# -*- encoding: utf-8 -*-
from functools import partial
import sys
from types import ModuleType
import typing


if typing.TYPE_CHECKING:
    import asyncio
    import asyncio as aio

from ddtrace.internal._unpatched import _threading as ddtrace_threading
from ddtrace.internal.datadog.profiling import stack
from ddtrace.internal.module import ModuleWatchdog
from ddtrace.internal.settings.profiling import config
from ddtrace.internal.utils import get_argument_value
from ddtrace.internal.wrapping import wrap


# Flipped to True by the asyncio import hook in this module once ``asyncio``
# has been imported; read by ``link_existing_loop_to_current_thread``.
ASYNCIO_IMPORTED: bool = False


def current_task(
    loop: typing.Optional["asyncio.AbstractEventLoop"] = None,
) -> typing.Optional["asyncio.Task[typing.Any]"]:
    """Placeholder for ``asyncio.current_task`` used before asyncio is imported.

    The asyncio import hook in this module replaces this global with the real
    asyncio implementation; until then there is no task to report.

    Args:
        loop: Accepted for signature compatibility; ignored by this placeholder.

    Returns:
        Always ``None``.
    """
    return None


def get_running_loop() -> typing.Optional["asyncio.AbstractEventLoop"]:
    """Placeholder for the running-loop lookup used before asyncio is imported.

    The asyncio import hook in this module replaces this global with a
    variant that returns the running loop, or ``None`` when there is none.

    Returns:
        Always ``None``.
    """
    return None


def _task_get_name(task: "asyncio.Task[typing.Any]") -> str:
    return "Task-%d" % id(task)


def _call_init_asyncio(asyncio: ModuleType) -> None:
    """Pass asyncio's internal task collections to ``stack.init_asyncio``.

    On Python 3.12+ the scheduled-task set (``_scheduled_tasks.data``) and the
    eager-task set (``_eager_tasks``) of ``asyncio.tasks`` are passed. On older
    versions ``_all_tasks.data`` is passed and the eager-task argument is
    ``None``.

    Args:
        asyncio: The imported ``asyncio`` module. It is not used directly; the
            collections are read from ``asyncio.tasks`` via a local import.
    """
    from asyncio import tasks as asyncio_tasks

    if sys.hexversion < 0x030C0000:
        # Older interpreters: a single task set and no eager-task set.
        stack.init_asyncio(asyncio_tasks._all_tasks.data, None)  # type: ignore[attr-defined]
        return

    # Python 3.12+: scheduled and eager tasks are kept in separate collections.
    scheduled = asyncio_tasks._scheduled_tasks.data  # type: ignore[attr-defined]
    eager = asyncio_tasks._eager_tasks  # type: ignore[attr-defined]
    stack.init_asyncio(scheduled, eager)


def link_existing_loop_to_current_thread() -> None:
    """Register the current thread's running asyncio loop with the profiler.

    Does nothing unless asyncio has been imported and an event loop is
    running in the calling thread. Otherwise the loop is reported through
    ``stack.track_asyncio_loop`` under the current thread's ident and
    ``_call_init_asyncio`` is invoked.
    """
    # The module flag persists across forks, so also require that asyncio is
    # really present in ``sys.modules`` before touching it.
    if not (ASYNCIO_IMPORTED and "asyncio" in sys.modules):
        return

    import asyncio

    try:
        active_loop: "asyncio.AbstractEventLoop" = asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so there is nothing to track.
        return

    thread_ident = typing.cast(int, ddtrace_threading.current_thread().ident)
    stack.track_asyncio_loop(thread_ident, active_loop)
    _call_init_asyncio(asyncio)


