from openai.version import VERSION as OPENAI_VERSION

from ddtrace.contrib.internal.openai.utils import OpenAIAsyncStreamHandler
from ddtrace.contrib.internal.openai.utils import OpenAIStreamHandler
from ddtrace.contrib.internal.openai.utils import _is_async_generator
from ddtrace.contrib.internal.openai.utils import _is_generator
from ddtrace.contrib.internal.openai.utils import _loop_handler
from ddtrace.contrib.internal.openai.utils import _process_finished_stream
from ddtrace.internal.utils.version import parse_version
from ddtrace.llmobs._constants import OAI_HANDOFF_TOOL_ARG
from ddtrace.llmobs._integrations.base_stream_handler import make_traced_stream
from ddtrace.llmobs._utils import _get_attr
from ddtrace.llmobs._utils import safe_load_json


# Version segment of the OpenAI REST API. _EndpointHook._record_request() uses it as the prefix of the
# "openai.request.endpoint" span tag, i.e. "/<API_VERSION>/<endpoint>".
API_VERSION = "v1"


class _EndpointHook:
    """
    Common behaviour shared by every OpenAI endpoint hook.

    Subclasses describe their endpoint declaratively through `_request_arg_params`,
    `_request_kwarg_params` and `_response_attrs`; `_record_request()` / `_record_response()` turn
    those declarations into span tags. Endpoint-specific tagging belongs in an override that first
    calls the base implementation through `super()`.
    """

    # Names of the positional parameters, in call order. Put `None` in a slot to exclude that
    # positional argument from automatic tagging.
    _request_arg_params = ()
    # Names of the keyword parameters that are considered for automatic tagging. Leave a name out of
    # this tuple when the kwarg needs special casing.
    _request_kwarg_params = ()
    # Response attributes copied onto the span as "openai.response.<attr>".
    _response_attrs = ()
    # Positional parameters that are tagged at the base level as "openai.<name>".
    _base_level_tag_args = ("api_base", "api_type", "api_version")
    ENDPOINT_NAME = "openai"
    HTTP_METHOD_TYPE = ""
    # Each endpoint hook must provide an operationID as specified in the OpenAI API specs:
    # https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml
    OPERATION_ID = ""

    def _record_request(self, pin, integration, instance, span, args, kwargs):
        """
        Tag the span with the endpoint, the HTTP method and the declared request parameters.

        Overrides in subclasses are expected to call this method via `super()` before adding their
        own request tags.
        """
        if self.ENDPOINT_NAME is not None:
            endpoint = self.ENDPOINT_NAME
        else:
            # No static endpoint name: fall back to the OBJECT_NAME of the instrumented instance.
            endpoint = "%s" % getattr(instance, "OBJECT_NAME", "")
        span._set_tag_str("openai.request.endpoint", "/%s/%s" % (API_VERSION, endpoint))
        span._set_tag_str("openai.request.method", self.HTTP_METHOD_TYPE)

        positional_names = self._request_arg_params
        if positional_names and len(positional_names) > 1:
            # zip() stops at the shorter sequence, so surplus names or surplus args are ignored.
            for name, value in zip(positional_names, args):
                if name is None or value is None:
                    continue
                if name in self._base_level_tag_args:
                    span._set_tag_str("openai.%s" % name, str(value))

        for key in self._request_kwarg_params:
            if key not in kwargs:
                continue
            value = kwargs[key]
            if isinstance(value, dict):
                # Dict-valued kwargs are flattened one level: "openai.request.<kwarg>.<key>".
                for sub_key, sub_value in value.items():
                    span._set_tag_str("openai.request.%s.%s" % (key, sub_key), str(sub_value))
                continue
            # Azure OpenAI requires using "engine" instead of "model"; both feed the same tag.
            if key in ("engine", "model") and value is not None:
                span._set_tag_str("openai.request.model", str(value))

    def handle_request(self, pin, integration, instance, span, args, kwargs):
        """
        Generator driving one traced call: tag the request, wait for the outcome, tag the response.

        The caller sends a `(resp, error)` tuple into the generator; the generator's return value is
        the object that should be handed back to the user.
        """
        self._record_request(pin, integration, instance, span, args, kwargs)
        resp, error = yield
        if not hasattr(resp, "parse"):
            return self._record_response(pin, integration, span, args, kwargs, resp, error)
        # Users can request the raw response: tag from the parsed payload, but hand back the
        # original raw object untouched.
        self._record_response(pin, integration, span, args, kwargs, resp.parse(), error)
        return resp

    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
        """Tag every attribute listed in `_response_attrs` that exists on `resp`, then return `resp`."""
        available = (name for name in self._response_attrs if hasattr(resp, name))
        for name in available:
            span._set_tag_str("openai.response.%s" % name, str(getattr(resp, name, "")))
        return resp


