from collections.abc import Mapping
import json
import re
from types import TracebackType
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Literal
from typing import Optional
from typing import Protocol
from typing import Union
from urllib import parse

from ddtrace._trace.span import Span
from ddtrace.appsec._constants import APPSEC
from ddtrace.appsec._constants import SPAN_DATA_NAMES
from ddtrace.appsec._constants import Constant_Class
from ddtrace.appsec._metrics import _set_waf_request_metrics
from ddtrace.appsec._utils import Block_config
from ddtrace.appsec._utils import Telemetry_result
from ddtrace.appsec._utils import get_triggers
from ddtrace.appsec._utils import is_inferred_span
from ddtrace.contrib.internal.trace_utils_base import _normalize_tag_name
from ddtrace.internal import core
from ddtrace.internal._exceptions import BlockingException
import ddtrace.internal.logger as ddlogger
from ddtrace.internal.settings.asm import config as asm_config


if TYPE_CHECKING:
    from ddtrace.appsec._utils import DDWaf_info
    from ddtrace.appsec._utils import DDWaf_result

# Module logger; every warning/debug message of this module goes through it.
logger = ddlogger.get_logger(__name__)


class WafCallable(Protocol):
    """Structural type of the per-request callable that runs the WAF.

    Stored on ``ASM_Environment.waf_callable`` and invoked positionally by
    ``call_waf_callback`` with the four arguments below; returns a
    ``DDWaf_result`` or None.
    """

    def __call__(
        self,
        custom_data: Optional[dict[str, Any]] = None,
        crop_trace: Optional[str] = None,
        rule_type: Optional[str] = None,
        force_sent: bool = False,
    ) -> Optional["DDWaf_result"]: ...


# Log message tags for the warnings emitted by this module. Every tag is registered
# with ddlogger.set_tag_rate_limit when the module is imported (loop right below).
class WARNING_TAGS(metaclass=Constant_Class):
    # an ASM_Environment was created without any span available
    ASM_ENV_NO_SPAN = "asm_context::ASM_Environment::no_span"
    SET_BLOCKED_NO_ASM_CONTEXT = "asm_context::set_blocked::no_active_context"
    SET_WAF_INFO_NO_ASM_CONTEXT = "asm_context::set_waf_info::no_active_context"
    # call_waf_callback found no active environment or no WAF callable
    CALL_WAF_CALLBACK_NOT_SET = "asm_context::call_waf_callback::not_set"
    # block_request found no stored block callable
    BLOCK_REQUEST_NOT_CALLABLE = "asm_context::block_request::not_callable"
    GET_DATA_SENT_NO_ASM_CONTEXT = "asm_context::get_data_sent::no_active_context"
    STORE_WAF_RESULTS_NO_ASM_CONTEXT = "asm_context::store_waf_results_data::no_active_context"


# Limit each warning tag to one log entry per day per process.
for _, tag in WARNING_TAGS:
    ddlogger.set_tag_rate_limit(tag, ddlogger.DAY)

# Default ``extra`` payload passed to this module's logger calls.
log_extra = {"product": "appsec", "stack_limit": 4, "exec_limit": 4}


# Stopgap module providing the ASM context for the blocking features, wrapping some contextvars.


# Key under which the per-request ASM_Environment is stored in the core context.
_ASM_CONTEXT: Literal["_asm_env"] = "_asm_env"
# Attribute names of ASM_Environment, usable as the ``category`` argument of set_value/get_value.
_WAF_ADDRESSES: Literal["waf_addresses"] = "waf_addresses"
_CALLBACKS: Literal["callbacks"] = "callbacks"
_TELEMETRY: Literal["telemetry"] = "telemetry"
# Callback keys: _CONTEXT_CALL in GLOBAL_CALLBACKS, _BLOCK_CALL in the environment's callbacks dict.
_CONTEXT_CALL: Literal["context"] = "context"
_BLOCK_CALL: Literal["block"] = "block"


# Process-wide callbacks; the ones under _CONTEXT_CALL are called with the environment
# when it is finalized (finalize_asm_env).
GLOBAL_CALLBACKS: dict[str, list[Callable]] = {_CONTEXT_CALL: []}


def report_error_on_entry_span(error: str, message: str) -> None:
    """Tag the entry span, when one can be found, with an AppSec error type and message."""
    span = get_entry_span()
    if span:
        span._set_tag_str(APPSEC.ERROR_TYPE, error)
        span._set_tag_str(APPSEC.ERROR_MESSAGE, message)


