import sys
import json
import time
from functools import wraps
from collections.abc import Iterable

import sentry_sdk
from sentry_sdk import consts
from sentry_sdk.ai.monitoring import record_token_usage
from sentry_sdk.ai.utils import (
    set_data_normalized,
    normalize_message_roles,
    truncate_and_annotate_messages,
    truncate_and_annotate_embedding_inputs,
)
from sentry_sdk.ai._openai_completions_api import (
    _is_system_instruction as _is_system_instruction_completions,
    _get_system_instructions as _get_system_instructions_completions,
    _transform_system_instructions,
    _get_text_items,
)
from sentry_sdk.ai._openai_responses_api import (
    _is_system_instruction as _is_system_instruction_responses,
    _get_system_instructions as _get_system_instructions_responses,
)
from sentry_sdk.consts import SPANDATA
from sentry_sdk.integrations import DidNotEnable, Integration
from sentry_sdk.scope import should_send_default_pii
from sentry_sdk.tracing_utils import set_span_errored
from sentry_sdk.utils import (
    capture_internal_exceptions,
    event_from_exception,
    safe_serialize,
    reraise,
)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import (
        Any,
        List,
        Optional,
        Callable,
        AsyncIterator,
        Iterator,
        Union,
        Iterable,
    )
    from sentry_sdk.tracing import Span
    from sentry_sdk._types import TextPart

    from openai.types.responses import ResponseInputParam, SequenceNotStr
    from openai.types.responses import ResponseStreamEvent
    from openai import Omit

# Import everything this integration patches. If any of the mandatory imports
# fails, the integration disables itself by raising DidNotEnable.
try:
    # NotGiven and Omit are optional: when `openai` does not export them the
    # names are bound to None instead of aborting the whole integration.
    try:
        from openai import NotGiven
    except ImportError:
        NotGiven = None

    try:
        from openai import Omit
    except ImportError:
        Omit = None

    # Resource classes whose `create` methods are wrapped in `setup_once`.
    from openai.resources.chat.completions import Completions, AsyncCompletions
    from openai.resources import Embeddings, AsyncEmbeddings

    # Stream types used to detect streaming responses.
    from openai import Stream, AsyncStream

    if TYPE_CHECKING:
        from openai.types.chat import (
            ChatCompletionMessageParam,
            ChatCompletionChunk,
        )
except ImportError:
    raise DidNotEnable("OpenAI not installed")

# Flipped to False below when the responses API cannot be imported; checked in
# `OpenAIIntegration.setup_once` before patching the responses resources.
RESPONSES_API_ENABLED = True
try:
    # responses API support was introduced in v1.66.0
    from openai.resources.responses import Responses, AsyncResponses
    from openai.types.responses.response_completed_event import ResponseCompletedEvent
except ImportError:
    RESPONSES_API_ENABLED = False


class OpenAIIntegration(Integration):
    """
    Sentry integration that instruments the `openai` client library.

    `setup_once` replaces the `create` methods of the chat completions and
    embeddings resources (and of the responses resources when they could be
    imported) with the wrappers defined in this module.
    """

    identifier = "openai"
    origin = f"auto.ai.{identifier}"

    def __init__(
        self: "OpenAIIntegration",
        include_prompts: bool = True,
        tiktoken_encoding_name: "Optional[str]" = None,
    ) -> None:
        """
        :param include_prompts: Stored on the instance. The helpers in this module
            check it together with `should_send_default_pii()` before attaching
            prompt or response text to a span.
        :param tiktoken_encoding_name: Name handed to `tiktoken.get_encoding`.
            When it is None, `tiktoken` is never imported and `count_tokens`
            always returns 0.
        """
        self.include_prompts = include_prompts
        self.tiktoken_encoding = None

        if tiktoken_encoding_name is None:
            return

        # Imported lazily so tiktoken is only required when an encoding was requested.
        import tiktoken  # type: ignore

        self.tiktoken_encoding = tiktoken.get_encoding(tiktoken_encoding_name)

    @staticmethod
    def setup_once() -> None:
        """
        Monkeypatch the `create` method of every supported OpenAI resource class.
        """
        always_patched = (
            (Completions, _wrap_chat_completion_create),
            (AsyncCompletions, _wrap_async_chat_completion_create),
            (Embeddings, _wrap_embeddings_create),
            (AsyncEmbeddings, _wrap_async_embeddings_create),
        )
        for resource, wrap in always_patched:
            resource.create = wrap(resource.create)

        # The responses resources only exist when their import succeeded.
        if RESPONSES_API_ENABLED:
            Responses.create = _wrap_responses_create(Responses.create)
            AsyncResponses.create = _wrap_async_responses_create(AsyncResponses.create)

    def count_tokens(self: "OpenAIIntegration", s: str) -> int:
        """
        Return the number of tokens `s` encodes to, or 0 when no encoding is
        configured or encoding fails for any reason.
        """
        encoding = self.tiktoken_encoding
        if encoding is None:
            return 0

        try:
            return len(encoding.encode_ordinary(s))
        except Exception:
            # Best effort: token counting must never break the instrumented call.
            return 0


def _capture_exception(exc: "Any", manual_span_cleanup: bool = True) -> None:
    """
    Mark the current span as errored and send `exc` to Sentry as an unhandled
    "openai" event.

    :param exc: The exception to report.
    :param manual_span_cleanup: When true and a current span exists, the span is
        closed here. This is required because the spans in this module are
        entered by hand rather than through the `start_span` context manager.
    """
    active_span = sentry_sdk.get_current_span()
    set_span_errored(active_span)

    if manual_span_cleanup and active_span is not None:
        active_span.__exit__(None, None, None)

    client_options = sentry_sdk.get_client().options
    mechanism = {"type": "openai", "handled": False}
    event, hint = event_from_exception(
        exc, client_options=client_options, mechanism=mechanism
    )
    sentry_sdk.capture_event(event, hint=hint)


