import collections

from ddtrace.appsec._iast._iast_env import _get_iast_env
from ddtrace.internal.logger import get_logger
from ddtrace.internal.settings.asm import config as asm_config


log = get_logger(__name__)

# Global vulnerability counters per endpoint: Map<str, Map<VulnerabilityType, int>>.
# - Outer key: ``context.endpoint_key``; per the original design, a string combining the
#   request method (GET, POST, ...) and the HTTP route (endpoint).
#   NOTE(review): the key is built outside this module; confirm its exact format there.
# - Inner map: vulnerability type -> occurrence count, recorded by
#   ``update_global_vulnerability_limit`` only for requests that exhausted their
#   vulnerability budget. Later requests to the same endpoint skip that many occurrences
#   of each type before reporting new ones (see ``should_process_vulnerability``).
# Memory is bounded: once the map holds more than MAX_ENDPOINTS entries,
# ``update_global_vulnerability_limit`` evicts entries from the front of the OrderedDict
# via ``popitem(last=False)``.
MAX_ENDPOINTS = 4096

# Per-request map: vulnerability type -> occurrence count.
RequestDictType = dict[str, int]
# Global map: endpoint key -> per-endpoint counts (same shape as a request map).
GlobalDictType = dict[str, RequestDictType]
# Annotated as a plain dict but instantiated as an OrderedDict because eviction relies on
# OrderedDict ordering; OrderedDict-only calls elsewhere therefore carry `# type: ignore`.
GLOBAL_VULNERABILITIES_LIMIT: GlobalDictType = collections.OrderedDict()


def init_request_vulnerability_maps(context) -> RequestDictType:
    """
    Snapshot the global vulnerability counts recorded for this request's endpoint.

    Args:
        context: IAST request context; only its ``endpoint_key`` attribute is read.

    Returns:
        A shallow copy of ``GLOBAL_VULNERABILITIES_LIMIT[context.endpoint_key]``, or a
        new empty dict when nothing has been recorded for that endpoint. Returning a copy
        keeps the request isolated from later updates to the global map. This function
        does not create the per-request occurrence counter; that is the caller's job.
    """
    endpoint_key = context.endpoint_key
    try:
        endpoint_counts = GLOBAL_VULNERABILITIES_LIMIT[endpoint_key]
    except KeyError:
        # Endpoint never exhausted a budget (or was evicted): nothing to skip.
        return {}
    return endpoint_counts.copy()


def should_process_vulnerability(vulnerability_type: str) -> bool:
    """
    Decide whether the current occurrence of ``vulnerability_type`` should be reported.

    Returns False when there is no IAST request context, or when the request has already
    used ``asm_config._iast_max_vulnerabilities_per_requests`` budget units.

    On the first call within a request, the endpoint's global counts are snapshotted and
    the per-request occurrence counter is reset. Every call past the guards counts one
    occurrence of the type, even when the occurrence is skipped. An occurrence is skipped
    (False) while the number of earlier occurrences of the same type in this request is
    below the snapshot count. Otherwise one budget unit is consumed and True is returned.
    """
    env = _get_iast_env()
    if not env or env.vulnerability_budget >= asm_config._iast_max_vulnerabilities_per_requests:
        return False

    if env.is_first_vulnerability:
        # Lazily initialise per-request state on the first candidate vulnerability.
        env.vulnerability_copy_global_limit = init_request_vulnerability_maps(env)
        env.vulnerabilities_request_limit = {}
        env.is_first_vulnerability = False

    seen_counts = env.vulnerabilities_request_limit
    seen_before = seen_counts.get(vulnerability_type, 0)
    already_reported = env.vulnerability_copy_global_limit.get(vulnerability_type, 0)
    seen_counts[vulnerability_type] = seen_before + 1

    # Only report occurrences beyond those a previous request already reported.
    should_report = seen_before >= already_reported
    if should_report:
        env.vulnerability_budget += 1
    return should_report


