# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from math import inf
from pathlib import Path
from typing import Callable, Dict, Optional, Union

import torch

from kornia.core import Module
from kornia.metrics import AverageMeter

from .utils import TrainerState


def default_filename_fcn(epoch: Union[str, int], metric: Union[str, float]) -> str:
    """Build the default checkpoint filename from the epoch and the metric value."""
    return "model_epoch={}_metricValue={}.pt".format(epoch, metric)


class EarlyStopping:
    """Callback that monitors a validation metric and signals early termination.

    The tracked metric is compared against the best value seen so far; once it
    regresses beyond ``min_delta`` for ``patience`` evaluations, a termination
    signal is returned to the trainer.

    Args:
        monitor: the name of the value to track.
        min_delta: the minimum difference between losses to increase the patience counter.
        patience: the number of times to wait until the trainer does not terminate.
        max_mode: set to ``True`` when a larger metric value is better (e.g. Accuracy);
                  the comparison direction is flipped accordingly.

    **Usage example:**

    .. code:: python

        early_stop = EarlyStopping(
            monitor="loss", patience=10
        )

        trainer = ImageClassifierTrainer(
            callbacks={"on_epoch_end", early_stop}
        )

    """

    def __init__(
        self,
        monitor: str,
        min_delta: float = 0.0,
        patience: int = 8,
        max_mode: bool = False,
    ) -> None:
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        # When max_mode is set, a growing metric (e.g. accuracy) counts as
        # improvement; otherwise smaller values (classical losses) are better.
        self.max_mode = max_mode

        self.counter: int = 0
        # Seed the best score with the worst possible value for the chosen mode.
        self.best_score: float = -inf if max_mode else inf
        self.early_stop: bool = False

    def __call__(self, model: Module, epoch: int, valid_metric: Dict[str, AverageMeter]) -> TrainerState:
        current: float = valid_metric[self.monitor].avg

        # Compute both the strict-improvement and the within-tolerance checks
        # in the direction dictated by the mode.
        if self.max_mode:
            improved = current > self.best_score
            tolerated = current > (self.best_score - self.min_delta)
        else:
            improved = current < self.best_score
            tolerated = current < (self.best_score + self.min_delta)

        if improved:
            # New best value: remember it and reset the patience counter.
            self.best_score = current
            self.counter = 0
        elif not tolerated:
            # Example: current = 1.9, best_score = 2.0, min_delta = 0.15
            # max_mode: (1.9 > (2.0 - 0.15)) == True  -> tolerated, no penalty
            # min_mode: (1.9 < (2.0 + 0.15)) == True  -> tolerated, no penalty
            # Only regressions larger than min_delta consume patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        if self.early_stop:
            print(f"[INFO] Early-Stopping the training process. Epoch: {epoch}.")
            return TrainerState.TERMINATE

        return TrainerState.TRAINING


class ModelCheckpoint:
    """Callback that saves the model at the end of every epoch.

    A checkpoint is written only when the monitored metric improves on the best
    value observed so far.

    Args:
        filepath: the directory where to save the model checkpoints.
        monitor: the name of the value to track.
        filename_fcn: optional callable mapping ``(epoch, metric)`` to a filename;
                      defaults to :func:`default_filename_fcn`.
        max_mode: set to ``True`` when a larger metric value is better (e.g. Accuracy);
                  checkpoints are then saved when the new metric exceeds the old one.

    **Usage example:**

    .. code:: python

        model_checkpoint = ModelCheckpoint(
            filepath="./outputs", monitor="loss",
        )

        trainer = ImageClassifierTrainer(...,
            callbacks={"on_checkpoint", model_checkpoint}
        )

    """

    def __init__(
        self,
        filepath: str,
        monitor: str,
        filename_fcn: Optional[Callable[..., str]] = None,
        max_mode: bool = False,
    ) -> None:
        self.filepath = filepath
        self.monitor = monitor
        self._filename_fcn = filename_fcn or default_filename_fcn
        # Best value so far, seeded with the worst possible value for the mode.
        self.best_metric: float = -inf if max_mode else inf
        # When max_mode is set, a growing metric (e.g. accuracy) counts as
        # improvement; otherwise smaller values (classical losses) are better.
        self.max_mode = max_mode

        # Make sure the output directory exists up front.
        Path(self.filepath).mkdir(parents=True, exist_ok=True)

    def __call__(self, model: Module, epoch: int, valid_metric: Dict[str, AverageMeter]) -> None:
        current: float = valid_metric[self.monitor].avg
        if self.max_mode:
            improved = current > self.best_metric
        else:
            improved = current < self.best_metric
        if improved:
            # Record the new best metric and persist the full model.
            self.best_metric = current
            target = Path(self.filepath) / self._filename_fcn(epoch, current)
            torch.save(model, target)