def _get_usage(usage: "Any", names: "List[str]") -> int:
    """
    Return the first attribute of `usage` listed in `names` that holds an int.

    Names are tried in order; attributes that are missing or not integers are
    skipped. Returns 0 when none of them qualifies.
    """
    for attribute in names:
        candidate = getattr(usage, attribute, None)
        if isinstance(candidate, int):
            return candidate

    return 0


def _calculate_token_usage(
    messages: "Optional[Iterable[ChatCompletionMessageParam]]",
    response: "Any",
    span: "Span",
    streaming_message_responses: "Optional[List[str]]",
    count_tokens: "Callable[..., Any]",
) -> None:
    """
    Record token usage for `response` on `span`.

    Counts reported in `response.usage` are used when present. An input or
    output count that is still 0 afterwards is estimated with `count_tokens`:
    input from `messages`, output from `streaming_message_responses` when given
    and otherwise from `response.choices`. Counts that end up as 0 are passed to
    `record_token_usage` as None.
    """
    prompt_count = 0
    cached_count = 0
    completion_count = 0
    reasoning_count = 0
    overall_count = 0

    if hasattr(response, "usage"):
        usage = response.usage

        # Both naming schemes are accepted: input/output_tokens and prompt/completion_tokens.
        prompt_count = _get_usage(usage, ["input_tokens", "prompt_tokens"])
        if hasattr(usage, "input_tokens_details"):
            cached_count = _get_usage(usage.input_tokens_details, ["cached_tokens"])

        completion_count = _get_usage(usage, ["output_tokens", "completion_tokens"])
        if hasattr(usage, "output_tokens_details"):
            reasoning_count = _get_usage(
                usage.output_tokens_details, ["reasoning_tokens"]
            )

        overall_count = _get_usage(usage, ["total_tokens"])

    # Fall back to counting the input ourselves.
    if prompt_count == 0:
        for message in messages or []:
            if isinstance(message, str):
                prompt_count += count_tokens(message)
            elif isinstance(message, dict):
                content = message.get("content")
                if content is not None:
                    # Deliberate use of Completions function for both Completions and Responses input format.
                    prompt_count += sum(
                        count_tokens(text) for text in _get_text_items(content)
                    )

    # Fall back to counting the output ourselves.
    if completion_count == 0:
        if streaming_message_responses is not None:
            completion_count += sum(map(count_tokens, streaming_message_responses))
        elif hasattr(response, "choices"):
            completion_count += sum(
                count_tokens(choice.message.content)
                for choice in response.choices
                if hasattr(choice, "message") and hasattr(choice.message, "content")
            )

    # A count of 0 is reported as "not available" rather than as zero tokens.
    record_token_usage(
        span,
        input_tokens=prompt_count or None,
        input_tokens_cached=cached_count or None,
        output_tokens=completion_count or None,
        output_tokens_reasoning=reasoning_count or None,
        total_tokens=overall_count or None,
    )


def _set_responses_api_input_data(
    span: "Span",
    kwargs: "dict[str, Any]",
    integration: "OpenAIIntegration",
) -> None:
    """
    Copy the request parameters of a responses API call onto `span`.

    Tools, model and sampling parameters are always recorded. System
    instructions and input messages are only recorded when
    `should_send_default_pii()` and `integration.include_prompts` are both true.
    The operation name "responses" is always set last.
    """
    explicit_instructions: "Union[Optional[str], Omit]" = kwargs.get("instructions")
    messages: "Optional[Union[str, ResponseInputParam]]" = kwargs.get("input")

    def _has_explicit_instructions() -> "Any":
        # `_is_given` is defined outside this view; presumably it filters out
        # OpenAI's "not given" sentinels - TODO confirm.
        return explicit_instructions is not None and _is_given(explicit_instructions)

    def _record_request_messages(raw_messages: "Any") -> None:
        normalized = normalize_message_roles(raw_messages)
        scope = sentry_sdk.get_current_scope()
        truncated = truncate_and_annotate_messages(normalized, span, scope)
        if truncated is not None:
            set_data_normalized(
                span, SPANDATA.GEN_AI_REQUEST_MESSAGES, truncated, unpack=False
            )

    def _record_prompt_data() -> None:
        if messages is None:
            # Without input only the explicit instructions can be reported.
            if _has_explicit_instructions():
                span.set_data(
                    SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS,
                    json.dumps([{"type": "text", "content": explicit_instructions}]),
                )
            return

        text_parts: "list[TextPart]" = []
        if _has_explicit_instructions():
            text_parts.append({"type": "text", "content": explicit_instructions})

        # Deliberate use of function accepting completions API type because
        # of shared structure FOR THIS PURPOSE ONLY.
        text_parts.extend(
            _transform_system_instructions(_get_system_instructions_responses(messages))
        )
        if text_parts:
            span.set_data(SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS, json.dumps(text_parts))

        if isinstance(messages, str):
            _record_request_messages([messages])
            return

        conversation = [
            message
            for message in messages
            if not _is_system_instruction_responses(message)
        ]
        if conversation:
            _record_request_messages(conversation)

    tools = kwargs.get("tools")
    if tools is not None and _is_given(tools) and len(tools) > 0:
        set_data_normalized(
            span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
        )

    model = kwargs.get("model")
    if model is not None:
        span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)

    optional_parameters = (
        ("max_output_tokens", SPANDATA.GEN_AI_REQUEST_MAX_TOKENS),
        ("temperature", SPANDATA.GEN_AI_REQUEST_TEMPERATURE),
        ("top_p", SPANDATA.GEN_AI_REQUEST_TOP_P),
    )
    for parameter, span_key in optional_parameters:
        value = kwargs.get(parameter)
        if value is not None and _is_given(value):
            span.set_data(span_key, value)

    if should_send_default_pii() and integration.include_prompts:
        _record_prompt_data()

    set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "responses")