class _BaseCompletionHook(_EndpointHook):
    """Base hook for endpoints whose responses may be streamed (completions, chat, responses)."""

    @staticmethod
    def _new_chunk_buffers(kwargs, operation_type):
        """Return one empty list per expected choice, used to accumulate streamed chunks.

        The count starts from the `n` request kwarg (defaulting to 1). For `completion` operations
        it is multiplied by the number of prompts when `prompt` is a non-empty list whose first
        element is not an int (a list of ints is presumably one tokenized prompt rather than a
        batch of prompts).
        """
        n = kwargs.get("n", 1) or 1
        if operation_type == "completion":
            prompts = kwargs.get("prompt", "")
            # Guard against an empty list before peeking at the first element.
            if isinstance(prompts, list) and prompts and not isinstance(prompts[0], int):
                n *= len(prompts)
        return [[] for _ in range(n)]

    def _handle_streamed_response(self, integration, span, kwargs, resp, operation_type=""):
        """Handle streamed response objects returned from completions/chat/response endpoint calls.

        This method returns a wrapped version of the OpenAIStream/OpenAIAsyncStream objects
        to trace the response while it is read by the user.

        For openai >= 1.6.0 the stream is wrapped with `make_traced_stream` and the matching stream
        handler. Otherwise a (sync or async) generator wrapper is returned which records every chunk
        and finishes the span once the stream is exhausted or closed. If `resp` is neither a
        generator nor an async generator it is returned unchanged.
        """
        if parse_version(OPENAI_VERSION) >= (1, 6, 0):
            if _is_async_generator(resp):
                return make_traced_stream(
                    resp,
                    OpenAIAsyncStreamHandler(integration, span, None, kwargs, operation_type=operation_type),
                )
            elif _is_generator(resp):
                return make_traced_stream(
                    resp, OpenAIStreamHandler(integration, span, None, kwargs, operation_type=operation_type)
                )

        def shared_gen():
            # Primed by the wrappers below; receives the collected chunks once streaming ends and
            # always finishes the span, even if processing the chunks raises.
            try:
                streamed_chunks = yield
                _process_finished_stream(integration, span, kwargs, streamed_chunks, operation_type=operation_type)
            finally:
                span.finish()

        if _is_async_generator(resp):

            async def traced_streamed_response():
                g = shared_gen()
                g.send(None)
                streamed_chunks = self._new_chunk_buffers(kwargs, operation_type)
                try:
                    async for chunk in resp:
                        _loop_handler(span, chunk, streamed_chunks)
                        yield chunk
                finally:
                    try:
                        g.send(streamed_chunks)
                    except StopIteration:
                        # Expected: shared_gen() returns right after processing the chunks.
                        pass

            return traced_streamed_response()

        elif _is_generator(resp):

            def traced_streamed_response():
                g = shared_gen()
                g.send(None)
                streamed_chunks = self._new_chunk_buffers(kwargs, operation_type)
                try:
                    for chunk in resp:
                        _loop_handler(span, chunk, streamed_chunks)
                        yield chunk
                finally:
                    try:
                        g.send(streamed_chunks)
                    except StopIteration:
                        # Expected: shared_gen() returns right after processing the chunks.
                        pass

            return traced_streamed_response()
        return resp


class _CompletionHook(_BaseCompletionHook):
    """Hook for the `completions` endpoint (operation `createCompletion`)."""

    ENDPOINT_NAME = "completions"
    HTTP_METHOD_TYPE = "POST"
    OPERATION_ID = "createCompletion"
    _request_kwarg_params = ("model", "engine", "suffix")
    _response_attrs = ("model",)

    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
        """Tag the response, then either wrap the stream or submit LLMObs tags.

        Returns the traced stream for a successful streamed call, `resp` for a regular response, and
        `None` when the response is falsy.
        """
        resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
        has_response = bool(resp)
        if has_response and kwargs.get("stream") and error is None:
            # LLMObs tagging for streams is deferred to the stream wrapper.
            return self._handle_streamed_response(integration, span, kwargs, resp, operation_type="completion")
        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="completion")
        return resp if has_response else None


