import io
import json

from ddtrace.appsec._asm_request_context import _CALLBACKS
from ddtrace.appsec._asm_request_context import _call_waf_first
from ddtrace.appsec._asm_request_context import _on_context_ended
from ddtrace.appsec._asm_request_context import _set_headers_and_response
from ddtrace.appsec._asm_request_context import call_waf_callback
from ddtrace.appsec._asm_request_context import get_blocked
from ddtrace.appsec._asm_request_context import get_value
from ddtrace.appsec._asm_request_context import in_asm_context
from ddtrace.appsec._asm_request_context import is_blocked
from ddtrace.appsec._asm_request_context import set_block_request_callable
from ddtrace.appsec._asm_request_context import set_value
from ddtrace.appsec._asm_request_context import set_waf_address
from ddtrace.appsec._utils import Block_config
from ddtrace.contrib import trace_utils
from ddtrace.contrib.internal.trace_utils_base import _get_request_header_user_agent
from ddtrace.contrib.internal.trace_utils_base import _set_url_tag
from ddtrace.ext import http
from ddtrace.internal import core
from ddtrace.internal.constants import REQUEST_PATH_PARAMS
from ddtrace.internal.constants import RESPONSE_HEADERS
from ddtrace.internal.logger import get_logger
from ddtrace.internal.settings.asm import config as asm_config
from ddtrace.internal.utils import http as http_utils
import ddtrace.vendor.xmltodict as xmltodict


logger = get_logger(__name__)

# HTTP methods whose request body is parsed and forwarded to the WAF by
# _on_request_span_modifier; requests with other methods are not inspected.
_BODY_METHODS = {
    "DELETE",
    "PATCH",
    "POST",
    "PUT",
}


def _get_content_length(environ):
    content_length = environ.get("CONTENT_LENGTH")
    transfer_encoding = environ.get("HTTP_TRANSFER_ENCODING")

    if transfer_encoding == "chunked" or content_length is None:
        return None

    try:
        return max(0, int(content_length))
    except Exception:
        return 0


def _request_media_type(content_type):
    """Return the lower-cased ``type/subtype`` part of a Content-Type header value.

    Parameters such as ``; charset=utf-8`` are dropped and the result is lower-cased,
    because media types are matched without parameters and case-insensitively.
    Empty or non-string values are returned unchanged (they match no known type).
    """
    if not isinstance(content_type, str) or not content_type:
        return content_type
    return content_type.split(";", 1)[0].strip().lower()


def _buffer_wsgi_input(wsgi_input, environ):
    """Read a non-seekable ``wsgi.input`` and replace it with an in-memory copy.

    Returns the bytes that were read, so the caller can restore a fresh stream later.
    """
    # https://gist.github.com/mitsuhiko/5721547
    # Provide wsgi.input as an end-of-file terminated stream.
    # In that case wsgi.input_terminated is set to True
    # and an app is required to read to the end of the file and disregard CONTENT_LENGTH for reading.
    if environ.get("wsgi.input_terminated"):
        body = wsgi_input.read()
    else:
        content_length = _get_content_length(environ)
        body = wsgi_input.read(content_length) if content_length else b""
    environ["wsgi.input"] = io.BytesIO(body)
    return body


def _parse_flask_request_body(request, has_json_mixin):
    """Parse a Flask request body according to its media type.

    JSON and XML bodies are decoded. Any other body is returned as a form dict when
    the request exposes ``form``. Returns ``None`` when there is nothing to parse.
    Parsing errors propagate to the caller.
    """
    media_type = _request_media_type(request.content_type)
    if media_type in ("application/json", "text/json"):
        # ``request.json`` raises in recent werkzeug versions when the mimetype is not
        # application/json (e.g. ``text/json``). Only use it when the request reports
        # itself as JSON; otherwise decode the raw bytes ourselves.
        if (
            has_json_mixin
            and getattr(request, "is_json", True)
            and hasattr(request, "json")
            and request.json
        ):
            return request.json
        data = request.data
        if data is None or data == b"":
            return None
        return json.loads(data.decode("UTF-8"))
    if media_type in ("application/xml", "text/xml"):
        return xmltodict.parse(request.get_data())
    if hasattr(request, "form"):
        return request.form.to_dict()
    # no raw body
    return None