class ASM_Environment:
    """
    Holds every piece of ASM data (WAF addresses, callbacks and telemetry) for one request.

    An instance is bound to a single ASM request context and is stored in the core
    context under the ``_ASM_CONTEXT`` key.
    """

    def __init__(
        self,
        waf_callable: Optional[WafCallable],
        span: Optional[Span] = None,
        rc_products: str = "",
    ):
        # Only the outermost environment registers BlockingException for suppression.
        self.root = not in_asm_context()
        if self.root:
            core.add_suppress_exception(BlockingException)

        # Fall back to the root span of the core context when no span is given.
        context_span = span or core.get_root_span()
        if context_span is None:
            logger.warning(WARNING_TAGS.ASM_ENV_NO_SPAN, extra=log_extra, stack_info=True)
            raise TypeError("ASM_Environment requires a span")
        self.span: Span = context_span
        self.entry_span: Span = context_span._service_entry_span

        # Framework name derived from the span name: ".request" suffix dropped,
        # lower-cased, spaces turned into underscores.
        self.framework = context_span.name.removesuffix(".request").lower().replace(" ", "_")

        # WAF state
        self.waf_callable: Optional[WafCallable] = waf_callable
        self.waf_info: Optional[Callable[[], "DDWaf_info"]] = None
        self.waf_addresses: dict[str, Any] = {}
        self.waf_triggers: list[dict[str, Any]] = []
        self.addresses_sent: set[str] = set()

        # request-scoped callbacks (e.g. the block request callable)
        self.callbacks: dict[str, Any] = {}

        # telemetry and request outcome
        self.telemetry: Telemetry_result = Telemetry_result()
        self.blocked: Optional[Block_config] = None
        self.finalized: bool = False
        self.api_security_reported: int = 0
        self.rc_products: str = rc_products
        self.downstream_requests: int = 0


def _get_asm_context() -> Optional[ASM_Environment]:
    """Look up the ASM environment stored in the current core context, finalized or not."""
    env: Optional[ASM_Environment] = core.find_item(_ASM_CONTEXT)
    return env


def get_active_asm_context() -> Optional[ASM_Environment]:
    """Return the current ASM environment, or None (with a debug log) if absent or finalized."""
    env = _get_asm_context()
    if env is not None and not env.finalized:
        return env
    logger.debug(
        "asm_context::get_value::no_active_context",
        extra={"product": "appsec", "stack_limit": 4},
        stack_info=True,
    )
    return None


def in_asm_context() -> bool:
    """Tell whether an ASM environment can be found in the current core context."""
    found = core.find_item(_ASM_CONTEXT)
    return found is not None


def is_blocked() -> bool:
    """Tell whether a block configuration has been recorded for the current request."""
    env = _get_asm_context()
    return env is not None and env.blocked is not None


def get_blocked() -> Optional[Block_config]:
    """Return the recorded block configuration of the current request, or None."""
    env = _get_asm_context()
    blocked = None if env is None else env.blocked
    return blocked or None


def get_entry_span() -> Optional[Span]:
    """Return the service entry span of the request.

    Uses the ASM environment when there is one; otherwise derives it from the current
    span, and finally falls back to the root span of the core context.
    """
    env = _get_asm_context()
    if env is not None:
        return env.entry_span
    current = core.get_span()
    if not current:
        return core.get_root_span()
    return current._service_entry_span


# Multiplier of the multiplicative hash used by should_analyze_body_response.
# Presumably Knuth's constant (~2**64 / golden ratio) — TODO confirm the intended exact value.
KNUTH_FACTOR: int = 11400714819323199488
# Largest unsigned 64-bit integer (2**64 - 1); also the modulus of that hash.
UINT64_MAX: int = (1 << 64) - 1


class DownstreamRequests:
    """Process-wide state for sampling downstream request bodies."""

    # number of should_analyze_body_response calls so far in this process
    counter: int = 0
    # threshold on the 0..UINT64_MAX scale, computed once at import from asm_config._dr_sample_rate
    sampling_rate: int = int(asm_config._dr_sample_rate * UINT64_MAX)


def should_analyze_body_response(env: ASM_Environment) -> bool:
    """Check if we should analyze body for API10.

    Every call advances the process-wide counter. The body is analyzed only while the
    per-request budget is not exhausted and the hashed counter falls under the
    sampling threshold.
    """
    DownstreamRequests.counter += 1
    within_budget = env.downstream_requests < asm_config._dr_body_limit_per_request
    if not within_budget:
        return False
    hashed = (DownstreamRequests.counter * KNUTH_FACTOR) % UINT64_MAX
    return hashed <= DownstreamRequests.sampling_rate