@ModuleWatchdog.after_module_imported("asyncio")
def _(asyncio: ModuleType) -> None:
    """Hook asyncio so the stack profiler can follow event loops and tasks.

    Runs once ``asyncio`` has been imported. It:

    * sets ``ASYNCIO_IMPORTED`` and replaces this module's ``current_task``,
      ``get_running_loop`` and ``_task_get_name`` placeholders with
      asyncio-backed implementations;
    * wraps the default event loop policy's ``set_event_loop`` so that loops
      are reported through ``stack.track_asyncio_loop``;
    * when the stack profiler is enabled and available, wraps
      ``_GatheringFuture.__init__``, ``_wait``, ``as_completed``, ``shield``,
      ``TaskGroup.create_task`` (Python 3.11+) and ``create_task`` so the
      current task is linked to the awaitables it spawns or waits on, then
      calls ``_call_init_asyncio``.

    The wrappers never link when no current task can be determined, and they
    never raise on their own when no event loop is running.

    Args:
        asyncio: The freshly imported ``asyncio`` module.
    """
    global ASYNCIO_IMPORTED

    ASYNCIO_IMPORTED = True

    if hasattr(asyncio, "current_task"):
        globals()["current_task"] = asyncio.current_task
    elif hasattr(asyncio.Task, "current_task"):
        globals()["current_task"] = asyncio.Task.current_task

    def _get_running_loop() -> typing.Optional["aio.AbstractEventLoop"]:
        """Return the running loop, or ``None`` instead of raising."""
        try:
            return asyncio.get_running_loop()
        except RuntimeError:
            return None

    globals()["get_running_loop"] = _get_running_loop
    globals()["_task_get_name"] = lambda task: task.get_name()

    def _current_task_or_none(
        loop: typing.Optional["aio.AbstractEventLoop"] = None,
    ) -> typing.Optional["aio.Task[typing.Any]"]:
        """Return the current task, or ``None`` when it cannot be determined.

        ``asyncio.current_task`` falls back to ``get_running_loop()`` when
        *loop* is ``None`` and therefore raises ``RuntimeError`` outside of a
        running loop. The instrumentation must not turn an otherwise valid
        asyncio call into an error, so that case is reported as "no task".
        The function is looked up through ``globals()`` at call time so that
        the currently installed ``current_task`` is always the one used.
        """
        try:
            return globals()["current_task"](loop)
        except RuntimeError:
            return None

    init_stack: bool = config.stack.enabled and stack.is_available

    # Python 3.14+: BaseDefaultEventLoopPolicy was renamed to _BaseDefaultEventLoopPolicy
    # Pick the name that matches the running interpreter.
    events_module = sys.modules["asyncio.events"]
    if sys.hexversion >= 0x030E0000:
        # Python 3.14+: Use _BaseDefaultEventLoopPolicy
        policy_class = getattr(events_module, "_BaseDefaultEventLoopPolicy", None)
    else:
        # Python < 3.14: Use BaseDefaultEventLoopPolicy
        policy_class = getattr(events_module, "BaseDefaultEventLoopPolicy", None)

    if policy_class is not None:

        @partial(wrap, policy_class.set_event_loop)
        def _(
            f: typing.Callable[..., typing.Any], args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]
        ) -> typing.Any:
            # Report the loop being installed for the calling thread, then delegate.
            loop: typing.Optional["aio.AbstractEventLoop"] = get_argument_value(args, kwargs, 1, "loop")
            if init_stack:
                stack.track_asyncio_loop(typing.cast(int, ddtrace_threading.current_thread().ident), loop)
            return f(*args, **kwargs)

    if init_stack:

        @partial(wrap, sys.modules["asyncio"].tasks._GatheringFuture.__init__)
        def _(f: typing.Callable[..., None], args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]) -> None:
            try:
                return f(*args, **kwargs)
            finally:
                children = get_argument_value(args, kwargs, 1, "children")
                assert children is not None  # nosec: assert is used for typing

                # Pass an invalid positional index for 'loop'
                # NOTE(review): presumably 'loop' is only ever supplied by keyword - confirm.
                loop = get_argument_value(args, kwargs, -1, "loop")

                # Link the parent gathering task to the gathered children.
                # Without a current task there is no parent to link to.
                parent = _current_task_or_none(loop)
                if parent is not None:
                    for child in children:
                        stack.link_tasks(parent, child)

        @partial(wrap, sys.modules["asyncio"].tasks._wait)
        def _(
            f: typing.Callable[..., tuple[set["aio.Future[typing.Any]"], set["aio.Future[typing.Any]"]]],
            args: tuple[typing.Any, ...],
            kwargs: dict[str, typing.Any],
        ) -> typing.Any:
            try:
                return f(*args, **kwargs)
            finally:
                futures = typing.cast(set["aio.Future[typing.Any]"], get_argument_value(args, kwargs, 0, "fs"))
                loop = typing.cast("aio.AbstractEventLoop", get_argument_value(args, kwargs, 3, "loop"))

                # Link the waiting task to the futures it waits on.
                # Without a current task there is no parent to link to.
                parent = _current_task_or_none(loop)
                if parent is not None:
                    for future in futures:
                        stack.link_tasks(parent, future)

        @partial(wrap, sys.modules["asyncio"].tasks.as_completed)
        def _(
            f: typing.Callable[..., typing.Generator["aio.Future[typing.Any]", typing.Any, None]],
            args: tuple[typing.Any, ...],
            kwargs: dict[str, typing.Any],
        ) -> typing.Any:
            loop = typing.cast(typing.Optional["aio.AbstractEventLoop"], kwargs.get("loop"))
            parent = _current_task_or_none(loop)

            if parent is not None:
                fs = typing.cast(typing.Iterable["aio.Future[typing.Any]"], get_argument_value(args, kwargs, 0, "fs"))
                # Materialize the awaitables as futures up front so they can be linked.
                futures: set["aio.Future"] = {asyncio.ensure_future(aw, loop=loop) for aw in set(fs)}
                for future in futures:
                    stack.link_tasks(parent, future)

                # Replace fs with the ensured futures to avoid double-wrapping
                args = (futures,) + args[1:]

            return f(*args, **kwargs)

        # Wrap asyncio.shield to link parent task to shielded future
        @partial(wrap, sys.modules["asyncio"].tasks.shield)
        def _(
            f: typing.Callable[..., "aio.Future[typing.Any]"],
            args: tuple[typing.Any, ...],
            kwargs: dict[str, typing.Any],
        ) -> typing.Any:
            loop = typing.cast(typing.Optional["aio.AbstractEventLoop"], kwargs.get("loop"))
            awaitable = typing.cast("aio.Future[typing.Any]", get_argument_value(args, kwargs, 0, "arg"))
            future = asyncio.ensure_future(awaitable, loop=loop)

            parent = _current_task_or_none()
            if parent is not None:
                stack.link_tasks(parent, future)

            # Hand the already-ensured future to shield instead of the raw awaitable.
            args = (future,) + args[1:]

            return f(*args, **kwargs)

        # Wrap asyncio.TaskGroup.create_task to link parent task to created tasks (Python 3.11+)
        if sys.hexversion >= 0x030B0000:  # Python 3.11+
            taskgroups_module = sys.modules.get("asyncio.taskgroups")
            if taskgroups_module is not None:
                taskgroup_class = getattr(taskgroups_module, "TaskGroup", None)
                if taskgroup_class is not None and hasattr(taskgroup_class, "create_task"):

                    @partial(wrap, taskgroup_class.create_task)
                    def _(
                        f: typing.Callable[..., "aio.Task[typing.Any]"],
                        args: tuple[typing.Any, ...],
                        kwargs: dict[str, typing.Any],
                    ) -> typing.Any:
                        result = f(*args, **kwargs)

                        parent = _current_task_or_none()
                        if parent is not None and result is not None:
                            # Link parent task to the task created by TaskGroup
                            stack.link_tasks(parent, result)

                        return result

        # Note: asyncio.timeout and asyncio.timeout_at don't create child tasks.
        # They are context managers that schedule a callback to cancel the current task
        # if it times out. The timeout._task is the same as the current task, so there's
        # no parent-child relationship to link. The timeout mechanism is handled by the
        # event loop's timeout handler, not by creating new tasks.
        @partial(wrap, sys.modules["asyncio"].tasks.create_task)
        def _(
            f: typing.Callable[..., "aio.Task[typing.Any]"],
            args: tuple[typing.Any, ...],
            kwargs: dict[str, typing.Any],
        ) -> "aio.Task[typing.Any]":
            # kwargs will typically contain context (Python 3.11+ only) and eager_start (Python 3.14+ only)
            task: "aio.Task[typing.Any]" = f(*args, **kwargs)
            parent = _current_task_or_none()

            if parent is not None:
                # create_task uses the weak variant of the link, unlike the wrappers above.
                stack.weak_link_tasks(parent, task)

            return task

        _call_init_asyncio(asyncio)