class _CompletionWithRawResponseHook(_CompletionHook):
    """Variant of `_CompletionHook` that declares no overrides; all behaviour is inherited."""


class _ChatCompletionHook(_BaseCompletionHook):
    """Hook for the `chat/completions` endpoint (operation `createChatCompletion`)."""

    _request_kwarg_params = (
        "model",
        "engine",
    )
    _response_attrs = ("model",)
    ENDPOINT_NAME = "chat/completions"
    HTTP_METHOD_TYPE = "POST"
    OPERATION_ID = "createChatCompletion"

    def _record_request(self, pin, integration, instance, span, args, kwargs):
        """Tag the request and, for streamed calls on openai >= 1.26, opt in to usage reporting.

        When the caller did not set `stream_options["include_usage"]` themselves, the request kwargs
        are updated so that `include_usage` is True and the span is flagged with
        `_dd.auto_extract_token_chunk`. `kwargs` is modified in place on purpose (the updated
        options must reach the request), but the caller's own `stream_options` dict is left
        untouched.
        """
        super()._record_request(pin, integration, instance, span, args, kwargs)
        if parse_version(OPENAI_VERSION) >= (1, 26) and kwargs.get("stream"):
            stream_options = kwargs.get("stream_options", {})
            if not isinstance(stream_options, dict):
                stream_options = {}
            if stream_options.get("include_usage", None) is not None:
                # Only perform token chunk auto-extraction if this option is not explicitly set
                return
            span._set_ctx_item("_dd.auto_extract_token_chunk", True)
            # Copy before patching: mutating a dict the caller reuses would make the next call look
            # like `include_usage` had been set explicitly, silently disabling auto-extraction.
            patched_options = dict(stream_options)
            patched_options["include_usage"] = True
            kwargs["stream_options"] = patched_options

    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
        """Tag the response, then either wrap the stream or submit LLMObs tags.

        Returns the traced stream for a successful streamed call, `resp` for a regular response, and
        `None` when the response is falsy.
        """
        resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
        if not resp:
            integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="chat")
            return
        if kwargs.get("stream") and error is None:
            return self._handle_streamed_response(integration, span, kwargs, resp, operation_type="chat")
        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="chat")
        return resp


class _ChatCompletionWithRawResponseHook(_ChatCompletionHook):
    """Variant of `_ChatCompletionHook` that declares no overrides; all behaviour is inherited."""


class _ChatCompletionParseHook(_ChatCompletionHook):
    """Same as `_ChatCompletionHook`, reported under the `parseChatCompletion` operation id."""

    OPERATION_ID = "parseChatCompletion"


class _EmbeddingHook(_EndpointHook):
    """Hook for the `embeddings` endpoint (operation `createEmbedding`)."""

    ENDPOINT_NAME = "embeddings"
    HTTP_METHOD_TYPE = "POST"
    OPERATION_ID = "createEmbedding"
    _request_kwarg_params = ("model", "engine")
    _response_attrs = ("model",)

    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
        """Tag the response attributes, submit LLMObs tags for the embedding and return the response."""
        tagged_resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=tagged_resp, operation="embedding")
        return tagged_resp


class _ListHook(_EndpointHook):
    """
    Hook for openai.ListableAPIResource, which is used by Model.list, File.list, and FineTune.list.

    `ENDPOINT_NAME` is None here, so unless a subclass sets it the endpoint is derived from the
    instrumented instance's `OBJECT_NAME`.
    """

    ENDPOINT_NAME = None
    HTTP_METHOD_TYPE = "GET"
    OPERATION_ID = "list"
    _request_arg_params = ("api_base", "api_version")
    _request_kwarg_params = ("user",)

    def _record_request(self, pin, integration, instance, span, args, kwargs):
        """Tag the request and pick the span resource name from the endpoint that was tagged."""
        super()._record_request(pin, integration, instance, span, args, kwargs)
        endpoint = span.get_tag("openai.request.endpoint")
        resource_by_suffix = (("/models", "listModels"), ("/files", "listFiles"))
        for suffix, resource in resource_by_suffix:
            if endpoint.endswith(suffix):
                span.resource = resource
                break

    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
        """Record how many items were listed; returns `None` for a falsy response, else `resp`."""
        resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
        if not resp:
            return None
        if hasattr(resp, "data"):
            listed = resp.data or []
            span.set_metric("openai.response.count", len(listed))
        return resp


