import typing

from ddtrace.appsec import _constants
from ddtrace.appsec._deduplications import deduplication
from ddtrace.appsec._utils import DDWaf_info
from ddtrace.appsec._utils import Telemetry_result
from ddtrace.appsec._utils import _observator
from ddtrace.internal import telemetry
import ddtrace.internal.logger as ddlogger
from ddtrace.internal.telemetry.constants import TELEMETRY_LOG_LEVEL
from ddtrace.internal.telemetry.constants import TELEMETRY_NAMESPACE


# Fallback tag value used whenever a version string is missing or empty.
UNKNOWN_VERSION = "unknown"
# Version of the WAF library, attached as the "waf_version" tag on the WAF/RASP
# metrics and logs below. Starts as UNKNOWN_VERSION.
# NOTE(review): presumably rebound from elsewhere once the WAF is loaded — confirm.
ddwaf_version = UNKNOWN_VERSION

# Lowercase string form of a boolean, indexed by the bool itself:
# bool_str[False] == "false", bool_str[True] == "true".
bool_str = ("false", "true")

logger = ddlogger.get_logger(__name__)


# Tags used as the message of the warnings logged when emitting telemetry
# fails; each tag is registered with its own rate limit at import time.
class WARNING_TAGS(metaclass=_constants.Constant_Class):
    TELEMETRY_LOGS = "telemetry_logs"  # failure while sending a telemetry log entry
    TELEMETRY_METRICS = "telemetry_metrics"  # failure while sending a telemetry metric


# Rate-limit the telemetry-failure warnings emitted in this module: every
# WARNING_TAGS tag is registered with ddlogger.HOUR, i.e. presumably at most one
# warning per tag per hour per process.
# NOTE(review): an older comment here said "one per day"; confirm the hourly
# interval is the intended one.
for _, tag in WARNING_TAGS:
    ddlogger.set_tag_rate_limit(tag, ddlogger.HOUR)

# Default `extra` payload for appsec log calls. Not referenced in the code shown
# here — presumably used by other modules importing it (verify before changing).
log_extra = {"product": "appsec", "stack_limit": 4, "exec_limit": 4}


@deduplication
def _set_waf_error_log(msg: str, version: str, action: str, error_level: bool = True) -> None:
    """Report a WAF configuration error through telemetry.

    Sends ``msg`` as a telemetry log entry (ERROR level when ``error_level`` is
    true, WARNING otherwise) and increments the ``waf.config_errors`` count metric
    tagged with the WAF version, the rules version and ``action``.

    The two emissions are guarded independently, so a failure in one does not
    stop the other. Failures are logged as warnings and never propagate.
    """
    # 1) telemetry log entry
    try:
        if error_level:
            level = TELEMETRY_LOG_LEVEL.ERROR
        else:
            level = TELEMETRY_LOG_LEVEL.WARNING
        telemetry.telemetry_writer.add_log(
            level,
            msg,
            tags={
                "waf_version": ddwaf_version,
                "event_rules_version": version or UNKNOWN_VERSION,
            },
        )
    except Exception:
        logger.warning(
            WARNING_TAGS.TELEMETRY_LOGS,
            extra={"product": "appsec", "exec_limit": 6, "more_info": ":waf:error"},
            exc_info=True,
        )

    # 2) config-error counter
    try:
        telemetry.telemetry_writer.add_count_metric(
            TELEMETRY_NAMESPACE.APPSEC,
            "waf.config_errors",
            1,
            tags=(
                ("waf_version", ddwaf_version),
                ("event_rules_version", version or UNKNOWN_VERSION),
                ("action", action),
            ),
        )
    except Exception:
        logger.warning(
            WARNING_TAGS.TELEMETRY_METRICS,
            extra={"product": "appsec", "exec_limit": 6, "more_info": ":waf:config_errors"},
            exc_info=True,
        )