def _set_completions_api_input_data(
    span: "Span",
    kwargs: "dict[str, Any]",
    integration: "OpenAIIntegration",
) -> None:
    """
    Copy the request parameters of a chat completions call onto `span`.

    Tools, model and sampling parameters are always recorded. System
    instructions and messages are only recorded when `should_send_default_pii()`
    and `integration.include_prompts` are both true and messages were passed.
    A non-string iterable of messages is materialized into a list and written
    back to `kwargs["messages"]`. The operation name "chat" is always set last.
    """
    messages: "Optional[Union[str, Iterable[ChatCompletionMessageParam]]]" = kwargs.get(
        "messages"
    )

    def _record_request_messages(raw_messages: "Any") -> None:
        normalized = normalize_message_roles(raw_messages)
        scope = sentry_sdk.get_current_scope()
        truncated = truncate_and_annotate_messages(normalized, span, scope)
        if truncated is not None:
            set_data_normalized(
                span, SPANDATA.GEN_AI_REQUEST_MESSAGES, truncated, unpack=False
            )

    def _record_prompt_data() -> None:
        if isinstance(messages, str):
            _record_request_messages([messages])
            return

        # dict special case following https://github.com/openai/openai-python/blob/3e0c05b84a2056870abf3bd6a5e7849020209cc3/src/openai/_utils/_transform.py#L194-L197
        if not isinstance(messages, Iterable) or isinstance(messages, dict):
            return

        # Materialize once and hand the list on to the wrapped call via kwargs.
        message_list = list(messages)
        kwargs["messages"] = message_list

        system_instructions = _get_system_instructions_completions(message_list)
        if len(system_instructions) > 0:
            span.set_data(
                SPANDATA.GEN_AI_SYSTEM_INSTRUCTIONS,
                json.dumps(_transform_system_instructions(system_instructions)),
            )

        conversation = [
            message
            for message in message_list
            if not _is_system_instruction_completions(message)
        ]
        if conversation:
            _record_request_messages(conversation)

    tools = kwargs.get("tools")
    if tools is not None and _is_given(tools) and len(tools) > 0:
        set_data_normalized(
            span, SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS, safe_serialize(tools)
        )

    model = kwargs.get("model")
    if model is not None:
        span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)

    optional_parameters = (
        ("max_tokens", SPANDATA.GEN_AI_REQUEST_MAX_TOKENS),
        ("presence_penalty", SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY),
        ("frequency_penalty", SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY),
        ("temperature", SPANDATA.GEN_AI_REQUEST_TEMPERATURE),
        ("top_p", SPANDATA.GEN_AI_REQUEST_TOP_P),
    )
    for parameter, span_key in optional_parameters:
        value = kwargs.get(parameter)
        if value is not None and _is_given(value):
            span.set_data(span_key, value)

    if (
        should_send_default_pii()
        and integration.include_prompts
        and messages is not None
    ):
        _record_prompt_data()

    set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "chat")


def _set_embeddings_input_data(
    span: "Span",
    kwargs: "dict[str, Any]",
    integration: "OpenAIIntegration",
) -> None:
    """
    Copy the request parameters of an embeddings call onto `span`.

    The model and the operation name "embeddings" are always recorded. The
    embedding input is only recorded when `should_send_default_pii()` and
    `integration.include_prompts` are both true and an input was passed.
    A non-string iterable input is materialized into a list and written back to
    `kwargs["input"]`.
    """
    messages: "Union[str, SequenceNotStr[str], Iterable[int], Iterable[Iterable[int]]]" = kwargs.get(
        "input"
    )

    def _mark_operation() -> None:
        set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, "embeddings")

    def _record_inputs(raw_inputs: "Any") -> None:
        normalized = normalize_message_roles(raw_inputs)
        scope = sentry_sdk.get_current_scope()
        truncated = truncate_and_annotate_embedding_inputs(normalized, span, scope)
        if truncated is not None:
            set_data_normalized(
                span, SPANDATA.GEN_AI_EMBEDDINGS_INPUT, truncated, unpack=False
            )

    request_model = kwargs.get("model")
    if request_model is not None:
        span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, request_model)

    prompts_allowed = should_send_default_pii() and integration.include_prompts
    if not prompts_allowed or messages is None:
        _mark_operation()
        return

    if isinstance(messages, str):
        # In this branch the operation name is written before the input.
        _mark_operation()
        _record_inputs([messages])
        return

    # dict special case following https://github.com/openai/openai-python/blob/3e0c05b84a2056870abf3bd6a5e7849020209cc3/src/openai/_utils/_transform.py#L194-L197
    if not isinstance(messages, Iterable) or isinstance(messages, dict):
        _mark_operation()
        return

    # Materialize once and hand the list on to the wrapped call via kwargs.
    input_list = list(messages)
    kwargs["input"] = input_list

    if input_list:
        _record_inputs(input_list)

    _mark_operation()


