import importlib
import sys

import openai
from openai import version

from ddtrace import config
from ddtrace.contrib.internal.openai import _endpoint_hooks
from ddtrace.contrib.trace_utils import unwrap
from ddtrace.contrib.trace_utils import wrap
from ddtrace.internal.logger import get_logger
from ddtrace.internal.utils.formats import deep_getattr
from ddtrace.internal.utils.version import parse_version
from ddtrace.llmobs._integrations import OpenAIIntegration
from ddtrace.trace import tracer


log = get_logger(__name__)


# Register the "openai" namespace on the global ddtrace config (no default settings).
config._add("openai", {})


def get_version() -> str:
    """Return the installed ``openai`` package version string."""
    return version.VERSION


def _supported_versions() -> dict[str, str]:
    return {"openai": ">=1.0"}


# Parsed (major, minor, patch) tuple of the installed openai version; used for
# version-gated patching below.
OPENAI_VERSION = parse_version(get_version())


# Maps a "<module>.<ClassName>" path under ``openai.resources`` to the methods
# patched on that class and the ``_endpoint_hooks`` class that traces each one.
# Async variants are derived in patch()/unpatch() by prefixing the class name
# with "Async" (e.g. "chat.Completions" -> "chat.AsyncCompletions").
_RESOURCES = {
    "models.Models": {
        "list": _endpoint_hooks._ModelListHook,
        "retrieve": _endpoint_hooks._ModelRetrieveHook,
        "delete": _endpoint_hooks._ModelDeleteHook,
    },
    "completions.Completions": {
        "create": _endpoint_hooks._CompletionHook,
    },
    "chat.Completions": {
        "create": _endpoint_hooks._ChatCompletionHook,
        "parse": _endpoint_hooks._ChatCompletionParseHook,
    },
    "images.Images": {
        "generate": _endpoint_hooks._ImageCreateHook,
        "edit": _endpoint_hooks._ImageEditHook,
        "create_variation": _endpoint_hooks._ImageVariationHook,
    },
    "audio.Transcriptions": {
        "create": _endpoint_hooks._AudioTranscriptionHook,
    },
    "audio.Translations": {
        "create": _endpoint_hooks._AudioTranslationHook,
    },
    "embeddings.Embeddings": {
        "create": _endpoint_hooks._EmbeddingHook,
    },
    "moderations.Moderations": {
        "create": _endpoint_hooks._ModerationHook,
    },
    "files.Files": {
        "create": _endpoint_hooks._FileCreateHook,
        "retrieve": _endpoint_hooks._FileRetrieveHook,
        "list": _endpoint_hooks._FileListHook,
        "delete": _endpoint_hooks._FileDeleteHook,
        "retrieve_content": _endpoint_hooks._FileDownloadHook,
    },
    "responses.Responses": {
        "create": _endpoint_hooks._ResponseHook,
        "parse": _endpoint_hooks._ResponseParseHook,
    },
}

# Sentinel kwarg injected by the *WithRawResponse wrappers so the inner patched
# endpoint can detect with_raw_response usage (streamed raw responses are not traced).
OPENAI_WITH_RAW_RESPONSE_ARG = "_dd.with_raw_response"


def patch():
    """Patch the openai library for tracing.

    Wraps client ``__init__``, response processing, the with_raw_response
    wrappers, and every supported resource endpoint method listed in
    ``_RESOURCES``. Idempotent; no-op for openai < 1.0.
    """
    if getattr(openai, "__datadog_patch", False):
        return

    if OPENAI_VERSION < (1, 0, 0):
        log.warning("openai version %s is not supported, please upgrade to openai version 1.0 or later", OPENAI_VERSION)
        return

    # Shared integration object; traced_client_init attaches the client to it.
    integration = OpenAIIntegration(integration_config=config.openai, openai=openai)
    openai._datadog_integration = integration

    # openai >= 1.8.0 exposes _process_response on Sync/AsyncAPIClient rather
    # than on BaseClient, so pick the wrap target per version.
    if OPENAI_VERSION >= (1, 8, 0):
        wrap(openai, "_base_client.SyncAPIClient._process_response", traced_convert)
        wrap(openai, "_base_client.AsyncAPIClient._process_response", traced_convert)
    else:
        wrap(openai, "_base_client.BaseClient._process_response", traced_convert)
    wrap(openai, "OpenAI.__init__", traced_client_init)
    wrap(openai, "AsyncOpenAI.__init__", traced_client_init)
    wrap(openai, "AzureOpenAI.__init__", traced_client_init)
    wrap(openai, "AsyncAzureOpenAI.__init__", traced_client_init)
    # with_raw_response wrappers need special handling so streamed raw
    # responses are not traced (see traced_completions_with_raw_response_init).
    wrap(openai, "resources.chat.CompletionsWithRawResponse.__init__", traced_completions_with_raw_response_init)
    wrap(openai, "resources.CompletionsWithRawResponse.__init__", traced_completions_with_raw_response_init)
    wrap(
        openai,
        "resources.chat.AsyncCompletionsWithRawResponse.__init__",
        traced_completions_with_raw_response_init,
    )
    wrap(openai, "resources.AsyncCompletionsWithRawResponse.__init__", traced_completions_with_raw_response_init)

    # HACK: openai.resources.responses is not imported by default in openai 1.78.0 and later, so we need to import it
    #       to detect and patch it below.
    try:
        importlib.import_module("openai.resources.responses")
    except ImportError:
        pass

    # Wrap each supported sync and async resource method with its endpoint hook.
    for resource, method_hook_dict in _RESOURCES.items():
        if deep_getattr(openai.resources, resource) is None:
            log.debug("WARNING: resource %s is not found", resource)
            continue
        for method_name, endpoint_hook in method_hook_dict.items():
            sync_method = "resources.{}.{}".format(resource, method_name)
            # "chat.Completions" -> "chat.AsyncCompletions"
            async_method = "resources.{}.{}".format(".Async".join(resource.split(".")), method_name)
            if deep_getattr(openai, sync_method) is not None:
                wrap(openai, sync_method, _patched_endpoint(endpoint_hook))
            if deep_getattr(openai, async_method) is not None:
                wrap(openai, async_method, _patched_endpoint_async(endpoint_hook))

    openai.__datadog_patch = True