def rollback_quota(vulnerability_type: str) -> bool:
    """
    Undo one accepted occurrence of ``vulnerability_type`` for the current request.

    Decrements the per-request occurrence counter for the type, dropping the key when it
    reaches zero, and gives back one unit of the request's vulnerability budget.

    Returns:
        bool: True if the quota was rolled back. False if there is no IAST request
        context, or if nothing is currently counted for ``vulnerability_type`` in this
        request (e.g. state was reset in between). In that case nothing is modified.
    """
    context = _get_iast_env()
    if not context:
        return False

    request_limit = context.vulnerabilities_request_limit
    count = request_limit.get(vulnerability_type, 0)
    if count <= 0:
        # Nothing to roll back: avoid a KeyError and avoid corrupting the budget.
        return False

    if count > 1:
        request_limit[vulnerability_type] = count - 1
    else:
        request_limit.pop(vulnerability_type)
    # A negative budget would let the request report more than the configured maximum.
    context.vulnerability_budget = max(context.vulnerability_budget - 1, 0)
    return True


def update_global_vulnerability_limit(context=None) -> None:
    """
    Merge the current request's vulnerability counts into the global per-endpoint map.

    Intended to run at the end of a request. Counts are merged only when the request
    exhausted its budget (``vulnerability_budget`` reached
    ``asm_config._iast_max_vulnerabilities_per_requests``). For each vulnerability type,
    the global entry keeps the larger of the stored count and the request's count. That
    way later requests never re-report occurrences an earlier request already consumed.
    The global map is then trimmed to ``MAX_ENDPOINTS`` entries, evicting the least
    recently updated endpoints first.

    Args:
        context: IAST request context to read from. When None, the current context is
            looked up via ``_get_iast_env()``; if none exists, the call is a no-op.
    """
    if context is None:
        context = _get_iast_env()
        if context is None:
            log.debug("No request context found when updating global vulnerability map")
            return

    # Only a request that used its whole budget tells us occurrences were left unreported.
    if context.vulnerability_budget >= asm_config._iast_max_vulnerabilities_per_requests:
        endpoint_key = context.endpoint_key
        global_map_entry = GLOBAL_VULNERABILITIES_LIMIT.get(endpoint_key, {})
        for vuln_type, count in context.vulnerabilities_request_limit.items():
            # Keep the maximum: a request that saw fewer occurrences of this type must not
            # lower the skip threshold recorded by an earlier request.
            global_map_entry[vuln_type] = max(count, global_map_entry.get(vuln_type, 0))
        GLOBAL_VULNERABILITIES_LIMIT[endpoint_key] = global_map_entry
        # Re-assigning an existing OrderedDict key keeps its old position; move it to the
        # end so the eviction below really drops the least recently updated endpoints.
        GLOBAL_VULNERABILITIES_LIMIT.move_to_end(endpoint_key)  # type: ignore[attr-defined]

    # Enforce the size bound by evicting from the front (least recently updated).
    while len(GLOBAL_VULNERABILITIES_LIMIT) > MAX_ENDPOINTS:
        GLOBAL_VULNERABILITIES_LIMIT.popitem(last=False)  # type: ignore[call-arg]


def reset_request_vulnerabilities(context=None) -> None:
    """
    Reset the per-request vulnerability sampling state on an IAST request context.

    Clears the snapshot of global counts and the per-request occurrence counter, re-arms
    lazy initialisation (``is_first_vulnerability = True``) and zeroes the budget.

    Args:
        context: IAST request context to reset. When None, the current context is looked
            up via ``_get_iast_env()``; if none exists, the call is a no-op.
    """
    if context is None:
        context = _get_iast_env()
        if context is None:
            log.debug("No request context found when resetting request vulnerabilities")
            return
    context.vulnerability_copy_global_limit = {}
    context.vulnerabilities_request_limit = {}
    context.is_first_vulnerability = True
    context.vulnerability_budget = 0


def _reset_global_limit():
    """
    Forget all recorded per-endpoint vulnerability counts.

    Clears the existing map in place rather than rebinding the module global. That way
    code that imported the object directly (``from ... import
    GLOBAL_VULNERABILITIES_LIMIT``) sees the reset instead of keeping a stale map.
    """
    GLOBAL_VULNERABILITIES_LIMIT.clear()
