"""
This module contains utility functions for writing ddtrace integrations.
"""

from collections import deque
import ipaddress
import re
from typing import TYPE_CHECKING  # noqa:F401
from typing import Any  # noqa:F401
from typing import Callable  # noqa:F401
from typing import Generator  # noqa:F401
from typing import Iterator  # noqa:F401
from typing import Mapping  # noqa:F401
from typing import Optional  # noqa:F401
from typing import Union  # noqa:F401
from typing import cast  # noqa:F401
from urllib import parse

import wrapt

from ddtrace._trace.pin import Pin
from ddtrace._trace.span import Span
from ddtrace.constants import _ORIGIN_KEY
from ddtrace.contrib.internal.trace_utils_base import USER_AGENT_PATTERNS  # noqa:F401
from ddtrace.contrib.internal.trace_utils_base import _get_header_value_case_insensitive
from ddtrace.contrib.internal.trace_utils_base import _get_request_header_user_agent
from ddtrace.contrib.internal.trace_utils_base import _normalize_tag_name
from ddtrace.contrib.internal.trace_utils_base import _set_url_tag
from ddtrace.contrib.internal.trace_utils_base import set_user  # noqa:F401
from ddtrace.ext import http
from ddtrace.ext import net
from ddtrace.internal import core
from ddtrace.internal.compat import ensure_text
from ddtrace.internal.compat import ip_is_global
from ddtrace.internal.constants import SAMPLING_DECISION_TRACE_TAG_KEY
from ddtrace.internal.core.event_hub import dispatch
from ddtrace.internal.logger import get_logger
from ddtrace.internal.settings._config import config
from ddtrace.internal.settings.asm import config as asm_config
import ddtrace.internal.utils.wrappers
from ddtrace.propagation.http import HTTPPropagator


if TYPE_CHECKING:  # pragma: no cover
    from ddtrace.internal.settings.integration import IntegrationConfig  # noqa:F401
    from ddtrace.trace import Span  # noqa:F401
    from ddtrace.trace import Tracer  # noqa:F401


# Module-level logger, named after this module.
log = get_logger(__name__)

# Short aliases re-exported for integrations: wrapt's function wrapper plus the
# unwrap / iswrapped helpers from ddtrace.internal.utils.wrappers.
wrap = wrapt.wrap_function_wrapper
unwrap = ddtrace.internal.utils.wrappers.unwrap
iswrapped = ddtrace.internal.utils.wrappers.iswrapped

# Values passed as ``request_or_response`` to _store_headers; they select which
# side of the HTTP exchange the stored header tags belong to.
REQUEST = "request"
RESPONSE = "response"

# Tag normalization based on: https://docs.datadoghq.com/tagging/#defining-tags
# With the exception of '.' in header names which are replaced with '_' to avoid
# starting a "new object" on the UI.
# The pattern matches any single character that is not a lowercase ASCII letter,
# a digit, or one of ``_ - : /``.
NORMALIZE_PATTERN = re.compile(r"([^a-z0-9_\-:/]){1}")


# Lower-case header names that may carry the client IP. They are checked in this
# order by _get_request_header_client_ip (when no header is explicitly configured)
# and the first one present in the request is used.
IP_PATTERNS = (
    "x-forwarded-for",
    "x-real-ip",
    "true-client-ip",
    "x-client-ip",
    "forwarded",
    "forwarded-for",
    "x-cluster-client-ip",
    "fastly-client-ip",
    "cf-connecting-ip",
    "cf-connecting-ipv6",
)