def _set_common_output_data(
    span: "Span",
    response: "Any",
    input: "Any",
    integration: "OpenAIIntegration",
    finish_span: bool = True,
) -> None:
    """
    Record response data on `span` and optionally close it.

    The response model is recorded when the response has a `model` attribute.
    Response text (and, for responses API output, tool calls) is only recorded
    when `should_send_default_pii()` and `integration.include_prompts` are both
    true. Token usage is always recorded via `_calculate_token_usage`.

    :param span: The span to annotate.
    :param response: Response object; handled by shape: `choices` (chat
        completions), `output` (responses API), or anything else.
    :param input: The request input, used for manual token counting when the
        response carries no usage numbers.
    :param integration: Supplies `include_prompts` and `count_tokens`.
    :param finish_span: When true, the span is closed after the data was set.
    """

    def _dump_model(model: "Any") -> "Any":
        # Prefer `model_dump()`, which the chat branch below already relies on;
        # pydantic v2 deprecates `.dict()` and warns on every call. Objects that
        # only offer `.dict()` keep working through the fallback.
        dump = getattr(model, "model_dump", None)
        if callable(dump):
            return dump()
        return model.dict()

    if hasattr(response, "model"):
        set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_MODEL, response.model)

    if hasattr(response, "choices"):
        if should_send_default_pii() and integration.include_prompts:
            response_text = [
                choice.message.model_dump()
                for choice in response.choices
                if choice.message is not None
            ]
            if len(response_text) > 0:
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_text)

    elif hasattr(response, "output"):
        if should_send_default_pii() and integration.include_prompts:
            tool_calls: "list[Any]" = []
            response_texts: "list[Any]" = []

            for output in response.output:
                if output.type == "function_call":
                    tool_calls.append(_dump_model(output))
                elif output.type == "message":
                    for output_message in output.content:
                        try:
                            response_texts.append(output_message.text)
                        except AttributeError:
                            # Unknown output message type, just return the json
                            response_texts.append(_dump_model(output_message))

            if len(tool_calls) > 0:
                set_data_normalized(
                    span,
                    SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
                    tool_calls,
                    unpack=False,
                )

            if len(response_texts) > 0:
                set_data_normalized(
                    span, SPANDATA.GEN_AI_RESPONSE_TEXT, response_texts
                )

    # Token usage and span closing are identical for every response shape.
    _calculate_token_usage(input, response, span, None, integration.count_tokens)

    if finish_span:
        span.__exit__(None, None, None)


def _new_chat_completion_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
    """
    Generator holding the instrumentation shared by the sync and async chat
    completion wrappers.

    It yields `(f, args, kwargs)` exactly once; the driver is expected to call
    `f` and `send()` the result back in. The generator's return value (read by
    the driver from `StopIteration.value`) is that response. If the call should
    not be instrumented, the generator returns `f(*args, **kwargs)` without
    yielding.

    For streaming responses the stream's `_iterator` is replaced by a wrapper
    that collects output data and closes the span; otherwise output data is set
    and the span is closed right away.
    """
    integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
    if integration is None or "messages" not in kwargs:
        # no "messages" means invalid call (in all versions of openai), let it return error
        return f(*args, **kwargs)

    try:
        iter(kwargs["messages"])
    except TypeError:
        # invalid call (in all versions), messages must be iterable
        return f(*args, **kwargs)

    model = kwargs.get("model")

    # The span is entered by hand because it may outlive this function when
    # the response is a stream.
    span = sentry_sdk.start_span(
        op=consts.OP.GEN_AI_CHAT,
        name=f"chat {model}",
        origin=OpenAIIntegration.origin,
    )
    span.__enter__()

    span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")

    # Same bool handling as in https://github.com/openai/openai-python/blob/acd0c54d8a68efeedde0e5b4e6c310eef1ce7867/src/openai/resources/completions.py#L585
    streaming_requested = kwargs.get("stream") or False
    span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, streaming_requested)

    _set_completions_api_input_data(span, kwargs, integration)

    start_time = time.perf_counter()
    response = yield f, args, kwargs

    wrap_iterator = None
    if isinstance(response, Stream):
        wrap_iterator = _wrap_synchronous_completions_chunk_iterator
    elif isinstance(response, AsyncStream):
        wrap_iterator = _wrap_asynchronous_completions_chunk_iterator

    # Attribute check to fail gracefully if the attribute is not present in future `openai` versions.
    if wrap_iterator is None or not hasattr(response, "_iterator"):
        _set_completions_api_output_data(
            span, response, kwargs, integration, finish_span=True
        )
        return response

    messages = kwargs.get("messages")
    if isinstance(messages, str):
        messages = [messages]

    response._iterator = wrap_iterator(
        span=span,
        integration=integration,
        start_time=start_time,
        messages=messages,
        response=response,
        old_iterator=response._iterator,
        finish_span=True,
    )

    return response


def _set_completions_api_output_data(
    span: "Span",
    response: "Any",
    kwargs: "dict[str, Any]",
    integration: "OpenAIIntegration",
    finish_span: bool = True,
) -> None:
    """
    Record the output of a chat completions call on `span`.

    Reads the request messages from `kwargs["messages"]` (a bare string is
    wrapped in a list) and delegates to `_set_common_output_data`.
    """
    request_messages = kwargs.get("messages")
    if isinstance(request_messages, str):
        request_messages = [request_messages]

    _set_common_output_data(
        span, response, request_messages, integration, finish_span
    )