class _ModelListHook(_ListHook):
    """List hook bound to the `models` endpoint: openai.resources.models.Models.list (v1)."""

    OPERATION_ID = "listModels"
    ENDPOINT_NAME = "models"


class _FileListHook(_ListHook):
    """List hook bound to the `files` endpoint: openai.resources.files.Files.list (v1)."""

    OPERATION_ID = "listFiles"
    ENDPOINT_NAME = "files"


class _RetrieveHook(_EndpointHook):
    """Hook for openai.APIResource, which is used by Model.retrieve, File.retrieve, and FineTune.retrieve."""

    ENDPOINT_NAME = None
    HTTP_METHOD_TYPE = "GET"
    OPERATION_ID = "retrieve"
    # The first positional argument (the resource id) is special-cased in _record_request().
    _request_arg_params = (None, "request_id", "request_timeout")
    _request_kwarg_params = ("user",)
    _response_attrs = (
        "id",
        "owned_by",
        "model",
        "parent",
        "root",
        "bytes",
        "created",
        "created_at",
        "purpose",
        "filename",
        "fine_tuned_model",
        "status",
        "status_details",
        "updated_at",
    )

    def _record_request(self, pin, integration, instance, span, args, kwargs):
        """Tag the request, the retrieved resource id, and rewrite the endpoint tag to "<endpoint>/*"."""
        super()._record_request(pin, integration, instance, span, args, kwargs)
        endpoint = span.get_tag("openai.request.endpoint")
        # (endpoint suffix, span resource, tag receiving the id, kwarg that may carry the id)
        targets = (
            ("/models", "retrieveModel", "openai.request.model", "model"),
            ("/files", "retrieveFile", "openai.request.file_id", "file_id"),
        )
        for suffix, resource, tag_name, kwarg_name in targets:
            if not endpoint.endswith(suffix):
                continue
            span.resource = resource
            # The id is the first positional argument when given, else a kwarg ("id" as fallback).
            resource_id = args[0] if len(args) >= 1 else kwargs.get(kwarg_name, kwargs.get("id", ""))
            span._set_tag_str(tag_name, resource_id)
            break
        span._set_tag_str("openai.request.endpoint", "%s/*" % endpoint)

    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
        """Tag the attributes listed in `_response_attrs` and return the response unchanged."""
        return super()._record_response(pin, integration, span, args, kwargs, resp, error)


class _ModelRetrieveHook(_RetrieveHook):
    """Retrieve hook bound to the `models` endpoint: openai.resources.models.Models.retrieve."""

    OPERATION_ID = "retrieveModel"
    ENDPOINT_NAME = "models"


class _FileRetrieveHook(_RetrieveHook):
    """Retrieve hook bound to the `files` endpoint: openai.resources.files.Files.retrieve."""

    OPERATION_ID = "retrieveFile"
    ENDPOINT_NAME = "files"


class _DeleteHook(_EndpointHook):
    """Hook for openai.DeletableAPIResource, which is used by File.delete, and Model.delete."""

    ENDPOINT_NAME = None
    HTTP_METHOD_TYPE = "DELETE"
    OPERATION_ID = "delete"
    # The first positional argument (the resource id) is special-cased in _record_request().
    _request_arg_params = (None, "api_type", "api_version")
    _request_kwarg_params = ("user",)

    def _record_request(self, pin, integration, instance, span, args, kwargs):
        """Tag the request, the deleted resource id, and rewrite the endpoint tag to "<endpoint>/*"."""
        super()._record_request(pin, integration, instance, span, args, kwargs)
        endpoint = span.get_tag("openai.request.endpoint")
        has_positional_id = len(args) >= 1
        if endpoint.endswith("/models"):
            span.resource = "deleteModel"
            model = args[0] if has_positional_id else kwargs.get("model", kwargs.get("sid", ""))
            span._set_tag_str("openai.request.model", model)
        elif endpoint.endswith("/files"):
            span.resource = "deleteFile"
            file_id = args[0] if has_positional_id else kwargs.get("file_id", kwargs.get("sid", ""))
            span._set_tag_str("openai.request.file_id", file_id)
        span._set_tag_str("openai.request.endpoint", "%s/*" % endpoint)

    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
        """Tag the id and deletion status of the response.

        Two response shapes are handled: objects exposing `data` (a mapping) together with
        `_headers`, and objects exposing `id` / `deleted` attributes directly. Returns `None` for a
        falsy response, otherwise `resp`.
        """
        resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
        if not resp:
            return None
        if not hasattr(resp, "data"):
            span._set_tag_str("openai.response.id", str(resp.id))
            span._set_tag_str("openai.response.deleted", str(resp.deleted))
            return resp
        organization = resp._headers.get("openai-organization")
        if organization:
            span._set_tag_str("openai.organization.name", organization)
        span._set_tag_str("openai.response.id", resp.data.get("id", ""))
        span._set_tag_str("openai.response.deleted", str(resp.data.get("deleted", "")))
        return resp


