import copyreg
import os

import fastapi
import fastapi.routing
import starlette
import wrapt
from wrapt import wrap_function_wrapper as _w

from ddtrace import config
from ddtrace._trace.pin import Pin
from ddtrace.contrib.internal.asgi.middleware import TraceMiddleware
from ddtrace.contrib.internal.starlette.patch import _trace_background_tasks
from ddtrace.contrib.internal.starlette.patch import traced_handler
from ddtrace.contrib.internal.starlette.patch import traced_route_init
from ddtrace.internal.compat import is_wrapted
from ddtrace.internal.logger import get_logger
from ddtrace.internal.schema import schematize_service_name
from ddtrace.internal.settings.asm import config as asm_config
from ddtrace.internal.telemetry import get_config as _get_config
from ddtrace.internal.utils.formats import asbool
from ddtrace.internal.utils.version import parse_version
from ddtrace.internal.utils.wrappers import unwrap as _u
from ddtrace.trace import tracer


log = get_logger(__name__)

# Set to True by _register_wrapt_pickle_reducers() once it has run (whether or not it
# actually registered any reducers), so repeated calls return immediately.
_WRAPT_REDUCERS_REGISTERED = False


def _identity(x):
    """Identity function for pickle reconstruction - returns unwrapped object."""
    return x


def _reduce_wrapt_proxy(proxy):
    """Reduce a wrapt proxy to its wrapped object for pickling.

    Produces the ``(callable, args)`` pair expected from a reducer registered
    in ``copyreg.dispatch_table``. On unpickling, ``_identity`` is called with
    the wrapped target, so the proxy layer is dropped and only the underlying
    object is restored.
    """
    target = proxy.__wrapped__
    reconstructor_args = (target,)
    return _identity, reconstructor_args


def _register_wrapt_pickle_reducers():
    """Install pickle reducers that strip wrapt proxies.

    Must run before a FastAPI app is pickled (e.g. by Ray Serve / vLLM).
    Repeated calls are cheap: the module-level ``_WRAPT_REDUCERS_REGISTERED``
    flag short-circuits every call after the first.

    Reducers are only installed for Starlette >= 0.24.0 (lazy middleware
    initialization). On older Starlette, the flag is still set so the version
    check is not repeated.
    """
    global _WRAPT_REDUCERS_REGISTERED
    if _WRAPT_REDUCERS_REGISTERED:
        return

    is_legacy_starlette = parse_version(starlette.__version__) < parse_version("0.24.0")
    if not is_legacy_starlette:
        proxy_types = (wrapt.ObjectProxy, wrapt.FunctionWrapper, wrapt.BoundFunctionWrapper)
        for proxy_type in proxy_types:
            # Keep any reducer that was already registered for this type.
            copyreg.dispatch_table.setdefault(proxy_type, _reduce_wrapt_proxy)

    _WRAPT_REDUCERS_REGISTERED = True


config._add(
    "fastapi",
    dict(
        _default_service=schematize_service_name("fastapi"),
        request_span_name="fastapi.request",
        distributed_tracing=True,
        trace_query_string=None,  # Default to global config
        # os.getenv returns the raw string when the variable is set, so values such as
        # "false" or "0" would be truthy. Normalize to a real bool like the other flags here.
        obfuscate_404_resource=asbool(os.getenv("DD_ASGI_OBFUSCATE_404_RESOURCE", default=False)),
        # DD_TRACE_WEBSOCKET_MESSAGES_ENABLED wins when set; the legacy
        # DD_ASGI_TRACE_WEBSOCKET only supplies the fallback default.
        trace_asgi_websocket_messages=_get_config(
            "DD_TRACE_WEBSOCKET_MESSAGES_ENABLED",
            default=_get_config("DD_ASGI_TRACE_WEBSOCKET", default=True, modifier=asbool),
            modifier=asbool,
        ),
        # Enabled only when both ..._INHERIT_SAMPLING and ..._SEPARATE_TRACES are true.
        asgi_websocket_messages_inherit_sampling=asbool(
            _get_config("DD_TRACE_WEBSOCKET_MESSAGES_INHERIT_SAMPLING", default=True)
        )
        and asbool(_get_config("DD_TRACE_WEBSOCKET_MESSAGES_SEPARATE_TRACES", default=True)),
        websocket_messages_separate_traces=asbool(
            _get_config("DD_TRACE_WEBSOCKET_MESSAGES_SEPARATE_TRACES", default=True)
        ),
    ),
)


def get_version() -> str:
    """Return FastAPI's ``__version__`` string, or ``""`` when it is not exposed."""
    version = getattr(fastapi, "__version__", "")
    return version


