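"""
Fork-safety helpers: registries of hooks that run before and after ``os.fork``,
and objects that re-create themselves in the child process after a fork.
"""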

import functools
import logging
import os
import typing
import weakref

import wrapt

from ddtrace.internal import _unpatched


log = logging.getLogger(__name__)

# Hook registries. Each holds zero-argument callables that are executed by the
# matching ``ddtrace_*`` runner defined below:
#   _registry              -> run in the child process after a fork
#   _registry_before_fork  -> run in the parent just before a fork
#   _registry_after_parent -> run in the parent after a fork
# IMPORTANT: Do not change typing.List to list until minimum Python version is 3.11+
# Module-level list[...] in Python 3.10 affects import timing. See packages.py for details.
_registry: typing.List[typing.Callable[[], None]] = []  # noqa: UP006
_registry_before_fork: typing.List[typing.Callable[[], None]] = []  # noqa: UP006
_registry_after_parent: typing.List[typing.Callable[[], None]] = []  # noqa: UP006

# Some integrations might require after-fork hooks to be executed after the
# actual call to os.fork with earlier versions of Python (<= 3.6), else issues
# like SIGSEGV will occur. Setting this to True will cause the after-fork hooks
# to be executed after the actual fork, which seems to prevent the issue.
# NOTE(review): not read anywhere in this part of the module; presumably
# consulted by callers elsewhere — verify before changing.
_soft = True


# Flag to determine, from the parent process, if fork has been called.
# Set by ``set_forked``, which is registered as an after-fork-in-parent hook.
_forked = False

# Flag to determine if the current process is a fork child.
# Set by ``set_fork_child``, which is registered as an after-fork-in-child hook.
_fork_child = False


def set_forked() -> None:
    """Record, in the parent process, that a fork has happened.

    Registered as an after-fork-in-parent hook; the flag is read back with
    ``has_forked``.
    """
    global _forked

    _forked = True


def has_forked() -> bool:
    """Return whether this process has forked a child, as recorded by ``set_forked``."""
    return _forked


def set_fork_child() -> None:
    """Flag the current process as a child that was produced by a fork."""
    global _fork_child
    _fork_child = True


def is_fork_child() -> bool:
    """Tell whether this process was created by a fork (see ``set_fork_child``)."""
    return _fork_child


def run_hooks(registry: list[typing.Callable[[], None]]) -> None:
    """Invoke each hook in *registry* in order, never letting one failure stop the rest.

    The registry is copied before iterating, so a hook may register or
    unregister hooks without affecting the run in progress. Any ``Exception``
    raised by a hook is logged and swallowed.
    """
    snapshot = tuple(registry)
    for callback in snapshot:
        try:
            callback()
        except Exception:
            # Python's own fork hooks ignore exceptions too: report and carry on.
            log.exception("Exception ignored in forksafe hook %r", callback)


# Runners that execute every hook in the corresponding registry; they are
# installed with os.register_at_fork below. Each partial holds a reference to the
# registry list object itself, so later register()/unregister() calls are seen.
ddtrace_before_fork = functools.partial(run_hooks, _registry_before_fork)
ddtrace_after_in_child = functools.partial(run_hooks, _registry)
ddtrace_after_in_parent = functools.partial(run_hooks, _registry_after_parent)


def register_hook(registry, hook):
    """Add *hook* to the end of *registry* and hand it back unchanged.

    Returning the hook lets the ``functools.partial`` wrappers built on this
    function double as decorators.
    """
    registry.append(hook)
    return hook


# Registration helpers, one per registry. Each appends the hook and returns it,
# so they can also be used as decorators.
register_before_fork = functools.partial(register_hook, _registry_before_fork)
register = functools.partial(register_hook, _registry)
register_after_parent = functools.partial(register_hook, _registry_after_parent)

# Keep the fork-state flags current: a child marks itself as a fork child, and
# the parent records that it has forked.
register(set_fork_child)
register_after_parent(set_forked)


def unregister(after_in_child: typing.Callable[[], None]) -> None:
    """Remove *after_in_child* from the after-fork-in-child hook registry.

    Only the first matching entry is removed. Unregistering a hook that was
    never registered is not an error: it is logged at INFO level and ignored.

    :param after_in_child: the hook previously passed to ``register``.
    """
    try:
        _registry.remove(after_in_child)
    except ValueError:
        # Not every callable has ``__name__`` (e.g. ``functools.partial`` objects
        # or callable instances); fall back to the object itself so this
        # best-effort path never raises AttributeError.
        log.info(
            "after_in_child hook %s was unregistered without first being registered",
            getattr(after_in_child, "__name__", after_in_child),
        )


def unregister_parent(after_in_parent: typing.Callable[[], None]) -> None:
    """Remove *after_in_parent* from the after-fork-in-parent hook registry.

    Only the first matching entry is removed. Unregistering a hook that was
    never registered is not an error: it is logged at INFO level and ignored.

    :param after_in_parent: the hook previously passed to ``register_after_parent``.
    """
    try:
        _registry_after_parent.remove(after_in_parent)
    except ValueError:
        # Not every callable has ``__name__`` (e.g. ``functools.partial`` objects
        # or callable instances); fall back to the object itself so this
        # best-effort path never raises AttributeError.
        log.info(
            "after_in_parent hook %s was unregistered without first being registered",
            getattr(after_in_parent, "__name__", after_in_parent),
        )


def unregister_before_fork(before_fork: typing.Callable[[], None]) -> None:
    """Remove *before_fork* from the before-fork hook registry.

    Only the first matching entry is removed. Unregistering a hook that was
    never registered is not an error: it is logged at INFO level and ignored.

    :param before_fork: the hook previously passed to ``register_before_fork``.
    """
    try:
        _registry_before_fork.remove(before_fork)
    except ValueError:
        # Label the hook by the registry it belongs to (before_fork). Not every
        # callable has ``__name__`` (e.g. ``functools.partial`` objects), so fall
        # back to the object itself rather than raising AttributeError.
        log.info(
            "before_fork hook %s was unregistered without first being registered",
            getattr(before_fork, "__name__", before_fork),
        )


# Install the three hook runners with the interpreter so that they run around
# os.fork(). On platforms without os.register_at_fork nothing is installed here.
# Availability: Unix, not WASI, not Android, not iOS.
# Added in version 3.7.
if hasattr(os, "register_at_fork"):
    os.register_at_fork(
        before=ddtrace_before_fork, after_in_child=ddtrace_after_in_child, after_in_parent=ddtrace_after_in_parent
    )


# Type of the object wrapped by a ResetObject.
_T = typing.TypeVar("_T")


class ResetObject(wrapt.ObjectProxy, typing.Generic[_T]):
    """A fork-safe proxy that re-creates the object it wraps after a fork.

    After a fork, the other threads no longer exist in the child. An object such
    as a Lock may therefore be left in whatever state (locked or not) one of
    those threads put it in. CPython resets its own locks with the internal
    ``threading._after_fork`` function; this class applies the same idea. Every
    instance is tracked in a weak registry, and an after-fork-in-child hook swaps
    in a newly constructed instance of the wrapped class.
    """

    def __init__(
        self,
        wrapped_class: type[_T],
    ) -> None:
        first_instance = wrapped_class()
        super(ResetObject, self).__init__(first_instance)
        # wrapt keeps ``_self_``-prefixed attributes on the proxy itself instead
        # of forwarding them to the wrapped object (per the wrapt documentation).
        self._self_wrapped_class = wrapped_class
        # Track this proxy so ``_reset_objects`` can rebuild it in a fork child.
        _resetable_objects.add(self)

    def _reset_object(self) -> None:
        """Drop the current wrapped object and wrap a freshly constructed one."""
        replacement = self._self_wrapped_class()
        self.__wrapped__ = replacement


# Live ResetObject proxies, walked by ``_reset_objects`` in the fork child.
# A WeakSet is used so that tracking a proxy does not keep it alive.
_resetable_objects: weakref.WeakSet[ResetObject] = weakref.WeakSet()


def _reset_objects() -> None:
    for obj in list(_resetable_objects):
        try:
            obj._reset_object()
        except Exception:
            log.exception("Exception ignored in object reset forksafe hook %r", obj)


# In the child after a fork, re-create the wrapped object of every ResetObject.
register(_reset_objects)


def Event() -> _unpatched.threading_Event:
    """Create a fork-safe event.

    The result is a ``ResetObject`` proxy around ``_unpatched.threading_Event``
    (presumably the original, un-instrumented ``threading.Event`` class — verify).
    In the child after a fork, the proxy wraps a newly constructed instance.
    """
    event_factory = _unpatched.threading_Event
    return ResetObject(event_factory)
