# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from dataclasses import dataclass
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from einops import rearrange
from megatron.energon import DefaultTaskEncoder, Sample, SkipSample
from megatron.energon.task_encoder.base import stateless
from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys

from nemo.lightning.io.mixin import IOMixin
from nemo.utils.sequence_packing_utils import first_fit_decreasing


@dataclass
class DiffusionSample(Sample):
    """
    Data class representing a sample for diffusion tasks.

    Attributes:
        video (torch.Tensor): Video latents (C T H W).
        t5_text_embeddings (torch.Tensor): Text embeddings (S D).
        t5_text_mask (torch.Tensor): Mask for text embeddings.
        loss_mask (torch.Tensor): Mask indicating valid positions for loss computation.
        image_size (Optional[torch.Tensor]): Tensor containing image dimensions.
        fps (Optional[torch.Tensor]): Frame rate of the video.
        num_frames (Optional[torch.Tensor]): Number of frames in the video.
        padding_mask (Optional[torch.Tensor]): Mask indicating padding positions.
        seq_len_q (Optional[torch.Tensor]): Sequence length for query embeddings.
        seq_len_kv (Optional[torch.Tensor]): Sequence length for key/value embeddings.
        pos_ids (Optional[torch.Tensor]): Positional IDs.
        latent_shape (Optional[torch.Tensor]): Shape of the latent tensor.
    """

    video: torch.Tensor  # video latents (C T H W)
    t5_text_embeddings: torch.Tensor  # (S D)
    t5_text_mask: torch.Tensor  # 1
    loss_mask: torch.Tensor
    image_size: Optional[torch.Tensor] = None
    fps: Optional[torch.Tensor] = None
    num_frames: Optional[torch.Tensor] = None
    padding_mask: Optional[torch.Tensor] = None
    seq_len_q: Optional[torch.Tensor] = None
    seq_len_kv: Optional[torch.Tensor] = None
    pos_ids: Optional[torch.Tensor] = None
    latent_shape: Optional[torch.Tensor] = None

    def to_dict(self) -> dict:
        """Converts the sample to a dictionary."""
        return dict(
            video=self.video,
            t5_text_embeddings=self.t5_text_embeddings,
            t5_text_mask=self.t5_text_mask,
            loss_mask=self.loss_mask,
            image_size=self.image_size,
            fps=self.fps,
            num_frames=self.num_frames,
            padding_mask=self.padding_mask,
            seq_len_q=self.seq_len_q,
            seq_len_kv=self.seq_len_kv,
            pos_ids=self.pos_ids,
            latent_shape=self.latent_shape,
        )

    def __add__(self, other: Any) -> int:
        """Adds the sequence length of this sample with another sample or integer."""
        if isinstance(other, DiffusionSample):
            # Combine the values of the two instances
            return self.seq_len_q.item() + other.seq_len_q.item()
        elif isinstance(other, int):
            # Add an integer to the value
            return self.seq_len_q.item() + other
        raise NotImplementedError

    def __radd__(self, other: Any) -> int:
        """Handles reverse addition for summing with integers."""
        # This is called if sum or other operations start with a non-DiffusionSample object.
        # e.g., sum([DiffusionSample(1), DiffusionSample(2)]) -> the 0 + DiffusionSample(1) calls __radd__.
        if isinstance(other, int):
            return self.seq_len_q.item() + other
        raise NotImplementedError

    def __lt__(self, other: Any) -> bool:
        """Compares this sample's sequence length with another sample or integer."""
        if isinstance(other, DiffusionSample):
            return self.seq_len_q.item() < other.seq_len_q.item()
        elif isinstance(other, int):
            return self.seq_len_q.item() < other
        raise NotImplementedError