def get_framework() -> str:
    """Return the framework name of the current ASM environment, or "" if none."""
    env = _get_asm_context()
    return "" if env is None else env.framework


def _use_html(headers: Mapping) -> bool:
    """decide if the response should be html or json.

    Add support for quality values in the Accept header.
    """
    ctype = headers.get("Accept", headers.get("accept", ""))
    if not ctype:
        return False
    html_score = 0.0
    json_score = 0.0
    ctypes = ctype.split(",")
    for ct in ctypes:
        if len(ct) > 128:
            # ignore long (and probably malicious) headers to avoid performances issues
            continue
        m = re.match(r"([^/;]+/[^/;]+)(?:;q=([01](?:\.\d*)?))?", ct.strip())
        if m:
            if m.group(1) == "text/html":
                html_score = max(html_score, min(1.0, float(1.0 if m.group(2) is None else m.group(2))))
            elif m.group(1) == "text/*":
                html_score = max(html_score, min(1.0, float(0.2 if m.group(2) is None else m.group(2))))
            elif m.group(1) == "application/json":
                json_score = max(json_score, min(1.0, float(1.0 if m.group(2) is None else m.group(2))))
            elif m.group(1) == "application/*":
                json_score = max(json_score, min(1.0, float(0.2 if m.group(2) is None else m.group(2))))
    return html_score > json_score


def _ctype_from_headers(block_config: Block_config) -> None:
    """Set ``content_type`` to text/html on the block config when an HTML response is wanted.

    That is the case for an explicit "html" type, or for "auto" when the request's
    Accept header prefers HTML. Nothing happens when no headers are available.
    """
    headers = get_headers()
    if headers is None:
        return
    wants_html = block_config.type == "html" or (block_config.type == "auto" and _use_html(headers))
    if wants_html:
        block_config.content_type = "text/html"


def set_blocked(blocked: Block_config) -> None:
    """Record ``blocked`` on the current ASM environment, resolving its content type first."""
    env = _get_asm_context()
    if env is not None:
        _ctype_from_headers(blocked)
        env.blocked = blocked
    else:
        logger.warning(WARNING_TAGS.SET_BLOCKED_NO_ASM_CONTEXT, extra=log_extra, stack_info=True)


def set_blocked_dict(block: Union[dict[str, Any], Block_config, None]) -> None:
    """Normalize ``block`` into a Block_config and record it with set_blocked.

    A dict is used as keyword arguments; None yields a default Block_config; any other
    value is passed through unchanged.
    """
    if block is None:
        config = Block_config()
    elif isinstance(block, dict):
        config = Block_config(**block)
    else:
        config = block
    set_blocked(config)


def update_span_metrics(span: Span, name: str, value: Union[float, int]) -> None:
    """Add ``value`` to the span metric ``name`` (a missing metric counts as 0.0)."""
    current = span.get_metric(name) or 0.0
    span.set_metric(name, value + current)


def _set_triggers_tag(span: Span, report_list: list[Any]) -> None:
    """Attach the WAF ``triggers`` report to ``span``.

    Uses the meta-struct tag when ``asm_config._use_metastruct_for_triggers`` is set,
    otherwise a compact JSON string tag.
    """
    payload = {"triggers": report_list}
    if asm_config._use_metastruct_for_triggers:
        span._set_struct_tag(APPSEC.STRUCT, payload)
    else:
        span.set_tag(APPSEC.JSON, json.dumps(payload, separators=(",", ":")))