def _wrap_synchronous_completions_chunk_iterator(
    span: "Span",
    integration: "OpenAIIntegration",
    start_time: "Optional[float]",
    messages: "Optional[Iterable[ChatCompletionMessageParam]]",
    response: "Stream[ChatCompletionChunk]",
    old_iterator: "Iterator[ChatCompletionChunk]",
    finish_span: "bool",
) -> "Iterator[ChatCompletionChunk]":
    """
    Pass the chunks of a chat completion stream through unchanged while
    collecting data for the AI Client Span.

    The response model is set from every chunk. Delta contents are buffered per
    position in the chunk's `choices` list. Once `old_iterator` is exhausted,
    the time to first token, the joined response text (only if PII and prompts
    are enabled) and the token usage are recorded, and the span is closed if
    `finish_span` is set. None of this runs if iteration stops early.
    """
    time_to_first_token = None
    parts_per_choice: "list[list[str]]" = []  # one for each choice

    for chunk in old_iterator:
        span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, chunk.model)

        with capture_internal_exceptions():
            if hasattr(chunk, "choices"):
                for position, choice in enumerate(chunk.choices):
                    if not (
                        hasattr(choice, "delta") and hasattr(choice.delta, "content")
                    ):
                        continue

                    # Measured at the first chunk that carries a content delta.
                    if start_time is not None and time_to_first_token is None:
                        time_to_first_token = time.perf_counter() - start_time

                    delta_text = choice.delta.content
                    if len(parts_per_choice) <= position:
                        parts_per_choice.append([])
                    parts_per_choice[position].append(delta_text or "")

        yield chunk

    with capture_internal_exceptions():
        if time_to_first_token is not None:
            set_data_normalized(
                span,
                SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN,
                time_to_first_token,
            )

        if parts_per_choice:
            joined_texts = ["".join(parts) for parts in parts_per_choice]
            if should_send_default_pii() and integration.include_prompts:
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, joined_texts)
            _calculate_token_usage(
                messages, response, span, joined_texts, integration.count_tokens
            )

    if finish_span:
        span.__exit__(None, None, None)


async def _wrap_asynchronous_completions_chunk_iterator(
    span: "Span",
    integration: "OpenAIIntegration",
    start_time: "Optional[float]",
    messages: "Optional[Iterable[ChatCompletionMessageParam]]",
    response: "AsyncStream[ChatCompletionChunk]",
    old_iterator: "AsyncIterator[ChatCompletionChunk]",
    finish_span: "bool",
) -> "AsyncIterator[ChatCompletionChunk]":
    """
    Pass the chunks of an async chat completion stream through unchanged while
    collecting data for the AI Client Span.

    The response model is set from every chunk. Delta contents are buffered per
    position in the chunk's `choices` list. Once `old_iterator` is exhausted,
    the time to first token, the joined response text (only if PII and prompts
    are enabled) and the token usage are recorded, and the span is closed if
    `finish_span` is set. None of this runs if iteration stops early.
    """
    time_to_first_token = None
    parts_per_choice: "list[list[str]]" = []  # one for each choice

    async for chunk in old_iterator:
        span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, chunk.model)

        with capture_internal_exceptions():
            if hasattr(chunk, "choices"):
                for position, choice in enumerate(chunk.choices):
                    if not (
                        hasattr(choice, "delta") and hasattr(choice.delta, "content")
                    ):
                        continue

                    # Measured at the first chunk that carries a content delta.
                    if start_time is not None and time_to_first_token is None:
                        time_to_first_token = time.perf_counter() - start_time

                    delta_text = choice.delta.content
                    if len(parts_per_choice) <= position:
                        parts_per_choice.append([])
                    parts_per_choice[position].append(delta_text or "")

        yield chunk

    with capture_internal_exceptions():
        if time_to_first_token is not None:
            set_data_normalized(
                span,
                SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN,
                time_to_first_token,
            )

        if parts_per_choice:
            joined_texts = ["".join(parts) for parts in parts_per_choice]
            if should_send_default_pii() and integration.include_prompts:
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, joined_texts)
            _calculate_token_usage(
                messages, response, span, joined_texts, integration.count_tokens
            )

    if finish_span:
        span.__exit__(None, None, None)


def _wrap_synchronous_responses_event_iterator(
    span: "Span",
    integration: "OpenAIIntegration",
    start_time: "Optional[float]",
    input: "Optional[Union[str, ResponseInputParam]]",
    response: "Stream[ResponseStreamEvent]",
    old_iterator: "Iterator[ResponseStreamEvent]",
    finish_span: "bool",
) -> "Iterator[ResponseStreamEvent]":
    """
    Pass the events of a responses API stream through unchanged while
    collecting data for the AI Client Span.

    Every event with a `delta` attribute contributes to the buffered response
    text. A `ResponseCompletedEvent` supplies the response model and the token
    usage. Once `old_iterator` is exhausted, the time to first token and the
    joined text (only if PII and prompts are enabled) are recorded; token usage
    is computed from the buffered text only if no completed event was seen.
    The span is closed afterwards if `finish_span` is set. None of this runs if
    iteration stops early.
    """
    time_to_first_token = None
    delta_parts: "list[str]" = []
    usage_reported_by_api = False

    for event in old_iterator:
        with capture_internal_exceptions():
            if hasattr(event, "delta"):
                # Measured at the first event that carries a delta.
                if start_time is not None and time_to_first_token is None:
                    time_to_first_token = time.perf_counter() - start_time
                delta_parts.append(event.delta or "")

            if isinstance(event, ResponseCompletedEvent):
                span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, event.response.model)

                _calculate_token_usage(
                    input, event.response, span, None, integration.count_tokens
                )
                usage_reported_by_api = True

        yield event

    with capture_internal_exceptions():
        if time_to_first_token is not None:
            set_data_normalized(
                span,
                SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN,
                time_to_first_token,
            )

        if delta_parts:
            joined_texts = ["".join(delta_parts)]
            if should_send_default_pii() and integration.include_prompts:
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, joined_texts)
            if not usage_reported_by_api:
                _calculate_token_usage(
                    input, response, span, joined_texts, integration.count_tokens
                )

    if finish_span:
        span.__exit__(None, None, None)


