# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional

import einops
import hydra
import torch
from lightning.pytorch import Trainer
from omegaconf import DictConfig

from nemo.collections.audio.models.audio_to_audio import AudioToAudioModel
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types import AudioSignal, LengthsType, LossType, NeuralType
from nemo.utils import logging

# Names exported by `from <this module> import *`
__all__ = [
    "EncMaskDecAudioToAudioModel",
    "ScoreBasedGenerativeAudioToAudioModel",
    "PredictiveAudioToAudioModel",
    "SchroedingerBridgeAudioToAudioModel",
    "FlowMatchingAudioToAudioModel",
]


class EncMaskDecAudioToAudioModel(AudioToAudioModel):
    """Class for encoder-mask-decoder audio processing models.

    The model consists of the following blocks:
        - encoder: transforms input multi-channel audio signal into an encoded representation (analysis transform)
        - mask_estimator: estimates a mask used by signal processor
        - mask_processor: mask-based signal processor, combines the encoded input and the estimated mask
        - decoder: transforms processor output into the time domain (synthesis transform)

    Optional blocks, instantiated only when present in the config:
        - mixture_consistency: applied to the processor output in the encoded domain
        - channel_augment: applied to the input signal in `training_step` while training
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0
        # Note: set before `super().__init__` so it is already available during base-class setup
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.world_size

        super().__init__(cfg=cfg, trainer=trainer)
        self.sample_rate = self._cfg.sample_rate

        # Setup processing modules
        self.encoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.encoder)
        self.mask_estimator = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_estimator)
        self.mask_processor = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mask_processor)
        self.decoder = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.decoder)

        if 'mixture_consistency' in self._cfg:
            logging.debug('Using mixture consistency')
            self.mixture_consistency = EncMaskDecAudioToAudioModel.from_config_dict(self._cfg.mixture_consistency)
        else:
            logging.debug('Mixture consistency not used')
            self.mixture_consistency = None

        # Setup augmentation
        if hasattr(self.cfg, 'channel_augment') and self.cfg.channel_augment is not None:
            logging.debug('Using channel augmentation')
            self.channel_augmentation = EncMaskDecAudioToAudioModel.from_config_dict(self.cfg.channel_augment)
        else:
            logging.debug('Channel augmentation not used')
            self.channel_augmentation = None

        # Setup optional Optimization flags
        self.setup_optimization_flags()

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        return {
            "input_signal": NeuralType(
                ('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)
            ),  # multi-channel format, channel dimension can be 1 for single-channel audio
            "input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        return {
            "output_signal": NeuralType(
                ('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)
            ),  # multi-channel format, channel dimension can be 1 for single-channel audio
            "output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @typecheck()
    def forward(self, input_signal, input_length=None):
        """
        Forward pass of the model.

        Args:
            input_signal: Tensor that represents a batch of raw audio signals,
                of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values. The channel dimension C can be 1
                for single-channel audio.
            input_length: Optional vector of length B, that contains the individual lengths of the audio
                sequences.

        Returns:
            Output signal `output` in the time domain and the length of the output signal `output_length`.
            The output is trimmed or padded to the time length of `input_signal`.
        """
        batch_length = input_signal.size(-1)

        # Encoder
        encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)

        # Mask estimator
        mask, _ = self.mask_estimator(input=encoded, input_length=encoded_length)

        # Mask-based processor in the encoded domain
        processed, processed_length = self.mask_processor(input=encoded, input_length=encoded_length, mask=mask)

        # Mixture consistency
        if self.mixture_consistency is not None:
            processed = self.mixture_consistency(mixture=encoded, estimate=processed)

        # Decoder
        processed, processed_length = self.decoder(input=processed, input_length=processed_length)

        # Trim or pad the estimated signal to match input length
        processed = self.match_batch_length(input=processed, batch_length=batch_length)
        return processed, processed_length

    # PTL-specific methods
    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for a single batch.

        Args:
            batch: either a dictionary with keys `input_signal`, `input_length` and `target_signal`
                (lhotse batches), or a tuple `(input_signal, input_length, target_signal, _)`
            batch_idx: index of the batch (not used)

        Returns:
            Training loss.
        """
        if isinstance(batch, dict):
            # lhotse batches are dictionaries
            input_signal = batch['input_signal']
            input_length = batch['input_length']
            target_signal = batch['target_signal']
        else:
            input_signal, input_length, target_signal, _ = batch

        # For consistency, the model uses multi-channel format, even if the channel dimension is 1
        if input_signal.ndim == 2:
            input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
        if target_signal.ndim == 2:
            target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')

        # Apply channel augmentation
        if self.training and self.channel_augmentation is not None:
            input_signal = self.channel_augmentation(input=input_signal)

        # Process input
        processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)

        # Calculate the loss
        loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)

        # Logs
        self.log('train_loss', loss)
        self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

        # Return loss
        return loss

    def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
        """Compute the loss and update the metrics for a single evaluation batch.

        Args:
            batch: either a dictionary with keys `input_signal`, `input_length` and `target_signal`
                (lhotse batches), or a tuple `(input_signal, input_length, target_signal, _)`
            batch_idx: index of the batch (not used)
            dataloader_idx: index of the evaluation dataloader, selects the metrics to update
            tag: evaluation tag, e.g., `val`, used to select metrics and to name the returned loss

        Returns:
            Dictionary with the loss stored under `f'{tag}_loss'`.
        """
        if isinstance(batch, dict):
            # lhotse batches are dictionaries
            input_signal = batch['input_signal']
            input_length = batch['input_length']
            target_signal = batch['target_signal']
        else:
            input_signal, input_length, target_signal, _ = batch

        # For consistency, the model uses multi-channel format, even if the channel dimension is 1
        if input_signal.ndim == 2:
            input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
        if target_signal.ndim == 2:
            target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')

        # Process input
        processed_signal, _ = self.forward(input_signal=input_signal, input_length=input_length)

        # Calculate the loss
        loss = self.loss(estimate=processed_signal, target=target_signal, input_length=input_length)

        # Update metrics
        if hasattr(self, 'metrics') and tag in self.metrics:
            # Update metrics for this (tag, dataloader_idx)
            for metric in self.metrics[tag][dataloader_idx].values():
                metric.update(preds=processed_signal, target=target_signal, input_length=input_length)

        # Log global step
        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

        # Return loss
        return {f'{tag}_loss': loss}

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models. Currently empty.
        """
        return []


class PredictiveAudioToAudioModel(AudioToAudioModel):
    """Predictive audio-to-audio model.

    The input signal is mapped to an encoded representation, a neural estimator
    directly predicts the coefficients of the output signal in that encoded domain,
    and the decoder maps the estimate back to the time domain.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        super().__init__(cfg=cfg, trainer=trainer)
        model_cfg = self._cfg
        self.sample_rate = model_cfg.sample_rate

        # Encoder / decoder pair between the time domain and the encoded domain.
        # Registration order (encoder, decoder, estimator) determines parameter order.
        self.encoder = self.from_config_dict(model_cfg.encoder)
        self.decoder = self.from_config_dict(model_cfg.decoder)

        # Neural model operating in the encoded domain
        self.estimator = self.from_config_dict(model_cfg.estimator)

        # Optional per-example peak normalization of the input, and the small
        # constant added to the peak to keep the division well-defined
        self.normalize_input = model_cfg.get('normalize_input', False)
        self.eps = model_cfg.get('eps', 1e-8)

        self.setup_optimization_flags()

        logging.debug('Initialized %s', self.__class__.__name__)
        logging.debug('\tnormalize_input: %s', self.normalize_input)
        logging.debug('\teps:             %s', self.eps)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        signal_type = NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate))
        length_type = NeuralType(tuple('B'), LengthsType(), optional=True)
        return {"input_signal": signal_type, "input_length": length_type}

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        signal_type = NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate))
        length_type = NeuralType(tuple('B'), LengthsType(), optional=True)
        return {"output_signal": signal_type, "output_length": length_type}

    @typecheck()
    def forward(self, input_signal, input_length=None):
        """Run encoder, estimator and decoder on a batch of time-domain signals.

        Args:
            input_signal: time-domain signal, shape (B, C, T)
            input_length: valid length of each example in the batch

        Returns:
            Tuple `(output, output_length)`: time-domain output trimmed or padded to the
            time length of `input_signal`, and the output lengths reported by the decoder.
        """
        num_samples = input_signal.size(-1)

        # Peak amplitude per example; computed only when normalization is enabled
        peak = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) if self.normalize_input else None
        if peak is not None:
            input_signal = input_signal / (peak + self.eps)

        encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)
        estimated, estimated_length = self.estimator(input=encoded, input_length=encoded_length)
        output, output_length = self.decoder(input=estimated, input_length=estimated_length)

        # Undo the input normalization on the output
        if peak is not None:
            output = output * peak

        output = self.match_batch_length(input=output, batch_length=num_samples)
        return output, output_length

    # PTL-specific methods
    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for a single batch.

        Args:
            batch: either a dictionary with keys `input_signal`, `input_length` and `target_signal`
                (lhotse batches), or a tuple `(input_signal, input_length, target_signal, _)`
            batch_idx: index of the batch (not used)

        Returns:
            Training loss.
        """
        if isinstance(batch, dict):
            # lhotse batches are dictionaries
            input_signal, input_length, target_signal = (
                batch['input_signal'],
                batch['input_length'],
                batch['target_signal'],
            )
        else:
            input_signal, input_length, target_signal, _ = batch

        # Work in multi-channel format throughout, even for single-channel audio
        if input_signal.ndim == 2:
            input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
        if target_signal.ndim == 2:
            target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')

        estimate, _ = self.forward(input_signal=input_signal, input_length=input_length)
        loss = self.loss(estimate=estimate, target=target_signal, input_length=input_length)

        self.log('train_loss', loss)
        self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

        return loss

    def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
        """Compute the loss and update the metrics for a single evaluation batch.

        Args:
            batch: either a dictionary with keys `input_signal`, `input_length` and `target_signal`
                (lhotse batches), or a tuple `(input_signal, input_length, target_signal, _)`
            batch_idx: index of the batch (not used)
            dataloader_idx: index of the evaluation dataloader, selects the metrics to update
            tag: evaluation tag, used to select metrics and to name the returned loss

        Returns:
            Dictionary with the loss stored under `f'{tag}_loss'`.
        """
        if isinstance(batch, dict):
            # lhotse batches are dictionaries
            input_signal, input_length, target_signal = (
                batch['input_signal'],
                batch['input_length'],
                batch['target_signal'],
            )
        else:
            input_signal, input_length, target_signal, _ = batch

        # Work in multi-channel format throughout, even for single-channel audio
        if input_signal.ndim == 2:
            input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
        if target_signal.ndim == 2:
            target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')

        estimate, _ = self.forward(input_signal=input_signal, input_length=input_length)

        # Loss against the target
        loss = self.loss(estimate=estimate, target=target_signal, input_length=input_length)

        # Metrics for this (tag, dataloader_idx), when configured
        if hasattr(self, 'metrics') and tag in self.metrics:
            for metric in self.metrics[tag][dataloader_idx].values():
                metric.update(preds=estimate, target=target_signal, input_length=input_length)

        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

        return {f'{tag}_loss': loss}


class ScoreBasedGenerativeAudioToAudioModel(AudioToAudioModel):
    """This model uses a score-based diffusion process to generate
    an encoded representation of the enhanced signal.

    The model consists of the following blocks:
        - encoder: transforms input multi-channel audio signal into an encoded representation (analysis transform)
        - estimator: neural model, estimates a score for the diffusion process
        - sde: stochastic differential equation (SDE) defining the forward and reverse diffusion process
        - sampler: sampler for the reverse diffusion process, estimates coefficients of the target signal
        - decoder: transforms sampler output into the time domain (synthesis transform)
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        super().__init__(cfg=cfg, trainer=trainer)
        self.sample_rate = self._cfg.sample_rate

        # Setup processing modules
        self.encoder = self.from_config_dict(self._cfg.encoder)
        self.decoder = self.from_config_dict(self._cfg.decoder)

        # Neural score estimator
        self.estimator = self.from_config_dict(self._cfg.estimator)

        # SDE
        self.sde = self.from_config_dict(self._cfg.sde)

        # Sampler
        # The sampler is given the model's SDE and score estimator, so neither may be
        # configured separately in the sampler config
        if 'sde' in self._cfg.sampler:
            raise ValueError('SDE should be defined in the model config, not in the sampler config')
        if 'score_estimator' in self._cfg.sampler:
            raise ValueError('Score estimator should be defined in the model config, not in the sampler config')

        self.sampler = hydra.utils.instantiate(self._cfg.sampler, sde=self.sde, score_estimator=self.estimator)

        # Normalization
        self.normalize_input = self._cfg.get('normalize_input', False)

        # Metric evaluation
        # If set, metrics are computed only on the first `max_utts_evaluation_metrics` examples
        # of each evaluation dataset, since producing the output requires running the sampler
        self.max_utts_evaluation_metrics = self._cfg.get('max_utts_evaluation_metrics')

        if self.max_utts_evaluation_metrics is not None:
            logging.warning(
                'Metrics will be evaluated on first %d examples of the evaluation datasets.',
                self.max_utts_evaluation_metrics,
            )

        # Term added to the denominator to improve numerical stability
        self.eps = self._cfg.get('eps', 1e-8)

        # Setup optional Optimization flags
        self.setup_optimization_flags()

        logging.debug('Initialized %s', self.__class__.__name__)
        logging.debug('\tnormalize_input: %s', self.normalize_input)
        logging.debug('\teps:             %s', self.eps)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        return {
            "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
            "input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        return {
            "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
            "output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @typecheck()
    @torch.inference_mode()
    def forward(self, input_signal, input_length=None):
        """Forward pass of the model.

        Forward pass of the model applies the following steps:
            - encoder to obtain the encoded representation of the input signal
            - sampler to generate the estimated coefficients of the target signal
            - decoder to transform the sampler output into the time domain

        Args:
            input_signal: Tensor that represents a batch of time-domain audio signals,
                of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_length: Optional vector of length B, contains the individual lengths of the audio sequences.

        Returns:
            Output `output_signal` in the time domain and the length of the output signal `output_length`.
            The output is trimmed or padded to the time length of `input_signal`.
        """
        batch_length = input_signal.size(-1)

        if self.normalize_input:
            # max for each example in the batch
            norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
            # scale input signal
            input_signal = input_signal / (norm_scale + self.eps)

        # Encoder
        encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)

        # Sampler
        generated, generated_length = self.sampler(
            prior_mean=encoded, score_condition=encoded, state_length=encoded_length
        )

        # Decoder
        output, output_length = self.decoder(input=generated, input_length=generated_length)

        if self.normalize_input:
            # rescale to the original scale
            output = output * norm_scale

        # Trim or pad the estimated signal to match input length
        output = self.match_batch_length(input=output, batch_length=batch_length)
        return output, output_length

    @typecheck(
        input_types={
            "target_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "input_length": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "loss": NeuralType(None, LossType()),
        },
    )
    def _step(self, target_signal, input_signal, input_length=None):
        """Randomly generate a time step for each example in the batch, estimate
        the score and calculate the loss value.

        Note that this step does not include sampler.

        Args:
            target_signal: target time-domain signal, shape (B, C, T)
            input_signal: input time-domain signal, shape (B, C, T)
            input_length: valid length of each example in the batch

        Returns:
            Score matching loss.
        """
        batch_size = target_signal.size(0)

        if self.normalize_input:
            # max for each example in the batch
            norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
            # scale input signal
            input_signal = input_signal / (norm_scale + self.eps)
            # scale the target signal with the same factor as the input
            target_signal = target_signal / (norm_scale + self.eps)

        # Apply encoder to both target and the input
        input_enc, input_enc_len = self.encoder(input=input_signal, input_length=input_length)
        target_enc, _ = self.encoder(input=target_signal, input_length=input_length)

        # Generate random time steps
        sde_time = self.sde.generate_time(size=batch_size, device=input_enc.device)

        # Get the mean and the variance of the perturbation kernel
        pk_mean, pk_std = self.sde.perturb_kernel_params(state=target_enc, prior_mean=input_enc, time=sde_time)

        # Generate a random sample from a standard normal distribution
        z_norm = torch.randn_like(input_enc)

        # Prepare perturbed data
        perturbed_enc = pk_mean + pk_std * z_norm

        # Score is conditioned on the perturbed data and the input
        estimator_input = torch.cat([perturbed_enc, input_enc], dim=-3)

        # Estimate the score using the neural estimator
        # SDE time is used to inform the estimator about the current time step
        # Note:
        # - some implementations use `score = -self._raw_dnn_output(x, t, y)`
        # - this seems to be unimportant, and is an artifact of transferring code from the original Song's repo
        score_est, score_len = self.estimator(input=estimator_input, input_length=input_enc_len, condition=sde_time)

        # Score loss weighting as in Section 4.2 in http://arxiv.org/abs/1907.05600
        score_est = score_est * pk_std
        score_ref = -z_norm

        # Score matching loss on the normalized scores
        loss = self.loss(estimate=score_est, target=score_ref, input_length=score_len)

        return loss

    # PTL-specific methods
    def training_step(self, batch, batch_idx):
        """Compute and log the score matching loss for a single training batch.

        Args:
            batch: either a dictionary with keys `input_signal`, `input_length` and `target_signal`
                (lhotse batches), or a tuple `(input_signal, input_length, target_signal, _)`
            batch_idx: index of the batch (not used)

        Returns:
            Training loss.
        """
        if isinstance(batch, dict):
            # lhotse batches are dictionaries
            input_signal = batch['input_signal']
            input_length = batch['input_length']
            target_signal = batch['target_signal']
        else:
            input_signal, input_length, target_signal, _ = batch

        # For consistency, the model uses multi-channel format, even if the channel dimension is 1
        if input_signal.ndim == 2:
            input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
        if target_signal.ndim == 2:
            target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')

        # Calculate the loss
        loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length)

        # Logs
        self.log('train_loss', loss)
        self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

        return loss

    def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
        """Compute the loss and, when metrics are configured, update them for a single evaluation batch.

        Metrics require generating the output signal with the sampler. Generation is skipped
        when no metrics are configured for `(tag, dataloader_idx)`, or when
        `max_utts_evaluation_metrics` examples have already been evaluated.

        Args:
            batch: either a dictionary with keys `input_signal`, `input_length` and `target_signal`
                (lhotse batches), or a tuple `(input_signal, input_length, target_signal, _)`
            batch_idx: index of the batch (not used)
            dataloader_idx: index of the evaluation dataloader, selects the metrics to update
            tag: evaluation tag, used to select metrics and to name the returned loss

        Returns:
            Dictionary with the loss stored under `f'{tag}_loss'`.
        """
        if isinstance(batch, dict):
            # lhotse batches are dictionaries
            input_signal = batch['input_signal']
            input_length = batch['input_length']
            target_signal = batch['target_signal']
        else:
            input_signal, input_length, target_signal, _ = batch

        # For consistency, the model uses multi-channel format, even if the channel dimension is 1
        if input_signal.ndim == 2:
            input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
        if target_signal.ndim == 2:
            target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')

        # Calculate loss
        loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length)

        # Metrics configured for this (tag, dataloader_idx), if any
        metrics = None
        if hasattr(self, 'metrics') and tag in self.metrics:
            metrics = self.metrics[tag][dataloader_idx]

        # Decide whether to update metrics and on how many examples of this batch.
        # Without any metrics there is nothing to update, so the sampler is not run.
        update_metrics = False
        num_examples = 0
        if metrics:
            if self.max_utts_evaluation_metrics is None:
                # Always update if max is not configured
                update_metrics = True
                # Number of examples to process
                num_examples = input_signal.size(0)  # batch size
            else:
                # Check how many examples have been used for metric calculation
                num_examples_evaluated = next(iter(metrics.values())).num_examples
                # Update metrics if some examples were not processed
                update_metrics = num_examples_evaluated < self.max_utts_evaluation_metrics
                # Number of examples to process
                num_examples = min(self.max_utts_evaluation_metrics - num_examples_evaluated, input_signal.size(0))

        if update_metrics:
            # Generate output signal
            output_signal, _ = self.forward(
                input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples]
            )

            # Update metrics for this (tag, dataloader_idx)
            for metric in metrics.values():
                metric.update(
                    preds=output_signal,
                    target=target_signal[:num_examples, ...],
                    input_length=input_length[:num_examples],
                )

        # Log global step
        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

        return {f'{tag}_loss': loss}


