# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import lightning.pytorch as pl
import torch
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader

from nemo.collections.diffusion.models.model import DiTConfig
from nemo.lightning.pytorch.plugins import MegatronDataSampler

from .diffusion_taskencoder import pos_id_3d


class PosEmb3D:
    """Generates and provides 3D positional embeddings for video data.

    A dense grid of integer ``(t, h, w)`` coordinates with shape
    ``(max_t, max_h, max_w, 3)`` is built on the CPU at construction time and
    sliced on demand. If a request exceeds the current bounds, the bounds are
    enlarged and the grid is rebuilt.
    """

    def __init__(self, *, max_t=96, max_h=960, max_w=960):
        """Stores the initial grid bounds and builds the grid.

        Parameters:
            max_t (int): Initial number of time positions in the grid.
            max_h (int): Initial number of height positions in the grid.
            max_w (int): Initial number of width positions in the grid.
        """
        self.max_t = max_t
        self.max_h = max_h
        self.max_w = max_w
        self.generate_pos_id()

    def generate_pos_id(self):
        """Generates the positional ID grid based on max_t, max_h, and max_w."""
        axes = (
            torch.arange(self.max_t, device='cpu'),
            torch.arange(self.max_h, device='cpu'),
            torch.arange(self.max_w, device='cpu'),
        )
        # 'ij' (matrix) indexing is stated explicitly: it keeps the (t, h, w) axis order and avoids
        # the deprecation warning torch.meshgrid emits when `indexing` is omitted.
        self.grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1)

    def get_pos_id_3d(self, *, t, h, w):
        """Retrieves a subset of the positional IDs for the specified dimensions.

        If any requested extent is larger than the current grid, the stored
        bounds are raised to cover it and the grid is regenerated first.

        Parameters:
            t (int): Number of time frames.
            h (int): Height dimension.
            w (int): Width dimension.

        Returns:
            torch.Tensor: The positional IDs tensor with shape (t, h, w, 3).
        """
        if t > self.max_t or h > self.max_h or w > self.max_w:
            # Grow only the axes that need it; the other bounds keep their current size.
            self.max_t = max(self.max_t, t)
            self.max_h = max(self.max_h, h)
            self.max_w = max(self.max_w, w)
            self.generate_pos_id()
        return self.grid[:t, :h, :w]


class DiTVideoLatentFakeDataset(torch.utils.data.Dataset):
    """A fake dataset for generating synthetic video latent data.

    Every sample has a fixed shape determined by the constructor arguments:
    a constant-valued video latent, a random text embedding, their sequence
    lengths, all-zero positional IDs and an all-ones loss mask.
    """

    def __init__(
        self,
        n_frames,
        max_h,
        max_w,
        patch_size,
        in_channels,
        crossattn_emb_size,
        max_text_seqlen=512,
        seq_length=8192,
    ):
        """Stores the shape parameters used to synthesize samples.

        Parameters:
            n_frames (int): Stored as ``max_t``; not used when generating samples.
            max_h (int): Stored as ``max_height``; not used when generating samples.
            max_w (int): Stored as ``max_width``; not used when generating samples.
            patch_size (int): Patch edge length; the latent feature size is ``in_channels * patch_size**2``.
            in_channels (int): Number of latent channels.
            crossattn_emb_size (int): Feature size of the fake text embedding.
            max_text_seqlen (int): Number of rows in the fake text embedding.
            seq_length (int): Number of rows in the fake video latent, positional IDs and loss mask.
        """
        self.max_t = n_frames
        self.max_height = max_h
        self.max_width = max_w
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.text_dim = crossattn_emb_size
        self.text_seqlen = max_text_seqlen
        self.seq_length = seq_length

    def __len__(self):
        """Returns the total number of samples."""
        # Arbitrary large size: samples are synthesized on the fly, so nothing is stored.
        return 100000000

    def __getitem__(self, idx):
        """Generates a single sample of data.

        Parameters:
            idx (int): Index of the data sample (ignored; every sample is generated the same way).

        Returns:
            dict: A dictionary containing video latent data and related information.
        """
        feature_dim = self.in_channels * self.patch_size**2

        video_latent = torch.full((self.seq_length, feature_dim), 0.5, dtype=torch.bfloat16)
        text_embedding = torch.randn(self.text_seqlen, self.text_dim, dtype=torch.bfloat16)

        # NOTE: positional IDs are an all-zero placeholder. `seq_length` is independent of the
        # frame/height/width settings, so a real (t, h, w) grid would not necessarily have
        # `seq_length` rows. The grid lookup that used to be computed here was never used and
        # has been dropped.
        return {
            'video': video_latent,
            't5_text_embeddings': text_embedding,
            'seq_len_q': torch.tensor(video_latent.shape[0], dtype=torch.int32),
            'seq_len_kv': torch.tensor(self.text_seqlen, dtype=torch.int32),
            'pos_ids': torch.zeros((self.seq_length, 3), dtype=torch.int32),
            'loss_mask': torch.ones(video_latent.shape[0], dtype=torch.bfloat16),
        }

    def _collate_fn(self, batch):
        """A default implementation of a collation function.

        Users should override this method to define custom data loaders.
        """
        return torch.utils.data.dataloader.default_collate(batch)

    def collate_fn(self, batch):
        """Method that user passes as a functor to DataLoader.

        Delegates to `_collate_fn`, which subclasses may override.

        Usage:
            dataloader = torch.utils.data.DataLoader(
                    ....,
                    collate_fn=dataset.collate_fn,
                    ....
            )

        Returns:
            Collated batch as produced by `_collate_fn`.
        """
        return self._collate_fn(batch)