def flush_waf_triggers(env: ASM_Environment) -> None:
    """Write the accumulated WAF triggers and WAF/RASP telemetry of ``env`` onto its entry span.

    Pending triggers are merged with any triggers already reported on the entry span and
    then cleared from ``env``. The same report is also set on an inferred parent span.
    WAF durations are reset to 0 after being reported.
    """
    from ddtrace.appsec._metrics import ddwaf_version

    entry_span = env.entry_span
    if env.waf_triggers:
        # merge with the triggers already reported on the entry span, if any
        report_list = get_triggers(entry_span)
        if report_list is not None:
            report_list.extend(env.waf_triggers)
        else:
            report_list = env.waf_triggers
        _set_triggers_tag(entry_span, report_list)

        # a parent span detected by is_inferred_span receives the same report
        parent = entry_span._parent
        if parent is not None and is_inferred_span(parent):
            _set_triggers_tag(parent, report_list)

        env.waf_triggers = []
    telemetry_results: Telemetry_result = env.telemetry

    entry_span._set_tag_str(APPSEC.WAF_VERSION, ddwaf_version)
    if env.downstream_requests:
        update_span_metrics(entry_span, APPSEC.DOWNSTREAM_REQUESTS, env.downstream_requests)
    if telemetry_results.total_duration:
        update_span_metrics(entry_span, APPSEC.WAF_DURATION, telemetry_results.duration)
        telemetry_results.duration = 0.0
        update_span_metrics(entry_span, APPSEC.WAF_DURATION_EXT, telemetry_results.total_duration)
        telemetry_results.total_duration = 0.0
    if telemetry_results.timeout:
        update_span_metrics(entry_span, APPSEC.WAF_TIMEOUTS, telemetry_results.timeout)
    rasp_timeouts = sum(telemetry_results.rasp.timeout.values())
    if rasp_timeouts:
        update_span_metrics(entry_span, APPSEC.RASP_TIMEOUTS, rasp_timeouts)
    if telemetry_results.rasp.sum_eval:
        update_span_metrics(entry_span, APPSEC.RASP_DURATION, telemetry_results.rasp.duration)
        update_span_metrics(entry_span, APPSEC.RASP_DURATION_EXT, telemetry_results.rasp.total_duration)
        update_span_metrics(entry_span, APPSEC.RASP_RULE_EVAL, telemetry_results.rasp.sum_eval)
    # truncation measurements are reported as the maximum seen during the request
    if telemetry_results.truncation.string_length:
        entry_span.set_metric(APPSEC.TRUNCATION_STRING_LENGTH, max(telemetry_results.truncation.string_length))
    if telemetry_results.truncation.container_size:
        entry_span.set_metric(APPSEC.TRUNCATION_CONTAINER_SIZE, max(telemetry_results.truncation.container_size))
    if telemetry_results.truncation.container_depth:
        entry_span.set_metric(APPSEC.TRUNCATION_CONTAINER_DEPTH, max(telemetry_results.truncation.container_depth))


def finalize_asm_env(env: ASM_Environment) -> None:
    """Finalize ``env`` exactly once and detach it from the current core context.

    This runs the global context callbacks, flushes WAF triggers and metrics, and tags
    the entry span with WAF rule info, the RC client id, collected headers and RC
    products. It then releases the environment's callbacks and WAF callable and
    discards it from the core context. Calls on an already finalized environment are
    no-ops.
    """
    if env.finalized:
        return
    env.finalized = True
    for function in GLOBAL_CALLBACKS[_CONTEXT_CALL]:
        function(env)
    flush_waf_triggers(env)
    _set_waf_request_metrics(env.telemetry)
    entry_span = env.entry_span
    if entry_span:
        if env.waf_info:
            try:
                # Call the lazy getter inside the guard: a failure while collecting WAF
                # diagnostics must not skip the tagging and context cleanup below.
                info = env.waf_info()
                if info.errors:
                    entry_span._set_tag_str(APPSEC.EVENT_RULE_ERRORS, info.errors)
                    extra = {"product": "appsec", "more_info": info.errors, "stack_limit": 4}
                    logger.debug("asm_context::finalize_asm_env::waf_errors", extra=extra, stack_info=True)
                entry_span._set_tag_str(APPSEC.EVENT_RULE_VERSION, info.version)
                entry_span.set_metric(APPSEC.EVENT_RULE_LOADED, info.loaded)
                entry_span.set_metric(APPSEC.EVENT_RULE_ERROR_COUNT, info.failed)
            except Exception:
                logger.debug("asm_context::finalize_asm_env::exception", extra=log_extra, exc_info=True)
        if asm_config._rc_client_id is not None:
            entry_span.set_tag(APPSEC.RC_CLIENT_ID, asm_config._rc_client_id)
        waf_addresses = env.waf_addresses
        req_headers = waf_addresses.get(SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES, {})
        if req_headers:
            _set_headers(entry_span, req_headers, kind="request")
        res_headers = waf_addresses.get(SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES, {})
        if res_headers:
            _set_headers(entry_span, res_headers, kind="response")
        if env.rc_products:
            entry_span._set_tag_str(APPSEC.RC_PRODUCTS, env.rc_products)

    # Manually clear reference cycles to simplify the work for the GC
    env.callbacks.clear()
    env.waf_callable = None
    core.discard_local_item(_ASM_CONTEXT)