class FlowMatchingAudioToAudioModel(AudioToAudioModel):
    """This model uses a flow matching process to generate
    an encoded representation of the enhanced signal.

    The model consists of the following blocks:
        - encoder: transforms input multi-channel audio signal into an encoded representation (analysis transform)
        - estimator: neural model, estimates the vector field of the flow, conditioned on the encoded input signal
        - flow: ordinary differential equation (ODE) defining a flow and a vector field.
        - sampler: sampler for the inference process, estimates coefficients of the target signal
        - decoder: transforms sampler output into the time domain (synthesis transform)
        - ssl_pretrain_masking: if it is defined, perform the ssl pretrain masking for self reconstruction in the training process
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        super().__init__(cfg=cfg, trainer=trainer)
        self.sample_rate = self._cfg.sample_rate

        # Setup processing modules
        self.encoder = self.from_config_dict(self._cfg.encoder)
        self.decoder = self.from_config_dict(self._cfg.decoder)

        # Neural estimator (trained to regress the conditional vector field of the flow)
        self.estimator = self.from_config_dict(self._cfg.estimator)

        # Flow
        self.flow = self.from_config_dict(self._cfg.flow)

        # Sampler
        self.sampler = hydra.utils.instantiate(self._cfg.sampler, estimator=self.estimator)

        # Probability that the conditional input will be fed into the
        # estimator in the training stage. With p_cond == 0 the condition
        # is also replaced by zeros at inference (see forward_internal).
        self.p_cond = self._cfg.get('p_cond', 1.0)

        # Self-Supervised Pretraining
        if self._cfg.get('ssl_pretrain_masking') is not None:
            logging.debug('SSL-pretrain_masking is found and will be initialized')
            self.ssl_pretrain_masking = self.from_config_dict(self._cfg.ssl_pretrain_masking)
        else:
            self.ssl_pretrain_masking = None

        # Normalization
        self.normalize_input = self._cfg.get('normalize_input', False)

        # Metric evaluation
        self.max_utts_evaluation_metrics = self._cfg.get('max_utts_evaluation_metrics')

        if self.max_utts_evaluation_metrics is not None:
            logging.warning(
                'Metrics will be evaluated on first %d examples of the evaluation datasets.',
                self.max_utts_evaluation_metrics,
            )

        # Regularization: term added to denominators for numerical stability
        self.eps = self._cfg.get('eps', 1e-8)

        # Setup optional Optimization flags
        self.setup_optimization_flags()

        logging.debug('Initialized              %s', self.__class__.__name__)
        logging.debug('\tdoing SSL-pretraining: %s', (self.ssl_pretrain_masking is not None))
        logging.debug('\tp_cond:                %s', self.p_cond)
        logging.debug('\tnormalize_input:       %s', self.normalize_input)
        logging.debug('\tloss:                  %s', self.loss)
        logging.debug('\teps:                   %s', self.eps)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        # time-domain input
        return {
            "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
            "input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        # time-domain output
        return {
            "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
            "output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @typecheck()
    @torch.inference_mode()
    def forward(self, input_signal, input_length=None):
        """Forward pass of the model to generate samples from the target distribution.
        This is used for inference mode only, and it explicitly disables SSL masking to the input.

        Args:
            input_signal: Tensor that represents a batch of time-domain audio signals,
                of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_length: Optional vector of length B that contains the individual lengths of the audio
                sequences.

        Returns:
            Output signal `output` in the time domain and the length of the output signal `output_length`.
        """
        return self.forward_internal(input_signal=input_signal, input_length=input_length, enable_ssl_masking=False)

    @typecheck(
        input_types={
            "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        },
        output_types={
            "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        },
    )
    @torch.inference_mode()
    def forward_eval(self, input_signal, input_length=None):
        """Forward pass of the model to generate samples from the target distribution.
        This is used for eval mode only, and it enables SSL masking to the input.

        Args:
            input_signal: Tensor that represents a batch of time-domain audio signals,
                of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_length: Optional vector of length B that contains the individual lengths of the audio
                sequences.

        Returns:
            Output signal `output` in the time domain and the length of the output signal `output_length`.
        """
        return self.forward_internal(input_signal=input_signal, input_length=input_length, enable_ssl_masking=True)

    @torch.inference_mode()
    def forward_internal(self, input_signal, input_length=None, enable_ssl_masking=False):
        """Internal forward pass of the model, shared by `forward` and `forward_eval`.

        Args:
            input_signal: Tensor that represents a batch of time-domain audio signals,
                of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_length: Optional vector of length B that contains the individual lengths of the audio
                sequences.
            enable_ssl_masking: Whether to enable SSL masking of the input. If using SSL pretraining, masking
                is applied to the encoded input. If not using SSL pretraining, masking is not applied.

        Returns:
            Output signal `output` in the time domain and the length of the output signal `output_length`.
        """
        batch_length = input_signal.size(-1)

        if self.normalize_input:
            # max for each example in the batch
            norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
            # scale input signal
            input_signal = input_signal / (norm_scale + self.eps)

        # Encoder
        encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)

        # Conditional input
        if self.p_cond == 0:
            # The model is trained without the conditional input
            encoded = torch.zeros_like(encoded)
        elif enable_ssl_masking and self.ssl_pretrain_masking is not None:
            # Masking for self-supervised pretraining
            encoded = self.ssl_pretrain_masking(input_spec=encoded, length=encoded_length)

        # Initial process state: Gaussian noise scaled by the flow's starting standard deviation
        init_state = torch.randn_like(encoded) * self.flow.sigma_start

        # Sampler
        generated, generated_length = self.sampler(
            state=init_state, estimator_condition=encoded, state_length=encoded_length
        )

        # Decoder
        output, output_length = self.decoder(input=generated, input_length=generated_length)

        if self.normalize_input:
            # rescale to the original scale
            output = output * norm_scale

        # Trim or pad the estimated signal to match input length
        output = self.match_batch_length(input=output, batch_length=batch_length)

        return output, output_length

    @typecheck(
        input_types={
            "target_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "input_length": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "loss": NeuralType(None, LossType()),
        },
    )
    def _step(self, target_signal, input_signal, input_length=None):
        """Sample a time and a point on the conditional path, estimate the vector field
        at that point with the neural estimator and return the loss against the
        conditional vector field defined by `self.flow`.
        """
        batch_size = target_signal.size(0)

        if self.normalize_input:
            # max for each example in the batch
            norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
            # scale input signal
            input_signal = input_signal / (norm_scale + self.eps)
            # scale the target signal
            target_signal = target_signal / (norm_scale + self.eps)

        # Apply encoder to both target and the input
        input_enc, input_enc_len = self.encoder(input=input_signal, input_length=input_length)
        target_enc, _ = self.encoder(input=target_signal, input_length=input_length)

        # Self-Supervised Pretraining
        if self.ssl_pretrain_masking is not None:
            input_enc = self.ssl_pretrain_masking(input_spec=input_enc, length=input_enc_len)

        # Drop off conditional inputs (input_enc) with (1 - p_cond) probability.
        # The dropped conditions will be set to zeros
        keep_conditions = einops.rearrange((torch.rand(batch_size) < self.p_cond).float(), 'B -> B 1 1 1')
        input_enc = input_enc * keep_conditions.to(input_enc.device)

        # Start point of the conditional path (zeros); how it is perturbed is defined by self.flow
        x_start = torch.zeros_like(input_enc)

        time = self.flow.generate_time(batch_size=batch_size).to(device=input_enc.device)
        sample = self.flow.sample(time=time, x_start=x_start, x_end=target_enc)

        # we want to get a vector field estimate given current state
        # at training time, current state is sampled from the conditional path
        #   the vector field model is also conditioned on input signal
        estimator_input = torch.cat([sample, input_enc], dim=-3)

        # Estimate the vector field using the neural estimator
        estimate, _ = self.estimator(input=estimator_input, input_length=input_enc_len, condition=time)

        conditional_vector_field = self.flow.vector_field(time=time, x_start=x_start, x_end=target_enc, point=sample)

        return self.loss(estimate=estimate, target=conditional_vector_field, input_length=input_enc_len)

    # PTL-specific methods
    def training_step(self, batch, batch_idx):
        if isinstance(batch, dict):
            # lhotse batches are dictionaries
            input_signal = batch['input_signal']
            input_length = batch['input_length']
            # Fall back to the input (self-reconstruction) only when no target is provided;
            # clone lazily so the input batch is not copied on every step
            target_signal = batch.get('target_signal')
            if target_signal is None:
                target_signal = input_signal.clone()
        else:
            input_signal, input_length, target_signal, _ = batch

        # For consistency, the model uses multi-channel format, even if the channel dimension is 1
        if input_signal.ndim == 2:
            input_signal = einops.rearrange(input_signal, "B T -> B 1 T")
        if target_signal.ndim == 2:
            target_signal = einops.rearrange(target_signal, "B T -> B 1 T")

        # Calculate the loss
        loss = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length)

        # Logs
        self.log('train_loss', loss)
        self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

        return loss

    def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
        """Compute the loss on an evaluation batch and update the metrics for (tag, dataloader_idx),
        limited to the first `max_utts_evaluation_metrics` examples if configured.
        """
        if isinstance(batch, dict):
            # lhotse batches are dictionaries
            input_signal = batch['input_signal']
            input_length = batch['input_length']
            # Fall back to the input (self-reconstruction) only when no target is provided
            target_signal = batch.get('target_signal')
            if target_signal is None:
                target_signal = input_signal.clone()
        else:
            input_signal, input_length, target_signal, _ = batch

        # For consistency, the model uses multi-channel format, even if the channel dimension is 1
        if input_signal.ndim == 2:
            input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
        if target_signal.ndim == 2:
            target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')

        # Calculate loss
        loss = self._step(
            target_signal=target_signal,
            input_signal=input_signal,
            input_length=input_length,
        )

        # Update metrics
        if self.max_utts_evaluation_metrics is None:
            # Always update if max is not configured
            update_metrics = True
            # Number of examples to process
            num_examples = input_signal.size(0)  # batch size
        else:
            # Check how many examples have been used for metric calculation
            first_metric_name = next(iter(self.metrics[tag][dataloader_idx]))
            num_examples_evaluated = self.metrics[tag][dataloader_idx][first_metric_name].num_examples
            # Update metrics if some examples were not processed
            update_metrics = num_examples_evaluated < self.max_utts_evaluation_metrics
            # Number of examples to process
            num_examples = min(self.max_utts_evaluation_metrics - num_examples_evaluated, input_signal.size(0))

        if update_metrics:
            # Generate output signal (SSL masking enabled, consistent with the loss above)
            output_signal, _ = self.forward_eval(
                input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples]
            )

            # Update metrics
            if hasattr(self, 'metrics') and tag in self.metrics:
                # Update metrics for this (tag, dataloader_idx)
                for metric in self.metrics[tag][dataloader_idx].values():
                    metric.update(
                        preds=output_signal,
                        target=target_signal[:num_examples, ...],
                        input_length=input_length[:num_examples],
                    )

        # Log global step
        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

        return {f'{tag}_loss': loss}


class SchroedingerBridgeAudioToAudioModel(AudioToAudioModel):
    """This model uses a Schrödinger Bridge process to generate
    an encoded representation of the enhanced signal.

    The model consists of the following blocks:
        - encoder: transforms input audio signal into an encoded representation (analysis transform)
        - estimator: neural model, estimates the coefficients for the SB process
        - noise_schedule: defines the path between the clean and noisy signals
        - sampler: sampler for the reverse process, estimates coefficients of the target signal
        - decoder: transforms sampler output into the time domain (synthesis transform)

    References:
        Schrödinger Bridge for Generative Speech Enhancement, https://arxiv.org/abs/2407.16074
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        super().__init__(cfg=cfg, trainer=trainer)
        self.sample_rate = self._cfg.sample_rate

        # Setup processing modules
        self.encoder = self.from_config_dict(self._cfg.encoder)
        self.decoder = self.from_config_dict(self._cfg.decoder)

        # Neural estimator
        self.estimator = self.from_config_dict(self._cfg.estimator)
        self.estimator_output = self._cfg.estimator_output

        # Noise schedule
        self.noise_schedule = self.from_config_dict(self._cfg.noise_schedule)

        # Sampler
        self.sampler = hydra.utils.instantiate(
            self._cfg.sampler,
            noise_schedule=self.noise_schedule,
            estimator=self.estimator,
            estimator_output=self.estimator_output,
        )

        # Normalization
        self.normalize_input = self._cfg.get('normalize_input', False)

        # Metric evaluation
        self.max_utts_evaluation_metrics = self._cfg.get('max_utts_evaluation_metrics')

        if self.max_utts_evaluation_metrics is not None:
            logging.warning(
                'Metrics will be evaluated on first %d examples of the evaluation datasets.',
                self.max_utts_evaluation_metrics,
            )

        # Loss in the encoded domain
        if 'loss_encoded' in self._cfg:
            self.loss_encoded = self.from_config_dict(self._cfg.loss_encoded)
            self.loss_encoded_weight = self._cfg.get('loss_encoded_weight', 1.0)
        else:
            self.loss_encoded = None
            self.loss_encoded_weight = 0.0

        # Loss in the time domain
        if 'loss_time' in self._cfg:
            self.loss_time = self.from_config_dict(self._cfg.loss_time)
            self.loss_time_weight = self._cfg.get('loss_time_weight', 1.0)
        else:
            self.loss_time = None
            self.loss_time_weight = 0.0

        # ``loss`` and the pair (``loss_encoded``, ``loss_time``) are mutually exclusive
        if self.loss is not None and (self.loss_encoded is not None or self.loss_time is not None):
            raise ValueError('Either ``loss`` or ``loss_encoded`` and ``loss_time`` should be defined, not both.')

        # Term added to the denominator to improve numerical stability
        self.eps = self._cfg.get('eps', 1e-8)

        # Setup optional optimization flags
        self.setup_optimization_flags()

        logging.debug('Initialized %s', self.__class__.__name__)
        logging.debug('\testimator_output:    %s', self.estimator_output)
        logging.debug('\tnormalize_input:     %s', self.normalize_input)
        logging.debug('\tloss:                %s', self.loss)
        logging.debug('\tloss_encoded:        %s', self.loss_encoded)
        logging.debug('\tloss_encoded_weight: %s', self.loss_encoded_weight)
        logging.debug('\tloss_time:           %s', self.loss_time)
        logging.debug('\tloss_time_weight:    %s', self.loss_time_weight)
        logging.debug('\teps:                 %s', self.eps)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        # time-domain input
        return {
            "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
            "input_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        # time-domain output
        return {
            "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)),
            "output_length": NeuralType(tuple('B'), LengthsType(), optional=True),
        }

    @typecheck()
    @torch.inference_mode()
    def forward(self, input_signal, input_length=None):
        """Forward pass of the model.

        Forward pass of the model consists of the following steps
            - encoder to obtain the encoded representation of the input signal
            - sampler to generate the estimated coefficients of the target signal
            - decoder to transform the estimated output into the time domain

        Args:
            input_signal: Tensor that represents a batch of time-domain audio signals,
                of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_length: Optional vector of length B, contains the individual lengths of the audio sequences.

        Returns:
            Output `output_signal` in the time domain and the length of the output signal `output_length`.
        """
        batch_length = input_signal.size(-1)

        if self.normalize_input:
            # max for each example in the batch
            norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
            # scale input signal
            input_signal = input_signal / (norm_scale + self.eps)

        # Encoder
        encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length)

        # Sampler: the encoded input is used both as the prior mean and as the estimator condition
        generated, generated_length = self.sampler(
            prior_mean=encoded, estimator_condition=encoded, state_length=encoded_length
        )

        # Decoder
        output, output_length = self.decoder(input=generated, input_length=generated_length)

        if self.normalize_input:
            # rescale to the original scale
            output = output * norm_scale

        # Trim or pad the estimated signal to match input length
        output = self.match_batch_length(input=output, batch_length=batch_length)

        return output, output_length

    @typecheck(
        input_types={
            "target_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()),
            "input_length": NeuralType(tuple('B'), LengthsType()),
        },
        output_types={
            "loss": NeuralType(None, LossType()),
            "loss_encoded": NeuralType(None, LossType()),
            "loss_time": NeuralType(None, LossType()),
        },
    )
    def _step(self, target_signal, input_signal, input_length=None):
        """Randomly generate time step for each example in the batch, run neural estimator
        to estimate the target and calculate the loss.

        Returns:
            Tuple (loss, loss_encoded, loss_time). `loss_encoded` and `loss_time` are None
            when the corresponding loss is not configured (or when a single ``loss`` is used).

        Raises:
            ValueError: if no loss (``loss``, ``loss_encoded`` or ``loss_time``) is configured.
            NotImplementedError: if `estimator_output` is not 'data_prediction'.
        """
        batch_size = target_signal.size(0)

        if self.normalize_input:
            # max for each example in the batch
            norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True)
            # scale input signal
            input_signal = input_signal / (norm_scale + self.eps)
            # scale the target signal
            target_signal = target_signal / (norm_scale + self.eps)

        # Apply encoder to both target and the input
        # For example, if the encoder is STFT, then _enc is the complex-valued STFT of the corresponding signal
        input_enc, input_enc_len = self.encoder(input=input_signal, input_length=input_length)
        target_enc, _ = self.encoder(input=target_signal, input_length=input_length)

        # Generate random time steps
        process_time = self.noise_schedule.generate_time(size=batch_size, device=input_enc.device)

        # Prepare necessary info from the noise schedule
        alpha_t, alpha_bar_t, alpha_t_max = self.noise_schedule.get_alphas(time=process_time)
        sigma_t, sigma_bar_t, sigma_t_max = self.noise_schedule.get_sigmas(time=process_time)

        # Marginal distribution
        weight_target = alpha_t * sigma_bar_t**2 / (sigma_t_max**2 + self.eps)
        weight_input = alpha_bar_t * sigma_t**2 / (sigma_t_max**2 + self.eps)
        # view weights as [B, C, D, T]
        weight_target = weight_target.view(-1, 1, 1, 1)
        weight_input = weight_input.view(-1, 1, 1, 1)
        # mean
        mean_x = weight_target * target_enc + weight_input * input_enc
        # standard deviation
        std_x = alpha_t * sigma_bar_t * sigma_t / (sigma_t_max + self.eps)
        # view as [B, C, D, T]
        std_x = std_x.view(-1, 1, 1, 1)

        # Generate a random sample from a standard normal distribution
        z_norm = torch.randn_like(input_enc)

        # Generate a random sample from the marginal distribution
        x_t = mean_x + std_x * z_norm

        # Estimator is conditioned on the generated sample and the original input (prior)
        estimator_input = torch.cat([x_t, input_enc], dim=-3)

        # Neural estimator
        # Estimator input is the same data type as the encoder output
        # For example, if the encoder is STFT, then the estimator input and output are complex-valued coefficients
        estimate, estimate_len = self.estimator(
            input=estimator_input, input_length=input_enc_len, condition=process_time
        )

        # Prepare output target and calculate loss
        if self.estimator_output == 'data_prediction':
            if self.loss is not None:
                # Single loss in the encoded domain
                loss = self.loss(estimate=estimate, target=target_enc, input_length=estimate_len)
                loss_encoded = loss_time = None
            else:
                if self.loss_encoded is None and self.loss_time is None:
                    # Without any configured loss the accumulator below would stay the
                    # Python float 0.0, which cannot be backpropagated
                    raise ValueError(
                        'No loss is configured: define either ``loss`` or at least one of ``loss_encoded`` and ``loss_time``.'
                    )

                # Weighted loss between encoded and time domain
                loss = 0.0

                # Loss in the encoded domain
                if self.loss_encoded is not None:
                    # Loss between the estimate and the target in the encoded domain
                    loss_encoded = self.loss_encoded(estimate=estimate, target=target_enc, input_length=estimate_len)
                    # Weighting
                    loss += self.loss_encoded_weight * loss_encoded
                else:
                    loss_encoded = None

                # Loss in the time domain
                if self.loss_time is not None:
                    # Convert the estimate to the time domain
                    with typecheck.disable_checks():
                        # Note: estimate is FloatType, decoder requires SpectrogramType
                        estimate_signal, _ = self.decoder(input=estimate, input_length=estimate_len)

                    # Match estimate length
                    batch_length = input_signal.size(-1)
                    estimate_signal = self.match_batch_length(input=estimate_signal, batch_length=batch_length)

                    # Loss between the estimate and the target in the time domain
                    loss_time = self.loss_time(
                        estimate=estimate_signal, target=target_signal, input_length=input_length
                    )
                    # Weighting
                    loss += self.loss_time_weight * loss_time
                else:
                    loss_time = None
        else:
            raise NotImplementedError(f'Output type {self.estimator_output} is not implemented')

        return loss, loss_encoded, loss_time

    # PTL-specific methods
    def training_step(self, batch, batch_idx):

        if isinstance(batch, dict):
            # lhotse batches are dictionaries
            input_signal = batch['input_signal']
            input_length = batch['input_length']
            target_signal = batch['target_signal']
        else:
            input_signal, input_length, target_signal, _ = batch

        # For consistency, the model uses multi-channel format, even if the channel dimension is 1
        if input_signal.ndim == 2:
            input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
        if target_signal.ndim == 2:
            target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')

        # Calculate the loss
        loss, loss_encoded, loss_time = self._step(
            target_signal=target_signal, input_signal=input_signal, input_length=input_length
        )

        # Logs
        self.log('train_loss', loss)
        self.log('learning_rate', self._optimizer.param_groups[0]['lr'])
        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

        # Component losses are only available when the weighted (encoded/time) loss is used
        if loss_encoded is not None:
            self.log('train_loss_encoded', loss_encoded)

        if loss_time is not None:
            self.log('train_loss_time', loss_time)

        return loss

    def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
        """Compute the loss on an evaluation batch and update the metrics for (tag, dataloader_idx),
        limited to the first `max_utts_evaluation_metrics` examples if configured.
        """
        if isinstance(batch, dict):
            # lhotse batches are dictionaries
            input_signal = batch['input_signal']
            input_length = batch['input_length']
            target_signal = batch['target_signal']
        else:
            input_signal, input_length, target_signal, _ = batch

        # For consistency, the model uses multi-channel format, even if the channel dimension is 1
        if input_signal.ndim == 2:
            input_signal = einops.rearrange(input_signal, 'B T -> B 1 T')
        if target_signal.ndim == 2:
            target_signal = einops.rearrange(target_signal, 'B T -> B 1 T')

        # Calculate loss
        loss, *_ = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length)

        # Update metrics
        if self.max_utts_evaluation_metrics is None:
            # Always update if max is not configured
            update_metrics = True
            # Number of examples to process
            num_examples = input_signal.size(0)  # batch size
        else:
            # Check how many examples have been used for metric calculation
            first_metric_name = next(iter(self.metrics[tag][dataloader_idx]))
            num_examples_evaluated = self.metrics[tag][dataloader_idx][first_metric_name].num_examples
            # Update metrics if some examples were not processed
            update_metrics = num_examples_evaluated < self.max_utts_evaluation_metrics
            # Number of examples to process
            num_examples = min(self.max_utts_evaluation_metrics - num_examples_evaluated, input_signal.size(0))

        if update_metrics:
            # Generate output signal
            output_signal, _ = self.forward(
                input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples]
            )

            # Update metrics
            if hasattr(self, 'metrics') and tag in self.metrics:
                # Update metrics for this (tag, dataloader_idx)
                for metric in self.metrics[tag][dataloader_idx].values():
                    metric.update(
                        preds=output_signal,
                        target=target_signal[:num_examples, ...],
                        input_length=input_length[:num_examples],
                    )

        # Log global step
        self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

        return {f'{tag}_loss': loss}
