# this module must not load any other unsafe appsec module directly

import collections
import contextlib
import json
import logging
import typing
from typing import Any
from typing import Optional
from typing import Union

from ddtrace._trace._inferred_proxy import SUPPORTED_PROXY_SPAN_NAMES
from ddtrace._trace.span import Span
from ddtrace.appsec._constants import API_SECURITY
from ddtrace.appsec._constants import APPSEC
from ddtrace.contrib.internal.trace_utils_base import _get_header_value_case_insensitive
from ddtrace.internal._unpatched import unpatched_json_loads
from ddtrace.internal.logger import get_logger
from ddtrace.internal.settings.asm import config as asm_config


log = get_logger(__name__)

# Integer constants named after the three truncation categories tracked in this module
# (string length, container depth, container size).
# NOTE(review): not referenced in the visible part of this module; the values (1, 4, 2)
# look like identifiers or bit flags rather than size limits — confirm against their users.
_TRUNC_STRING_LENGTH = 1
_TRUNC_CONTAINER_DEPTH = 4
_TRUNC_CONTAINER_SIZE = 2


class _observator:
    """Keep the largest string length, container size and container depth reported so far.

    Every field starts as ``None`` and, once its setter has been called, holds the
    maximum of all values that were passed to that setter.
    """

    def __init__(self) -> None:
        self.string_length: Optional[int] = None
        self.container_size: Optional[int] = None
        self.container_depth: Optional[int] = None

    def set_string_length(self, length: int) -> None:
        """Record ``length``, keeping the maximum string length seen."""
        previous = self.string_length
        self.string_length = length if previous is None else max(previous, length)

    def set_container_size(self, size: int) -> None:
        """Record ``size``, keeping the maximum container size seen."""
        previous = self.container_size
        self.container_size = size if previous is None else max(previous, size)

    def set_container_depth(self, depth: int) -> None:
        """Record ``depth``, keeping the maximum container depth seen."""
        previous = self.container_depth
        self.container_depth = depth if previous is None else max(previous, depth)

    def __repr__(self) -> str:
        return f"_observator(length={self.string_length}, size={self.container_size}, depth={self.container_depth})"


class DDWaf_result:
    """Outcome of one WAF call, with its derivatives sorted into tag buckets.

    ``derivatives`` is split at construction time:

    * keys starting with ``"_dd.appsec.s."`` go to ``api_security`` (value kept as is);
    * other entries whose value is a ``str`` go to ``meta_tags``;
    * other entries whose value is a ``bool`` go to ``metrics`` converted to ``int``;
    * everything else goes to ``metrics`` unchanged.

    All remaining constructor arguments are stored verbatim on the instance.
    """

    __slots__ = [
        "return_code",
        "data",
        "actions",
        "runtime",
        "total_runtime",
        "timeout",
        "truncation",
        "meta_tags",
        "metrics",
        "api_security",
        "keep",
    ]

    def __init__(
        self,
        return_code: int,
        data: list[dict[str, Any]],
        actions: dict[str, Any],
        runtime: float,
        total_runtime: float,
        timeout: bool,
        truncation: _observator,
        derivatives: dict[str, Any],
        keep: bool = False,
    ) -> None:
        self.return_code = return_code
        self.data = data
        self.actions = actions
        self.runtime = runtime
        self.total_runtime = total_runtime
        self.timeout = timeout
        self.truncation = truncation
        self.metrics: dict[str, Union[int, float]] = {}
        self.meta_tags: dict[str, str] = {}
        self.api_security: dict[str, str] = {}
        for k, v in derivatives.items():
            if k.startswith("_dd.appsec.s."):
                self.api_security[k] = v
            elif isinstance(v, str):
                self.meta_tags[k] = v
            elif isinstance(v, bool):
                # bool must be tested explicitly so that it is stored as 0/1.
                self.metrics[k] = int(v)
            else:
                self.metrics[k] = v
        self.keep = keep

    def __repr__(self) -> str:
        # The closing parenthesis belongs after the last field, and every field is
        # separated by a comma, so the whole state reads as one parenthesised group.
        return (
            f"DDWaf_result(return_code: {self.return_code}, data: {self.data},"
            f" actions: {self.actions}, runtime: {self.runtime},"
            f" total_runtime: {self.total_runtime}, timeout: {self.timeout},"
            f" truncation: {self.truncation}, meta_tags: {self.meta_tags},"
            f" metrics: {self.metrics}, api_security: {self.api_security}, keep: {self.keep})"
        )