def set_value(category: str, address: str, value: Any) -> None:
    """Store ``value`` under ``address`` in the ``category`` attribute of the current ASM environment.

    Logs at debug level when there is no environment. Does nothing when the environment
    has no such attribute (or it is None).
    """
    env = _get_asm_context()
    if env is None:
        logger.debug(
            "asm_context::set_value::no_active_context",
            extra={"product": "appsec", "more_info": f"::{category}::{address}", "stack_limit": 4},
            stack_info=True,
        )
        return
    target = getattr(env, category, None)
    if target is None:
        return
    target[address] = value


def set_headers_response(headers: Any) -> None:
    """Store the response headers (without cookies) as a WAF address; None is ignored."""
    if headers is None:
        return
    set_waf_address(SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES, headers)


def set_body_response(body_response: Any) -> None:
    """Register a lazily parsed response body as the RESPONSE_BODY WAF address.

    Parsing is deferred: the stored value is a zero-argument callable that parses
    ``body_response`` with the response headers known at call time.
    """
    # imported here to avoid a circular import with ddtrace.appsec._utils
    from ddtrace.appsec._utils import parse_response_body

    env = _get_asm_context()
    if env is None:
        logger.debug(
            "asm_context::set_body_response::no_active_context",
            extra={"product": "appsec", "more_info": "::set_body_response", "stack_limit": 4},
            stack_info=True,
        )
        return

    def _lazy_body() -> Any:
        response_headers = env.waf_addresses.get(SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES, None)
        return parse_response_body(body_response, response_headers)

    set_waf_address(SPAN_DATA_NAMES.RESPONSE_BODY, _lazy_body)


def set_waf_address(address: str, value: Any) -> None:
    """Store ``value`` as the WAF address ``address`` of the current request.

    The raw request URI is stored without its scheme and network location. The client
    IP and the header case-sensitivity flag are also mirrored, unmodified, into the
    core context.
    """
    stored = value
    if address == SPAN_DATA_NAMES.REQUEST_URI_RAW:
        _scheme, _netloc, *rest = parse.urlparse(value)
        stored = parse.urlunparse(parse.ParseResult("", "", *rest))
    set_value(_WAF_ADDRESSES, address, stored)
    mirrored = (SPAN_DATA_NAMES.REQUEST_HTTP_IP, SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES_CASE)
    if address in mirrored:
        core.set_item(address, value)


def get_value(category: str, address: str, default: Any = None) -> Any:
    """Read ``address`` from the ``category`` attribute of the active ASM environment.

    Returns ``default`` when there is no active environment, no such attribute, or no
    such address.
    """
    env = get_active_asm_context()
    source = None if env is None else getattr(env, category, None)
    if source is None:
        return default
    return source.get(address, default)


def get_waf_address(address: str, default: Any = None) -> Any:
    """Read a WAF address of the active ASM environment, or ``default``."""
    return get_value(category=_WAF_ADDRESSES, address=address, default=default)


def add_context_callback(function: Callable[[ASM_Environment], Any], global_callback: bool = False) -> None:
    """Register ``function`` to be called on environment finalization (global callbacks only).

    Calls with ``global_callback=False`` are ignored.
    """
    if not global_callback:
        return
    GLOBAL_CALLBACKS.setdefault(_CONTEXT_CALL, []).append(function)


def remove_context_callback(function: Callable[[ASM_Environment], Any], global_callback: bool = False) -> None:
    """Unregister every occurrence of ``function`` from the global context callbacks.

    Calls with ``global_callback=False`` are ignored.
    """
    if not global_callback:
        return
    registered = GLOBAL_CALLBACKS.get(_CONTEXT_CALL)
    if not registered:
        return
    # slice assignment keeps the same list object for everyone holding a reference to it
    registered[:] = [cb for cb in registered if cb != function]


def set_waf_info(info: Callable[[], "DDWaf_info"]) -> None:
    """Store the lazy WAF info getter on the current ASM environment."""
    env = _get_asm_context()
    if env is not None:
        env.waf_info = info
    else:
        logger.warning(WARNING_TAGS.SET_WAF_INFO_NO_ASM_CONTEXT, extra=log_extra, stack_info=True)