def _set_waf_updates_metric(info: DDWaf_info, success: bool) -> None:
    """Increment the ``waf.updates`` count metric after a WAF rules update.

    Tags: rules version (``UNKNOWN_VERSION`` when ``info.version`` is empty),
    WAF version, and ``success`` as ``"true"``/``"false"``. Telemetry failures
    are logged as warnings and swallowed.
    """
    try:
        update_tags: tuple[tuple[str, str], ...] = (
            ("event_rules_version", info.version or UNKNOWN_VERSION),
            ("waf_version", ddwaf_version),
            ("success", bool_str[success]),
        )
        telemetry.telemetry_writer.add_count_metric(TELEMETRY_NAMESPACE.APPSEC, "waf.updates", 1, tags=update_tags)
    except Exception:
        logger.warning(
            WARNING_TAGS.TELEMETRY_METRICS,
            extra={"product": "appsec", "exec_limit": 6, "more_info": ":waf:updates"},
            exc_info=True,
        )


def _set_waf_init_metric(info: DDWaf_info, success: bool) -> None:
    """Increment the ``waf.init`` count metric after a WAF initialisation attempt.

    A failed initialisation additionally increments ``waf.config_errors`` with
    ``action="init"``. Both metrics share the rules-version and WAF-version tags.
    Telemetry failures are logged as warnings and swallowed.
    """
    try:
        shared_tags: tuple[tuple[str, str], ...] = (
            ("event_rules_version", info.version or UNKNOWN_VERSION),
            ("waf_version", ddwaf_version),
        )
        init_tags = shared_tags + (("success", bool_str[success]),)
        telemetry.telemetry_writer.add_count_metric(TELEMETRY_NAMESPACE.APPSEC, "waf.init", 1, tags=init_tags)
        if success:
            return
        # init failed: also account for it as a configuration error
        error_tags = shared_tags + (("action", "init"),)
        telemetry.telemetry_writer.add_count_metric(TELEMETRY_NAMESPACE.APPSEC, "waf.config_errors", 1, tags=error_tags)
    except Exception:
        logger.warning(
            WARNING_TAGS.TELEMETRY_METRICS,
            extra={"product": "appsec", "exec_limit": 6, "more_info": ":waf:init"},
            exc_info=True,
        )


# Extra telemetry tags identifying an exploit-prevention (RASP) rule type, keyed
# by _constants.EXPLOIT_PREVENTION.TYPE values. Callers look types up with
# .get(rule_type, ()), so an unmapped type simply adds no extra tags.
_TYPES_AND_TAGS = {
    _constants.EXPLOIT_PREVENTION.TYPE.CMDI: (("rule_type", "command_injection"), ("rule_variant", "exec")),
    _constants.EXPLOIT_PREVENTION.TYPE.SHI: (("rule_type", "command_injection"), ("rule_variant", "shell")),
    _constants.EXPLOIT_PREVENTION.TYPE.LFI: (("rule_type", "lfi"),),
    _constants.EXPLOIT_PREVENTION.TYPE.SSRF: (("rule_type", "ssrf"),),
    _constants.EXPLOIT_PREVENTION.TYPE.SSRF_REQ: (("rule_type", "ssrf"), ("rule_variant", "request")),
    _constants.EXPLOIT_PREVENTION.TYPE.SSRF_RES: (("rule_type", "ssrf"), ("rule_variant", "response")),
    _constants.EXPLOIT_PREVENTION.TYPE.SQLI: (("rule_type", "sql_injection"),),
}

# Tags for the "waf.truncated_value_size" distribution, one per truncation cause.
# The reason values 1/2/4 are the same bits that _report_waf_truncations ORs
# together into the "truncation_reason" tag of "waf.input_truncated".
TAGS_STRING_LENGTH = (("truncation_reason", "1"),)
TAGS_CONTAINER_SIZE = (("truncation_reason", "2"),)
TAGS_CONTAINER_DEPTH = (("truncation_reason", "4"),)