def cook(sample: dict) -> dict:
    """
    Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys.

    Args:
        sample (dict): The input dictionary containing the raw sample data.

    Returns:
        dict: A new dictionary containing the processed sample data with the following keys:
            - All keys from the result of `basic_sample_keys(sample)`
            - 'json': The contains meta data like resolution, aspect ratio, fps, etc.
            - 'pth': contains video latent tensor
            - 'pickle': contains text embeddings
    """
    return dict(
        **basic_sample_keys(sample),
        json=sample['.json'],
        pth=sample['.pth'],
        pickle=sample['.pickle'],
    )


class BasicDiffusionTaskEncoder(DefaultTaskEncoder, IOMixin):
    """
    BasicDiffusionTaskEncoder is a class that encodes image/video samples for diffusion tasks.
    Attributes:
        cookers (list): A list of Cooker objects used for processing.
        max_frames (int, optional): The maximum number of frames to consider from the video. Defaults to None.
        text_embedding_padding_size (int): The padding size for text embeddings. Defaults to 512.
    Methods:
        __init__(*args, max_frames=None, text_embedding_padding_size=512, **kwargs):
            Initializes the BasicDiffusionTaskEncoder with optional maximum frames and text embedding padding size.
        encode_sample(sample: dict) -> dict:
            Encodes a given sample dictionary containing video and text data.
            Args:
                sample (dict): A dictionary containing 'pth' for video latent and 'json' for additional info.
            Returns:
                dict: A dictionary containing encoded video, text embeddings, text mask, and loss mask.
            Raises:
                SkipSample: If the video latent contains NaNs, Infs, or is not divisible by the tensor parallel size.
    """

    cookers = [
        Cooker(cook),
    ]

    def __init__(
        self,
        *args,
        max_frames: int = None,
        text_embedding_padding_size: int = 512,
        seq_length: int = None,
        max_seq_length: int = None,
        patch_spatial: int = 2,
        patch_temporal: int = 1,
        aesthetic_score: float = 0.0,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.max_frames = max_frames
        self.text_embedding_padding_size = text_embedding_padding_size
        self.seq_length = seq_length
        self.max_seq_length = max_seq_length
        self.patch_spatial = patch_spatial
        self.patch_temporal = patch_temporal
        self.aesthetic_score = aesthetic_score

    @stateless(restore_seeds=True)
    def encode_sample(self, sample: dict) -> dict:
        """
        Encodes video / text sample.
        """
        video_latent = sample['pth']

        if torch.isnan(video_latent).any() or torch.isinf(video_latent).any():
            raise SkipSample()
        if torch.max(torch.abs(video_latent)) > 1e3:
            raise SkipSample()

        info = sample['json']
        if info['aesthetic_score'] < self.aesthetic_score:
            raise SkipSample()

        C, T, H, W = video_latent.shape
        seq_len = (
            video_latent.shape[-1]
            * video_latent.shape[-2]
            * video_latent.shape[-3]
            // self.patch_spatial**2
            // self.patch_temporal
        )
        is_image = T == 1

        if self.seq_length is not None and seq_len > self.seq_length:
            raise SkipSample()
        if self.max_seq_length is not None and seq_len > self.max_seq_length:
            raise SkipSample()

        if self.max_frames is not None:
            video_latent = video_latent[:, : self.max_frames, :, :]

        video_latent = rearrange(
            video_latent,
            'C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)',
            ph=self.patch_spatial,
            pw=self.patch_spatial,
            pt=self.patch_temporal,
        )

        if is_image:
            t5_text_embeddings = torch.from_numpy(sample['pickle']).to(torch.bfloat16)
        else:
            t5_text_embeddings = torch.from_numpy(sample['pickle'][0]).to(torch.bfloat16)
        t5_text_embeddings_seq_length = t5_text_embeddings.shape[0]

        if t5_text_embeddings_seq_length > self.text_embedding_padding_size:
            t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size]
        else:
            t5_text_embeddings = F.pad(
                t5_text_embeddings,
                (
                    0,
                    0,
                    0,
                    self.text_embedding_padding_size - t5_text_embeddings_seq_length,
                ),
            )
        t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16)

        if is_image:
            h, w = info['image_height'], info['image_width']
            fps = torch.tensor([30] * 1, dtype=torch.bfloat16)
            num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16)
        else:
            h, w = info['height'], info['width']
            fps = torch.tensor([info['framerate']] * 1, dtype=torch.bfloat16)
            num_frames = torch.tensor([info['num_frames']] * 1, dtype=torch.bfloat16)
        image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16)

        pos_ids = rearrange(
            pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial),
            'T H W d -> (T H W) d',
        )

        if self.seq_length is not None and self.max_seq_length is None:
            pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len))
            loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16)
            loss_mask[:seq_len] = 1
            video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len))
        else:
            loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)

        return DiffusionSample(
            __key__=sample['__key__'],
            __restore_key__=sample['__restore_key__'],
            __subflavor__=None,
            __subflavors__=sample['__subflavors__'],
            video=video_latent,
            t5_text_embeddings=t5_text_embeddings,
            t5_text_mask=t5_text_mask,
            image_size=image_size,
            fps=fps,
            num_frames=num_frames,
            loss_mask=loss_mask,
            seq_len_q=torch.tensor(seq_len, dtype=torch.int32),
            seq_len_kv=torch.tensor(self.text_embedding_padding_size, dtype=torch.int32),
            pos_ids=pos_ids,
            latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32),
        )

    def select_samples_to_pack(self, samples: List[DiffusionSample]) -> List[List[DiffusionSample]]:
        """
        Selects sequences to pack for mixed image-video training.
        """
        results = first_fit_decreasing(samples, self.max_seq_length)
        random.shuffle(results)
        return results

    @stateless
    def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSample:
        """Construct a new Diffusion sample by concatenating the sequences."""

        def stack(attr):
            return torch.stack([getattr(sample, attr) for sample in samples], dim=0)

        def cat(attr):
            return torch.cat([getattr(sample, attr) for sample in samples], dim=0)

        video = concat_pad([i.video for i in samples], self.max_seq_length)
        loss_mask = concat_pad([i.loss_mask for i in samples], self.max_seq_length)
        pos_ids = concat_pad([i.pos_ids for i in samples], self.max_seq_length)

        return DiffusionSample(
            __key__=",".join([s.__key__ for s in samples]),
            __restore_key__=(),  # Will be set by energon based on `samples`
            __subflavor__=None,
            __subflavors__=samples[0].__subflavors__,
            video=video,
            t5_text_embeddings=cat('t5_text_embeddings'),
            t5_text_mask=cat('t5_text_mask'),
            # image_size=stack('image_size'),
            # fps=stack('fps'),
            # num_frames=stack('num_frames'),
            loss_mask=loss_mask,
            seq_len_q=stack('seq_len_q'),
            seq_len_kv=stack('seq_len_kv'),
            pos_ids=pos_ids,
            latent_shape=stack('latent_shape'),
        )

    @stateless
    def batch(self, samples: List[DiffusionSample]) -> dict:
        """Return dictionary with data for batch."""
        if self.max_seq_length is None:
            # no packing
            return super().batch(samples).to_dict()

        # packing
        sample = samples[0]
        return dict(
            video=sample.video.unsqueeze_(0),
            t5_text_embeddings=sample.t5_text_embeddings.unsqueeze_(0),
            t5_text_mask=sample.t5_text_mask.unsqueeze_(0),
            loss_mask=sample.loss_mask.unsqueeze_(0),
            # image_size=sample.image_size,
            # fps=sample.fps,
            # num_frames=sample.num_frames,
            # padding_mask=sample.padding_mask.unsqueeze_(0),
            seq_len_q=sample.seq_len_q,
            seq_len_kv=sample.seq_len_kv,
            pos_ids=sample.pos_ids.unsqueeze_(0),
            latent_shape=sample.latent_shape,
        )


