# MIT License
#
# Copyright (c) 2024-2025 CNRS
# Copyright (c) 2025- pyannoteAI
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


import os
from pathlib import Path
from typing import Optional

import numpy as np
import safetensors.numpy
import scipy.interpolate
from pyannote.audio.utils.hf_hub import AssetFileName, download_from_hf_hub
from sklearn.isotonic import IsotonicRegression
from sklearn.utils.validation import NotFittedError, check_is_fitted


class Calibration(IsotonicRegression):
    """Logit/distance calibration"""

    def __init__(self):
        super().__init__(y_min=0.0, y_max=1.0, increasing="auto", out_of_bounds="clip")

    def safe_transform(
        self,
        values: np.ndarray,
        nan_value: float = 2.0,
    ) -> np.ndarray:
        """Apply calibration handling NaN values and any shape gracefully
        
        Parameters
        ----------
        values : np.ndarray
            Values to calibrate
        nan_value : float, optional
            Value to use in place of NaN values during calibration. Default is 2.0.

        Returns
        -------
        calibrated_values : np.ndarray
            Calibrated values
        """
        # temporarily replace NaN values with `nan_value` so `transform()` does not fail
        transformed = np.nan_to_num(values.reshape(-1), nan=nan_value)

        # apply calibration
        transformed: np.ndarray = self.transform(transformed)

        # recover original shape
        return transformed.reshape(values.shape)

    def save(self, path: str):
        """Save fitted calibration to disk

        Parameters
        ----------
        path : str
            Path to the file where the calibration should be saved

        Raises
        ------
        NotFittedError
            If the calibration is not fitted yet

        """
        try:
            check_is_fitted(self)
        except NotFittedError:
            raise NotFittedError("Cannot save an unfitted model.")

        tensor_dict = {
            "X_min_": self.X_min_,
            "X_max_": self.X_max_,
            "X_thresholds_": self.X_thresholds_,
            "y_thresholds_": self.y_thresholds_,
            "increasing_": self.increasing_,
        }

        safetensors.numpy.save_file(tensor_dict, path)

    @classmethod
    def from_tensor_dict(cls, tensor_dict: dict) -> "Calibration":
        """Load calibration from a dictionary of tensors

        Parameters
        ----------
        tensor_dict : dict
            Dictionary containing the calibration tensors

        Returns
        -------
        calibration : Calibration
            Fitted calibration
        """
        calibration = cls()

        for key, value in tensor_dict.items():
            setattr(calibration, key, value)

        calibration.f_ = scipy.interpolate.interp1d(
            np.hstack(
                [
                    [np.min(calibration.X_thresholds_) - 1.0],
                    calibration.X_thresholds_,
                    [np.max(calibration.X_thresholds_) + 1.0],
                ]
            ),
            np.hstack(
                [
                    [1.0 - calibration.increasing_],
                    calibration.y_thresholds_,
                    [1.0 * calibration.increasing_],
                ]
            ),
            kind="linear",
            bounds_error=False,
        )

        return calibration

    @classmethod
    def from_file(cls, path: str) -> "Calibration":
        """Load calibration from disk

        Parameters
        ----------
        path : str
            Path to the file where the calibration is saved

        Returns
        -------
        calibration : Calibration
            Fitted calibration
        """
        tensor_dict = safetensors.numpy.load_file(path)
        return cls.from_tensor_dict(tensor_dict)

    @classmethod
    def from_pretrained(
        cls,
        checkpoint: Path | str,
        subfolder: str | None = None,
        revision: str | None = None,
        token: str | bool | None = None,
        cache_dir: Path | str | None = None,
        **kwargs,
    ) -> Optional["Calibration"]:
        """Load calibration from disk or Huggingface Hub

        Parameters
        ----------
        checkpoint : Path or str
            Path to checkpoint or a model identifier from the hf.co model hub.
        subfolder : str, optional
            Folder inside the hf.co model repo.
        revision : str, optional
            Revision when loading from the huggingface.co model hub.
        token : str, optional
            When loading a private hf.co model, set `token`
            to True or to a string containing your hugginface.co authentication
            token that can be obtained by running `huggingface-cli login`
        cache_dir: Path or str, optional
            Path to model cache directory.
        """

        # if checkpoint is a directory, look for the calibration checkpoint
        # inside this directory (or inside a subfolder if specified)
        if os.path.isdir(checkpoint):
            if revision is not None:
                raise ValueError("Revisions cannot be used with local checkpoints.")

            if subfolder:
                path_to_calibration_checkpoint = (
                    Path(checkpoint) / subfolder / AssetFileName.Calibration.value
                )
            else:
                path_to_calibration_checkpoint = (
                    Path(checkpoint) / AssetFileName.Calibration.value
                )

        # if checkpoint is a file, use it as is
        elif os.path.isfile(checkpoint):
            if revision is not None:
                raise ValueError("Revisions cannot be used with local checkpoints.")

            path_to_calibration_checkpoint = checkpoint

        # otherwise, assume that the checkpoint is hosted on HF model hub
        else:
            checkpoint = str(checkpoint)
            if "@" in checkpoint:
                raise ValueError(
                    "Revisions must be passed with `revision` keyword argument."
                )

            path_to_calibration_checkpoint = download_from_hf_hub(
                checkpoint,
                AssetFileName.Calibration,
                subfolder=subfolder,
                revision=revision,
                cache_dir=cache_dir,
                token=token,
            )

            if path_to_calibration_checkpoint is None:
                return None

        return cls.from_file(path_to_calibration_checkpoint)
