# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, Tuple

from kornia.core import Module, Tensor
from kornia.metrics.average_meter import AverageMeter

# import yaml


class TrainerState(Enum):
    """Enumeration of the named trainer states.

    NOTE(review): the code that switches between these states is not part of this
    file; the per-member comments below are inferred from the names only - confirm
    against the trainer implementation.
    """

    # presumably the state before any training step has run
    STARTING = 0
    # presumably set while the training loop is running
    TRAINING = 1
    # presumably set while running validation
    VALIDATE = 2
    # presumably signals that the run should stop
    TERMINATE = 3


# NOTE: this class needs to be redefined according to the needed parameters.
@dataclass
class Configuration:
    """Default parameters for a training run.

    Every field stores a short description of itself under the ``help`` key of its
    dataclass field metadata.

    Args:
        data_path: the input data directory.
        batch_size: the batch size for the training dataloader.
        num_epochs: the number of epochs to run the training.
        lr: the learning rate to be used for the optimizer.
        output_path: the output data directory.
        image_size: the input image size, given as a pair of ints.

    """

    data_path: str = field(default="./", metadata={"help": "The input data directory."})
    # The previous help text described this as "the number of batches", which is not what a batch size is.
    batch_size: int = field(default=1, metadata={"help": "The batch size for the training dataloader."})
    num_epochs: int = field(default=1, metadata={"help": "The number of epochs to run the training."})
    lr: float = field(default=1e-3, metadata={"help": "The learning rate to be used for the optimizer."})
    output_path: str = field(default="./output", metadata={"help": "The output data directory."})
    image_size: Tuple[int, int] = field(default=(224, 224), metadata={"help": "The input image size."})

    # TODO: possibly remove because hydra already does this
    # def __init__(self, **entries):
    #     for k, v in entries.items():
    #         self.__dict__[k] = Configuration(**v) if isinstance(v, dict) else v

    # @classmethod
    # def from_yaml(cls, config_file: str):
    #     """Create an instance of the configuration from a yaml file."""
    #     with open(config_file) as f:
    #         data = yaml.safe_load(f)
    #     return cls(**data)


class Lambda(Module):
    """Wrap an arbitrary callable so that it can be used wherever a Module is expected.

    Args:
        fcn: the callable invoked with the input on every forward pass.

    Example:
        >>> import torch
        >>> import kornia as K
        >>> fcn = Lambda(lambda x: K.geometry.resize(x, (32, 16)))
        >>> fcn(torch.rand(1, 4, 64, 32)).shape
        torch.Size([1, 4, 32, 16])

    """

    def __init__(self, fcn: Callable[..., Any]) -> None:
        super().__init__()
        # Keep a reference to the callable; forward() delegates to it.
        self.fcn = fcn

    def forward(self, x: Tensor) -> Any:
        """Call the wrapped callable with ``x`` and return its result unchanged."""
        wrapped = self.fcn
        return wrapped(x)


class StatsTracker:
    """Stats tracker for computing metrics on the fly.

    Holds one ``AverageMeter`` per metric name. A meter is created the first time
    its name is passed to :meth:`update`.

    """

    def __init__(self) -> None:
        self._stats: Dict[str, AverageMeter] = {}

    @property
    def stats(self) -> Dict[str, AverageMeter]:
        """Return the internal name-to-meter mapping (the same object, not a copy)."""
        return self._stats

    def update(self, key: str, val: float, batch_size: int) -> None:
        """Update the stats by the key value pair.

        Args:
            key: name of the metric; a new meter is created if it is not tracked yet.
            val: value forwarded to the meter.
            batch_size: forwarded as the second argument of ``AverageMeter.update``
                (presumably the number of samples ``val`` was computed over).

        """
        if key not in self._stats:
            self._stats[key] = AverageMeter()
        self._stats[key].update(val, batch_size)

    def update_from_dict(self, dic: Dict[str, float], batch_size: int) -> None:
        """Update the stats by the dict.

        Args:
            dic: mapping from metric name to value; each pair is passed to :meth:`update`.
            batch_size: shared by every entry of ``dic``.

        """
        for k, v in dic.items():
            self.update(k, v, batch_size)

    def __repr__(self) -> str:
        """Format every metric as ``NAME: <latest value> <running average>``."""
        # The second number used to repeat ``v.val``; report the meter's average instead
        # so the two fields carry different information.
        return " ".join(f"{k.upper()}: {v.val:.2f} {v.avg:.2f} " for k, v in self._stats.items())

    def as_dict(self) -> Dict[str, AverageMeter]:
        """Return the dict format.

        This is the internal name-to-meter mapping itself, not a copy.

        """
        return self._stats