class VideoLatentFakeDataModule(pl.LightningDataModule):
    """A LightningDataModule for generating fake video latent data for training."""

    def __init__(
        self,
        model_config: DiTConfig,
        seq_length: int = 2048,
        micro_batch_size: int = 1,
        global_batch_size: int = 8,
        num_workers: int = 1,
        pin_memory: bool = True,
        task_encoder=None,
        use_train_split_for_val: bool = False,
    ) -> None:
        """Stores the loader settings and creates the Megatron data sampler.

        Parameters:
            model_config (DiTConfig): Model configuration; its frame, image, patch, channel and
                cross-attention sizes are read in `setup` to shape the fake samples.
            seq_length (int): Sequence length given to the sampler and to the fake dataset.
            micro_batch_size (int): Micro batch size given to the sampler.
            global_batch_size (int): Global batch size given to the sampler.
            num_workers (int): Number of DataLoader worker processes.
            pin_memory (bool): Whether the DataLoader pins host memory.
            task_encoder: Accepted for interface compatibility; not used by this module.
            use_train_split_for_val (bool): Accepted for interface compatibility; not used by this
                module (validation always reuses the training dataset).
        """
        super().__init__()
        self.seq_length = seq_length
        self.micro_batch_size = micro_batch_size
        self.global_batch_size = global_batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.model_config = model_config

        self.data_sampler = MegatronDataSampler(
            seq_len=self.seq_length,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
        )

    def setup(self, stage: str = "") -> None:
        """Sets up the dataset for training and validation.

        Parameters:
            stage (str): Optional stage argument (unused).
        """
        self._train_ds = DiTVideoLatentFakeDataset(
            n_frames=self.model_config.max_frames,
            max_h=self.model_config.max_img_h,
            max_w=self.model_config.max_img_w,
            patch_size=self.model_config.patch_spatial,
            in_channels=self.model_config.in_channels,
            crossattn_emb_size=self.model_config.crossattn_emb_size,
            # Keep the generated samples in step with the length the sampler was configured with;
            # otherwise the dataset silently falls back to its own default.
            seq_length=self.seq_length,
        )

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """Returns the training DataLoader."""
        if not hasattr(self, "_train_ds"):
            self.setup()
        return self._create_dataloader(self._train_ds)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """Returns the validation DataLoader (backed by the same fake dataset as training)."""
        if not hasattr(self, "_train_ds"):
            self.setup()
        return self._create_dataloader(self._train_ds)

    def _create_dataloader(self, dataset, **kwargs) -> DataLoader:
        """Creates a DataLoader for the given dataset.

        Parameters:
            dataset (Dataset): The dataset to load; must provide a `collate_fn` method.
            **kwargs: Additional arguments for DataLoader. These take precedence over the
                defaults chosen here.

        Returns:
            DataLoader: The DataLoader instance.
        """
        loader_kwargs = {
            "num_workers": self.num_workers,
            "pin_memory": self.pin_memory,
            "collate_fn": dataset.collate_fn,
        }
        loader_kwargs.update(kwargs)
        # DataLoader rejects persistent_workers=True when there are no worker processes,
        # so only enable it when workers are actually used.
        loader_kwargs.setdefault("persistent_workers", loader_kwargs["num_workers"] > 0)
        return DataLoader(dataset, **loader_kwargs)