def _supported_versions() -> dict[str, str]:
    return {"fastapi": ">=0.64.0"}


def wrap_middleware_stack(wrapped, instance, args, kwargs):
    """wrapt wrapper for ``FastAPI.build_middleware_stack``.

    Builds the application's middleware stack as usual, then puts the ASGI
    ``TraceMiddleware`` (configured with ``config.fastapi``) in front of it.
    """
    inner_app = wrapped(*args, **kwargs)
    return TraceMiddleware(app=inner_app, integration_config=config.fastapi)


async def traced_serialize_response(wrapped, instance, args, kwargs):
    """wrapt wrapper around ``fastapi.routing.serialize_response``.

    FastAPI calls ``serialize_response`` for every endpoint result that is not
    already a ``Response``, to turn it into a serializable structure (it is
    the layer that invokes ``jsonable_encoder``). Producing the final JSON
    text is left to ``Response.render``, so this span does not cover that step.

    When the FastAPI ``Pin`` is present and enabled, the call runs inside a
    ``fastapi.serialize_response`` span. Otherwise it is passed straight through.

    DEV: ``jsonable_encoder`` is deliberately not wrapped. It recurses into
    itself, so creating a span per call could cost more than it is worth.
    """
    pin = Pin.get_from(fastapi)
    if pin and pin.enabled():
        with tracer.trace("fastapi.serialize_response"):
            return await wrapped(*args, **kwargs)
    return await wrapped(*args, **kwargs)


def patch():
    """Instrument FastAPI.

    Does nothing if FastAPI is already patched (``fastapi._datadog_patch``).
    Otherwise it:
      * registers the wrapt pickle reducers,
      * attaches a ``Pin`` to the ``fastapi`` module,
      * wraps ``FastAPI.build_middleware_stack`` and ``serialize_response``,
      * wraps ``BackgroundTasks.add_task``, ``APIRoute.__init__``,
        ``APIRoute.handle`` and ``Mount.handle``, but only if each is not
        already wrapped (e.g. by the Starlette integration),
      * runs the IAST FastAPI hook when IAST is enabled.
    """
    already_patched = getattr(fastapi, "_datadog_patch", False)
    if already_patched:
        return

    # Must happen before any FastAPI app can be pickled.
    _register_wrapt_pickle_reducers()

    fastapi._datadog_patch = True
    Pin().onto(fastapi)

    # Wrappers owned solely by this integration.
    _w("fastapi.applications", "FastAPI.build_middleware_stack", wrap_middleware_stack)
    _w("fastapi.routing", "serialize_response", traced_serialize_response)

    # The Starlette instrumentation may already have wrapped these targets, so
    # only wrap the ones that are still bare.
    routing = fastapi.routing
    if not is_wrapted(fastapi.BackgroundTasks.add_task):
        _w("fastapi", "BackgroundTasks.add_task", _trace_background_tasks(fastapi))
    if not is_wrapted(routing.APIRoute.__init__):
        _w("fastapi.routing", "APIRoute.__init__", traced_route_init)
    if not is_wrapted(routing.APIRoute.handle):
        _w("fastapi.routing", "APIRoute.handle", traced_handler)
    if not is_wrapted(routing.Mount.handle):
        _w("starlette.routing", "Mount.handle", traced_handler)

    if asm_config._iast_enabled:
        from ddtrace.appsec._iast._handlers import _on_iast_fastapi_patch

        _on_iast_fastapi_patch()


def unpatch():
    """Remove the instrumentation installed by ``patch()``.

    Does nothing when FastAPI is not currently patched. Targets that may be
    shared with the Starlette integration are only unwrapped if they are still
    wrapped, because that integration may already have unpatched them.
    """
    if not getattr(fastapi, "_datadog_patch", False):
        return

    fastapi._datadog_patch = False

    _u(fastapi.applications.FastAPI, "build_middleware_stack")
    _u(fastapi.routing, "serialize_response")

    # patch() also wraps APIRoute.__init__ with traced_route_init. Undo it here
    # too, otherwise that wrapper stays installed after unpatch().
    if is_wrapted(fastapi.routing.APIRoute.__init__):
        _u(fastapi.routing.APIRoute, "__init__")

    # We need to check that Starlette instrumentation hasn't already unpatched these
    if is_wrapted(fastapi.routing.APIRoute.handle):
        _u(fastapi.routing.APIRoute, "handle")

    if is_wrapted(fastapi.routing.Mount.handle):
        _u(fastapi.routing.Mount, "handle")

    if is_wrapted(fastapi.BackgroundTasks.add_task):
        _u(fastapi.BackgroundTasks, "add_task")