async def _wrap_asynchronous_responses_event_iterator(
    span: "Span",
    integration: "OpenAIIntegration",
    start_time: "Optional[float]",
    input: "Optional[Union[str, ResponseInputParam]]",
    response: "AsyncStream[ResponseStreamEvent]",
    old_iterator: "AsyncIterator[ResponseStreamEvent]",
    finish_span: "bool",
) -> "AsyncIterator[ResponseStreamEvent]":
    """
    Pass the events of an async responses API stream through unchanged while
    collecting data for the AI Client Span.

    Every event with a `delta` attribute contributes to the buffered response
    text. A `ResponseCompletedEvent` supplies the response model and the token
    usage. Once `old_iterator` is exhausted, the time to first token and the
    joined text (only if PII and prompts are enabled) are recorded; token usage
    is computed from the buffered text only if no completed event was seen.
    The span is closed afterwards if `finish_span` is set. None of this runs if
    iteration stops early.
    """
    time_to_first_token: "Optional[float]" = None
    delta_parts: "list[str]" = []
    usage_reported_by_api = False

    async for event in old_iterator:
        with capture_internal_exceptions():
            if hasattr(event, "delta"):
                # Measured at the first event that carries a delta.
                if start_time is not None and time_to_first_token is None:
                    time_to_first_token = time.perf_counter() - start_time
                delta_parts.append(event.delta or "")

            if isinstance(event, ResponseCompletedEvent):
                span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, event.response.model)

                _calculate_token_usage(
                    input, event.response, span, None, integration.count_tokens
                )
                usage_reported_by_api = True

        yield event

    with capture_internal_exceptions():
        if time_to_first_token is not None:
            set_data_normalized(
                span,
                SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN,
                time_to_first_token,
            )

        if delta_parts:
            joined_texts = ["".join(delta_parts)]
            if should_send_default_pii() and integration.include_prompts:
                set_data_normalized(span, SPANDATA.GEN_AI_RESPONSE_TEXT, joined_texts)
            if not usage_reported_by_api:
                _calculate_token_usage(
                    input, response, span, joined_texts, integration.count_tokens
                )

    if finish_span:
        span.__exit__(None, None, None)


def _set_responses_api_output_data(
    span: "Span",
    response: "Any",
    kwargs: "dict[str, Any]",
    integration: "OpenAIIntegration",
    finish_span: bool = True,
) -> None:
    """
    Record the output of a responses API call on `span`.

    Reads the request input from `kwargs["input"]` (a bare string is wrapped in
    a list) and delegates to `_set_common_output_data`.
    """
    request_input = kwargs.get("input")
    if isinstance(request_input, str):
        request_input = [request_input]

    _set_common_output_data(span, response, request_input, integration, finish_span)


def _set_embeddings_output_data(
    span: "Span",
    response: "Any",
    kwargs: "dict[str, Any]",
    integration: "OpenAIIntegration",
    finish_span: bool = True,
) -> None:
    """
    Record output data for an embeddings call on ``span``.

    The ``"input"`` entry of ``kwargs`` is looked up, a bare string is wrapped
    in a one-element list, and the result is forwarded together with the other
    arguments to ``_set_common_output_data``.
    """
    embedding_input = kwargs.get("input")

    if isinstance(embedding_input, str):
        # A single string is handed on as a one-item list.
        embedding_input = [embedding_input]

    _set_common_output_data(span, response, embedding_input, integration, finish_span)