def _store_headers(
    headers: dict[str, str], span: Span, integration_config: "IntegrationConfig", request_or_response: str
) -> None:
    """
    Tag ``span`` with every header that the integration config maps to a tag.

    :param headers: The http headers to inspect; anything that is not already a dict is
        converted with ``dict()`` and silently ignored when that conversion fails
    :type headers: dict or list
    :param span: The Span instance where tags will be stored
    :type span: ddtrace.trace.Span
    :param integration_config: An integration specific config object; when ``None`` nothing is stored.
    :type integration_config: ddtrace.settings.IntegrationConfig
    :param request_or_response: Either ``REQUEST`` or ``RESPONSE``; used to build the default tag name
    :type request_or_response: str
    """
    header_map = headers
    if not isinstance(header_map, dict):
        try:
            header_map = dict(header_map)
        except Exception:
            # Best effort: headers that cannot be viewed as a mapping are skipped entirely
            return

    if integration_config is None:
        log.debug("Skipping headers tracing as no integration config was provided")
        return

    for name, value in header_map.items():
        # _header_tag_name looks the header up in the configured header tags
        # (populated from the DD_TRACE_HEADER_TAGS environment variable).
        configured_tag = integration_config._header_tag_name(name)
        if configured_tag is not None:
            # An empty tag means "use the default": http.<request or response>.headers.<header name>
            span._set_tag_str(
                configured_tag if configured_tag else _normalize_tag_name(request_or_response, name),
                value,
            )


def _get_request_header_referrer_host(headers: Mapping[str, str], headers_are_case_sensitive: bool = False) -> str:
    """Return the hostname found in the request's referer header.

    :param headers: The http request headers
    :type headers: dict or list
    :param headers_are_case_sensitive: When true the header is searched with the
        case-insensitive lookup helper, otherwise a plain ``headers.get`` is used
    :type headers_are_case_sensitive: bool
    :return: The referrer host, or an empty string when the header is absent,
        cannot be parsed, or carries no hostname
    :rtype: str
    """
    referrer = (
        _get_header_value_case_insensitive(headers, http.REFERER_HEADER)
        if headers_are_case_sensitive
        else headers.get(http.REFERER_HEADER)
    )
    if not referrer:
        return ""

    try:
        hostname = parse.urlparse(referrer).hostname
    except (ValueError, AttributeError):
        # Malformed referer value: treat it as if there were no referrer
        return ""
    return hostname or ""


def _parse_ip_header(ip_header_value: str) -> str:
    """Parse the ip header, either in Forwarded-For format or Forwarded format.

    references: https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Forwarded
    """
    IP_EXTRACTIONS = [
        r"^\s*(?P<ip>[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)$",  # ipv4 simple format
        r'(?:^|;)\s*for="?(?P<ip>[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)',  # ipv4 forwarded format
        r"^\s*(?P<ip>[0-9a-fA-F:]+)$",  # ipv6 simple format
        r'(?:^|;)\s*for="\[(?P<ip>[0-9a-fA-F:]+)\]',  # ipv6 forwarded format
    ]
    for pattern in IP_EXTRACTIONS:
        if m := re.search(pattern, ip_header_value, re.IGNORECASE):
            return m.group("ip")
    return ""