def _report_waf_truncations(observator: _observator) -> None:
    """Report how the WAF input was truncated, based on ``observator``.

    For each truncation cause that was recorded (attribute not ``None``), the
    observed value goes into the ``waf.truncated_value_size`` distribution with
    that cause's reason tag. If any cause was recorded, ``waf.input_truncated``
    is incremented once, tagged with the bitmask of all recorded causes
    (1 = string length, 2 = container size, 4 = container depth).
    Telemetry failures are logged as warnings and swallowed.
    """
    try:
        reason_mask = 0
        for reason_bit, attribute, size_tags in (
            (1, "string_length", TAGS_STRING_LENGTH),
            (2, "container_size", TAGS_CONTAINER_SIZE),
            (4, "container_depth", TAGS_CONTAINER_DEPTH),
        ):
            observed = getattr(observator, attribute)
            if observed is None:
                continue
            reason_mask |= reason_bit
            telemetry.telemetry_writer.add_distribution_metric(
                TELEMETRY_NAMESPACE.APPSEC, "waf.truncated_value_size", observed, size_tags
            )
        if reason_mask != 0:
            telemetry.telemetry_writer.add_count_metric(
                TELEMETRY_NAMESPACE.APPSEC,
                "waf.input_truncated",
                1,
                tags=(("truncation_reason", str(reason_mask)),),
            )
    except Exception:
        logger.warning(
            WARNING_TAGS.TELEMETRY_METRICS,
            extra={"product": "appsec", "exec_limit": 6, "more_info": ":waf:truncations"},
            exc_info=True,
        )


def _report_waf_run_error(error: int, rule_version: str, rule_type: typing.Optional[str]) -> None:
    """Count a WAF run error in telemetry.

    When ``rule_type`` is ``None`` the error is counted as ``waf.error``.
    Otherwise it is counted as ``rasp.error`` and additionally tagged with the
    exploit-prevention rule tags from ``_TYPES_AND_TAGS``. Both metrics carry
    the WAF version, the rules version (``UNKNOWN_VERSION`` when empty) and the
    error code as a string.

    Like every other reporter in this module, this is best effort: a telemetry
    failure is logged as a warning and is not propagated to the caller.
    """
    try:
        tags = (
            ("waf_version", ddwaf_version),
            ("event_rules_version", rule_version or UNKNOWN_VERSION),
            ("waf_error", str(error)),
        )
        if rule_type is None:
            metric_name = "waf.error"
        else:
            metric_name = "rasp.error"
            tags += _TYPES_AND_TAGS.get(rule_type, ())
        telemetry.telemetry_writer.add_count_metric(TELEMETRY_NAMESPACE.APPSEC, metric_name, 1, tags=tags)
    except Exception:
        # Deliberately swallowed after logging (no re-raise): a metrics failure
        # must not turn into an error for the code that reported the WAF error.
        extra = {"product": "appsec", "exec_limit": 6, "more_info": f":waf:run_error:{rule_type or 'srb'}"}
        logger.warning(WARNING_TAGS.TELEMETRY_METRICS, extra=extra, exc_info=True)


def _set_waf_request_metrics(result: Telemetry_result) -> None:
    """Emit the per-request WAF telemetry summarised in ``result``.

    Always increments ``waf.requests``, tagged with the versions and with the
    triggered / blocked / timeout / truncated / error / rate-limited flags as
    ``"true"``/``"false"`` strings. If any RASP rule was evaluated
    (``result.rasp.sum_eval``), it also emits ``rasp.rule.eval``,
    ``rasp.rule.match`` and ``rasp.timeout`` counts per rule type with a non-zero
    value. Match counts carry an extra ``block`` tag. Telemetry failures are
    logged as warnings and swallowed.
    """
    try:
        trunc = result.truncation
        was_truncated = bool(trunc.string_length or trunc.container_size or trunc.container_depth)
        rules_version = result.version or UNKNOWN_VERSION
        request_tags = (
            ("event_rules_version", rules_version),
            ("waf_version", ddwaf_version),
            ("rule_triggered", bool_str[result.triggered]),
            ("request_blocked", bool_str[result.blocked]),
            ("waf_timeout", bool_str[bool(result.timeout)]),
            ("input_truncated", bool_str[was_truncated]),
            # in waf.requests, waf_error is only a flag: any negative code means an error
            ("waf_error", bool_str[result.error < 0]),
            ("rate_limited", bool_str[result.rate_limited]),
        )
        telemetry.telemetry_writer.add_count_metric(TELEMETRY_NAMESPACE.APPSEC, "waf.requests", 1, tags=request_tags)

        rasp = result.rasp
        if not rasp.sum_eval:
            return
        per_kind_metrics = (("eval", "rasp.rule.eval"), ("match", "rasp.rule.match"), ("timeout", "rasp.timeout"))
        for counter_attr, metric_name in per_kind_metrics:
            for rule_type, count in getattr(rasp, counter_attr).items():
                if not count:
                    continue
                rule_tags = _TYPES_AND_TAGS.get(rule_type, ()) + (
                    ("waf_version", ddwaf_version),
                    ("event_rules_version", rules_version),
                )
                if counter_attr == "match":
                    rule_tags += (("block", ("irrelevant", "success")[rasp.blocked]),)
                telemetry.telemetry_writer.add_count_metric(TELEMETRY_NAMESPACE.APPSEC, metric_name, count, tags=rule_tags)
    except Exception:
        logger.warning(
            WARNING_TAGS.TELEMETRY_METRICS,
            extra={"product": "appsec", "exec_limit": 6, "more_info": ":waf:request"},
            exc_info=True,
        )