# Module-level result instance with return code -127, no data, no actions, zero runtimes,
# no timeout, a fresh _observator and no derivatives.
# NOTE(review): presumably the value reported when the WAF binding itself fails — confirm
# with callers. It is a single shared object, so it should be treated as read-only.
Binding_error = DDWaf_result(-127, [], {}, 0.0, 0.0, False, _observator(), {})


class DDWaf_info:
    """Record holding ``loaded`` / ``failed`` counts, an ``errors`` string and a ``version`` string."""

    __slots__ = ["loaded", "failed", "errors", "version"]

    def __init__(self, loaded: int, failed: int, errors: str, version: str) -> None:
        self.loaded, self.failed = loaded, failed
        self.errors, self.version = errors, version

    def __repr__(self) -> str:
        # %-formatting is kept on purpose: "%d" also accepts non-int numbers.
        values = (self.loaded, self.failed, self.errors, self.version)
        return "{loaded: %d, failed: %d, errors: %s, version: %s}" % values


class Truncation_result:
    """Three independent lists of integers, one per truncation category."""

    __slots__ = ["string_length", "container_size", "container_depth"]

    string_length: list[int]
    container_size: list[int]
    container_depth: list[int]

    def __init__(self) -> None:
        # Each attribute gets its own fresh list; they must never be shared.
        self.string_length = []
        self.container_size = []
        self.container_depth = []


class Rasp_result:
    """Accumulator of RASP counters and timings.

    ``eval``, ``match`` and ``timeout`` are ``defaultdict(int)`` counters and ``durations``
    is a ``defaultdict(float)``, so missing keys read as zero.
    """

    __slots__ = ["blocked", "sum_eval", "duration", "total_duration", "eval", "match", "timeout", "durations"]

    eval: dict[str, int]
    match: dict[str, int]
    timeout: dict[str, int]
    durations: dict[str, float]

    def __init__(self) -> None:
        self.blocked = False
        self.sum_eval = 0
        self.duration = self.total_duration = 0.0
        # One separate integer counter per category.
        for counter_name in ("eval", "match", "timeout"):
            setattr(self, counter_name, collections.defaultdict(int))
        self.durations = collections.defaultdict(float)


class Block_config:
    """Blocking-response settings with a small dictionary-like read interface.

    ``security_response_id`` is stored as ``block_id``, and every occurrence of
    ``APPSEC.SECURITY_RESPONSE_ID`` inside ``location`` is replaced by it.
    ``content_type`` is always ``"application/json"``. Extra keyword arguments are
    accepted and ignored.

    In ``get``, ``[]`` and ``in`` the key ``"content-type"`` is an alias of the
    ``content_type`` attribute.
    """

    __slots__ = ["block_id", "grpc_status_code", "status_code", "type", "content_type", "location"]

    def __init__(
        self,
        type: str = "auto",  # noqa: A002
        status_code: int = 403,
        grpc_status_code: int = 10,
        security_response_id: str = "default",
        location: str = "",
        **_kwargs: Any,
    ) -> None:
        self.type: str = type
        self.status_code: int = status_code
        self.grpc_status_code: int = grpc_status_code
        self.block_id: str = security_response_id
        self.content_type: str = "application/json"
        self.location = location.replace(APPSEC.SECURITY_RESPONSE_ID, security_response_id)

    def get(self, key: str, default: Any = None) -> Union[str, int]:
        """
        Dictionary-like get method for backward compatibility with Lambda integration.

        Return the attribute named ``key`` when it exists, otherwise ``default``, so that
        a Block_config can be handed to code expecting dictionary-style access.
        """
        attribute = "content_type" if key == "content-type" else key
        return getattr(self, attribute, default)

    def __getitem__(self, key: str) -> Optional[Union[str, int]]:
        """Return the attribute named ``key``; unknown keys give ``None`` rather than raising."""
        attribute = "content_type" if key == "content-type" else key
        return getattr(self, attribute, None)

    def __contains__(self, key: str) -> bool:
        """Return True only when the attribute named ``key`` exists and is truthy."""
        attribute = "content_type" if key == "content-type" else key
        return bool(getattr(self, attribute, None))