def _get_request_header_client_ip(
    headers: Optional[Mapping[str, str]], peer_ip: Optional[str] = None, headers_are_case_sensitive: bool = False
) -> str:
    """Resolve the client IP of a request from its headers and/or the peer IP.

    Resolution order, as implemented below:

    1. Without headers, ``peer_ip`` is returned if ``ipaddress.ip_address`` accepts it,
       otherwise an empty string.
    2. When ``config._client_ip_header`` is set, only that header is consulted; a missing
       header or a value that is not a single valid IP address yields an empty string.
    3. Otherwise the first header of ``IP_PATTERNS`` present in ``headers`` is used.
    4. The header value is split on commas and each element goes through ``_parse_ip_header``;
       the first address for which ``ip_is_global`` is true is returned, and the first other
       valid address is kept as a fallback.
    5. Finally ``peer_ip`` is returned when ``ip_is_global(peer_ip)`` is true or when there is
       no fallback from the headers; otherwise the fallback is returned.

    :param headers: The http request headers
    :param peer_ip: The IP address of the connection peer, if known
    :param headers_are_case_sensitive: When true, header names are matched after lower-casing
        and replacing ``_`` with ``-`` (or through the case-insensitive lookup helper)
    :return: The resolved IP address, or an empty string
    """
    # NOTE(review): the return annotation is ``str`` but the ``return peer_ip`` statements
    # hand back ``peer_ip`` unchanged; whether ``None`` can escape depends on how
    # ``ip_is_global`` treats ``None`` — confirm against its implementation.

    def get_header_value(key: str) -> Optional[str]:
        # Plain lookup unless the caller told us header names need case-insensitive matching
        if not headers_are_case_sensitive:
            return headers.get(key)

        return _get_header_value_case_insensitive(headers, key)

    if not headers:
        try:
            # Validation only: the parsed address object itself is not needed
            _ = ipaddress.ip_address(str(peer_ip))
        except ValueError:
            return ""
        return peer_ip

    ip_header_value = ""
    user_configured_ip_header = config._client_ip_header
    if user_configured_ip_header:
        # User selected the header to use to get the IP
        ip_header_value = get_header_value(
            user_configured_ip_header.lower().replace("_", "-")
            if headers_are_case_sensitive
            else user_configured_ip_header
        )
        if not ip_header_value:
            log.debug("DD_TRACE_CLIENT_IP_HEADER configured but '%s' header missing", user_configured_ip_header)
            return ""

        try:
            # The configured header must hold exactly one valid IP address
            _ = ipaddress.ip_address(str(ip_header_value))
        except ValueError:
            log.debug("Invalid IP address from configured %s header: %s", user_configured_ip_header, ip_header_value)
            return ""

    else:
        if headers_are_case_sensitive:
            # Normalize header names once so they can be compared with IP_PATTERNS
            new_headers = {k.lower().replace("_", "-"): v for k, v in headers.items()}
            for ip_header in IP_PATTERNS:
                if ip_header in new_headers:
                    ip_header_value = new_headers[ip_header]
                    break
        else:
            for ip_header in IP_PATTERNS:
                if ip_header in headers:
                    ip_header_value = headers[ip_header]
                    break

    private_ip_from_headers = ""

    if ip_header_value:
        # At this point, we have one IP header, check its value and retrieve the first public IP

        ip_list = ip_header_value.split(",")
        for ip in ip_list:
            ip = _parse_ip_header(ip)
            if not ip:
                continue
            try:
                if ip_is_global(ip):
                    return ip
                elif not private_ip_from_headers:
                    # IP is private, store it just in case we don't find a public one later
                    private_ip_from_headers = ip
            except ValueError:  # invalid IP
                continue

    # At this point we have none or maybe one private ip from the headers: check the peer ip in
    # case it's public and, if not, return either the private_ip from the headers (if we have one)
    # or the peer private ip
    try:
        if ip_is_global(peer_ip) or not private_ip_from_headers:
            return peer_ip
    except ValueError:
        pass

    return private_ip_from_headers


def _store_request_headers(headers: dict[str, str], span: Span, integration_config: "IntegrationConfig") -> None:
    """
    Store request headers as a span's tags.

    Thin wrapper that calls ``_store_headers`` with ``REQUEST`` as the header kind.

    :param headers: All the request's http headers; only those the integration config maps to a tag are stored
    :type headers: dict or list
    :param span: The Span instance where tags will be stored
    :type span: ddtrace.trace.Span
    :param integration_config: An integration specific config object.
    :type integration_config: ddtrace.settings.IntegrationConfig
    """
    _store_headers(headers, span, integration_config, REQUEST)


def _store_response_headers(headers: dict[str, str], span: Span, integration_config: "IntegrationConfig") -> None:
    """
    Store response headers as a span's tags.

    Thin wrapper that calls ``_store_headers`` with ``RESPONSE`` as the header kind.

    :param headers: All the response's http headers; only those the integration config maps to a tag are stored
    :type headers: dict or list
    :param span: The Span instance where tags will be stored
    :type span: ddtrace.trace.Span
    :param integration_config: An integration specific config object.
    :type integration_config: ddtrace.settings.IntegrationConfig
    """
    _store_headers(headers, span, integration_config, RESPONSE)


def _sanitized_url(url: str) -> str:
    """
    Sanitize url by removing parts with potential auth info
    """
    if "@" in url:
        parsed = parse.urlparse(url)
        netloc = parsed.netloc

        if "@" not in netloc:
            # Safe url, `@` not in netloc
            return url

        netloc = netloc[netloc.index("@") + 1 :]
        return parse.urlunparse(
            (
                parsed.scheme,
                netloc,
                parsed.path,
                "",
                parsed.query,
                "",
            )
        )

    return url