def _wrap_chat_completion_create(f: "Callable[..., Any]") -> "Callable[..., Any]":
    """
    Build the patched synchronous replacement for a chat completion ``create`` callable.

    The returned function calls ``f`` directly when the OpenAI integration is
    not enabled or when no ``messages`` keyword argument was supplied.
    Otherwise the call is driven through the ``_new_chat_completion_common``
    generator.
    """

    def _execute_sync(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
        # The generator yields the callable and arguments to invoke and
        # expects the result of that call to be sent back into it.
        driver = _new_chat_completion_common(f, *args, **kwargs)

        try:
            target, call_args, call_kwargs = next(driver)
        except StopIteration as finished:
            # The generator returned without yielding; hand back its return value.
            return finished.value

        # NOTE: the outer handler also sees a StopIteration re-raised from the
        # wrapped call itself, not only one raised by `driver.send`.
        try:
            try:
                outcome = target(*call_args, **call_kwargs)
            except Exception as error:
                # Remember the active exception so exactly that one is
                # re-raised after it has been reported.
                original_exc_info = sys.exc_info()
                with capture_internal_exceptions():
                    _capture_exception(error)
                reraise(*original_exc_info)

            return driver.send(outcome)
        except StopIteration as finished:
            # The generator's return value travels on StopIteration.
            return finished.value

    @wraps(f)
    def _sentry_patched_create_sync(*args: "Any", **kwargs: "Any") -> "Any":
        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
        if integration is not None and "messages" in kwargs:
            return _execute_sync(f, *args, **kwargs)

        # no "messages" means invalid call (in all versions of openai), let it return error
        return f(*args, **kwargs)

    return _sentry_patched_create_sync


def _wrap_async_chat_completion_create(f: "Callable[..., Any]") -> "Callable[..., Any]":
    """
    Build the patched asynchronous replacement for a chat completion ``create`` callable.

    The returned coroutine function awaits ``f`` directly when the OpenAI
    integration is not enabled or when no ``messages`` keyword argument was
    supplied. Otherwise the call is driven through the
    ``_new_chat_completion_common`` generator.
    """

    async def _execute_async(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
        # The generator yields the callable and arguments to invoke and
        # expects the awaited result of that call to be sent back into it.
        driver = _new_chat_completion_common(f, *args, **kwargs)

        try:
            target, call_args, call_kwargs = next(driver)
        except StopIteration as finished:
            # The generator returned without yielding; its return value is
            # awaited before being handed back.
            return await finished.value

        # NOTE: the outer handler also sees a StopIteration re-raised from the
        # wrapped call itself, not only one raised by `driver.send`.
        try:
            try:
                outcome = await target(*call_args, **call_kwargs)
            except Exception as error:
                # Remember the active exception so exactly that one is
                # re-raised after it has been reported.
                original_exc_info = sys.exc_info()
                with capture_internal_exceptions():
                    _capture_exception(error)
                reraise(*original_exc_info)

            return driver.send(outcome)
        except StopIteration as finished:
            # The generator's return value travels on StopIteration.
            return finished.value

    @wraps(f)
    async def _sentry_patched_create_async(*args: "Any", **kwargs: "Any") -> "Any":
        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
        if integration is not None and "messages" in kwargs:
            return await _execute_async(f, *args, **kwargs)

        # no "messages" means invalid call (in all versions of openai), let it return error
        return await f(*args, **kwargs)

    return _sentry_patched_create_async


def _new_embeddings_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
    """
    Generator holding the instrumentation shared by the sync and async embeddings wrappers.

    When the OpenAI integration is not enabled, the generator returns
    ``f(*args, **kwargs)`` without yielding. Otherwise it opens an embeddings
    span, records the input data, yields ``(f, args, kwargs)`` so the caller
    can perform the actual call, receives the result via ``send``, records the
    output data and returns the result.
    """
    integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
    if integration is None:
        return f(*args, **kwargs)

    model = kwargs.get("model")
    span_options = {
        "op": consts.OP.GEN_AI_EMBEDDINGS,
        "name": f"embeddings {model}",
        "origin": OpenAIIntegration.origin,
    }

    with sentry_sdk.start_span(**span_options) as span:
        span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")
        _set_embeddings_input_data(span, kwargs, integration)

        # Hand control to the caller, which sends back the call's result.
        api_result = yield (f, args, kwargs)

        # The surrounding `with` block closes the span, so the output helper
        # is told not to finish it.
        _set_embeddings_output_data(
            span,
            api_result,
            kwargs,
            integration,
            finish_span=False,
        )

        return api_result


def _wrap_embeddings_create(f: "Any") -> "Any":
    """
    Build the patched synchronous replacement for an embeddings ``create`` callable.

    The returned function calls ``f`` directly when the OpenAI integration is
    not enabled. Otherwise the call is driven through the
    ``_new_embeddings_create_common`` generator.
    """

    def _execute_sync(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
        # The generator yields the callable and arguments to invoke and
        # expects the result of that call to be sent back into it.
        driver = _new_embeddings_create_common(f, *args, **kwargs)

        try:
            target, call_args, call_kwargs = next(driver)
        except StopIteration as finished:
            # The generator returned without yielding; hand back its return value.
            return finished.value

        try:
            try:
                outcome = target(*call_args, **call_kwargs)
            except Exception as error:
                # Remember the active exception so exactly that one is
                # re-raised after it has been reported.
                original_exc_info = sys.exc_info()
                with capture_internal_exceptions():
                    # NOTE(review): manual span cleanup is switched off here,
                    # presumably because the generator manages the span with a
                    # `with` block - confirm against `_capture_exception`.
                    _capture_exception(error, manual_span_cleanup=False)
                reraise(*original_exc_info)

            return driver.send(outcome)
        except StopIteration as finished:
            # The generator's return value travels on StopIteration.
            return finished.value

    @wraps(f)
    def _sentry_patched_create_sync(*args: "Any", **kwargs: "Any") -> "Any":
        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
        if integration is not None:
            return _execute_sync(f, *args, **kwargs)

        # Integration disabled: behave exactly like the unpatched callable.
        return f(*args, **kwargs)

    return _sentry_patched_create_sync


def _wrap_async_embeddings_create(f: "Any") -> "Any":
    """
    Build the patched asynchronous replacement for an embeddings ``create`` callable.

    The returned coroutine function awaits ``f`` directly when the OpenAI
    integration is not enabled. Otherwise the call is driven through the
    ``_new_embeddings_create_common`` generator.
    """

    async def _execute_async(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
        # The generator yields the callable and arguments to invoke and
        # expects the awaited result of that call to be sent back into it.
        driver = _new_embeddings_create_common(f, *args, **kwargs)

        try:
            target, call_args, call_kwargs = next(driver)
        except StopIteration as finished:
            # The generator returned without yielding; its return value is
            # awaited before being handed back.
            return await finished.value

        try:
            try:
                outcome = await target(*call_args, **call_kwargs)
            except Exception as error:
                # Remember the active exception so exactly that one is
                # re-raised after it has been reported.
                original_exc_info = sys.exc_info()
                with capture_internal_exceptions():
                    # NOTE(review): manual span cleanup is switched off here,
                    # presumably because the generator manages the span with a
                    # `with` block - confirm against `_capture_exception`.
                    _capture_exception(error, manual_span_cleanup=False)
                reraise(*original_exc_info)

            return driver.send(outcome)
        except StopIteration as finished:
            # The generator's return value travels on StopIteration.
            return finished.value

    @wraps(f)
    async def _sentry_patched_create_async(*args: "Any", **kwargs: "Any") -> "Any":
        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
        if integration is not None:
            return await _execute_async(f, *args, **kwargs)

        # Integration disabled: behave exactly like the unpatched callable.
        return await f(*args, **kwargs)

    return _sentry_patched_create_async


def _new_responses_create_common(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
    """
    Generator holding the instrumentation shared by the sync and async Responses API wrappers.

    When the OpenAI integration is not enabled, the generator returns
    ``f(*args, **kwargs)`` without yielding. Otherwise it starts and enters a
    span, records the request data, yields ``(f, args, kwargs)`` so the caller
    can perform the actual call, and receives the result via ``send``.

    For ``Stream`` / ``AsyncStream`` results exposing ``_iterator``, that
    iterator is replaced by a wrapping iterator which is asked to finish the
    span. For any other result the output data is recorded right away. The
    result object itself is always the generator's return value.
    """
    integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
    if integration is None:
        return f(*args, **kwargs)

    model = kwargs.get("model")

    # The span is entered by hand (not via `with`): it is passed on below
    # together with `finish_span=True`.
    span = sentry_sdk.start_span(
        op=consts.OP.GEN_AI_RESPONSES,
        name=f"responses {model}",
        origin=OpenAIIntegration.origin,
    )
    span.__enter__()

    span.set_data(SPANDATA.GEN_AI_SYSTEM, "openai")

    # Same bool handling as in https://github.com/openai/openai-python/blob/acd0c54d8a68efeedde0e5b4e6c310eef1ce7867/src/openai/resources/responses/responses.py#L940
    streaming_requested = kwargs.get("stream", False) or False
    span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, streaming_requested)

    _set_responses_api_input_data(span, kwargs, integration)

    start_time = time.perf_counter()
    response = yield (f, args, kwargs)

    # Pick the iterator wrapper matching the kind of stream that came back.
    # Attribute check to fail gracefully if the attribute is not present in future `openai` versions.
    if isinstance(response, Stream) and hasattr(response, "_iterator"):
        wrap_iterator = _wrap_synchronous_responses_event_iterator
    elif isinstance(response, AsyncStream) and hasattr(response, "_iterator"):
        wrap_iterator = _wrap_asynchronous_responses_event_iterator
    else:
        # Not a wrappable stream: record the output data immediately.
        _set_responses_api_output_data(
            span,
            response,
            kwargs,
            integration,
            finish_span=True,
        )
        return response

    request_input = kwargs.get("input")
    if isinstance(request_input, str):
        # A single string is handed on as a one-item list.
        request_input = [request_input]

    response._iterator = wrap_iterator(
        span=span,
        integration=integration,
        start_time=start_time,
        input=request_input,
        response=response,
        old_iterator=response._iterator,
        finish_span=True,
    )

    return response


def _wrap_responses_create(f: "Any") -> "Any":
    """
    Build the patched synchronous replacement for a Responses API ``create`` callable.

    The returned function calls ``f`` directly when the OpenAI integration is
    not enabled. Otherwise the call is driven through the
    ``_new_responses_create_common`` generator.
    """

    def _execute_sync(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
        # The generator yields the callable and arguments to invoke and
        # expects the result of that call to be sent back into it.
        driver = _new_responses_create_common(f, *args, **kwargs)

        try:
            target, call_args, call_kwargs = next(driver)
        except StopIteration as finished:
            # The generator returned without yielding; hand back its return value.
            return finished.value

        try:
            try:
                outcome = target(*call_args, **call_kwargs)
            except Exception as error:
                # Remember the active exception so exactly that one is
                # re-raised after it has been reported.
                original_exc_info = sys.exc_info()
                with capture_internal_exceptions():
                    _capture_exception(error)
                reraise(*original_exc_info)

            return driver.send(outcome)
        except StopIteration as finished:
            # The generator's return value travels on StopIteration.
            return finished.value

    @wraps(f)
    def _sentry_patched_create_sync(*args: "Any", **kwargs: "Any") -> "Any":
        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
        if integration is not None:
            return _execute_sync(f, *args, **kwargs)

        # Integration disabled: behave exactly like the unpatched callable.
        return f(*args, **kwargs)

    return _sentry_patched_create_sync


def _wrap_async_responses_create(f: "Any") -> "Any":
    """
    Build the patched asynchronous replacement for a Responses API ``create`` callable.

    The returned coroutine function awaits ``f`` directly when the OpenAI
    integration is not enabled. Otherwise the call is driven through the
    ``_new_responses_create_common`` generator.
    """

    async def _execute_async(f: "Any", *args: "Any", **kwargs: "Any") -> "Any":
        # The generator yields the callable and arguments to invoke and
        # expects the awaited result of that call to be sent back into it.
        driver = _new_responses_create_common(f, *args, **kwargs)

        try:
            target, call_args, call_kwargs = next(driver)
        except StopIteration as finished:
            # The generator returned without yielding; its return value is
            # awaited before being handed back.
            return await finished.value

        try:
            try:
                outcome = await target(*call_args, **call_kwargs)
            except Exception as error:
                # Remember the active exception so exactly that one is
                # re-raised after it has been reported.
                original_exc_info = sys.exc_info()
                with capture_internal_exceptions():
                    _capture_exception(error)
                reraise(*original_exc_info)

            return driver.send(outcome)
        except StopIteration as finished:
            # The generator's return value travels on StopIteration.
            return finished.value

    @wraps(f)
    async def _sentry_patched_responses_async(*args: "Any", **kwargs: "Any") -> "Any":
        integration = sentry_sdk.get_client().get_integration(OpenAIIntegration)
        if integration is not None:
            return await _execute_async(f, *args, **kwargs)

        # Integration disabled: behave exactly like the unpatched callable.
        return await f(*args, **kwargs)

    return _sentry_patched_responses_async


def _is_given(obj: "Any") -> bool:
    """
    Check for givenness safely across different openai versions.

    Returns False when ``obj`` is an instance of openai's ``NotGiven`` or
    ``Omit`` sentinel types and True otherwise. Either name is ``None`` when it
    could not be imported from ``openai``, in which case its check is skipped.
    """
    return not any(
        sentinel_type is not None and isinstance(obj, sentinel_type)
        for sentinel_type in (NotGiven, Omit)
    )
