import json
import operator
import os
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

import torch
from lightning_utilities.core.imports import compare_version

from lightning.fabric.utilities.types import _PATH

if TYPE_CHECKING:
    from lightning.fabric.fabric import Fabric

_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0")


class SpikeDetection:
    """Spike Detection Callback.

    Terminates training with a ``TrainingSpikeException`` when a loss-spike was detected and
    saves the batches to skip when resuming to a file.

    We skip the current and the previous batch since it is unclear whether the previous batch
    altered the weights in a way that it causes the spike or just the current batch is corrupted somehow.

    Args:
        mode: Whether to minimize or maximize the tracked metric
        window: A running mean of metrics with ``window`` size. Serves as reference value for spikes.
        warmup: After how many batches spike-tracking should start
        atol: An absolute tolerance.  Every diff between the running mean and the current value,
            that's not an improvement and above ``atol`` will be considered a spike
        rtol: A relative tolerance. Every diff between the running mean and the current value,
            that's higher than ``rtol * running_mean`` is considered a spike
        exclude_batches_path: Where to save the file that contains the batches to exclude.
            Will default to current directory.
        finite_only: If set to ``False``, consider non-finite values like NaN, inf and -inf a spike as well.

    """

    def __init__(
        self,
        mode: Literal["min", "max"] = "min",
        window: int = 10,
        warmup: int = 1,
        atol: Optional[float] = None,
        rtol: Optional[float] = 2.0,
        exclude_batches_path: Optional[_PATH] = None,
        finite_only: bool = True,
    ):
        if _TORCHMETRICS_GREATER_EQUAL_1_0_0:
            from torchmetrics.aggregation import MeanMetric
            from torchmetrics.wrappers import Running
        else:
            raise RuntimeError("SpikeDetection requires torchmetrics>=1.0.0! Please upgrade your version!")
        super().__init__()

        self.last_val: Union[torch.Tensor, float] = 0.0
        # spike detection happens individually on each machine
        self.running_mean = Running(MeanMetric(dist_sync_on_step=False, sync_on_compute=False), window=window)
        # workaround for https://github.com/Lightning-AI/torchmetrics/issues/1899
        self.running_mean.dist_sync_on_step = False
        self.running_mean.sync_on_compute = False

        self.mode = mode
        self.warmup = warmup
        self.atol = atol
        self.rtol = rtol
        self.bad_batches: List[int] = []
        self.exclude_batches_path = exclude_batches_path
        self.finite_only = finite_only

    @torch.no_grad()
    def on_train_batch_end(self, fabric: "Fabric", loss: torch.Tensor, batch: Any, batch_idx: int) -> None:
        """Checks if we currently have a loss-spike."""
        if batch_idx == 0:
            self.running_mean.to(fabric.strategy.root_device)

        if self.exclude_batches_path is None:
            self.exclude_batches_path = os.getcwd()

        if not str(self.exclude_batches_path).endswith(".json"):
            self.exclude_batches_path = os.path.join(self.exclude_batches_path, "skip_batches.json")

        is_spike = bool(batch_idx >= self.warmup and self._is_spike(loss))
        fabric.strategy.barrier()

        # While spike-detection happens on a per-rank level, we need to fail all ranks if any rank detected a spike
        is_spike_global = fabric.strategy.reduce_boolean_decision(is_spike, all=False)

        if is_spike_global:
            self._handle_spike(fabric, batch_idx)
        else:
            is_finite_all = self.finite_only or fabric.strategy.reduce_boolean_decision(
                bool(torch.isfinite(loss).all()), all=True
            )
            if is_finite_all:
                self._update_stats(loss)

    def _is_spike(self, loss: torch.Tensor) -> bool:
        # we might call compute more often than update which is fine as long as the
        # metric has at least one internal value.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            running_val = self.running_mean.compute()
        curr_diff = loss - self.last_val

        if self.finite_only and not torch.isfinite(loss):
            return True

        if self._is_better(curr_diff):
            return False

        return self._check_atol(loss, running_val) and self._check_rtol(loss, running_val)

    def _handle_spike(self, fabric: "Fabric", batch_idx: int) -> None:
        # Exclude current and last batch
        # Current batch is excluded since it could be that the data of this batch produces a high loss
        # Last batch is excluded since the previous batch could have "corrupted" the weights
        self.bad_batches.extend([batch_idx - 1, batch_idx])

        if fabric.global_rank == 0:
            assert self.exclude_batches_path is not None
            os.makedirs(os.path.dirname(self.exclude_batches_path), exist_ok=True)

            with open(self.exclude_batches_path, "w") as f:
                json.dump(self.bad_batches, f, indent=4)

        raise TrainingSpikeException(batch_idx=batch_idx)

    def _check_atol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
        return (self.atol is None) or bool(abs(val_a - val_b) >= abs(self.atol))

    def _check_rtol(self, val_a: Union[float, torch.Tensor], val_b: Union[float, torch.Tensor]) -> bool:
        return (self.rtol is None) or bool(abs(val_a - val_b) >= abs(self.rtol * val_b))

    def _is_better(self, diff_val: torch.Tensor) -> bool:
        if self.mode == "min":
            return bool((diff_val <= 0.0).all())
        if self.mode == "max":
            return bool((diff_val >= 0).all())

        raise ValueError(f"Invalid mode. Has to be min or max, found {self.mode}")

    def _update_stats(self, val: torch.Tensor) -> None:
        # only update if finite
        self.running_mean.update(val)
        self.last_val = val

    def state_dict(self) -> Dict[str, Any]:
        return {
            "last_val": self.last_val.item() if isinstance(self.last_val, torch.Tensor) else self.last_val,
            "mode": self.mode,
            "warmup": self.warmup,
            "atol": self.atol,
            "rtol": self.rtol,
            "bad_batches": self.bad_batches,
            "bad_batches_path": self.exclude_batches_path,
            "running": self.running_mean.state_dict(),
            "mean": self.running_mean.base_metric.state_dict(),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.last_val = state_dict.pop("last_val")
        self.mode = state_dict.pop("mode")
        self.warmup = state_dict.pop("warmup")
        self.atol = state_dict.pop("atol")
        self.rtol = state_dict.pop("rtol")
        self.bad_batches = state_dict.pop("bad_batches")
        self.exclude_batches_path = state_dict.pop("bad_batches_path")
        self.running.load_state_dict(state_dict.pop("running"))
        self.running_mean.base_metric.load_state_dict(state_dict.pop("mean"))


class TrainingSpikeException(RuntimeError):
    """Exception to be raised with Training Spikes."""

    def __init__(self, batch_idx: int, *args: Any, **kwargs: Any):
        super().__init__(f"Training spike detected in batch {batch_idx}", *args, **kwargs)
