# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Callable, Dict, Optional, Sequence, Tuple

import nemo_run as run
import torch
import torch.distributed
import torch.utils.checkpoint
import torchvision
import wandb
from autovae import VAEGenerator
from contperceptual_loss import LPIPSWithDiscriminator
from diffusers import AutoencoderKL
from megatron.core import parallel_state
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.energon import DefaultTaskEncoder, ImageSample
from torch import Tensor, nn

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.diffusion.data.diffusion_energon_datamodule import DiffusionDataModule
from nemo.collections.diffusion.train import pretrain
from nemo.collections.llm.gpt.model.base import GPTModel
from nemo.lightning.io.mixin import IOMixin
from nemo.lightning.megatron_parallel import DataT, MegatronLossReduction, ReductionT
from nemo.lightning.pytorch.optim import OptimizerModule


class AvgLossReduction(MegatronLossReduction):
    """Performs average loss reduction across micro-batches."""

    def forward(self, batch: DataT, forward_out: Tensor) -> Tuple[Tensor, ReductionT]:
        """
        Forward pass for loss reduction.

        Args:
            batch: The batch of data.
            forward_out: The output tensor from forward computation.

        Returns:
            A tuple of (loss, reduction dictionary).
        """
        loss = forward_out.mean()
        return loss, {"avg": loss}

    def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> Tensor:
        """
        Reduce losses across multiple micro-batches by averaging them.

        Args:
            losses_reduced_per_micro_batch: A sequence of loss dictionaries.

        Returns:
            The averaged loss tensor.
        """
        losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
        return losses.mean()


class VAE(MegatronModule):
    """Variational Autoencoder (VAE) module."""

    def __init__(self, config, pretrained_model_name_or_path, search_vae=False):
        """
        Initialize the VAE model.

        Args:
            config: Transformer configuration.
            pretrained_model_name_or_path: Path or name of the pretrained model.
            search_vae: Flag to indicate whether to search for a target VAE using AutoVAE.
        """
        super().__init__(config)
        if search_vae:
            # Get VAE automatically from AutoVAE
            self.vae = VAEGenerator(input_resolution=1024, compression_ratio=16)
            # Below line is commented out due to an undefined 'generator' variable in original code snippet.
            # self.vae = generator.search_for_target_vae(parameters_budget=895.178707, cuda_max_mem=0)
        else:
            self.vae = AutoencoderKL.from_config(pretrained_model_name_or_path, weight_dtype=torch.bfloat16)

        sdxl_vae = AutoencoderKL.from_pretrained(
            'stabilityai/stable-diffusion-xl-base-1.0', subfolder="vae", weight_dtype=torch.bfloat16
        )
        sd_dict = sdxl_vae.state_dict()
        vae_dict = self.vae.state_dict()
        pre_dict = {k: v for k, v in sd_dict.items() if (k in vae_dict) and (vae_dict[k].numel() == v.numel())}
        self.vae.load_state_dict(pre_dict, strict=False)
        del sdxl_vae

        self.vae_loss = LPIPSWithDiscriminator(
            disc_start=50001,
            logvar_init=0.0,
            kl_weight=0.000001,
            pixelloss_weight=1.0,
            disc_num_layers=3,
            disc_in_channels=3,
            disc_factor=1.0,
            disc_weight=0.5,
            perceptual_weight=1.0,
            use_actnorm=False,
            disc_conditional=False,
            disc_loss="hinge",
        )

    def forward(self, target, global_step):
        """
        Forward pass through the VAE.

        Args:
            target: Target images.
            global_step: Current global step.

        Returns:
            A tuple (aeloss, log_dict_ae, pred) containing the loss, log dictionary, and predictions.
        """
        posterior = self.vae.encode(target).latent_dist
        z = posterior.sample()
        pred = self.vae.decode(z).sample
        aeloss, log_dict_ae = self.vae_loss(
            inputs=target,
            reconstructions=pred,
            posteriors=posterior,
            optimizer_idx=0,
            global_step=global_step,
            last_layer=self.vae.decoder.conv_out.weight,
        )
        return aeloss, log_dict_ae, pred

    def set_input_tensor(self, input_tensor: Tensor) -> None:
        """
        Set input tensor.

        Args:
            input_tensor: The input tensor to the model.
        """
        pass