def call_waf_callback(
    custom_data: Optional[dict[str, Any]] = None,
    crop_trace: Optional[str] = None,
    rule_type: Optional[str] = None,
    force_sent: bool = False,
) -> Optional["DDWaf_result"]:
    """Run the WAF callable of the active ASM environment and return its result.

    Returns None when ASM is disabled. When there is no active environment or no WAF
    callable, logs a warning, reports a diagnostic error on the entry span and
    returns None.
    """
    if not asm_config._asm_enabled:
        return None
    env = get_active_asm_context()
    waf_callable = env.waf_callable if env is not None else None
    if waf_callable is None:
        logger.warning(WARNING_TAGS.CALL_WAF_CALLBACK_NOT_SET, extra=log_extra, stack_info=True)
        report_error_on_entry_span("appsec::instrumentation::diagnostic", WARNING_TAGS.CALL_WAF_CALLBACK_NOT_SET)
        return None
    return waf_callable(custom_data, crop_trace, rule_type, force_sent)


def call_waf_callback_no_instrumentation() -> None:
    """call the waf once if it was not already called"""
    if not asm_config._asm_enabled:
        return
    env = _get_asm_context()
    if not env or env.telemetry.triggered:
        return
    waf_callable = env.waf_callable
    if waf_callable:
        waf_callable()


def set_ip(ip: Optional[str]) -> None:
    """Store the client IP as a WAF address; None is ignored."""
    if ip is None:
        return
    set_waf_address(SPAN_DATA_NAMES.REQUEST_HTTP_IP, ip)


def get_ip() -> Optional[str]:
    """Return the client IP stored as a WAF address, or None."""
    return get_waf_address(SPAN_DATA_NAMES.REQUEST_HTTP_IP)


# Note: the get/set headers helpers only carry the headers here, without changing or using them.
# Different frameworks use different header types, and we don't want to force them into a Mapping
# at the early point where set_headers is usually called.


def set_headers(headers: Any) -> None:
    """Store the request headers (without cookies) as a WAF address; None is ignored.

    The headers are carried as-is. Frameworks pass different container types, and
    callers such as asm_request_context_set may pass None, hence the ``Any`` annotation.
    """
    if headers is not None:
        set_waf_address(SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES, headers)


def get_headers() -> Optional[Mapping]:
    """Return the stored request headers (without cookies), or a new empty dict."""
    return get_waf_address(SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES, default={})


def set_headers_case_sensitive(case_sensitive: bool) -> None:
    """Record whether the stored request headers are case sensitive."""
    set_waf_address(address=SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES_CASE, value=case_sensitive)


def get_headers_case_sensitive() -> bool:
    """Return whether the stored request headers are case sensitive (False if unknown)."""
    return get_waf_address(SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES_CASE, default=False)


def set_block_request_callable(_callable: Optional[Callable[[], Any]], *_: Any) -> None:
    """
    Store a zero-argument callable used as a best-effort way to block the request.

    Any parameters it needs (headers, ...) must be bound beforehand, e.g. with
    functools.partial. Ignored when ASM is disabled or ``_callable`` is falsy; extra
    positional arguments are ignored.
    """
    if not (asm_config._asm_enabled and _callable):
        return
    set_value(_CALLBACKS, _BLOCK_CALL, _callable)


def block_request() -> None:
    """
    Call the stored block request callable, or log a warning when none is set.
    """
    _callable = get_value(_CALLBACKS, _BLOCK_CALL)
    if not _callable:
        logger.warning(WARNING_TAGS.BLOCK_REQUEST_NOT_CALLABLE, extra=log_extra, stack_info=True)
        return
    _callable()


def get_data_sent() -> set[str]:
    """Return the set of addresses already sent to the WAF (a new empty set without context)."""
    env = _get_asm_context()
    if env is not None:
        return env.addresses_sent
    logger.warning(WARNING_TAGS.GET_DATA_SENT_NO_ASM_CONTEXT, extra=log_extra, stack_info=True)
    return set()


def asm_request_context_set(
    remote_ip: Optional[str] = None,
    headers: Any = None,
    headers_case_sensitive: bool = False,
    block_request_callable: Optional[Callable] = None,
) -> None:
    """Seed the current ASM environment with the basic request data, in this order:
    client IP, request headers, header case sensitivity, block request callable."""
    set_ip(ip=remote_ip)
    set_headers(headers=headers)
    set_headers_case_sensitive(case_sensitive=headers_case_sensitive)
    set_block_request_callable(_callable=block_request_callable)


