import logging
import os
import re
import signal
import threading
from subprocess import call
from types import FrameType
from typing import Any, Callable, Union

import torch
import torch.distributed as dist

import pytorch_lightning as pl
from lightning_fabric.plugins.environments import SLURMEnvironment
from lightning_fabric.utilities.imports import _IS_WINDOWS
from pytorch_lightning.utilities.rank_zero import rank_prefixed_message, rank_zero_info

# copied from signal.pyi
_SIGNUM = Union[int, signal.Signals]
_HANDLER = Union[Callable[[_SIGNUM, FrameType], Any], int, signal.Handlers, None]

log = logging.getLogger(__name__)


class _HandlersCompose:
    def __init__(self, signal_handlers: Union[list[_HANDLER], _HANDLER]) -> None:
        if not isinstance(signal_handlers, list):
            signal_handlers = [signal_handlers]
        self.signal_handlers = signal_handlers

    def __call__(self, signum: _SIGNUM, frame: FrameType) -> None:
        for signal_handler in self.signal_handlers:
            if isinstance(signal_handler, int):
                signal_handler = signal.getsignal(signal_handler)
            if callable(signal_handler):
                signal_handler(signum, frame)


class _SignalConnector:
    def __init__(self, trainer: "pl.Trainer") -> None:
        self.received_sigterm = False
        self.trainer = trainer
        self._original_handlers: dict[_SIGNUM, _HANDLER] = {}

    def register_signal_handlers(self) -> None:
        self.received_sigterm = False
        self._original_handlers = self._get_current_signal_handlers()

        sigusr_handlers: list[_HANDLER] = []
        sigterm_handlers: list[_HANDLER] = [self._sigterm_notifier_fn]

        environment = self.trainer._accelerator_connector.cluster_environment
        if isinstance(environment, SLURMEnvironment) and environment.auto_requeue:
            log.info("SLURM auto-requeueing enabled. Setting signal handlers.")
            sigusr_handlers.append(self._slurm_sigusr_handler_fn)
            sigterm_handlers.append(self._sigterm_handler_fn)

        # Windows seems to have signal incompatibilities
        if not _IS_WINDOWS:
            sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1
            assert sigusr is not None
            if sigusr_handlers and not self._has_already_handler(sigusr):
                self._register_signal(sigusr, _HandlersCompose(sigusr_handlers))

            # we have our own handler, but include existing ones too
            if self._has_already_handler(signal.SIGTERM):
                sigterm_handlers.append(signal.getsignal(signal.SIGTERM))
            self._register_signal(signal.SIGTERM, _HandlersCompose(sigterm_handlers))

    def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
        rank_zero_info(f"Handling auto-requeue signal: {signum}")

        # save logger to make sure we get all the metrics
        for logger in self.trainer.loggers:
            logger.finalize("finished")

        hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.default_root_dir)
        self.trainer.save_checkpoint(hpc_save_path)

        if self.trainer.is_global_zero:
            # find job id
            array_job_id = os.getenv("SLURM_ARRAY_JOB_ID")
            if array_job_id is not None:
                array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
                job_id = f"{array_job_id}_{array_task_id}"
            else:
                job_id = os.environ["SLURM_JOB_ID"]

            assert re.match("[0-9_-]+", job_id)
            cmd = ["scontrol", "requeue", job_id]

            # requeue job
            log.info(f"requeing job {job_id}...")
            try:
                result = call(cmd)
            except FileNotFoundError:
                # This can occur if a subprocess call to `scontrol` is run outside a shell context
                # Re-attempt call (now with shell context). If any error is raised, propagate to user.
                # When running a shell command, it should be passed as a single string.
                result = call(" ".join(cmd), shell=True)

            # print result text
            if result == 0:
                log.info(f"Requeued SLURM job: {job_id}")
            else:
                log.warning(f"Requeuing SLURM job {job_id} failed with error code {result}")

    def _sigterm_notifier_fn(self, signum: _SIGNUM, _: FrameType) -> None:
        log.info(rank_prefixed_message(f"Received SIGTERM: {signum}", self.trainer.local_rank))
        if not self.received_sigterm:
            launcher = self.trainer.strategy.launcher
            if launcher is not None:
                launcher.kill(signum)

        # New broadcast logic
        if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
            sigterm_tensor = torch.tensor([1], device=self.trainer.strategy.root_device)
            dist.broadcast(sigterm_tensor, src=0)

        self.received_sigterm = True

    def _sigterm_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
        log.info(f"Bypassing SIGTERM: {signum}")

    def teardown(self) -> None:
        """Restores the signals that were previously configured before :class:`_SignalConnector` replaced them."""
        for signum, handler in self._original_handlers.items():
            if handler is not None:
                self._register_signal(signum, handler)
        self._original_handlers = {}

    @staticmethod
    def _get_current_signal_handlers() -> dict[_SIGNUM, _HANDLER]:
        """Collects the currently assigned signal handlers."""
        valid_signals = _SignalConnector._valid_signals()
        if not _IS_WINDOWS:
            # SIGKILL and SIGSTOP are not allowed to be modified by the user
            valid_signals -= {signal.SIGKILL, signal.SIGSTOP}
        return {signum: signal.getsignal(signum) for signum in valid_signals}

    @staticmethod
    def _valid_signals() -> set[signal.Signals]:
        """Returns all valid signals supported on the current platform."""
        return signal.valid_signals()

    @staticmethod
    def _has_already_handler(signum: _SIGNUM) -> bool:
        return signal.getsignal(signum) not in (None, signal.SIG_DFL)

    @staticmethod
    def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None:
        if threading.current_thread() is threading.main_thread():
            signal.signal(signum, handlers)  # type: ignore[arg-type]

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        state["_original_handlers"] = {}
        return state


def _get_sigkill_signal() -> _SIGNUM:
    return signal.SIGTERM if _IS_WINDOWS else signal.SIGKILL
