import json
from typing import Any
from typing import Optional
from typing import TypedDict

from ddtrace._trace.sampling_rule import SamplingRule
from ddtrace._trace.span import Span
from ddtrace.constants import _SAMPLING_AGENT_DECISION
from ddtrace.constants import _SAMPLING_RULE_DECISION
from ddtrace.constants import _SINGLE_SPAN_SAMPLING_MAX_PER_SEC
from ddtrace.constants import _SINGLE_SPAN_SAMPLING_MAX_PER_SEC_NO_LIMIT
from ddtrace.constants import _SINGLE_SPAN_SAMPLING_MECHANISM
from ddtrace.constants import _SINGLE_SPAN_SAMPLING_RATE
from ddtrace.internal.constants import _KEEP_PRIORITY_INDEX
from ddtrace.internal.constants import _REJECT_PRIORITY_INDEX
from ddtrace.internal.constants import MAX_UINT_64BITS
from ddtrace.internal.constants import SAMPLING_DECISION_MAKER_INHERITED
from ddtrace.internal.constants import SAMPLING_DECISION_MAKER_RESOURCE
from ddtrace.internal.constants import SAMPLING_DECISION_MAKER_SERVICE
from ddtrace.internal.constants import SAMPLING_DECISION_TRACE_TAG_KEY
from ddtrace.internal.constants import SAMPLING_HASH_MODULO
from ddtrace.internal.constants import SAMPLING_KNUTH_FACTOR
from ddtrace.internal.constants import SAMPLING_MECHANISM_TO_PRIORITIES
from ddtrace.internal.constants import SamplingMechanism
from ddtrace.internal.glob_matching import GlobMatcher
from ddtrace.internal.logger import get_logger
from ddtrace.internal.settings._config import config

from .rate_limiter import RateLimiter


log = get_logger(__name__)


class PriorityCategory:
    """String constants naming the sampling priority categories."""

    DEFAULT = "default"
    AUTO = "auto"
    RULE_DEFAULT = "rule_default"
    RULE_CUSTOMER = "rule_customer"
    RULE_DYNAMIC = "rule_dynamic"


# Every upper-case attribute of ``SamplingMechanism`` rendered as "-<value>".
# ``validate_sampling_decision`` accepts a decision maker tag only if its value
# is a member of this set.
SAMPLING_MECHANISM_CONSTANTS = {
    f"-{mechanism_value}"
    for attr_name, mechanism_value in vars(SamplingMechanism).items()
    if attr_name.isupper()
}

# Tag key under which ``_set_sampling_tags`` stores the formatted sample rate.
KNUTH_SAMPLE_RATE_KEY = "_dd.p.ksr"

class SpanSamplingRules(TypedDict, total=False):
    """Shape of one span sampling rule mapping; every key is optional."""

    name: str
    service: str
    sample_rate: float
    max_per_second: int


def validate_sampling_decision(
    meta: dict[str, str],
) -> dict[str, str]:
    """Remove an unrecognised sampling decision maker tag from ``meta``.

    The value stored under ``SAMPLING_DECISION_TRACE_TAG_KEY`` is checked
    against ``SAMPLING_MECHANISM_CONSTANTS``. A missing or empty value is left
    alone. An unknown value is deleted, ``_dd.propagation_error`` is set to
    ``"decoding_error"`` and a warning is logged.

    :param meta: Tag mapping to validate; it is modified in place.
    :returns: The same ``meta`` object that was passed in.
    """
    decision_maker = meta.get(SAMPLING_DECISION_TRACE_TAG_KEY)

    # Nothing to do when the tag is absent/empty or names a known mechanism.
    if not decision_maker or decision_maker in SAMPLING_MECHANISM_CONSTANTS:
        return meta

    # Do not keep an invalid sampling mechanism tag; flag the error instead.
    del meta[SAMPLING_DECISION_TRACE_TAG_KEY]
    meta["_dd.propagation_error"] = "decoding_error"
    log.warning("failed to decode _dd.p.dm: %r", decision_maker)
    return meta