@ModuleWatchdog.after_module_imported("uvloop")
def _(uvloop: ModuleType) -> None:
    """Install uvloop hooks so its event loops are tracked.

    uvloop doesn't inherit from BaseDefaultEventLoopPolicy, and on Python 3.11+
    uvloop.run() uses asyncio.Runner which bypasses set_event_loop entirely,
    so ``uvloop.new_event_loop`` is wrapped to catch every loop creation.

    ``uvloop.EventLoopPolicy.set_event_loop`` is wrapped as well, to cover the
    deprecated uvloop.install() + asyncio.run() pattern.

    Nothing is hooked when uvloop support is turned off in the configuration.

    Args:
        uvloop: The freshly imported ``uvloop`` module.
    """
    # Bail out early if uvloop support is disabled via configuration
    if not config.stack.uvloop:  # pyright: ignore[reportAttributeAccessIssue]
        return

    import asyncio

    init_stack: bool = config.stack.enabled and stack.is_available

    # Hook loop creation: uvloop.new_event_loop
    new_event_loop_func = getattr(uvloop, "new_event_loop", None)
    if new_event_loop_func is not None:

        @partial(wrap, new_event_loop_func)
        def _(
            f: typing.Callable[..., "asyncio.AbstractEventLoop"],
            args: tuple[typing.Any, ...],
            kwargs: dict[str, typing.Any],
        ) -> "asyncio.AbstractEventLoop":
            created_loop = f(*args, **kwargs)
            if not init_stack:
                return created_loop

            tid = typing.cast(int, ddtrace_threading.current_thread().ident)
            stack.set_uvloop_mode(tid, True)
            stack.track_asyncio_loop(tid, created_loop)
            # Make sure asyncio task tracking has been initialized
            _call_init_asyncio(asyncio)
            return created_loop

    # Hook uvloop.EventLoopPolicy.set_event_loop (uvloop.install() + asyncio.run() pattern)
    policy_class = getattr(uvloop, "EventLoopPolicy", None)
    if policy_class is not None and hasattr(policy_class, "set_event_loop"):

        @partial(wrap, policy_class.set_event_loop)
        def _(
            f: typing.Callable[..., typing.Any], args: tuple[typing.Any, ...], kwargs: dict[str, typing.Any]
        ) -> typing.Any:
            tid = typing.cast(int, ddtrace_threading.current_thread().ident)
            if init_stack:
                stack.set_uvloop_mode(tid, True)

            installed_loop: typing.Optional["asyncio.AbstractEventLoop"] = get_argument_value(
                args, kwargs, 1, "loop"
            )
            # Only a real loop is tracked; set_event_loop(None) is passed through untouched.
            if init_stack and installed_loop is not None:
                stack.track_asyncio_loop(tid, installed_loop)
                _call_init_asyncio(asyncio)

            return f(*args, **kwargs)
