import inspect
from itertools import chain
from typing import Any
from typing import Iterable
from typing import Optional

from ddtrace._trace.span import Span
from ddtrace.appsec import _asm_request_context
from ddtrace.appsec._constants import STACK_TRACE
from ddtrace.internal import core
from ddtrace.internal.settings.asm import config as asm_config


def report_stack(
    message: Optional[str] = None,
    span: Optional[Span] = None,
    crop_stack: Optional[str] = None,
    stack_id: Optional[str] = None,
    namespace: str = STACK_TRACE.RASP,
) -> bool:
    """
    Report a stack trace to the current span.

    This is used to report stack traces for exploit prevention and IAST. The
    captured trace is appended to the list stored under ``namespace`` in the
    span's ``STACK_TRACE.TAG`` struct tag, identified by ``stack_id`` so it can
    be linked from triggers.

    :param message: optional message stored alongside the stack trace.
    :param span: span to attach the trace to. For the IAST namespace with
        root-span reporting enabled, the root span is used instead. If no span
        is available, the ASM entry span is used.
    :param crop_stack: name of a function; every frame up to and including the
        first frame whose code name matches is dropped from the report. If no
        frame matches, the full stack is reported.
    :param stack_id: identifier of the stack trace; nothing is reported
        without it.
    :param namespace: stack trace namespace (``STACK_TRACE.RASP`` or
        ``STACK_TRACE.IAST``).
    :return: True if a stack trace was recorded on the span, False otherwise
        (reporting disabled, no span, no ``stack_id``, or the per-namespace
        limit ``_ep_max_stack_traces`` already reached).
    """
    if not asm_config._ep_stack_trace_enabled:
        # stack trace report disabled
        return False
    if namespace == STACK_TRACE.RASP and not (asm_config._asm_enabled and asm_config._ep_enabled):
        # exploit prevention stack trace with ep disabled
        return False
    if namespace == STACK_TRACE.IAST and not (asm_config._iast_enabled):
        # iast stack trace with iast disabled
        return False

    if namespace == STACK_TRACE.IAST and asm_config._iast_use_root_span:
        span = core.get_root_span()

    if span is None:
        span = _asm_request_context.get_entry_span()

    if span is None or stack_id is None:
        return False
    appsec_traces = span._get_struct_tag(STACK_TRACE.TAG) or {}
    current_list = appsec_traces.get(namespace, [])
    total_length = len(current_list)

    # Do not report more than the maximum number of stack traces
    if asm_config._ep_max_stack_traces and total_length >= asm_config._ep_max_stack_traces:
        return False

    # context=0: only frame/filename/lineno are used below, so skip reading
    # source lines for every frame (findsource/linecache I/O).
    stack = inspect.stack(context=0)
    try:
        if crop_stack is not None:
            # A generator expression keeps no lingering reference to a frame
            # once the search completes.
            crop_index = next(
                (index for index, info in enumerate(stack) if info.frame.f_code.co_name == crop_stack),
                None,
            )
            if crop_index is not None:
                stack = stack[crop_index + 1 :]

        stack_len = len(stack)
        max_depth = asm_config._ep_max_stack_trace_depth
        if stack_len > max_depth > 0:
            # Keep the top `top_stack` frames and the bottom
            # `max_depth - top_stack` frames of the stack.
            top_stack = int(max_depth * asm_config._ep_stack_top_percent / 100)
            # Clamp so a misconfigured percentage (outside 0-100) cannot
            # produce out-of-range or negative indices.
            top_stack = min(max(top_stack, 0), max_depth)
            bottom_stack = max_depth - top_stack
            iterator: Iterable[int] = chain(range(top_stack), range(stack_len - bottom_stack, stack_len))
        else:
            iterator = range(stack_len)

        frames = [
            {
                "id": i,
                "function": getattr(stack[i].frame.f_code, "co_qualname", stack[i].frame.f_code.co_name),
                "file": stack[i].filename,
                "line": stack[i].lineno,
            }
            for i in iterator
        ]
    finally:
        # The FrameInfo list includes this function's own frame; drop it to
        # avoid a frame <-> locals reference cycle (see inspect docs).
        del stack

    res: dict[str, Any] = {
        "language": "python",
        "id": stack_id,
    }
    if message is not None:
        res["message"] = message
    res["frames"] = frames
    current_list.append(res)
    appsec_traces[namespace] = current_list
    span._set_struct_tag(STACK_TRACE.TAG, appsec_traces)
    return True