class _FileDeleteHook(_DeleteHook):
    """Delete hook bound to the `files` endpoint: openai.resources.files.Files.delete."""

    ENDPOINT_NAME = "files"


class _ModelDeleteHook(_DeleteHook):
    """Delete hook bound to the `models` endpoint: openai.resources.models.Models.delete."""

    ENDPOINT_NAME = "models"


class _ImageHook(_EndpointHook):
    """Base hook for the `images` endpoints; concrete endpoints are declared by subclasses."""

    ENDPOINT_NAME = "images"
    HTTP_METHOD_TYPE = "POST"
    _response_attrs = ("created",)

    def _record_request(self, pin, integration, instance, span, args, kwargs):
        """Tag the request and always report "dall-e" as the request model."""
        super()._record_request(pin, integration, instance, span, args, kwargs)
        model_name = "dall-e"
        span._set_tag_str("openai.request.model", model_name)


class _ImageCreateHook(_ImageHook):
    """Image hook for `images/generations` (operation `createImage`)."""

    OPERATION_ID = "createImage"
    ENDPOINT_NAME = "images/generations"


class _ImageEditHook(_ImageHook):
    """Image hook for `images/edits` (operation `createImageEdit`)."""

    OPERATION_ID = "createImageEdit"
    ENDPOINT_NAME = "images/edits"


class _ImageVariationHook(_ImageHook):
    """Image hook for `images/variations` (operation `createImageVariation`)."""

    OPERATION_ID = "createImageVariation"
    ENDPOINT_NAME = "images/variations"


class _BaseAudioHook(_EndpointHook):
    """Base hook for the `audio` endpoints; concrete endpoints are declared by subclasses."""

    ENDPOINT_NAME = "audio"
    HTTP_METHOD_TYPE = "POST"
    _request_arg_params = ("model",)


class _AudioTranscriptionHook(_BaseAudioHook):
    """Audio hook for `audio/transcriptions` (operation `createTranscription`)."""

    OPERATION_ID = "createTranscription"
    ENDPOINT_NAME = "audio/transcriptions"


class _AudioTranslationHook(_BaseAudioHook):
    """Audio hook for `audio/translations` (operation `createTranslation`)."""

    OPERATION_ID = "createTranslation"
    ENDPOINT_NAME = "audio/translations"


class _ModerationHook(_EndpointHook):
    """Hook for the `moderations` endpoint (operation `createModeration`)."""

    ENDPOINT_NAME = "moderations"
    HTTP_METHOD_TYPE = "POST"
    OPERATION_ID = "createModeration"
    _request_arg_params = ("model",)
    _request_kwarg_params = ("model",)
    _response_attrs = ("id", "model")


class _BaseFileHook(_EndpointHook):
    """Base hook for the `files` endpoints; it only pins the endpoint name."""

    ENDPOINT_NAME = "files"


class _FileCreateHook(_BaseFileHook):
    """Hook for file uploads on the `files` endpoint (operation `createFile`)."""

    # The first positional argument (the file itself) is special-cased in _record_request().
    _request_arg_params = (
        None,
        "purpose",
        "model",
        "api_key",
        "api_base",
        "api_type",
        "api_version",
        "organization",
        "user_provided_filename",
    )
    _request_kwarg_params = ("purpose", "user_provided_filename")
    _response_attrs = ("id", "bytes", "created_at", "filename", "purpose", "status", "status_details")
    HTTP_METHOD_TYPE = "POST"
    OPERATION_ID = "createFile"

    def _record_request(self, pin, integration, instance, span, args, kwargs):
        """Tag the request and the base name of the uploaded file.

        The file is taken from the first positional argument, or from the `file` kwarg. The
        `openai.request.filename` tag is the part of its `name` after the last "/", or an empty
        string when the file is falsy, has no `name`, or its `name` is not a string.
        """
        super()._record_request(pin, integration, instance, span, args, kwargs)
        fp = args[0] if len(args) >= 1 else kwargs.get("file", "")
        name = getattr(fp, "name", None) if fp else None
        # `name` is not guaranteed to be text (e.g. a file object opened from a file descriptor
        # exposes the integer descriptor); tracing must never break the user's call over that.
        filename = name.split("/")[-1] if isinstance(name, str) else ""
        span._set_tag_str("openai.request.filename", filename)