def set_waf_telemetry_results(
    rules_version: str,
    is_blocked: bool,
    waf_results: "DDWaf_result",
    rule_type: Optional[str],
    is_sampled: bool,
) -> None:
    """Accumulate the outcome of one WAF run into the current request's telemetry.

    Does nothing when no ASM environment is found. Runs with ``rule_type`` None are
    accounted in the request-level WAF counters; other runs are accounted per
    ``rule_type`` in the RASP (exploit prevention) counters.

    :param rules_version: rules version; stored when non-empty (request-level runs only)
        and passed to the run-error report.
    :param is_blocked: whether this run resulted in blocking; OR-ed into ``blocked``.
    :param waf_results: result of the WAF run (return code, data, timings, truncations).
    :param rule_type: None for request-level runs, otherwise the RASP rule type.
    :param is_sampled: OR-ed into ``rate_limited``.
    """
    env = _get_asm_context()
    if env is None:
        return
    result: Telemetry_result = env.telemetry
    # any returned data means the run triggered
    is_triggered = bool(waf_results.data)
    from ddtrace.appsec._metrics import _report_waf_truncations

    result.rate_limited |= is_sampled
    # a negative return code is reported as a WAF run error
    if waf_results.return_code < 0:
        # keep the larger (closest to zero) of the stored and the new error codes
        if result.error:
            result.error = max(result.error, waf_results.return_code)
        else:
            result.error = waf_results.return_code
        from ddtrace.appsec._metrics import _report_waf_run_error

        _report_waf_run_error(waf_results.return_code, rules_version, rule_type)
    _report_waf_truncations(waf_results.truncation)
    # keep every integer truncation measurement; flush_waf_triggers later reports their max
    for key in ["container_size", "container_depth", "string_length"]:
        res = getattr(waf_results.truncation, key)
        if isinstance(res, int):
            getattr(result.truncation, key).append(res)
    if rule_type is None:
        # Request Blocking telemetry
        result.triggered |= is_triggered
        result.blocked |= is_blocked
        result.timeout += waf_results.timeout
        if rules_version:
            result.version = rules_version
        result.duration += waf_results.runtime
        result.total_duration += waf_results.total_runtime
    else:
        # Exploit Prevention telemetry
        result.rasp.blocked |= is_blocked
        result.rasp.sum_eval += 1
        result.rasp.eval[rule_type] += 1
        result.rasp.match[rule_type] += int(is_triggered)
        result.rasp.timeout[rule_type] += int(waf_results.timeout)
        result.rasp.durations[rule_type] += waf_results.runtime
        result.rasp.duration += waf_results.runtime
        result.rasp.total_duration += waf_results.total_runtime


def get_waf_telemetry_results() -> Optional[Telemetry_result]:
    """Return the telemetry accumulator of the current ASM environment, or None."""
    env = _get_asm_context()
    return env.telemetry if env else None


def store_waf_results_data(data: list[dict[str, Any]]) -> None:
    """Queue WAF trigger entries on the current environment, stamping each with the span id.

    The entries are mutated in place (``span_id`` key added). Empty input is ignored.
    """
    if not data:
        return
    env = _get_asm_context()
    if env is None:
        logger.warning(WARNING_TAGS.STORE_WAF_RESULTS_NO_ASM_CONTEXT, extra=log_extra, stack_info=True)
        return
    span_id = env.span.span_id
    for entry in data:
        entry["span_id"] = span_id
    env.waf_triggers.extend(data)


def start_context(waf_callable: Optional[WafCallable], span: Span, rc_products: str) -> None:
    """Create and store the ASM environment for a new core context (no-op when ASM is disabled).

    The environment is then seeded from the core context items "remote_addr",
    "headers", "headers_case_sensitive" and "block_request_callable".
    """
    if not asm_config._asm_enabled:
        return
    # expected only at the start of a core context, before any ASM environment is stored
    env = ASM_Environment(
        waf_callable=waf_callable,
        span=span,
        rc_products=rc_products,
    )
    core.set_item(_ASM_CONTEXT, env)
    asm_request_context_set(
        remote_ip=core.get_item("remote_addr"),
        headers=core.get_item("headers"),
        headers_case_sensitive=core.get_item("headers_case_sensitive"),
        block_request_callable=core.get_item("block_request_callable"),
    )


def end_context(span: Span) -> None:
    """Finalize the current ASM environment, but only if it is bound to ``span``."""
    env = _get_asm_context()
    if env is None or env.span is not span:
        return
    finalize_asm_env(env)