class Telemetry_result:
    """Mutable bag of telemetry values, created with neutral defaults.

    Flags start as ``False``, counters as ``0``, durations as ``0.0``, ``version`` as an
    empty string, and ``truncation`` / ``rasp`` as fresh ``Truncation_result`` /
    ``Rasp_result`` instances.
    """

    __slots__ = [
        "blocked",
        "triggered",
        "timeout",
        "version",
        "duration",
        "total_duration",
        "truncation",
        "rasp",
        "rate_limited",
        "error",
    ]

    def __init__(self) -> None:
        # Boolean flags.
        self.blocked = self.triggered = self.rate_limited = False
        # Integer counters.
        self.timeout = self.error = 0
        # Timings.
        self.duration = self.total_duration = 0.0
        self.version: str = ""
        # Nested accumulators, one new instance per Telemetry_result.
        self.truncation = Truncation_result()
        self.rasp = Rasp_result()


def parse_response_body(raw_body: Any, headers: Any) -> Optional[Any]:
    """Parse a response body according to its ``content-type`` header.

    :param raw_body: the body; a ``dict`` is returned unchanged, otherwise a ``str``,
        ``bytes`` or a list of ``str``/``bytes`` chunks is expected.
    :param headers: anything accepted by ``dict()``; keys and values are converted to ``str``
        before the ``content-type`` lookup.
    :return: the parsed JSON or XML document, ``raw_body`` itself when it is already a
        ``dict``, or ``None`` when the body or headers are empty, there is no content type,
        the content type is neither JSON nor XML, or parsing fails (failures are logged at
        debug level and never raised).
    """
    if not raw_body:
        return None

    if isinstance(raw_body, dict):
        return raw_body

    if not headers:
        return None
    content_type = _get_header_value_case_insensitive(
        {str(k): str(v) for k, v in dict(headers).items()},
        "content-type",
    )
    if not content_type:
        return None
    # Media types are case-insensitive, so "Application/JSON" must match as well.
    media_type = content_type.lower()

    def access_body(bd: Any) -> Any:
        """Return the body as one text value, raising ValueError when it is too large."""
        if isinstance(bd, list) and isinstance(bd[0], (str, bytes)):
            # bd[0][:0] is an empty str/bytes of the same type as the chunks.
            bd = bd[0][:0].join(bd)
        if getattr(bd, "decode", False):
            bd = bd.decode("UTF-8", errors="ignore")
        if len(bd) >= API_SECURITY.MAX_PAYLOAD_SIZE:
            raise ValueError("response body larger than 16MB")
        return bd

    try:
        # TODO handle charset
        if "json" in media_type:
            return unpatched_json_loads(access_body(raw_body))
        if "xml" in media_type:
            import ddtrace.vendor.xmltodict as xmltodict

            return xmltodict.parse(access_body(raw_body))
    except Exception:
        # Best effort: an unparsable body must never break the caller.
        log.debug("Failed to parse response body", exc_info=True)
    return None


def _hash_user_id(user_id: str) -> str:
    """Return ``"anon_"`` followed by the first 32 hex digits of the SHA-256 of ``user_id``."""
    import hashlib

    encoded = user_id.encode()
    hex_digest = hashlib.sha256(encoded).hexdigest()
    return "anon_" + hex_digest[:32]