def with_traced_module(func):
    """Decorator factory giving tracing wrappers their module and pin.

    The decorated function is called as ``func(mod, pin, wrapped, instance, args, kwargs)``
    only when a pin is found for the instance/module and that pin is enabled. In every
    other case the wrapped callable is invoked directly, which lets a wrapper be switched
    off dynamically through its pin.

    Usage::

        @with_traced_module
        def my_traced_wrapper(django, pin, func, instance, args, kwargs):
            # Do tracing stuff
            pass

        def patch():
            import django
            wrap(django.somefunc, my_traced_wrapper(django))
    """

    def with_mod(mod):
        def wrapper(wrapped, instance, args, kwargs):
            pin = Pin._find(instance, mod)

            if not pin:
                # Nothing to trace with: report it and fall through to the original call
                log.debug("Pin not found for traced method %r", wrapped)
                return wrapped(*args, **kwargs)

            if not pin.enabled():
                # Tracing explicitly disabled through the pin
                return wrapped(*args, **kwargs)

            return func(mod, pin, wrapped, instance, args, kwargs)

        return wrapper

    return with_mod


def distributed_tracing_enabled(int_config: "IntegrationConfig", default: bool = False) -> bool:
    """Tell whether distributed tracing is enabled for this integration config.

    ``distributed_tracing_enabled`` takes precedence over ``distributed_tracing``; an
    option that is absent from the config or set to ``None`` is skipped, and ``default``
    is returned when neither option provides a value.
    """
    for option in ("distributed_tracing_enabled", "distributed_tracing"):
        if option in int_config:
            setting = getattr(int_config, option)
            if setting is not None:
                return setting
    return default


def int_service(pin: Optional[Pin], int_config: "IntegrationConfig", default: Optional[str] = None) -> Optional[str]:
    """Resolve the service name of an *internal* integration.

    Internal means the traced work belongs to the user's application
    (eg. web framework, sqlalchemy, web servers).

    Lookup order: pin override, integration config (``service`` then ``service_name``),
    explicitly configured global service, the integration's ``_default_service``, and
    finally ``default`` (or the global service when no ``default`` was given).
    """
    # A service set on the pin is defined by the user in code: it always wins
    if pin is not None and pin.service:
        return pin.service

    # Next the integration config, also set through code. Integrations use
    # either ``service`` or ``service_name``.
    for option in ("service", "service_name"):
        if option in int_config:
            configured = getattr(int_config, option)
            if configured is not None:
                return cast(str, configured)

    global_config = int_config.global_config
    global_service = global_config._get_service()
    # The global service (config.service) falls back to _inferred_base_service when
    # DD_SERVICE is not set. An inferred value must not take precedence over the
    # integration's default service, otherwise the introduction of
    # inferred_base_service would have been a massive breaking change.
    if global_service and global_service != global_config._inferred_base_service:
        return cast(str, global_service)

    if "_default_service" in int_config:
        integration_default = int_config._default_service
        if integration_default is not None:
            return cast(str, integration_default)

    if default is None and global_service:
        return cast(str, global_service)

    return default


def ext_service(pin: Optional[Pin], int_config: "IntegrationConfig", default: Optional[str] = None) -> Optional[str]:
    """Resolve the service name of an *external* integration.

    External means the integration generates spans wrapping code that lives outside
    the scope of the user's application (eg. a database, RPC, cache, etc).

    Lookup order: pin override, then the integration config options ``service``,
    ``service_name`` and ``_default_service``, then ``default``.
    """
    if pin is not None and pin.service:
        return pin.service

    for option in ("service", "service_name", "_default_service"):
        if option in int_config:
            configured = getattr(int_config, option)
            if configured is not None:
                return cast(str, configured)

    # A default is required since it's an external service.
    return default