def _report_api_security(route: bool, schemas: int, framework: str = "") -> None:
    """Count one API Security sampling outcome, tagged with ``framework``.

    Without a route the metric is ``api_security.missing_route``. With a route it
    is ``api_security.request.schema`` when at least one schema was produced,
    otherwise ``api_security.request.no_schema``. Telemetry failures are logged
    as warnings and swallowed.
    """
    try:
        if not route:
            metric_name = "api_security.missing_route"
        elif schemas > 0:
            metric_name = "api_security.request.schema"
        else:
            metric_name = "api_security.request.no_schema"
        telemetry.telemetry_writer.add_count_metric(
            TELEMETRY_NAMESPACE.APPSEC, metric_name, 1, tags=(("framework", framework),)
        )
    except Exception:
        logger.warning(
            WARNING_TAGS.TELEMETRY_METRICS,
            extra={"product": "appsec", "exec_limit": 6, "more_info": ":api_security"},
            exc_info=True,
        )


def _report_rasp_skipped(rule_type: str, import_error: bool) -> None:
    """Increment ``rasp.rule.skipped`` for a RASP rule that was not evaluated.

    The ``reason`` tag is ``"app-startup"`` when ``import_error`` is true, and
    ``"out-of-request"`` otherwise. It is combined with the rule-type tags from
    ``_TYPES_AND_TAGS``. Telemetry failures are logged as warnings and swallowed.
    """
    try:
        type_tags = _TYPES_AND_TAGS.get(rule_type, ())
        skip_reason = "out-of-request"
        if import_error:
            skip_reason = "app-startup"
        telemetry.telemetry_writer.add_count_metric(
            TELEMETRY_NAMESPACE.APPSEC, "rasp.rule.skipped", 1, tags=type_tags + (("reason", skip_reason),)
        )
    except Exception:
        logger.warning(
            WARNING_TAGS.TELEMETRY_METRICS,
            extra={
                "product": "appsec",
                "exec_limit": 6,
                "more_info": f":waf:rasp_rule_skipped:{rule_type}:{import_error}",
            },
            exc_info=True,
        )


def _report_ato_sdk_usage(event_type: str, v2: bool = True) -> None:
    """Increment ``sdk.event`` for one SDK user-event call.

    Tagged with ``event_type`` and the SDK API version (``"v2"`` when ``v2`` is
    true, ``"v1"`` otherwise). Telemetry failures are logged as warnings and
    swallowed.
    """
    if v2:
        sdk_version = "v2"
    else:
        sdk_version = "v1"
    try:
        telemetry.telemetry_writer.add_count_metric(
            TELEMETRY_NAMESPACE.APPSEC,
            "sdk.event",
            1,
            tags=(("event_type", event_type), ("sdk_version", sdk_version)),
        )
    except Exception:
        logger.warning(
            WARNING_TAGS.TELEMETRY_METRICS,
            extra={
                "product": "appsec",
                "exec_limit": 6,
                "more_info": f":waf:sdk.event:{event_type}:{sdk_version}",
            },
            exc_info=True,
        )
