import os
import random

import lightning.fabric as lightning_fabric
import numpy as np
import torch
import wandb
from absl import logging

from tts.core import constants
from tts.utils import configuration, custom_logging

_BEST_CHECKPOINT_NAME = "best_model.pt"


def _capture_rng_state() -> dict:
    """Captures the full RNG state so exact resume is possible without fast-forward."""
    state = {
        "torch_rng": torch.random.get_rng_state(),
        "python_rng": random.getstate(),
        "numpy_rng": np.random.get_state(),
    }
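    # CUDA generators are tracked separately from the CPU generator; capture one
    # state per visible device so multi-GPU runs resume identically.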
    if torch.cuda.is_available():
        state["cuda_rng"] = torch.cuda.get_rng_state_all()
    return state


def _restore_rng_state(rng_state: dict) -> None:
    """Restores RNG state captured by _capture_rng_state."""
    torch.random.set_rng_state(rng_state["torch_rng"])
    random.setstate(rng_state["python_rng"])
    np.random.set_state(rng_state["numpy_rng"])
    if "cuda_rng" in rng_state and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(rng_state["cuda_rng"])


def load_from_checkpoint(
    fabric: lightning_fabric.Fabric,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    checkpoint_file_to_resume_from: str,
    load_full_checkpoint: bool = True,
) -> tuple[torch.nn.Module, custom_logging.Statistics | None, torch.optim.Optimizer, dict | None]:
    """Loads checkpoint and updates model state. Returns RNG state if available."""
    checkpoint = {"model": model}
    if load_full_checkpoint:
        checkpoint.update({"optimizer": optimizer, "loss_statistics": {}, "rng_state": {}})

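    # Fabric.load restores stateful objects (module, optimizer) in place and
    # replaces the plain dict entries above with the values stored on disk.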
    fabric.load(checkpoint_file_to_resume_from, checkpoint, strict=True)
    statistics = None
    rng_state = None
    if load_full_checkpoint:
        statistics = custom_logging.Statistics.from_dict(checkpoint["loss_statistics"])
        rng_state = checkpoint.get("rng_state")
        if isinstance(rng_state, dict) and "torch_rng" in rng_state:
            logging.info("RNG state found in checkpoint; will restore for exact resume.")
        else:
            rng_state = None

    return model, statistics, optimizer, rng_state


def save_to_checkpoint(
    fabric: lightning_fabric.Fabric,
    model: torch.nn.Module,
    config: configuration.ExperimentConfig,
    optimizer: torch.optim.Optimizer,
    statistics: custom_logging.Statistics,
    checkpoint_name: str | None = None,
) -> str:
    """Saves the model, optimizer, RNG state, and training stats to a checkpoint."""
    checkpoint_name = checkpoint_name or f"checkpoint_{statistics.step}.pt"
    checkpoint_file = os.path.join(config.checkpointing.directory, checkpoint_name)

    checkpoint = {
        "model": model,
        "loss_statistics": statistics.as_dict(),
        "optimizer": optimizer,
        "config": config.to_dict(),
        "rng_state": _capture_rng_state(),
    }
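    # fabric.save must be called on every rank; the active strategy decides which
    # rank(s) actually write the file.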
    fabric.save(path=checkpoint_file, state=checkpoint)

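    # Prune old checkpoints from rank zero only so each file is removed exactly once.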
    if fabric.is_global_zero:
        keep_only_last_n_checkpoints = config.checkpointing.keep_only_last_n_checkpoints
        if keep_only_last_n_checkpoints is not None:
            checkpoint_files = [
                f
                for f in os.listdir(config.checkpointing.directory)
                if f.startswith("checkpoint_") and f.endswith(".pt")
                and f != _BEST_CHECKPOINT_NAME
            ]
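            # Sort by training step numerically; a lexicographic sort would order
            # checkpoint_10.pt before checkpoint_9.pt and prune the wrong files.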
            checkpoint_files.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
            for f in checkpoint_files[:-keep_only_last_n_checkpoints]:
                logging.info("Removing too old checkpoint %s...", f)
                os.remove(os.path.join(config.checkpointing.directory, f))

    return checkpoint_file


def maybe_save_best_checkpoint(
    fabric: lightning_fabric.Fabric,
    model: torch.nn.Module,
    config: configuration.ExperimentConfig,
    optimizer: torch.optim.Optimizer,
    statistics: custom_logging.Statistics,
    eval_loss: float,
    best_eval_loss: float,
) -> float:
    """Saves a 'best_model.pt' checkpoint when eval_loss improves."""
    if eval_loss < best_eval_loss:
        logging.info(
            "New best eval loss %.5f < %.5f, saving best checkpoint.",
            eval_loss,
            best_eval_loss,
        )
        save_to_checkpoint(
            fabric, model, config, optimizer, statistics,
            checkpoint_name=_BEST_CHECKPOINT_NAME,
        )
        return eval_loss
    return best_eval_loss


def save_config(
    experiment_config: configuration.ExperimentConfig,
    checkpoint_dir: str,
    use_wandb: bool,
) -> None:
    """Saves model config to a file."""
    config_file = os.path.join(checkpoint_dir, constants.CONFIG_FILE_NAME)
    with open(config_file, "w") as f:
        f.write(str(experiment_config))

    # The config may have picked up new values since it was first initialized;
    # to keep things consistent, push the latest values to W&B here.
    #
    # TODO: values set by a sweep run can be overridden by the Python training
    #       code, so the config saved here and the one shown in the W&B UI may
    #       still differ.
    if use_wandb:
        wandb.config.update(experiment_config.to_dict(), allow_val_change=True)