def unpatch():
    """Undo everything patch() wrapped and detach the integration object.

    No-op when not currently patched or for openai < 1.0. Mirrors the wrap
    calls in patch() with unwrap calls in the same order.
    """
    if not getattr(openai, "__datadog_patch", False):
        return

    if OPENAI_VERSION < (1, 0, 0):
        log.warning("openai version %s is not supported, please upgrade to openai version 1.0 or later", OPENAI_VERSION)
        return

    openai.__datadog_patch = False

    # Same version gate as patch(): _process_response lives on different
    # classes depending on the openai version.
    if OPENAI_VERSION >= (1, 8, 0):
        unwrap(openai._base_client.SyncAPIClient, "_process_response")
        unwrap(openai._base_client.AsyncAPIClient, "_process_response")
    else:
        unwrap(openai._base_client.BaseClient, "_process_response")
    unwrap(openai.OpenAI, "__init__")
    unwrap(openai.AsyncOpenAI, "__init__")
    unwrap(openai.AzureOpenAI, "__init__")
    unwrap(openai.AsyncAzureOpenAI, "__init__")
    unwrap(openai.resources.chat.CompletionsWithRawResponse, "__init__")
    unwrap(openai.resources.CompletionsWithRawResponse, "__init__")
    unwrap(openai.resources.chat.AsyncCompletionsWithRawResponse, "__init__")
    unwrap(openai.resources.AsyncCompletionsWithRawResponse, "__init__")

    # Unwrap every resource method that patch() wrapped (sync and async).
    for resource, method_hook_dict in _RESOURCES.items():
        if deep_getattr(openai.resources, resource) is None:
            continue
        for method_name, _ in method_hook_dict.items():
            sync_resource = deep_getattr(openai.resources, resource)
            async_resource = deep_getattr(openai.resources, ".Async".join(resource.split(".")))
            if sync_resource is not None and hasattr(sync_resource, method_name):
                unwrap(sync_resource, method_name)
            if async_resource is not None and hasattr(async_resource, method_name):
                unwrap(async_resource, method_name)

    delattr(openai, "_datadog_integration")


def traced_client_init(func, instance, args, kwargs):
    """
    Wrapper for `openai.OpenAI/AsyncOpenAI` client init methods that records
    the newly constructed client on the shared OpenAIIntegration object.
    """
    func(*args, **kwargs)
    openai._datadog_integration._client = instance


def traced_completions_with_raw_response_init(func, instance, args, kwargs):
    """
    Patch create method of CompletionsWithRawResponse classes to catch requests that use with_raw_response wrapper
    since the response for these streamed requests cannot be traced and we therefore need to avoid creating
    spans for these cases.
    """
    func(*args, **kwargs)
    if hasattr(instance, "create"):
        # Dispatch on the concrete wrapper type; the *WithRawResponseHook hooks
        # cause the inner patched endpoint to set OPENAI_WITH_RAW_RESPONSE_ARG,
        # which lets it skip tracing for streamed raw responses.
        if isinstance(instance, openai.resources.completions.CompletionsWithRawResponse):
            wrap(instance, "create", _patched_endpoint(_endpoint_hooks._CompletionWithRawResponseHook))
        elif isinstance(instance, openai.resources.chat.CompletionsWithRawResponse):
            wrap(instance, "create", _patched_endpoint(_endpoint_hooks._ChatCompletionWithRawResponseHook))
        elif isinstance(instance, openai.resources.completions.AsyncCompletionsWithRawResponse):
            wrap(instance, "create", _patched_endpoint_async(_endpoint_hooks._CompletionWithRawResponseHook))
        elif isinstance(instance, openai.resources.chat.AsyncCompletionsWithRawResponse):
            wrap(instance, "create", _patched_endpoint_async(_endpoint_hooks._ChatCompletionWithRawResponseHook))
        return