class PosID3D:
    """
    Generates 3D positional IDs for video data.

    Attributes:
        max_t (int): Maximum number of time frames.
        max_h (int): Maximum height dimension.
        max_w (int): Maximum width dimension.
    """

    def __init__(self, *, max_t=32, max_h=128, max_w=128):
        self.max_t = max_t
        self.max_h = max_h
        self.max_w = max_w
        self.generate_pos_id()

    def generate_pos_id(self):
        """Generates a grid of positional IDs based on max_t, max_h, and max_w."""
        self.grid = torch.stack(
            torch.meshgrid(
                torch.arange(self.max_t, device='cpu'),
                torch.arange(self.max_h, device='cpu'),
                torch.arange(self.max_w, device='cpu'),
            ),
            dim=-1,
        )

    def get_pos_id_3d(self, *, t, h, w):
        """Retrieves positional IDs for specified dimensions."""
        if t > self.max_t or h > self.max_h or w > self.max_w:
            self.max_t = max(self.max_t, t)
            self.max_h = max(self.max_h, h)
            self.max_w = max(self.max_w, w)
            self.generate_pos_id()
        return self.grid[:t, :h, :w]


def pad_divisible(x, padding_value=0):
    """
    Pads the input tensor to make its size divisible by a specified value.

    Args:
        x (torch.Tensor): Input tensor.
        padding_value (int): The value to make the tensor size divisible by.

    Returns:
        torch.Tensor: Padded tensor.
    """
    if padding_value == 0:
        return x
    # Get the size of the first dimension
    n = x.size(0)

    # Compute the padding needed to make the first dimension divisible by 16
    padding_needed = (padding_value - n % padding_value) % padding_value

    if padding_needed <= 0:
        return x

    # Create a new shape with the padded first dimension
    new_shape = list(x.shape)
    new_shape[0] += padding_needed

    # Create a new tensor filled with zeros
    x_padded = torch.zeros(new_shape, dtype=x.dtype, device=x.device)

    # Assign the original tensor to the beginning of the new tensor
    x_padded[:n] = x
    return x_padded