class SpanSamplingRule:
    """Single span sampling rule: matches spans by glob and tags the ones it samples."""

    __slots__ = (
        "_service_matcher",
        "_name_matcher",
        "_sample_rate",
        "_max_per_second",
        "_sampling_id_threshold",
        "_limiter",
        "_matcher",
    )

    def __init__(
        self,
        sample_rate: float,
        max_per_second: int,
        service: Optional[str] = None,
        name: Optional[str] = None,
    ):
        """Store the rule parameters and build its limiter and matchers.

        :param sample_rate: Rate used by ``_sample``; also scaled by
            ``MAX_UINT_64BITS`` to obtain the span id hash threshold.
        :param max_per_second: Value handed to ``RateLimiter`` and reported as
            a metric unless it equals ``_SINGLE_SPAN_SAMPLING_MAX_PER_SEC_NO_LIMIT``.
        :param service: Optional glob pattern for the span service.
        :param name: Optional glob pattern for the span name.
        """
        self._sample_rate = sample_rate
        self._sampling_id_threshold = sample_rate * MAX_UINT_64BITS

        self._max_per_second = max_per_second
        self._limiter = RateLimiter(max_per_second)

        # A matcher exists only for the patterns that were actually supplied.
        self._service_matcher = None if service is None else GlobMatcher(service)
        self._name_matcher = None if name is None else GlobMatcher(name)

    def sample(self, span: Span) -> bool:
        """Decide whether ``span`` is kept by this rule, tagging it when it is.

        The rate limiter is consulted only after the rate-based check passed.
        This method does not call ``match``.

        :returns: ``True`` when the span was sampled and tagged, else ``False``.
        """
        if not self._sample(span) or not self._limiter.is_allowed():
            return False
        self.apply_span_sampling_tags(span)
        return True

    def _sample(self, span: Span) -> bool:
        """Rate-based decision derived from the span id."""
        rate = self._sample_rate
        # Rates of exactly 1 and 0 never need the hash.
        if rate == 1:
            return True
        if rate == 0:
            return False

        hashed_id = (span.span_id * SAMPLING_KNUTH_FACTOR) % SAMPLING_HASH_MODULO
        return hashed_id <= self._sampling_id_threshold

    def match(self, span: Span) -> bool:
        """Return whether the span's service and name satisfy the configured patterns."""
        name = span.name
        service = span.service
        # A span with neither a name nor a service can never match.
        if name is None and service is None:
            return False

        # A pattern that was not configured counts as matching.
        service_ok = True
        if self._service_matcher:
            if service is None:
                return False
            service_ok = self._service_matcher.match(service)

        name_ok = True
        if self._name_matcher:
            if name is None:
                return False
            name_ok = self._name_matcher.match(name)

        return service_ok and name_ok

    def apply_span_sampling_tags(self, span: Span) -> None:
        """Write the single span sampling metrics onto ``span``."""
        span.set_metric(_SINGLE_SPAN_SAMPLING_MECHANISM, SamplingMechanism.SPAN_SAMPLING_RULE)
        span.set_metric(_SINGLE_SPAN_SAMPLING_RATE, self._sample_rate)
        # The limit metric is omitted when no limit is configured.
        if self._max_per_second == _SINGLE_SPAN_SAMPLING_MAX_PER_SEC_NO_LIMIT:
            return
        span.set_metric(_SINGLE_SPAN_SAMPLING_MAX_PER_SEC, self._max_per_second)


def get_span_sampling_rules() -> list[SpanSamplingRule]:
    """Build the span sampling rules from the configured JSON.

    Each entry returned by ``_get_span_sampling_json`` is turned into a
    ``SpanSamplingRule``:

    * ``sample_rate`` defaults to ``1.0`` and ``max_per_second`` defaults to
      ``_SINGLE_SPAN_SAMPLING_MAX_PER_SEC_NO_LIMIT`` when absent.
    * An entry that is not a JSON object, or that supplies neither ``service``
      nor ``name``, rejects the whole configuration: a warning is logged and
      an empty list is returned.
    * An entry whose pattern check or rule construction raises is logged and
      skipped; the remaining entries are still processed.

    :returns: The successfully created rules, in configuration order.
    """
    sampling_rules: list[SpanSamplingRule] = []
    for rule in _get_span_sampling_json():
        # The loader only guarantees a list; entries may be any JSON value.
        # Without this guard ``rule.get`` raises AttributeError for e.g. `[1]`.
        if not isinstance(rule, dict):
            log.warning("Sampling rules must be JSON objects, got %s", json.dumps(rule))
            return []

        # If sample_rate not specified default to 100%
        sample_rate = rule.get("sample_rate", 1.0)
        service = rule.get("service")
        name = rule.get("name")

        if not service and not name:
            log.warning("Sampling rules must supply at least 'service' or 'name', got %s", json.dumps(rule))
            return []

        # If max_per_second not specified default to no limit
        max_per_second = rule.get("max_per_second", _SINGLE_SPAN_SAMPLING_MAX_PER_SEC_NO_LIMIT)

        try:
            if service:
                _check_unsupported_pattern(service)
            if name:
                _check_unsupported_pattern(name)
            sampling_rule = SpanSamplingRule(
                sample_rate=sample_rate, service=service, name=name, max_per_second=max_per_second
            )
        except Exception as e:
            # Best effort: one bad rule must not prevent the others from loading.
            log.warning("Error creating single span sampling rule %s: %s", json.dumps(rule), e)
        else:
            sampling_rules.append(sampling_rule)
    return sampling_rules


def _get_span_sampling_json() -> list[dict[str, Any]]:
    """Return the raw rule list, preferring the env value over the file.

    Both sources are always read. When both yield rules a warning is logged
    and the env rules are used. When neither yields rules an empty list is
    returned.
    """
    rules_from_env = _get_env_json()
    rules_from_file = _get_file_json()

    if rules_from_env:
        if rules_from_file:
            log.warning(
                (
                    "DD_SPAN_SAMPLING_RULES and DD_SPAN_SAMPLING_RULES_FILE detected. "
                    "Defaulting to DD_SPAN_SAMPLING_RULES value."
                )
            )
        return rules_from_env

    return rules_from_file or []


