import copy
import time
from functools import wraps

from sentry_sdk.integrations import DidNotEnable

from ..spans import ai_client_span, update_ai_client_span

import sentry_sdk
from sentry_sdk.consts import SPANDATA
from sentry_sdk.utils import logger
from sentry_sdk.tracing import BAGGAGE_HEADER_NAME
from sentry_sdk.tracing_utils import (
    should_propagate_trace,
    add_sentry_baggage_to_headers,
)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any, Callable, Optional
    from sentry_sdk.tracing import Span

try:
    import agents
    from agents.tool import HostedMCPTool
except ImportError:
    raise DidNotEnable("OpenAI Agents not installed")


def _set_response_model_on_agent_span(
    agent: "agents.Agent", response_model: "Optional[str]"
) -> None:
    """Set the response model on the agent's invoke_agent span if available."""
    if response_model:
        agent_span = getattr(agent, "_sentry_agent_span", None)
        if agent_span:
            agent_span.set_data(SPANDATA.GEN_AI_RESPONSE_MODEL, response_model)


def _inject_trace_propagation_headers(
    hosted_tool: "HostedMCPTool", span: "Span"
) -> None:
    headers = hosted_tool.tool_config.get("headers")
    if headers is None:
        headers = {}
        hosted_tool.tool_config["headers"] = headers

    mcp_url = hosted_tool.tool_config.get("server_url")
    if not mcp_url:
        return

    if should_propagate_trace(sentry_sdk.get_client(), mcp_url):
        for (
            key,
            value,
        ) in sentry_sdk.get_current_scope().iter_trace_propagation_headers(span=span):
            logger.debug(
                "[Tracing] Adding `{key}` header {value} to outgoing request to {mcp_url}.".format(
                    key=key, value=value, mcp_url=mcp_url
                )
            )
            if key == BAGGAGE_HEADER_NAME:
                add_sentry_baggage_to_headers(headers, value)
            else:
                headers[key] = value


def _get_model(
    original_get_model: "Callable[..., agents.Model]",
    agent: "agents.Agent",
    run_config: "agents.RunConfig",
) -> "agents.Model":
    """
    Responsible for
    - creating and managing AI client spans.
    - adding trace propagation headers to tools with type HostedMCPTool.
    - setting the response model on agent invocation spans.
    """
    # copy the model to double patching its methods. We use copy on purpose here (instead of deepcopy)
    # because we only patch its direct methods, all underlying data can remain unchanged.
    model = copy.copy(original_get_model(agent, run_config))

    # Capture the request model name for spans (agent.model can be None when using defaults)
    request_model_name = model.model if hasattr(model, "model") else str(model)
    agent._sentry_request_model = request_model_name

    # Wrap _fetch_response if it exists (for OpenAI models) to capture response model
    if hasattr(model, "_fetch_response"):
        original_fetch_response = model._fetch_response

        @wraps(original_fetch_response)
        async def wrapped_fetch_response(*args: "Any", **kwargs: "Any") -> "Any":
            response = await original_fetch_response(*args, **kwargs)
            if hasattr(response, "model") and response.model:
                agent._sentry_response_model = str(response.model)
            return response

        model._fetch_response = wrapped_fetch_response

    original_get_response = model.get_response

    @wraps(original_get_response)
    async def wrapped_get_response(*args: "Any", **kwargs: "Any") -> "Any":
        mcp_tools = kwargs.get("tools")
        hosted_tools = []
        if mcp_tools is not None:
            hosted_tools = [
                tool for tool in mcp_tools if isinstance(tool, HostedMCPTool)
            ]

        with ai_client_span(agent, kwargs) as span:
            for hosted_tool in hosted_tools:
                _inject_trace_propagation_headers(hosted_tool, span=span)

            result = await original_get_response(*args, **kwargs)

            # Get response model captured from _fetch_response and clean up
            response_model = getattr(agent, "_sentry_response_model", None)
            if response_model:
                delattr(agent, "_sentry_response_model")

            _set_response_model_on_agent_span(agent, response_model)
            update_ai_client_span(span, result, response_model, agent)

        return result

    model.get_response = wrapped_get_response

    # Also wrap stream_response for streaming support
    if hasattr(model, "stream_response"):
        original_stream_response = model.stream_response

        @wraps(original_stream_response)
        async def wrapped_stream_response(*args: "Any", **kwargs: "Any") -> "Any":
            # Uses explicit try/finally instead of context manager to ensure cleanup
            # even if the consumer abandons the stream (GeneratorExit).
            span_kwargs = dict(kwargs)
            if len(args) > 0:
                span_kwargs["system_instructions"] = args[0]
            if len(args) > 1:
                span_kwargs["input"] = args[1]

            hosted_tools = []
            if len(args) > 3:
                mcp_tools = args[3]

                if mcp_tools is not None:
                    hosted_tools = [
                        tool for tool in mcp_tools if isinstance(tool, HostedMCPTool)
                    ]

            with ai_client_span(agent, span_kwargs) as span:
                for hosted_tool in hosted_tools:
                    _inject_trace_propagation_headers(hosted_tool, span=span)

                span.set_data(SPANDATA.GEN_AI_RESPONSE_STREAMING, True)

                streaming_response = None
                ttft_recorded = False
                # Capture start time locally to avoid race conditions with concurrent requests
                start_time = time.perf_counter()

                async for event in original_stream_response(*args, **kwargs):
                    # Detect first content token (text delta event)
                    if not ttft_recorded and hasattr(event, "delta"):
                        ttft = time.perf_counter() - start_time
                        span.set_data(
                            SPANDATA.GEN_AI_RESPONSE_TIME_TO_FIRST_TOKEN, ttft
                        )
                        ttft_recorded = True

                    # Capture the full response from ResponseCompletedEvent
                    if hasattr(event, "response"):
                        streaming_response = event.response
                    yield event

                # Update span with response data (usage, output, model)
                if streaming_response:
                    response_model = (
                        str(streaming_response.model)
                        if hasattr(streaming_response, "model")
                        and streaming_response.model
                        else None
                    )
                    _set_response_model_on_agent_span(agent, response_model)
                    update_ai_client_span(
                        span, streaming_response, response_model, agent
                    )

        model.stream_response = wrapped_stream_response

    return model