def _on_context_ended(
    ctx: Any,
    _exc_info: tuple[Optional[type[BaseException]], Optional[BaseException], Optional[TracebackType]],
) -> None:
    """Finalize the ASM environment stored in the ending core context ``ctx``, if any."""
    env = ctx.get_item(_ASM_CONTEXT)
    if env is None:
        return
    finalize_asm_env(env)


def _set_headers_and_response(response: Any, headers: Any, *_: Any) -> None:
    """Record response headers and (optionally) the response body for API Security.

    Only active when ASM and the API Security feature are enabled. A dict of headers
    is turned into a list of (name, value) pairs; other iterables are listed as-is.
    """
    if not (asm_config._asm_enabled and asm_config._api_security_feature_active):
        return
    if headers:
        # start_response has not been called yet: record the HTTP response headers early
        pairs = list(headers.items()) if isinstance(headers, dict) else list(headers)
        set_headers_response(pairs)
    if response and asm_config._api_security_parse_response_body:
        set_body_response(response)


def _call_waf_first(integration: Any, *_: Any) -> None:
    """Run the WAF at request time for ``integration`` (no-op when ASM is disabled)."""
    if asm_config._asm_enabled:
        logger.debug(f"{integration}::srb_on_request", extra=log_extra)
        call_waf_callback()


def _call_waf(integration: Any, *_: Any) -> None:
    """Run the WAF at response time for ``integration`` (no-op when ASM is disabled)."""
    if asm_config._asm_enabled:
        logger.debug(f"{integration}::srb_on_response", extra=log_extra)
        call_waf_callback()


def _get_headers_if_appsec() -> Optional[Any]:
    """Return the stored request headers when ASM is enabled, otherwise None."""
    return get_headers() if asm_config._asm_enabled else None


## headers tags

# Lower-case header names that _set_headers collects when called with only_asm_enabled=True.
_COLLECTED_REQUEST_HEADERS_ASM_ENABLED = {
    "accept",
    "content-type",
    "user-agent",
    "x-amzn-trace-id",
    "cloudfront-viewer-ja3-fingerprint",
    "cf-ray",
    "x-cloud-trace-context",
    "x-appgw-trace-id",
    "akamai-user-risk",
    "x-sigsci-requestid",
    "x-sigsci-tags",
}

# Lower-case header names that _set_headers collects by default (also used to filter
# response headers); extended right below with the ASM-enabled set.
_COLLECTED_REQUEST_HEADERS = {
    "accept-encoding",
    "accept-language",
    "cf-connecting-ip",
    "cf-connecting-ipv6",
    "content-encoding",
    "content-language",
    "content-length",
    "fastly-client-ip",
    "forwarded",
    "forwarded-for",
    "host",
    "true-client-ip",
    "via",
    "x-client-ip",
    "x-cluster-client-ip",
    "x-forwarded",
    "x-forwarded-for",
    "x-real-ip",
}

_COLLECTED_REQUEST_HEADERS.update(_COLLECTED_REQUEST_HEADERS_ASM_ENABLED)


def _set_headers(span: Span, headers: Any, kind: str, only_asm_enabled: bool = False) -> None:
    """Tag ``span`` with the allow-listed headers of ``headers``.

    ``headers`` is either a mapping or an iterable of (name, value) pairs; bytes names
    and values are decoded. Undecodable bytes are replaced rather than raising, so one
    malformed header cannot abort finalization. Names are matched case-insensitively
    against the allow-list, and tag names come from ``_normalize_tag_name(kind, name)``.
    """
    allowed = _COLLECTED_REQUEST_HEADERS_ASM_ENABLED if only_asm_enabled else _COLLECTED_REQUEST_HEADERS
    for k in headers:
        if isinstance(k, tuple):
            key, value = k
        else:
            key, value = k, headers[k]
        if isinstance(key, bytes):
            key = key.decode(errors="replace")
        if isinstance(value, bytes):
            value = value.decode(errors="replace")
        if key.lower() in allowed:
            # since the header value can be a list, use `set_tag()` to ensure it is converted to a string
            span.set_tag(_normalize_tag_name(kind, key), value)


def asm_listen() -> None:
    """Subscribe this module's blocking helpers to the "asm.set_blocked" and "asm.get_blocked" core events."""
    core.on("asm.set_blocked", set_blocked_dict)
    # the third argument is passed through to core.on; presumably the name of the returned value — confirm
    core.on("asm.get_blocked", get_blocked, "block_config")