def _on_request_span_modifier(
    ctx, flask_config, request, environ, _HAS_JSON_MIXIN, flask_version, flask_version_str, exception_type
):
    """Parse the request body so that the WAF can inspect it.

    Only runs when AppSec is enabled and the HTTP method is in ``_BODY_METHODS``.
    A non-seekable ``wsgi.input`` is first copied into an in-memory buffer. After
    parsing, the input is rewound (seekable streams) or replaced with a fresh buffer
    (non-seekable streams), so the application can still read the body.

    Most parameters belong to the ``flask.request_call_modifier`` event signature
    and are not used here.

    Returns:
        The parsed body (JSON value, ``xmltodict`` mapping or form dict), or ``None``
        when nothing was parsed or parsing failed. Failures are logged at debug level.
    """
    req_body = None
    if not (asm_config._asm_enabled and request.method in _BODY_METHODS):
        return req_body

    wsgi_input = environ.get("wsgi.input", "")
    seekable = True
    body = b""
    if wsgi_input:
        try:
            seekable = wsgi_input.seekable()
        # expect AttributeError in normal error cases
        except Exception:
            seekable = False
        if not seekable:
            body = _buffer_wsgi_input(wsgi_input, environ)

    try:
        req_body = _parse_flask_request_body(request, _HAS_JSON_MIXIN)
    except Exception:
        logger.debug("Failed to parse request body", exc_info=True)
    finally:
        # Reset wsgi input to the beginning so the application can read the body.
        if wsgi_input:
            if seekable:
                wsgi_input.seek(0)
            else:
                environ["wsgi.input"] = io.BytesIO(body)
    return req_body


def _on_flask_blocked_request(span):
    """Tag the request span of a Flask request that AppSec blocked.

    The status code is always set to ``"403"``. URL, query string, method and
    user-agent tags are then filled in from the ``flask_request`` and
    ``flask_config`` items found in the current core context. Any failure while
    doing so is logged as a warning and otherwise ignored.
    """
    span._set_tag_str(http.STATUS_CODE, "403")
    request = core.find_item("flask_request")
    try:
        flask_config = core.find_item("flask_config")
        base_url = getattr(request, "base_url", None)
        query_string = getattr(request, "query_string", None)
        # Tag the URL whenever it is known, with or without a query string.
        # This is the same as the WSGI blocking path, which always calls _set_url_tag.
        if base_url:
            _set_url_tag(flask_config, span, base_url, query_string)
        if query_string and flask_config.trace_query_string:
            span._set_tag_str(http.QUERY_STRING, query_string)
        if request.method is not None:
            span._set_tag_str(http.METHOD, request.method)
        user_agent = _get_request_header_user_agent(request.headers)
        if user_agent:
            span._set_tag_str(http.USER_AGENT, user_agent)
    except Exception as e:
        logger.warning("Could not set some span tags on blocked request: %s", str(e))


def _on_start_response_blocked(ctx, flask_config, response_headers, status):
    """Record the blocked response's status code and headers on the request span.

    The span is read from ``ctx["req_span"]``. Tagging is delegated to
    ``trace_utils.set_http_meta``.
    """
    request_span = ctx["req_span"]
    trace_utils.set_http_meta(
        request_span,
        flask_config,
        status_code=status,
        response_headers=response_headers,
    )


def _on_wrapped_view(kwargs):
    """Send a Flask view's path parameters to the WAF and return a block callback if needed.

    Only active when AppSec is enabled and an ASM request context exists.
    Non-empty ``kwargs`` are registered as the request path-parameters address
    before the WAF runs.

    Returns:
        The ``flask_block`` callback stored in the ASM context when the request
        is now blocked, otherwise ``None``.
    """
    # Path parameters only become available once the view is resolved, which is
    # why blocking can be attempted again at this point.
    if not (asm_config._asm_enabled and in_asm_context()):
        return None
    logger.debug("asm_context::flask::srb_on_request_param")
    if kwargs:
        set_waf_address(REQUEST_PATH_PARAMS, kwargs)
    call_waf_callback()
    return get_value(_CALLBACKS, "flask_block") if is_blocked() else None