def _get_file_json() -> Optional[list[dict[str, Any]]]:
    """Load span sampling rules from the file named by ``config._sampling_rules_file``.

    :returns: The parsed rule list from ``_load_span_sampling_json``, or
        ``None`` when no file is configured.
    :raises OSError: If the configured file cannot be opened or read; the
        error is not handled here.
    """
    rules_path = config._sampling_rules_file
    if not rules_path:
        return None

    # JSON text is UTF-8 (RFC 8259), so decode it explicitly instead of
    # relying on the platform's locale encoding.
    with open(rules_path, encoding="utf-8") as rules_file:
        return _load_span_sampling_json(rules_file.read())


def _get_env_json() -> Optional[list[dict[str, Any]]]:
    """Parse ``config._sampling_rules``; return ``None`` when it is unset or empty."""
    raw_rules = config._sampling_rules
    return _load_span_sampling_json(raw_rules) if raw_rules else None


def _load_span_sampling_json(raw_json_rules: str) -> list[dict[str, Any]]:
    try:
        json_rules = json.loads(raw_json_rules)
        if not isinstance(json_rules, list):
            log.warning("DD_SPAN_SAMPLING_RULES is not list, got %r", json_rules)
            return []
    except json.JSONDecodeError:
        log.warning("Unable to parse DD_SPAN_SAMPLING_RULES=%r", raw_json_rules)
        return []

    return json_rules


def _check_unsupported_pattern(string: str) -> None:
    # We don't support pattern bracket expansion or escape character
    unsupported_chars = {"[", "]", "\\"}
    for char in string:
        if char in unsupported_chars and config._raise:
            raise ValueError("Unsupported Glob pattern found, character:%r is not supported" % char)


def _set_sampling_tags(span: Span, sampled: bool, sample_rate: float, mechanism: int) -> None:
    """Record a sampling decision on ``span`` and its context.

    :param span: Span whose tags, metrics and context priority are updated.
    :param sampled: Selects the keep or the reject priority for ``mechanism``.
    :param sample_rate: Rate written as a metric and as the
        ``KNUTH_SAMPLE_RATE_KEY`` tag for trace sampling rule and agent
        mechanisms.
    :param mechanism: Sampling mechanism; it must be a key of
        ``SAMPLING_MECHANISM_TO_PRIORITIES``.
    """
    # The decision maker is written once; an existing tag is never overwritten.
    if not span.context._meta.get(SAMPLING_DECISION_TRACE_TAG_KEY):
        span._set_sampling_decision_maker(mechanism)

    # Work out which rate metric, if any, this mechanism reports.
    decided_by_trace_rule = mechanism in (
        SamplingMechanism.LOCAL_USER_TRACE_SAMPLING_RULE,
        SamplingMechanism.REMOTE_USER_TRACE_SAMPLING_RULE,
        SamplingMechanism.REMOTE_DYNAMIC_TRACE_SAMPLING_RULE,
    )
    decided_by_agent = not decided_by_trace_rule and mechanism == SamplingMechanism.AGENT_RATE_BY_SERVICE

    if decided_by_trace_rule or decided_by_agent:
        rate_metric_key = _SAMPLING_RULE_DECISION if decided_by_trace_rule else _SAMPLING_AGENT_DECISION
        span.set_metric(rate_metric_key, sample_rate)
        span._set_tag_str(KNUTH_SAMPLE_RATE_KEY, f"{sample_rate:.6g}")

    # Each mechanism maps to a pair of priorities, indexed by keep/reject.
    mechanism_priorities = SAMPLING_MECHANISM_TO_PRIORITIES[mechanism]
    span.context.sampling_priority = mechanism_priorities[
        _KEEP_PRIORITY_INDEX if sampled else _REJECT_PRIORITY_INDEX
    ]


def _inherit_sampling_tags(target: Span, source: Span):
    """Mark ``target`` as inheriting its sampling decision from ``source``.

    Sets the inherited metric to 1 on ``target`` and copies the service and
    resource of ``source`` into the decision maker tags.
    """
    target.set_metric(SAMPLING_DECISION_MAKER_INHERITED, 1)

    decision_service = source.service
    target._set_tag_str(SAMPLING_DECISION_MAKER_SERVICE, decision_service)  # type: ignore[arg-type]

    decision_resource = source.resource
    target._set_tag_str(SAMPLING_DECISION_MAKER_RESOURCE, decision_resource)


def _get_highest_precedence_rule_matching(span: Span, rules: list[SamplingRule]) -> Optional[SamplingRule]:
    """Return the first rule in ``rules`` whose ``matches(span)`` is truthy.

    :param span: Span handed to each rule's ``matches``.
    :param rules: Candidate rules, checked in list order.
    :returns: The first matching rule, or ``None`` when ``rules`` is empty or
        nothing matches.
    """
    if not rules:
        return None

    return next((candidate for candidate in rules if candidate.matches(span)), None)
