# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

from typing import Any, Optional

from torch import nn

import kornia
from kornia.core import Module, tensor
from kornia.core.external import segmentation_models_pytorch as smp

from .base import SemanticSegmentation

__all__ = ["SegmentationModelsBuilder"]


class SegmentationModelsBuilder:
    @staticmethod
    def build(
        model_name: str = "Unet",
        encoder_name: str = "resnet34",
        encoder_weights: Optional[str] = "imagenet",
        in_channels: int = 3,
        classes: int = 1,
        activation: str = "softmax",
        **kwargs: Any,
    ) -> SemanticSegmentation:
        """SegmentationModel is a module that wraps a segmentation model.

        This module uses SegmentationModel library for segmentation.

        Args:
            model_name: Name of the model to use. Valid options are:
                "Unet", "UnetPlusPlus", "MAnet", "LinkNet", "FPN", "PSPNet", "PAN", "DeepLabV3", "DeepLabV3Plus".
            encoder_name: Name of the encoder to use.
            encoder_depth: Depth of the encoder.
            encoder_weights: Weights of the encoder.
            decoder_channels: Number of channels in the decoder.
            in_channels: Number of channels in the input.
            classes: Number of classes to predict.
            activation: Type of activation layer.
            **kwargs: Additional arguments to pass to the model. Detailed arguments can be found at:
                https://github.com/qubvel-org/segmentation_models.pytorch/tree/main/segmentation_models_pytorch/decoders

        Note:
            Only encoder weights are available.
            Pretrained weights for the whole model are not available.

        """
        preproc_params = smp.encoders.get_preprocessing_params(encoder_name)  # type: ignore
        preprocessor = SegmentationModelsBuilder.get_preprocessing_pipeline(preproc_params)
        segmentation_model = getattr(smp, model_name)(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=classes,
            activation=activation,
            **kwargs,
        )

        return SemanticSegmentation(
            model=segmentation_model,
            pre_processor=preprocessor,
            post_processor=nn.Identity(),
            name=f"{model_name}_{encoder_name}",
        )

    @staticmethod
    def get_preprocessing_pipeline(preproc_params: dict[str, Any]) -> kornia.augmentation.container.ImageSequential:
        # Ensure the color space transformation is ONNX-friendly
        proc_sequence: list[Module] = []
        input_space = preproc_params["input_space"]
        if input_space == "BGR":
            proc_sequence.append(kornia.color.BgrToRgb())
        elif input_space == "RGB":
            pass
        else:
            raise ValueError(f"Unsupported input space: {input_space}")

        # Normalize input range if needed
        input_range = preproc_params["input_range"]
        if input_range[1] == 255:
            proc_sequence.append(kornia.enhance.Normalize(mean=0.0, std=1 / 255.0))
        elif input_range[1] == 1:
            pass
        else:
            raise ValueError(f"Unsupported input range: {input_range}")

        # Handle mean and std normalization
        if preproc_params["mean"] is not None:
            mean = tensor([preproc_params["mean"]])
        else:
            mean = tensor(0.0)

        if preproc_params["std"] is not None:
            std = tensor([preproc_params["std"]])
        else:
            std = tensor(1.0)
        proc_sequence.append(kornia.enhance.Normalize(mean=mean, std=std))

        return kornia.augmentation.container.ImageSequential(*proc_sequence)
