# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import logging
import os
from typing import List, Optional, Union

from kornia.core import Module, Tensor, stack
from kornia.core.external import PILImage as Image
from kornia.core.external import numpy as np
from kornia.core.mixin.onnx import ONNXExportMixin
from kornia.io import write_image
from kornia.utils.image import tensor_to_image

logger = logging.getLogger(__name__)


class ModelBaseMixin:
    """Output-handling helpers shared by model wrappers.

    Provides conversion of model outputs to tensors or PIL images, and saving of
    output images to a directory.
    """

    # Prefix of the default output directory created by `_save_outputs`.
    name: str = "model"

    def _tensor_to_type(
        self, output: Union[Tensor, List[Tensor]], output_type: str, is_batch: bool = False
    ) -> Union[Tensor, List[Tensor], List["Image.Image"]]:  # type: ignore
        """Convert the output tensor to the desired type.

        Args:
            output: The output tensor or list of tensors.
            output_type: The desired output type. Accepted values are "torch" and "pil".
            is_batch: Only used for "torch". If True, the result is a single tensor
                (a list is stacked, a tensor is returned as is). If False, the result
                is a list (a tensor is converted with ``list(output)``, a list is
                returned as is).

        Returns:
            For "torch", a tensor or a list of tensors depending on ``is_batch``.
            For "pil", a list with one PIL image per element of ``output``.

        Raises:
            RuntimeError: If the output type is not supported.

        """
        if output_type == "torch":
            if isinstance(output, Tensor):
                return output if is_batch else list(output)
            return stack(output) if is_batch else output

        if output_type == "pil":
            # NOTE(review): the array is scaled by 255 and cast to uint8 without clamping;
            # presumably values are expected in [0, 1] -- TODO confirm.
            return [Image.fromarray((tensor_to_image(out_img) * 255).astype(np.uint8)) for out_img in output]  # type: ignore

        raise RuntimeError(f"Unsupported output type `{output_type}`.")

    def _save_outputs(
        self, outputs: Union[Tensor, List[Tensor]], directory: Optional[str] = None, suffix: str = ""
    ) -> None:
        """Save the output image(s) to a directory.

        Each element of ``outputs`` is written as ``<index><suffix>.jpg``, where the
        index is zero-padded to six digits. The directory is created if it is missing.

        Args:
            outputs: output tensor, or list of tensors, iterated element by element.
            directory: directory to save the images. If None, defaults to
                ``kornia_outputs/<name>_<UTC timestamp>``.
            suffix: filename suffix, inserted before the ``.jpg`` extension.

        """
        if directory is None:
            timestamp = datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y%m%d%H%M%S")
            directory = os.path.join("kornia_outputs", f"{self.name}_{timestamp}")

        os.makedirs(directory, exist_ok=True)
        for i, out_image in enumerate(outputs):
            write_image(
                os.path.join(directory, f"{i:06d}{suffix}.jpg"),
                # NOTE(review): no clamping before the byte cast; presumably images
                # are expected in [0, 1] -- TODO confirm.
                out_image.mul(255.0).byte(),
            )
        # Lazy %-formatting: the message is only built if the record is emitted.
        logger.info("Outputs are saved in %s", directory)


class ModelBase(Module, ONNXExportMixin, ModelBaseMixin):
    """Bundle a model together with its pre-processing and post-processing modules."""

    def __init__(
        self, model: Module, pre_processor: Module, post_processor: Module, name: Optional[str] = None
    ) -> None:
        """Store the three stages and optionally override the instance name.

        ``.eval()`` is called on each module and the returned object is what gets stored.

        Args:
            model: the wrapped model, stored as ``self.model``.
            pre_processor: the pre-processing module, stored as ``self.pre_processor``.
            post_processor: the post-processing module, stored as ``self.post_processor``.
            name: name of the model. When ``None``, ``self.name`` is not assigned here.

        """
        super().__init__()

        # Kept as explicit attribute assignments, in the original order.
        self.model = model.eval()
        self.pre_processor = pre_processor.eval()
        self.post_processor = post_processor.eval()

        if name is None:
            return
        self.name = name