class _FileDownloadHook(_BaseFileHook):
    """Hook for file content downloads, `files/*/content` (operation `downloadFile`)."""

    ENDPOINT_NAME = "files/*/content"
    HTTP_METHOD_TYPE = "GET"
    OPERATION_ID = "downloadFile"
    # The first positional argument (the file id) is special-cased in _record_request().
    _request_arg_params = (None, "api_key", "api_base", "api_type", "api_version", "organization")

    def _record_request(self, pin, integration, instance, span, args, kwargs):
        """Tag the request and the id of the requested file (first positional arg or `file_id` kwarg)."""
        super()._record_request(pin, integration, instance, span, args, kwargs)
        if len(args) >= 1:
            file_id = args[0]
        else:
            file_id = kwargs.get("file_id", "")
        span._set_tag_str("openai.request.file_id", file_id)

    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
        """Record the size of the downloaded content; returns `None` for a falsy response, else `resp`."""
        resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
        if not resp:
            return None
        if isinstance(resp, (bytes, str)):
            total_bytes = len(resp)
        else:
            # Other response objects are expected to carry the size themselves; default to 0.
            total_bytes = getattr(resp, "total_bytes", 0)
        span.set_metric("openai.response.total_bytes", total_bytes)
        return resp


class _ResponseHook(_BaseCompletionHook):
    """Hook for the `responses` endpoint (operation `createResponse`)."""

    ENDPOINT_NAME = "responses"
    HTTP_METHOD_TYPE = "POST"
    OPERATION_ID = "createResponse"
    # Only "model" is considered for automatic request tagging.
    _request_kwarg_params = ("model",)
    _response_attrs = ("model",)

    def _record_response(self, pin, integration, span, args, kwargs, resp, error):
        """Tag the response, trace MCP tool calls, then wrap the stream or submit LLMObs tags.

        Returns the traced stream for a successful streamed call; otherwise returns `resp` itself
        (including when it is falsy).
        """
        resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
        self._trace_mcp_tool_usage(pin, integration, resp)
        if resp and kwargs.get("stream") and error is None:
            # LLMObs tagging for streams is deferred to the stream wrapper.
            return self._handle_streamed_response(integration, span, kwargs, resp, operation_type="response")
        integration.llmobs_set_tags(span, args=[], kwargs=kwargs, response=resp, operation="response")
        return resp

    def _trace_mcp_tool_usage(self, pin, integration, resp):
        """Detect and trace server-side MCP tool usage in the response.

        Every entry of the response's `output` list whose `type` is "mcp_call" gets its own tool span.
        """
        if not resp:
            return

        output_items = _get_attr(resp, "output", [])
        if not output_items or not isinstance(output_items, list):
            return

        for output_item in output_items:
            if _get_attr(output_item, "type", "") == "mcp_call":
                self._create_mcp_tool_span(output_item, integration, pin)

    def _create_mcp_tool_span(self, item, integration, pin):
        """Creates and submits a tool span to LLMObs to represent a server-side MCP tool call."""
        with integration.trace("client_tool_call", submit_to_llmobs=True, kind="tool") as tool_span:
            call_id = str(_get_attr(item, "id", ""))
            call_name = str(_get_attr(item, "name", ""))
            # Arguments are read as text and parsed as JSON; a constant stands in when absent.
            parsed_arguments = safe_load_json(str(_get_attr(item, "arguments", OAI_HANDOFF_TOOL_ARG)))
            call_output = str(_get_attr(item, "output", ""))
            tool_call = {"name": call_name, "arguments": parsed_arguments, "tool_id": call_id}
            integration.llmobs_set_tags(
                tool_span,
                args=[],
                kwargs=tool_call,
                response=call_output,
                operation="tool",
            )


class _ResponseParseHook(_ResponseHook):
    """Same as `_ResponseHook`, reported under the `parseResponse` operation id."""

    OPERATION_ID = "parseResponse"
