import abc
from typing import Any  # noqa:F401
from typing import Optional  # noqa:F401

from ddtrace import config
from ddtrace._trace.sampler import RateSampler
from ddtrace.constants import _SPAN_MEASURED_KEY
from ddtrace.contrib.internal.trace_utils import int_service
from ddtrace.ext import SpanTypes
from ddtrace.internal.logger import get_logger
from ddtrace.internal.settings.integration import IntegrationConfig
from ddtrace.llmobs._constants import INTEGRATION
from ddtrace.llmobs._constants import PROXY_REQUEST
from ddtrace.llmobs._llmobs import LLMObs
from ddtrace.trace import Span
from ddtrace.trace import tracer


# Module-level logger, obtained for this module's ``__name__``.
log = get_logger(__name__)


class BaseLLMIntegration:
    """Shared base for LLM integrations.

    Provides the common plumbing for starting an LLM request span (``trace``)
    and for tagging that span with LLMObs data (``llmobs_set_tags``).
    Subclasses customize behavior through the ``_set_base_span_tags``,
    ``_llmobs_set_tags`` and ``_get_base_url`` hooks.

    NOTE: this class does not use ``abc.ABCMeta`` / ``abc.ABC``, so the
    ``@abc.abstractmethod`` markers below document intent only; they are not
    enforced when the class (or a subclass) is instantiated.
    """

    # Used to build the default span name ("<name>.request") and stored on the
    # span under the INTEGRATION context item when LLMObs is enabled.
    _integration_name = "baseLLM"

    def __init__(self, integration_config: IntegrationConfig) -> None:
        """Store the integration config and build the LLMObs payload sampler.

        :param integration_config: Config for this integration; passed to
            ``int_service`` when resolving the span's service name.
        """
        self.integration_config = integration_config
        self._llmobs_pc_sampler = RateSampler(sample_rate=config._llmobs_sample_rate)

    @property
    def llmobs_enabled(self) -> bool:
        """Return whether submitting llmobs payloads is enabled."""
        return LLMObs.enabled

    def is_pc_sampled_llmobs(self, span: Span) -> bool:
        """Return whether the LLMObs payload for ``span`` should be kept.

        Always ``False`` when LLMObs is disabled; otherwise the decision is
        delegated to the sampler built in ``__init__``.
        """
        # Sampling of llmobs payloads is independent of spans, but we're using a RateSampler for consistency.
        if not self.llmobs_enabled:
            return False
        return self._llmobs_pc_sampler.sample(span)

    @abc.abstractmethod
    def _set_base_span_tags(self, span: Span, **kwargs: Any) -> None:
        """Set default LLM span attributes when possible.

        Hook called by ``trace`` with the same keyword arguments ``trace``
        received. The base implementation does nothing.
        """
        pass

    def trace(self, operation_id: str, submit_to_llmobs: bool = False, **kwargs: Any) -> Span:
        """
        Start a LLM request span.
        Reuse the service of the application since we'll tag downstream request spans with the LLM name.
        Eventually those should also be internal service spans once peer.service is implemented.

        :param operation_id: Used as the span's resource.
        :param submit_to_llmobs: When true *and* LLMObs is enabled, the span is
            created with type ``SpanTypes.LLM``; otherwise the span type is ``None``.
        :param kwargs: ``span_name`` (optional) overrides the default
            ``"<integration name>.request"`` span name, and ``parent_context``
            (optional) overrides the currently active context as the parent.
            All keyword arguments, including those two, are forwarded to
            ``_get_base_url`` and ``_set_base_span_tags``.
        :returns: The started, activated span.
        """
        span_name = kwargs.get("span_name") or f"{self._integration_name}.request"
        span_type = SpanTypes.LLM if (submit_to_llmobs and self.llmobs_enabled) else None
        parent_context = kwargs.get("parent_context") or tracer.context_provider.active()

        span = tracer.start_span(
            span_name,
            child_of=parent_context,
            service=int_service(None, self.integration_config),
            resource=operation_id,
            span_type=span_type,
            activate=True,
        )

        # NOTE(review): from here on the span is already started and activated.
        # If one of the subclass hooks below raises, the exception propagates
        # and this method does not finish the span -- confirm callers/hooks
        # cannot raise here, or that this is acceptable.
        log.debug("Creating LLM span with type %s", span.span_type)
        # determine if the span represents a proxy request
        base_url = self._get_base_url(**kwargs)
        if self._is_instrumented_proxy_url(base_url):
            span._set_ctx_item(PROXY_REQUEST, True)
        # Enable trace metrics for these spans so users can see per-service openai usage in APM.
        # PERF: avoid setting via Span.set_tag
        span.set_metric(_SPAN_MEASURED_KEY, 1)
        self._set_base_span_tags(span, **kwargs)
        if self.llmobs_enabled:
            span._set_ctx_item(INTEGRATION, self._integration_name)
        return span

    def llmobs_set_tags(
        self,
        span: Span,
        args: list[Any],
        kwargs: dict[str, Any],
        response: Optional[Any] = None,
        operation: str = "",
    ) -> None:
        """Extract input/output information from the request and response to be submitted to LLMObs.

        Does nothing when LLMObs is disabled or the span is not sampled.
        Any ``Exception`` raised by the subclass hook ``_llmobs_set_tags`` is
        logged (with traceback) and swallowed, so tagging failures never
        propagate to the caller.
        """
        if not self.llmobs_enabled or not self.is_pc_sampled_llmobs(span):
            return
        try:
            self._llmobs_set_tags(span, args, kwargs, response, operation)
        except Exception:
            log.error("Error extracting LLMObs fields for span %s, likely due to malformed data", span, exc_info=True)

    @abc.abstractmethod
    def _llmobs_set_tags(
        self,
        span: Span,
        args: list[Any],
        kwargs: dict[str, Any],
        response: Optional[Any] = None,
        operation: str = "",
    ) -> None:
        """Hook that performs the actual LLMObs tagging; subclasses must override.

        The base implementation raises ``NotImplementedError``. When reached
        through ``llmobs_set_tags`` that error is caught and logged there
        rather than propagated.
        """
        raise NotImplementedError()

    def _get_base_url(self, **kwargs: Any) -> Optional[str]:
        """Return the base URL of the request, if known.

        Hook called by ``trace`` with its keyword arguments. The base
        implementation always returns ``None``, which ``trace`` treats as
        "not a proxy request".
        """
        return None

    def _is_instrumented_proxy_url(self, base_url: Optional[str] = None) -> bool:
        """Return whether ``base_url`` is one of the configured instrumented proxy URLs.

        A missing/empty ``base_url`` is never a proxy URL. The match is an
        exact membership test against ``config._llmobs_instrumented_proxy_urls``
        (treated as empty when unset).
        """
        if not base_url:
            return False
        instrumented_proxy_urls = config._llmobs_instrumented_proxy_urls or set()
        return base_url in instrumented_proxy_urls