def set_http_meta(
    span: Span,
    integration_config: "IntegrationConfig",
    method: Optional[str] = None,
    url: Optional[str] = None,
    target_host: Optional[str] = None,
    server_address: Optional[str] = None,
    status_code: Optional[Union[int, str]] = None,
    status_msg: Optional[str] = None,
    query: Optional[str] = None,
    parsed_query: Optional[Mapping[str, str]] = None,
    request_headers: Optional[Mapping[str, str]] = None,
    response_headers: Optional[Mapping[str, str]] = None,
    retries_remain: Optional[Union[int, str]] = None,
    raw_uri: Optional[str] = None,
    request_cookies: Optional[dict[str, str]] = None,
    request_path_params: Optional[dict[str, str]] = None,
    request_body: Optional[Union[str, dict[str, list[str]]]] = None,
    peer_ip: Optional[str] = None,
    headers_are_case_sensitive: bool = False,
    route: Optional[str] = None,
    response_cookies: Optional[dict[str, str]] = None,
) -> None:
    """
    Set HTTP metas on the span

    Every argument is optional: a tag is only written for the values that are provided.
    After the tags are set, a ``set_http_meta_for_asm`` event is dispatched through
    ``core.dispatch`` with the request/response data.

    :param span: the span receiving the tags
    :param integration_config: the integration config; consulted for ``trace_query_string``
         and ``is_header_tracing_configured`` and forwarded to the url/header helpers
    :param method: the HTTP method
    :param url: the HTTP URL; sanitized with ``_sanitized_url`` before being tagged
    :param target_host: value stored under the ``net.TARGET_HOST`` tag
    :param server_address: value stored under the ``net.SERVER_ADDRESS`` tag
    :param status_code: the HTTP status code; when it cannot be converted to ``int`` no status tag
         is set, and the span is flagged as an error when ``config._http_server.is_error_code`` says so
    :param status_msg: the HTTP status message
    :param query: the HTTP query part of the URI as a string; tagged only when
         ``integration_config.trace_query_string`` is truthy
    :param parsed_query: the HTTP query part of the URI as parsed by the framework and forwarded to the user code
    :param request_headers: the HTTP request headers
    :param response_headers: the HTTP response headers
    :param retries_remain: stored (as a string) under the ``http.RETRIES_REMAIN`` tag
    :param raw_uri: the full raw HTTP URI (including ports and query)
    :param request_cookies: the HTTP request cookies as a dict
    :param request_path_params: the parameters of the HTTP URL as set by the framework: /posts/<id:int> would give us
         { "id": <int_value> }
    :param request_body: the request body; only forwarded in the dispatched event
    :param peer_ip: the IP of the connection peer, used when no client IP is resolved from the headers
    :param headers_are_case_sensitive: forwarded to the request header lookup helpers
    :param route: value stored under the ``http.ROUTE`` tag
    :param response_cookies: the HTTP response cookies; only forwarded in the dispatched event
    """
    if method is not None:
        span._set_tag_str(http.METHOD, method)

    if url is not None:
        # Strip potential credentials before the url is written on the span
        url = _sanitized_url(url)
        _set_url_tag(integration_config, span, url, query)

    if target_host is not None:
        span._set_tag_str(net.TARGET_HOST, target_host)

    if server_address is not None:
        span._set_tag_str(net.SERVER_ADDRESS, server_address)

    if status_code is not None:
        try:
            int_status_code = int(status_code)
        except (TypeError, ValueError):
            log.debug("failed to convert http status code %r to int", status_code)
        else:
            # The tag keeps the caller's representation; the int form is only used for the error check
            span._set_tag_str(http.STATUS_CODE, str(status_code))
            if config._http_server.is_error_code(int_status_code):
                span.error = 1

    if status_msg is not None:
        span._set_tag_str(http.STATUS_MSG, status_msg)

    if query is not None and integration_config.trace_query_string:
        span._set_tag_str(http.QUERY_STRING, query)

    # Default to the peer IP; it may be replaced below by an IP resolved from the request headers
    request_ip = peer_ip
    if request_headers:
        user_agent = _get_request_header_user_agent(request_headers, headers_are_case_sensitive)
        if user_agent:
            span._set_tag_str(http.USER_AGENT, user_agent)

        # Extract referrer host if referer header is present
        referrer_host = _get_request_header_referrer_host(request_headers, headers_are_case_sensitive)
        if referrer_host:
            span._set_tag_str(http.REFERRER_HOSTNAME, referrer_host)

        # We always collect the IP if appsec is enabled to report it on potential vulnerabilities.
        # https://datadoghq.atlassian.net/wiki/spaces/APS/pages/2118779066/Client+IP+addresses+resolution
        if asm_config._asm_enabled or config._retrieve_client_ip:
            # Retrieve the IP if it was calculated on AppSecProcessor.on_span_start
            request_ip = core.find_item("http.request.remote_ip")

            if not request_ip:
                # Not calculated: framework does not support IP blocking or testing env
                request_ip = (
                    _get_request_header_client_ip(request_headers, peer_ip, headers_are_case_sensitive) or peer_ip
                )

            if request_ip:
                # The same IP is written under both tag names
                span._set_tag_str(http.CLIENT_IP, request_ip)
                span._set_tag_str("network.client.ip", request_ip)

        if integration_config.is_header_tracing_configured:
            """We should store both http.<request_or_response>.headers.<header_name> and
            http.<key>. The last one
            is the DD standardized tag for user-agent"""
            _store_request_headers(dict(request_headers), span, integration_config)

    if response_headers is not None and integration_config.is_header_tracing_configured:
        _store_response_headers(dict(response_headers), span, integration_config)

    if retries_remain is not None:
        span._set_tag_str(http.RETRIES_REMAIN, str(retries_remain))

    # Hand the request/response data over to the listeners of this event. The position of
    # each element in the list is part of the event's contract.
    core.dispatch(
        "set_http_meta_for_asm",
        [
            span,
            request_ip,
            raw_uri,
            route,
            method,
            request_headers,
            request_cookies,
            parsed_query,
            request_path_params,
            request_body,
            status_code,
            response_headers,
            response_cookies,
        ],
    )

    # The route tag is written after the event has been dispatched
    if route is not None:
        span._set_tag_str(http.ROUTE, route)