class VAEModel(GPTModel):
    """A GPTModel wrapper for the VAE."""

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        optim: Optional[OptimizerModule] = None,
        model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
    ):
        """
        Initialize the VAEModel.

        Args:
            pretrained_model_name_or_path: Path or name of the pretrained model.
            optim: Optional optimizer module.
            model_transform: Optional function to transform the model.
        """
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        config = TransformerConfig(num_layers=1, hidden_size=1, num_attention_heads=1)
        self.model_type = ModelType.encoder_or_decoder
        super().__init__(config, optim=optim, model_transform=model_transform)

    def configure_model(self) -> None:
        """Configure the model by initializing the module."""
        if not hasattr(self, "module"):
            self.module = VAE(self.config, self.pretrained_model_name_or_path)

    def data_step(self, dataloader_iter) -> Dict[str, Any]:
        """
        Perform a single data step to fetch a batch from the iterator.

        Args:
            dataloader_iter: The dataloader iterator.

        Returns:
            A dictionary with 'pixel_values' ready for the model.
        """
        batch = next(dataloader_iter)[0]
        return {'pixel_values': batch.image.to(device='cuda', dtype=torch.bfloat16, non_blocking=True)}

    def forward(self, *args, **kwargs):
        """
        Forward pass through the underlying module.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            The result of forward pass of self.module.
        """
        return self.module(*args, **kwargs)

    def training_step(self, batch, batch_idx=None) -> torch.Tensor:
        """
        Perform a single training step.

        Args:
            batch: The input batch.
            batch_idx: Batch index.

        Returns:
            The loss tensor.
        """
        loss, log_dict_ae, pred = self(batch["pixel_values"], self.global_step)

        if torch.distributed.get_rank() == 0:
            self.log_dict(log_dict_ae)

        return loss

    def validation_step(self, batch, batch_idx=None) -> torch.Tensor:
        """
        Perform a single validation step.

        Args:
            batch: The input batch.
            batch_idx: Batch index.

        Returns:
            The loss tensor.
        """
        loss, log_dict_ae, pred = self(batch["pixel_values"], self.global_step)

        image = torch.cat([batch["pixel_values"].cpu(), pred.cpu()], axis=0)
        image = (image + 0.5).clamp(0, 1)

        # wandb is on the last rank for megatron, first rank for nemo
        wandb_rank = 0

        if parallel_state.get_data_parallel_src_rank() == wandb_rank:
            if torch.distributed.get_rank() == wandb_rank:
                gather_list = [None for _ in range(parallel_state.get_data_parallel_world_size())]
            else:
                gather_list = None
            torch.distributed.gather_object(
                image, gather_list, wandb_rank, group=parallel_state.get_data_parallel_group()
            )
            if gather_list is not None:
                self.log_dict(log_dict_ae)
                wandb.log(
                    {
                        "Original (left), Reconstruction (right)": [
                            wandb.Image(torchvision.utils.make_grid(image)) for _, image in enumerate(gather_list)
                        ]
                    },
                )

        return loss

    @property
    def training_loss_reduction(self) -> AvgLossReduction:
        """Returns the loss reduction method for training."""
        if not self._training_loss_reduction:
            self._training_loss_reduction = AvgLossReduction()
        return self._training_loss_reduction

    @property
    def validation_loss_reduction(self) -> AvgLossReduction:
        """Returns the loss reduction method for validation."""
        if not self._validation_loss_reduction:
            self._validation_loss_reduction = AvgLossReduction()
        return self._validation_loss_reduction

    def on_validation_model_zero_grad(self) -> None:
        """
        Hook to handle zero grad on validation model step.
        Used here to skip first validation on resume.
        """
        super().on_validation_model_zero_grad()
        if self.trainer.ckpt_path is not None and getattr(self, '_restarting_skip_val_flag', True):
            self.trainer.sanity_checking = True
            self._restarting_skip_val_flag = False


def crop_image(img, divisor=16):
    """
    Crop the image so that both dimensions are divisible by the given divisor.

    Args:
        img: Image tensor.
        divisor: The divisor to use for cropping.

    Returns:
        The cropped image tensor.
    """
    h, w = img.shape[-2], img.shape[-1]

    delta_h = h % divisor
    delta_w = w % divisor

    delta_h_top = delta_h // 2
    delta_h_bottom = delta_h - delta_h_top

    delta_w_left = delta_w // 2
    delta_w_right = delta_w - delta_w_left

    img_cropped = img[..., delta_h_top : h - delta_h_bottom, delta_w_left : w - delta_w_right]

    return img_cropped


class ImageTaskEncoder(DefaultTaskEncoder, IOMixin):
    """Image task encoder that crops and normalizes the image."""

    def encode_sample(self, sample: ImageSample) -> ImageSample:
        """
        Encode a single image sample by cropping and shifting its values.

        Args:
            sample: An image sample.

        Returns:
            The transformed image sample.
        """
        sample = super().encode_sample(sample)
        sample.image = crop_image(sample.image, 16)
        sample.image -= 0.5
        return sample


@run.cli.factory(target=llm.train)
def train_vae() -> run.Partial:
    """
    Training factory function for VAE.

    Returns:
        A run.Partial recipe for training.
    """
    recipe = pretrain()
    recipe.model = run.Config(
        VAEModel,
        pretrained_model_name_or_path='nemo/collections/diffusion/vae/vae16x/config.json',
    )
    recipe.data = run.Config(
        DiffusionDataModule,
        task_encoder=run.Config(ImageTaskEncoder),
        global_batch_size=24,
        num_workers=10,
    )
    recipe.optim.lr_scheduler = run.Config(nl.lr_scheduler.WarmupHoldPolicyScheduler, warmup_steps=100, hold_steps=1e9)
    recipe.optim.config.lr = 5e-6
    recipe.optim.config.weight_decay = 1e-2
    recipe.log.log_dir = 'nemo_experiments/train_vae'
    recipe.trainer.val_check_interval = 1000
    recipe.trainer.callbacks[0].every_n_train_steps = 1000

    return recipe


if __name__ == "__main__":
    run.cli.main(llm.train, default_factory=train_vae)