def _safe_userid(user_id: Any) -> Optional[Any]:
    """Return ``user_id`` unchanged when it is an integer or a UUID, else ``None``.

    A value qualifies when ``int(user_id)`` succeeds or when ``uuid.UUID(user_id)``
    accepts it. Values of an unsupported type (``None``, bytes, containers, ...) give
    ``None`` instead of raising.
    """
    try:
        int(user_id)
    except (ValueError, TypeError, OverflowError):
        # Not integer-like (or not convertible at all): fall through to the UUID check.
        pass
    else:
        return user_id

    # Import uuid lazily because this also imports threading via the
    # platform module
    import uuid

    try:
        uuid.UUID(user_id)
    except (ValueError, TypeError, AttributeError):
        # uuid.UUID raises TypeError/AttributeError for non-string input.
        return None
    return user_id


class _UserInfoRetriever:
    """Read id, login, email and name values from an arbitrary user object.

    Each getter first tries the attribute whose name is configured in ``asm_config`` and
    then falls back to probing a fixed list of attribute names, in order.
    """

    def __init__(self, user: Any) -> None:
        self.user = user
        # Fallback attribute names, tried from left to right.
        self.possible_user_id_fields = ["pk", "id", "uid", "userid", "user_id", "PK", "ID", "UID", "USERID"]
        self.possible_login_fields = ["username", "user", "login", "USERNAME", "USER", "LOGIN"]
        self.possible_email_fields = ["email", "mail", "address", "EMAIL", "MAIL", "ADDRESS"]
        self.possible_name_fields = [
            "name",
            "fullname",
            "full_name",
            "first_name",
            "NAME",
            "FULLNAME",
            "FULL_NAME",
            "FIRST_NAME",
        ]

    def find_in_user_model(self, possible_fields: typing.Sequence[str]) -> typing.Optional[str]:
        """Return the first non-``None`` attribute of the user among ``possible_fields``.

        ``None`` is returned when no listed attribute exists or all of them are ``None``.
        """
        candidates = (getattr(self.user, field_name, None) for field_name in possible_fields)
        # The generator is lazy, so lookup stops at the first usable attribute.
        return next((found for found in candidates if found is not None), None)

    def get_userid(self) -> Any:
        """Return the configured login-field value, else the first known id attribute."""
        configured = getattr(self.user, asm_config._user_model_login_field, None)
        if configured is None:
            return self.find_in_user_model(self.possible_user_id_fields)
        return configured

    def get_username(self) -> Any:
        """Return the configured name-field value, else ``user.get_username()``, else a known login attribute."""
        configured = getattr(self.user, asm_config._user_model_name_field, None)
        if configured is not None:
            return configured

        if hasattr(self.user, "get_username"):
            try:
                return self.user.get_username()
            except Exception:
                # A failing accessor is not fatal: fall back to attribute probing.
                log.debug("User model get_username member produced an exception: ", exc_info=True)

        return self.find_in_user_model(self.possible_login_fields)

    def get_user_email(self) -> Any:
        """Return the configured email-field value, else the first known email attribute."""
        configured = getattr(self.user, asm_config._user_model_email_field, None)
        if configured is None:
            return self.find_in_user_model(self.possible_email_fields)
        return configured

    def get_name(self) -> Any:
        """Return the configured name-field value, else the first known name attribute."""
        configured = getattr(self.user, asm_config._user_model_name_field, None)
        if configured is None:
            return self.find_in_user_model(self.possible_name_fields)
        return configured

    def get_user_info(self, login: bool = False, email: bool = False, name: bool = False) -> tuple[Any, dict[str, Any]]:
        """
        Return ``(user_id, extra_info)`` for the wrapped user.

        ``(None, {})`` is returned when no user id can be found. Otherwise ``extra_info``
        holds a ``"login"``, ``"email"`` and/or ``"name"`` entry for each flag that is set.
        """
        user_id = self.get_userid()
        if user_id is None:
            return None, {}

        requested = (
            (login, "login", self.get_username),
            (email, "email", self.get_user_email),
            (name, "name", self.get_name),
        )
        extra_info: dict[str, Any] = {}
        for wanted, key, getter in requested:
            if wanted:
                extra_info[key] = getter()
        return user_id, extra_info