def concat_pad(tensor_list, max_seq_length):
    """
    Efficiently concatenates a list of tensors along the first dimension and pads with zeros
    to reach max_seq_length.

    Args:
        tensor_list (list of torch.Tensor): List of tensors to concatenate and pad.
        max_seq_length (int): The desired size of the first dimension of the output tensor.

    Returns:
        torch.Tensor: A tensor of shape [max_seq_length, ...], where ... represents the remaining dimensions.
    """
    import torch

    # Get common properties from the first tensor
    other_shape = tensor_list[0].shape[1:]
    dtype = tensor_list[0].dtype
    device = tensor_list[0].device

    # Initialize the result tensor with zeros
    result = torch.zeros((max_seq_length, *other_shape), dtype=dtype, device=device)

    current_index = 0
    for tensor in tensor_list:
        length = tensor.shape[0]
        # Directly assign the tensor to the result tensor without checks
        result[current_index : current_index + length] = tensor
        current_index += length

    return result


pos_id_3d = PosID3D()


def cook_raw_images(sample: dict) -> dict:
    """
    Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys.

    Args:
        sample (dict): The input dictionary containing the raw sample data.

    Returns:
        dict: A new dictionary containing the processed sample data with the following keys:
            - All keys from the result of `basic_sample_keys(sample)`
            - 'jpg': original images
            - 'png': contains control images
            - 'txt': contains raw text
    """
    return dict(
        **basic_sample_keys(sample),
        images=sample['jpg'],
        hint=sample['png'],
        txt=sample['txt'],
    )


class RawImageDiffusionTaskEncoder(DefaultTaskEncoder, IOMixin):
    '''
    Dummy task encoder takes raw image input on CrudeDataset.
    '''

    cookers = [
        # Cooker(cook),
        Cooker(cook_raw_images),
    ]