def activate_distributed_headers(
    tracer: "Tracer",
    int_config: Optional["IntegrationConfig"] = None,
    request_headers: Optional[dict[str, str]] = None,
    override: Optional[bool] = None,
) -> None:
    """
    Activate the distributed trace context carried by ``request_headers``, when enabled.

    ``override`` has the last word when it is not ``None``: ``False`` disables the
    activation and a truthy value forces it. Otherwise the decision comes from
    ``int_config`` through ``distributed_tracing_enabled``.
    """
    if override is False:
        return None

    extraction_enabled = override or (int_config and distributed_tracing_enabled(int_config))
    if not extraction_enabled:
        return None

    extracted = HTTPPropagator.extract(request_headers)
    # bail out if no context was extracted
    if extracted is None:
        return None

    # Only activate when something was actually propagated: the extracted context
    # must carry a trace id, baggage or span links.
    if not (extracted.trace_id or extracted._baggage or extracted._span_links):
        return None

    # Do not reactivate a context with the same trace id
    # DEV: This happens with nested web frameworks, when one layer already
    #      parsed the request headers and activated them.
    #
    # Example::
    #
    #     app = Flask(__name__)  # Traced via Flask instrumentation
    #     app = DDWSGIMiddleware(app)  # Extra layer on top for WSGI
    active = tracer.current_trace_context()

    # Incoming contexts holding only baggage or only span links are accepted, but
    # once a context is already active an incoming one without a trace id, or
    # with the same trace id, must not replace it.
    if active and (not extracted.trace_id or extracted.trace_id == active.trace_id):
        log.debug(
            "will not activate extracted Context(trace_id=%r, span_id=%r), a context with that trace id is already active",  # noqa: E501
            extracted.trace_id,
            extracted.span_id,
        )
        return None

    # A trace id was parsed from the headers and no context with that trace id
    # is currently active.
    tracer.context_provider.activate(extracted)
    core.dispatch("http.activate_distributed_headers", (request_headers, extracted))

    dispatch("distributed_context.activated", (extracted,))


def _copy_trace_level_tags(target_span: Span, parent: Span):
    """
    Propagate the trace level data of ``parent`` to ``target_span``.

    Baggage items are copied to the target's context and mirrored as ``baggage.<key>``
    tags, the sampling priority is copied when the parent has one, and the origin and
    sampling decision entries of the parent's context meta are copied as tags when set.
    """
    for name, item in parent.context._baggage.items():
        target_span.context.set_baggage_item(name, item)
        target_span._set_tag_str(f"baggage.{name}", item)

    priority = parent.context.sampling_priority
    if priority is not None:
        target_span.context.sampling_priority = priority

    # Both trace level tags follow the same rule: copy only when the parent holds a truthy value
    for trace_tag in (_ORIGIN_KEY, SAMPLING_DECISION_TRACE_TAG_KEY):
        tag_value = parent.context._meta.get(trace_tag)
        if tag_value:
            target_span._set_tag_str(trace_tag, tag_value)