def _on_pre_tracedrequest(ctx):
    """Register the span-bound block callable for a Flask request and block early if needed.

    When AppSec is enabled, the context's ``block_request_callable`` item is bound
    to the context span with ``functools.partial`` and passed to
    ``set_block_request_callable``. If ``get_blocked()`` already reports a
    decision, ``block_request()`` is invoked immediately.
    """
    import functools

    span = ctx.span
    request_blocker = ctx.get_item("block_request_callable")
    if not asm_config._asm_enabled:
        return
    from ddtrace.appsec._asm_request_context import block_request

    bound_blocker = functools.partial(request_blocker, span)
    set_block_request_callable(bound_blocker)
    if get_blocked():
        block_request()


def _on_block_decided(callback):
    """Store the Flask block callback and subscribe it to block-content requests.

    Does nothing unless AppSec is enabled. Otherwise the callback is saved in
    the ASM context under ``flask_block`` and registered on the
    ``flask.block.request.content`` core event.
    """
    if asm_config._asm_enabled:
        set_value(_CALLBACKS, "flask_block", callback)
        core.on("flask.block.request.content", callback, "block_requested")


def _wsgi_make_block_content(ctx, construct_url):
    """Build the ``(status, headers, body)`` response for a blocked WSGI request.

    The response is derived from the recorded blocking configuration, or from a
    default ``Block_config`` when none is recorded:

    - Type ``"none"``: an empty body with a plain-text content type and a
      ``location`` header.
    - Any other type: the rendered blocking template for the configured content type.

    Response and request tags are then copied onto the request span. Failures
    there are logged as warnings and otherwise ignored.

    Raises:
        ValueError: if the context has no ``req_span`` item.
    """
    middleware = ctx.get_item("middleware")
    req_span = ctx.get_item("req_span")
    headers = ctx.get_item("headers")
    environ = ctx.get_item("environ")
    if req_span is None:
        raise ValueError("request span not found")

    block_cfg = get_blocked() or Block_config()
    if block_cfg.type != "none":
        ctype = block_cfg.content_type
        template = http_utils._get_blocked_template(ctype, block_cfg.block_id)
        content = template.encode("UTF-8")
        resp_headers = [("content-type", ctype)]
    else:
        ctype = None
        content = b""
        resp_headers = [("content-type", "text/plain; charset=utf-8"), ("location", block_cfg.location)]
    status = block_cfg.status_code
    body_length = str(len(content))

    try:
        req_span._set_tag_str(RESPONSE_HEADERS + ".content-length", body_length)
        if ctype is not None:
            req_span._set_tag_str(RESPONSE_HEADERS + ".content-type", ctype)
        req_span._set_tag_str(http.STATUS_CODE, str(status))
        integration_config = middleware._config
        url = construct_url(environ)
        query_string = environ.get("QUERY_STRING")
        _set_url_tag(integration_config, req_span, url, query_string)
        if query_string and integration_config.trace_query_string:
            req_span._set_tag_str(http.QUERY_STRING, query_string)
        method = environ.get("REQUEST_METHOD")
        if method:
            req_span._set_tag_str(http.METHOD, method)
        user_agent = _get_request_header_user_agent(headers, headers_are_case_sensitive=True)
        if user_agent:
            req_span._set_tag_str(http.USER_AGENT, user_agent)
    except Exception as e:
        logger.warning("Could not set some span tags on blocked request: %s", str(e))

    resp_headers.append(("Content-Length", body_length))
    return status, resp_headers, content


def listen():
    """Subscribe the Flask/WSGI AppSec handlers to their core events.

    Each entry is passed to ``core.on`` exactly as listed and in this order.
    Entries with a third element register a named listener.
    """
    registrations = (
        ("flask.request_call_modifier", _on_request_span_modifier, "request_body"),
        ("flask.blocked_request_callable", _on_flask_blocked_request),
        ("flask.start_response.blocked", _on_start_response_blocked),
        ("wsgi.block.started", _wsgi_make_block_content, "status_headers_content"),
        ("flask.finalize_request.post", _set_headers_and_response),
        ("flask.wrapped_view", _on_wrapped_view, "callbacks"),
        ("flask._patched_request", _on_pre_tracedrequest),
        ("wsgi.block_decided", _on_block_decided),
        ("flask.start_response", _call_waf_first),
        ("context.ended.wsgi.__call__", _on_context_ended),
    )
    for registration in registrations:
        core.on(*registration)