def _traced_endpoint(endpoint_hook, integration, instance, args, kwargs):
    """Generator implementing the shared tracing protocol around an endpoint call.

    Protocol (see _patched_endpoint/_patched_endpoint_async/_TracedAsyncPaginator):
      1. the caller primes with ``g.send(None)`` -- this starts the span and
         primes the endpoint hook's ``handle_request`` generator;
      2. after the wrapped call, the caller sends ``(resp, err)`` -- error info
         is recorded on the span, the pair is forwarded to the hook, and the
         hook's StopIteration value (if any) becomes this generator's return
         value so the caller can override the wrapped function's result.

    Non-streamed spans are finished here; successful streamed responses are
    finished later, when their stream generator is exhausted.
    """
    span = integration.trace(endpoint_hook.OPERATION_ID, instance=instance)
    resp, err = None, None
    try:
        # Start the hook
        hook = endpoint_hook().handle_request(None, integration, instance, span, args, kwargs)
        hook.send(None)

        resp, err = yield

        # Record any error information
        if err is not None:
            span.set_exc_info(*sys.exc_info())

        # Pass the response and the error to the hook
        try:
            hook.send((resp, err))
        except StopIteration as e:
            if err is None:
                return e.value
    finally:
        # Streamed responses will be finished when the generator exits, so finish non-streamed spans here.
        # Streamed responses with error will need to be finished manually as well.
        if not kwargs.get("stream") or err is not None or resp is None:
            span.finish()


def _patched_endpoint(patch_hook):
    """Return a sync wrapper that traces an endpoint call via ``patch_hook``.

    Raw-response hooks only mark the call with OPENAI_WITH_RAW_RESPONSE_ARG and
    delegate; streamed with_raw_response calls bypass tracing entirely.
    """
    def patched_endpoint(func, instance, args, kwargs):
        # The *WithRawResponse wrappers don't trace themselves -- they tag the
        # kwargs so the inner (resource-level) patched endpoint can react.
        if (
            patch_hook is _endpoint_hooks._ChatCompletionWithRawResponseHook
            or patch_hook is _endpoint_hooks._CompletionWithRawResponseHook
        ):
            kwargs[OPENAI_WITH_RAW_RESPONSE_ARG] = True
            return func(*args, **kwargs)
        # Streamed raw responses cannot be traced; pop the sentinel and skip.
        if kwargs.pop(OPENAI_WITH_RAW_RESPONSE_ARG, False) and kwargs.get("stream", False):
            return func(*args, **kwargs)

        integration = openai._datadog_integration
        # Prime the shared tracing generator (starts the span + hook).
        g = _traced_endpoint(patch_hook, integration, instance, args, kwargs)
        g.send(None)
        resp, err = None, None
        override_return = None
        try:
            resp = func(*args, **kwargs)
        except BaseException as e:
            err = e
            raise
        finally:
            try:
                g.send((resp, err))
            except StopIteration as e:
                if err is None:
                    # This return takes priority over the implicit None return
                    override_return = e.value

        if override_return is not None:
            return override_return

    return patched_endpoint


class _TracedAsyncPaginator:
    """Wrapper for AsyncPaginator objects to enable tracing for both await and async for usage."""

    def __init__(self, paginator, integration, patch_hook, instance, args, kwargs):
        # The wrapped paginator plus everything _traced_endpoint needs to
        # build a span lazily when the paginator is iterated or awaited.
        self._paginator = paginator
        self._integration = integration
        self._patch_hook = patch_hook
        self._instance = instance
        self._args = args
        self._kwargs = kwargs

    def __aiter__(self):
        async def _traced_aiter():
            # Prime the shared tracing generator (starts the span + hook).
            g = _traced_endpoint(self._patch_hook, self._integration, self._instance, self._args, self._kwargs)
            g.send(None)
            err = None
            # `completed` guards against sending into the generator twice.
            completed = False
            try:
                iterator = self._paginator.__aiter__()
                # Fetch first item to trigger trace completion before iteration starts.
                # This ensures the span is recorded even if iteration stops early.
                first_item = await iterator.__anext__()
                try:
                    g.send((None, None))
                    completed = True
                except StopIteration:
                    completed = True
                yield first_item
                async for item in iterator:
                    yield item
            except StopAsyncIteration:
                # Empty paginator: fall through to the finally to close the trace.
                pass
            except BaseException as e:
                err = e
                raise
            finally:
                if not completed:
                    # Iteration failed (or was empty) before the trace was
                    # completed above; finish it with any captured error.
                    try:
                        g.send((None, err))
                    except StopIteration:
                        pass

        return _traced_aiter()

    def __await__(self):
        async def _trace_and_await():
            # Prime the shared tracing generator (starts the span + hook).
            g = _traced_endpoint(self._patch_hook, self._integration, self._instance, self._args, self._kwargs)
            g.send(None)
            resp, err = None, None
            try:
                resp = await self._paginator
            except BaseException as e:
                err = e
                raise
            finally:
                try:
                    g.send((resp, err))
                except StopIteration as e:
                    if err is None:
                        # The hook may override the awaited result.
                        resp = e.value
            return resp

        return _trace_and_await().__await__()