def has_triggers(span: Span) -> bool:
    """Tell whether the span carries appsec trigger data.

    With ``asm_config._use_metastruct_for_triggers`` set, this checks for a ``"triggers"``
    entry in the ``APPSEC.STRUCT`` struct tag; otherwise it checks that the ``APPSEC.JSON``
    tag is present.
    """
    if not asm_config._use_metastruct_for_triggers:
        return span.get_tag(APPSEC.JSON) is not None
    struct_payload = span._get_struct_tag(APPSEC.STRUCT) or {}
    return struct_payload.get("triggers", None) is not None


def get_triggers(span: Span) -> Any:
    """Return the ``"triggers"`` value stored on the span, or ``None``.

    With ``asm_config._use_metastruct_for_triggers`` set, the value is read from the
    ``APPSEC.STRUCT`` struct tag. Otherwise the ``APPSEC.JSON`` tag is decoded as JSON;
    a missing tag or a decoding failure (logged at debug level) gives ``None``.
    """
    if asm_config._use_metastruct_for_triggers:
        struct_payload = span._get_struct_tag(APPSEC.STRUCT) or {}
        return struct_payload.get("triggers", None)

    raw_json = span.get_tag(APPSEC.JSON)
    if not raw_json:
        return None
    try:
        return json.loads(raw_json).get("triggers", None)
    except Exception:
        log.debug("Failed to parse triggers", exc_info=True)
        return None


def add_context_log(logger: logging.Logger, msg: str, offset: int = 0) -> str:
    """Append ``[file, line N, in function]`` caller information to ``msg``.

    The caller is resolved by ``logger.findCaller`` with a stack level of ``3 + offset``;
    a larger ``offset`` points further up the call stack.
    """
    stack_level = 3 + offset
    # findCaller must be called directly from this frame so the stack level stays valid.
    source_file, source_line, source_function, _unused = logger.findCaller(False, stack_level)
    return f"{msg}[{source_file}, line {source_line}, in {source_function}]"


@contextlib.contextmanager
def unpatching_popen() -> typing.Iterator[None]:
    """
    Context manager that temporarily installs the unpatched `os.close` and `subprocess.Popen`.

    While the body runs, `os.close` and `subprocess.Popen` are replaced by `unpatched_close`
    and `unpatched_Popen` from `ddtrace.internal._unpatched`, and
    `asm_config._bypass_instrumentation_for_waf` is set to True.

    On exit, including when the body raises, the previous `os.close` and `subprocess.Popen`
    are put back. The bypass flag is restored to its saved value only when
    `asm_config._is_testing_instrumentation_for_waf` is truthy; otherwise it is always reset
    to False, whatever its value was on entry.
    """
    import os
    import subprocess  # nosec B404

    from ddtrace.internal._unpatched import unpatched_close
    from ddtrace.internal._unpatched import unpatched_Popen

    # Remember the currently installed callables so they can be reinstated on exit.
    original_os_close = os.close
    os.close = unpatched_close
    original_popen = subprocess.Popen
    setattr(subprocess, "Popen", unpatched_Popen)
    # Save the bypass flag value seen on entry, then force the bypass on for the body.
    original_bypass_flag = asm_config._bypass_instrumentation_for_waf
    asm_config._bypass_instrumentation_for_waf = True
    try:
        yield
    finally:
        # Undo the two replacements made above.
        setattr(subprocess, "Popen", original_popen)
        os.close = original_os_close
        # The branch depends on the testing switch, not on the flag's previous value.
        if asm_config._is_testing_instrumentation_for_waf:
            # Testing mode: put back whatever value the flag had on entry so that test
            # configurations are not corrupted.
            asm_config._bypass_instrumentation_for_waf = original_bypass_flag
        else:
            # Otherwise: always reset to False so instrumentation is re-enabled, even if
            # the flag was already True on entry.
            asm_config._bypass_instrumentation_for_waf = False


def is_inferred_span(span: Span) -> bool:
    """Return True when the span's name is one of ``SUPPORTED_PROXY_SPAN_NAMES``."""
    span_name = span.name
    return span_name in SUPPORTED_PROXY_SPAN_NAMES
