import torch
from typing import Optional
from torch import Tensor

from ..core.transforms_interface import BaseWaveformTransform
from ..utils.object_dict import ObjectDict


class Padding(BaseWaveformTransform):
    """
    Silence a randomly sized section at the start or at the end of each example.

    Despite the name, the length of the waveform is not changed: for every
    example along the first dimension a pad length is drawn at random and that
    many leading (``pad_section="start"``) or trailing (``pad_section="end"``)
    positions along the last dimension are overwritten with ``0.0``.
    """

    supported_modes = {"per_batch", "per_example", "per_channel"}
    supports_multichannel = True
    requires_sample_rate = False

    supports_target = True
    requires_target = False

    def __init__(
        self,
        min_fraction=0.1,
        max_fraction=0.5,
        pad_section="end",
        mode="per_batch",
        p=0.5,
        p_mode: Optional[str] = None,
        sample_rate: Optional[int] = None,
        target_rate: Optional[int] = None,
        output_type: Optional[str] = None,
    ):
        """
        :param min_fraction: Lower bound of the pad length, as a fraction of the
            input length (size of the last dimension). Must be >= 0.
        :param max_fraction: Upper bound of the pad length, as a fraction of the
            input length. Must be >= ``min_fraction``.
        :param pad_section: Which side gets silenced: ``"start"`` or ``"end"``.
        :param mode: Forwarded unchanged to ``BaseWaveformTransform``.
        :param p: Forwarded unchanged to ``BaseWaveformTransform``.
        :param p_mode: Forwarded unchanged to ``BaseWaveformTransform``.
        :param sample_rate: Forwarded unchanged to ``BaseWaveformTransform``.
        :param target_rate: Forwarded unchanged to ``BaseWaveformTransform``.
        :param output_type: Forwarded unchanged to ``BaseWaveformTransform``.
        :raises ValueError: If ``min_fraction`` is negative (or NaN), or if
            ``min_fraction`` is greater than ``max_fraction``.
        :raises AssertionError: If ``pad_section`` is neither ``"start"`` nor
            ``"end"``.
        """
        super().__init__(
            mode=mode,
            p=p,
            p_mode=p_mode,
            sample_rate=sample_rate,
            target_rate=target_rate,
            output_type=output_type,
        )
        self.min_fraction = min_fraction
        self.max_fraction = max_fraction
        self.pad_section = pad_section
        # Written as "not >=" (rather than "<") so that NaN is rejected as well.
        if not self.min_fraction >= 0.0:
            raise ValueError(
                "minimum fraction should be greater than or equal to zero."
            )
        if self.min_fraction > self.max_fraction:
            raise ValueError(
                "minimum fraction should be less than or equal to maximum fraction."
            )
        # NOTE(review): kept as an assert so the raised exception type stays the
        # same for existing callers; be aware it is skipped under ``python -O``.
        assert self.pad_section in (
            "start",
            "end",
        ), 'pad_section must be "start" or "end"'

    def randomize_parameters(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        """
        Draw one pad length per entry along the first dimension of ``samples``.

        The result is stored in ``self.transform_parameters["pad_length"]`` as an
        integer tensor of shape ``(samples.shape[0],)``.

        :param samples: Input tensor; only its shape is read here.
        :param sample_rate: Unused.
        :param targets: Unused.
        :param target_rate: Unused.
        """
        input_length = samples.shape[-1]
        low = int(input_length * self.min_fraction)
        high = int(input_length * self.max_fraction)
        # torch.randint samples from the half-open range [low, high) and raises
        # when that range is empty. That happens when min_fraction equals
        # max_fraction, or when the input is so short that both bounds truncate
        # to the same integer. In that case always use ``low``; non-empty ranges
        # are sampled exactly as before.
        high = max(high, low + 1)
        self.transform_parameters["pad_length"] = torch.randint(
            low,
            high,
            (samples.shape[0],),
        )

    def apply_transform(
        self,
        samples: Tensor,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ) -> ObjectDict:
        """
        Zero out the drawn section of every example, modifying ``samples`` in place.

        :param samples: 3-D tensor indexed as ``[example, :, position]``
            (presumably ``(batch, channels, time)`` - confirm against the base
            class).
        :param sample_rate: Passed through to the result unchanged.
        :param targets: Passed through to the result unchanged.
        :param target_rate: Passed through to the result unchanged.
        :return: ``ObjectDict`` holding the modified ``samples`` together with the
            untouched ``sample_rate``, ``targets`` and ``target_rate``.
        """
        for i, index in enumerate(self.transform_parameters["pad_length"]):
            pad_length = int(index)
            if pad_length <= 0:
                # Nothing to silence. This guard matters for the "end" case:
                # ``-0`` is ``0``, so ``samples[i, :, -0:]`` would select (and
                # wipe) the whole example instead of an empty slice.
                continue
            if self.pad_section == "start":
                samples[i, :, :pad_length] = 0.0
            else:
                samples[i, :, -pad_length:] = 0.0

        return ObjectDict(
            samples=samples,
            sample_rate=sample_rate,
            targets=targets,
            target_rate=target_rate,
        )