def _patched_endpoint_async(patch_hook):
    """Return an async-aware wrapper that traces an endpoint call via ``patch_hook``.

    Mirrors _patched_endpoint, but the wrapped function returns an awaitable
    (or an AsyncPaginator), so tracing happens inside a coroutine/wrapper.
    """
    def patched_endpoint(func, instance, args, kwargs):
        # The *WithRawResponse wrappers don't trace themselves -- they tag the
        # kwargs so the inner (resource-level) patched endpoint can react.
        if (
            patch_hook is _endpoint_hooks._ChatCompletionWithRawResponseHook
            or patch_hook is _endpoint_hooks._CompletionWithRawResponseHook
        ):
            kwargs[OPENAI_WITH_RAW_RESPONSE_ARG] = True
            return func(*args, **kwargs)
        # Streamed raw responses cannot be traced; pop the sentinel and skip.
        if kwargs.pop(OPENAI_WITH_RAW_RESPONSE_ARG, False) and kwargs.get("stream", False):
            return func(*args, **kwargs)

        result = func(*args, **kwargs)
        # Detect AsyncPaginator objects (have both __aiter__ and __await__).
        # These must be returned directly (not awaited) to preserve iteration behavior.
        if hasattr(result, "__aiter__") and hasattr(result, "__await__"):
            return _TracedAsyncPaginator(result, openai._datadog_integration, patch_hook, instance, args, kwargs)

        async def async_wrapper():
            integration = openai._datadog_integration
            # Prime the shared tracing generator (starts the span + hook).
            g = _traced_endpoint(patch_hook, integration, instance, args, kwargs)
            g.send(None)
            resp, err = None, None
            override_return = None
            try:
                resp = await result
            except BaseException as e:
                err = e
                raise
            finally:
                try:
                    g.send((resp, err))
                except StopIteration as e:
                    if err is None:
                        # The hook may override the awaited result.
                        override_return = e.value

            if override_return is not None:
                return override_return
            return resp

        return async_wrapper()

    return patched_endpoint


def traced_convert(func, instance, args, kwargs):
    """Capture organization/rate-limit info from response headers onto the active span.

    Wraps the openai base client's ``_process_response``: reads the HTTP
    response headers, tags the current span with the organization name and
    rate-limit gauges, then delegates to the wrapped function. No-op when
    there is no active span or the span is already tagged (this runs once per
    streamed chunk, so tagging is short-circuited after the first chunk).
    """
    span = tracer.current_span()
    if not span:
        return func(*args, **kwargs)

    if OPENAI_VERSION < (1, 0, 0):
        resp = args[0]
        if not isinstance(resp, openai.openai_response.OpenAIResponse):
            return func(*args, **kwargs)
        headers = resp._headers
    else:
        # NOTE(review): a missing "response" kwarg would make this `{}` and the
        # attribute access below raise AttributeError -- presumably the wrapped
        # call always supplies it; confirm against openai's _process_response.
        resp = kwargs.get("response", {})
        headers = resp.headers
    # This function is called for each chunk in the stream.
    # To prevent needlessly setting the same tags for each chunk, short-circuit here.
    if span.get_tag("openai.organization.name") is not None:
        return func(*args, **kwargs)

    org_name = headers.get("openai-organization")
    if org_name:
        span._set_tag_str("openai.organization.name", org_name)

    # Gauge rate-limit totals and remaining counts. Each header is read once
    # (the previous version looked every header up two or three times).
    for header, metric in (
        ("x-ratelimit-limit-requests", "openai.organization.ratelimit.requests.limit"),
        ("x-ratelimit-limit-tokens", "openai.organization.ratelimit.tokens.limit"),
        ("x-ratelimit-remaining-requests", "openai.organization.ratelimit.requests.remaining"),
        ("x-ratelimit-remaining-tokens", "openai.organization.ratelimit.tokens.remaining"),
    ):
        value = headers.get(header)
        if value:
            span.set_metric(metric, int(value))

    return func(*args, **kwargs)