def _flatten(
    obj: Any,
    sep: str = ".",
    prefix: str = "",
    exclude_policy: Optional[Callable[[str], bool]] = None,
) -> Generator[tuple[str, Any], None, None]:
    s = deque()  # type: ignore
    s.append((prefix, obj))
    while s:
        p, v = s.pop()
        if exclude_policy is not None and exclude_policy(p):
            continue
        if isinstance(v, dict):
            s.extend((sep.join((p, k)) if p else k, v) for k, v in v.items())
        else:
            yield p, v


def set_flattened_tags(
    span: Span,
    items: Iterator[tuple[str, Any]],
    sep: str = ".",
    exclude_policy: Optional[Callable[[str], bool]] = None,
    processor: Optional[Callable[[Any], Any]] = None,
) -> None:
    """Set one tag on ``span`` per leaf value found in ``items``.

    Each ``(prefix, value)`` pair is flattened with ``_flatten`` and every resulting
    ``(tag, leaf)`` pair is set on the span; when ``processor`` is given, the leaf goes
    through it first.
    """
    for prefix, value in items:
        for tag, leaf in _flatten(value, sep, prefix, exclude_policy):
            if processor is not None:
                leaf = processor(leaf)
            span.set_tag(tag, leaf)


def extract_netloc_and_query_info_from_url(url: str) -> tuple[str, str]:
    """Return the host part of the network location and the query string of ``url``.

    Auth info and port are removed from the network location. The userinfo is split
    off at the LAST ``@`` (as ``urllib.parse`` does) so that a password containing a raw
    ``@`` is not partially kept, and a bracketed IPv6 literal such as ``[::1]:8080`` is
    returned with its brackets (``[::1]``) instead of being cut at its first colon.

    :param url: An absolute url, or a scheme-less one such as ``example.com/path``
    :return: A ``(netloc, query)`` tuple
    """
    parse_result = parse.urlparse(url)
    query = parse_result.query

    # Relative URLs don't have a netloc, so we force them
    if not parse_result.netloc:
        parse_result = parse.urlparse("//{url}".format(url=url))

    host = parse_result.netloc.rpartition("@")[2]  # Discard auth info
    if host.startswith("[") and "]" in host:
        # IPv6 literal: colons inside the brackets are not port separators
        host = host[: host.index("]") + 1]
    else:
        host = host.split(":", 1)[0]  # Discard port information
    return host, query


class InterruptException(Exception):
    """Plain ``Exception`` subclass that adds no behavior of its own.

    NOTE(review): nothing in this module raises or catches it; presumably it is
    raised by integrations to interrupt the processing of a request — confirm
    against its callers.
    """


def _convert_to_string(attr):
    # ensures attribute is converted to a string
    if attr:
        if isinstance(attr, int) or isinstance(attr, float):
            return str(attr)
        else:
            return ensure_text(attr)
    return attr


def check_module_path(module, attr_path):
    """
    Safely tell whether a nested attribute path can be resolved from a module.

    Args:
        module: The root module object
        attr_path: Dot-separated path to the attribute (e.g., "flows.llm_flows.functions")

    Returns:
        bool: True if every attribute of the path resolves, False otherwise
        (including when ``module`` is falsy)

    Example:
        check_module_path(adk, "flows.llm_flows.functions.__call_tool_async")
        check_module_path(adk, "agents.llm_agent.LlmAgent")
    """
    if not module:
        return False

    try:
        node = module
        for name in attr_path.split("."):
            node = getattr(node, name)
    except (ImportError, AttributeError):
        # AttributeError: a segment of the path does not exist.
        # ImportError: some modules raise it when accessing attributes that require
        # additional dependencies (e.g., ContainerCodeExecutor requiring Docker for
        # google-adk and an extra pkg install).
        return False
    return True
