# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataloaders."""

import abc
import warnings
from itertools import chain
from typing import Optional, Tuple

import torch

from nemo.utils import logging
from nemo.utils.decorators import experimental

# Names exported by a star-import of this module. Only the two batch samplers are
# listed; the other sampler classes defined below have to be imported by name.
__all__ = [
    "MegatronPretrainingBatchSampler",
    "MegatronPretrainingRandomBatchSampler",
]


class BaseMegatronSampler:
    """Shared validation, state and length logic for the micro-batch samplers in this file.

    Subclasses provide ``__iter__``; this class only validates the constructor
    arguments, stores them, and reports how many micro batches remain.
    """

    def __init__(
        self,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last: bool = True,
        global_batch_size: Optional[int] = None,
        rampup_batch_size: Optional[list] = None,
        pad_samples_to_global_batch_size: Optional[bool] = False,
    ) -> None:
        """Validate and store the sampler configuration.

        Args:
            total_samples: Number of samples in the dataset; must be positive.
            consumed_samples: Number of samples already consumed.
            micro_batch_size: Size of one micro batch; must be positive.
            data_parallel_rank: Rank of this process; must be smaller than
                ``data_parallel_size``.
            data_parallel_size: Number of data-parallel ranks; must be positive.
            drop_last: Whether a trailing incomplete batch is excluded from ``__len__``.
            global_batch_size: Optional global batch size. When given without
                ``rampup_batch_size`` it must be divisible by
                ``micro_batch_size * data_parallel_size``.
            rampup_batch_size: Only consulted to decide whether the divisibility
                check on ``global_batch_size`` applies; it is not stored.
            pad_samples_to_global_batch_size: Requires ``global_batch_size`` to be set.

        Raises:
            RuntimeError: If any of the checks above fails.
        """
        # Sanity checks.
        if total_samples <= 0:
            raise RuntimeError("no sample to consume: {}".format(total_samples))
        if micro_batch_size <= 0:
            raise RuntimeError(f"micro_batch_size size must be greater than 0, but {micro_batch_size}")
        if data_parallel_size <= 0:
            raise RuntimeError(f"data parallel size must be greater than 0, but {data_parallel_size}")
        if data_parallel_rank >= data_parallel_size:
            raise RuntimeError(
                "data_parallel_rank should be smaller than data size, but {} >= {}".format(
                    data_parallel_rank, data_parallel_size
                )
            )
        if global_batch_size is not None and rampup_batch_size is None:
            if global_batch_size % (micro_batch_size * data_parallel_size) != 0:
                raise RuntimeError(
                    f"`global_batch_size` ({global_batch_size}) is not divisible by "
                    f"`micro_batch_size ({micro_batch_size}) x data_parallel_size "
                    f"({data_parallel_size})`"
                )
        if pad_samples_to_global_batch_size and global_batch_size is None:
            raise RuntimeError(
                "`pad_samples_to_global_batch_size` can be `True` only when "
                "`global_batch_size` is set to an integer value"
            )

        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        # Number of samples consumed across all ranks by one micro-batch step.
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last
        self.global_batch_size = global_batch_size
        self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size

        logging.info(
            f'Instantiating MegatronPretrainingSampler with total_samples: {total_samples} '
            f'and consumed_samples: {consumed_samples}'
        )

    def __len__(self):
        """Return the number of micro batches still to be produced.

        With ``global_batch_size`` set, the count is the number of remaining global
        batches (floored when ``drop_last`` is true, rounded up otherwise) times the
        number of micro batches per global batch. Without it, the remaining samples
        are divided directly by ``micro_batch_size * data_parallel_size``.
        """
        num_available_samples: int = self.total_samples - self.consumed_samples
        if self.global_batch_size is not None:
            if self.drop_last:
                num_global_batches = num_available_samples // self.global_batch_size
            else:
                num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
            # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader and
            # num of batches fetched (as training step fetches in terms of micro batches)
            return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)

        if self.drop_last:
            # A trailing incomplete micro batch is dropped during iteration, so it
            # must not be counted here either.
            return num_available_samples // self.micro_batch_times_data_parallel_size
        # Ceiling division: the trailing incomplete micro batch is kept.
        return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1

    @abc.abstractmethod
    def __iter__(self): ...


class MegatronPretrainingSampler(BaseMegatronSampler):
    """Sequential micro-batch sampler.

    Walks the indices from ``consumed_samples`` to ``total_samples`` in order,
    groups them into chunks of ``micro_batch_size * data_parallel_size`` and
    yields, for every chunk, the slice that belongs to this data-parallel rank.
    """

    def get_start_end_idx(self):
        """Return the ``(start, end)`` bounds of this rank's slice inside a chunk."""
        first = self.data_parallel_rank * self.micro_batch_size
        return first, first + self.micro_batch_size

    def _get_padding_indices(self, pad_samples_num):
        """Return ``pad_samples_num`` padding indices: ``-1, -2, ..., -pad_samples_num``."""
        return range(-1, -pad_samples_num - 1, -1)

    def __iter__(self):
        sample_ids = range(self.consumed_samples, self.total_samples)
        if self.pad_samples_to_global_batch_size and not self.drop_last:
            # Append just enough padding indices to reach a multiple of the global batch size.
            missing = -len(sample_ids) % self.global_batch_size
            sample_ids = chain(sample_ids, self._get_padding_indices(missing))

        pending = []
        for sample_id in sample_ids:
            pending.append(sample_id)
            if len(pending) != self.micro_batch_times_data_parallel_size:
                continue
            lo, hi = self.get_start_end_idx()
            yield pending[lo:hi]
            pending = []

        # A leftover incomplete chunk is only emitted when drop_last is disabled.
        if pending and not self.drop_last:
            assert (
                not self.pad_samples_to_global_batch_size
            ), 'with pad_samples_to_global_batch_size all batches should be complete'
            lo, hi = self.get_start_end_idx()
            yield pending[lo:hi]


class MegatronCorePretrainingSampler(MegatronPretrainingSampler):
    """Variant of ``MegatronPretrainingSampler`` that pads with ``None`` instead of negative indices."""

    def _get_padding_indices(self, pad_samples_num):
        """Return a list containing ``pad_samples_num`` ``None`` placeholders."""
        return [None for _ in range(pad_samples_num)]


class MegatronPretrainingRandomSampler(BaseMegatronSampler):
    """Micro-batch sampler that visits this rank's shard in a seeded random order.

    The index space is cut into one contiguous bucket per data-parallel rank.
    On every pass the bucket is permuted with a generator seeded by
    ``seed + epoch``, and the part of the permutation already covered by
    ``consumed_samples`` is skipped.
    """

    def __init__(
        self,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last: bool = True,
        global_batch_size: Optional[int] = None,
        pad_samples_to_global_batch_size: Optional[bool] = False,
        seed: int = 0,
    ) -> None:
        """Validate and store the sampler configuration.

        Args:
            total_samples: Number of samples in the dataset.
            consumed_samples: Number of samples already consumed (may span several epochs).
            micro_batch_size: Size of one micro batch.
            data_parallel_rank: Rank of this process within the data-parallel group.
            data_parallel_size: Number of data-parallel ranks.
            drop_last: Whether the trailing incomplete batch is dropped. ``False`` is
                only accepted when ``micro_batch_size * data_parallel_size == 1``.
            global_batch_size: Optional global batch size, used by ``__len__``.
            pad_samples_to_global_batch_size: Must be falsy; padding is not supported.
            seed: Base seed for the per-epoch permutation.

        Raises:
            RuntimeError: If the base-class checks fail or ``drop_last=False`` is
                combined with ``micro_batch_size * data_parallel_size > 1``.
            AssertionError: If ``pad_samples_to_global_batch_size`` is truthy.
        """
        super().__init__(
            total_samples=total_samples,
            consumed_samples=consumed_samples,
            micro_batch_size=micro_batch_size,
            data_parallel_rank=data_parallel_rank,
            data_parallel_size=data_parallel_size,
            drop_last=drop_last,
            global_batch_size=global_batch_size,
            pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
        )
        assert (
            not pad_samples_to_global_batch_size
        ), "`MegatronPretrainingRandomSampler` does not support sample padding"
        if (not drop_last) and self.micro_batch_times_data_parallel_size > 1:
            raise RuntimeError(
                "`MegatronPretrainingRandomSampler` does not support drop_last=False when \
                  micro_batch_size * data_parallel_size > 1. Please reduce your MBS and data parallelism to 1 \
                  if you want to use drop_last=False, or switch to drop_last=True to avoid this error"
            )
        # Number of samples that do not fill a whole step across all ranks.
        self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size
        self.seed = seed

    def __len__(self):
        """Return the number of micro batches left in the current epoch."""
        active_total_samples = self.total_samples - (self.last_batch_size if self.drop_last else 0)
        # consumed_samples keeps growing across epochs, so wrap it to the current epoch.
        num_available_samples = active_total_samples - self.consumed_samples % active_total_samples
        if self.global_batch_size is not None:
            if self.drop_last:
                num_global_batches = num_available_samples // self.global_batch_size
            else:
                num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
            # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader and
            # num of batches fetched (as training step fetches in terms of micro batches)
            return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)

        if self.drop_last:
            return num_available_samples // self.micro_batch_times_data_parallel_size
        # Ceiling division so that the trailing batch is counted. The previous
        # expression lacked the `+ 1` and reported one batch fewer than __iter__ yields.
        return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # data sharding and random sampling
        # bucket_size: samples owned by one rank; bucket_offset: how many of them
        # this rank has already consumed in the current epoch.
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        # Seeding with seed + epoch gives a different but reproducible order per epoch.
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # Last batch if not complete will be dropped.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                # Track progress in units of samples across all ranks.
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []

        # Check the last partial batch and see drop_last is set
        if len(batch) > 0 and not self.drop_last:
            yield batch


class BaseMegatronBatchSampler:
    """Megatron style BatchSampler.

    Let mbs, gbs, tp, pp, and dp stand for "micro batch size", "global batch size",
    "tensor model parallel world size", "pipeline model parallel world size", and
    "data parallel world size", the number of micro batches (hereafter, nmb) is defined as
    :math:`nmb = gbs \\div (mbs \\times dp)`.

    See `apex/transformer/microbatches.py#L91-L98 <https://github.com/NVIDIA/apex/blob/
    44c3043685b6115e7b81b3458a6c76601b1e55b4/apex/transformer/microbatches.py#L91-L98>`_
    for the initial settings of the number of micro batches and
    `apex/transformer/microbatches.py#L160-L177 <https://github.com/NVIDIA/apex/blob/
    44c3043685b6115e7b81b3458a6c76601b1e55b4/apex/transformer/microbatches.py#L160-L177>`_
    for warming up of global batch size.

    e.g.) `(mbs, gbs, tp, pp, dp) = (1, 16, 1, 1, 2)`, then the number of micro batches is
    :math:`gbs \\div (mbs \\times dp) = 16 \\div (1 \\times 2) = 8`.
    In this case, an instance of Megatron Batch Sampler on each data parallel rank is expected
    to return :math:`nmb \\times mbs = 8` indices.
    """

    # Current global batch size; change it through `update_global_batch_size`.
    _global_batch_size: int
    # Number of micro batches per global batch on one rank: gbs // (mbs * dp).
    _num_micro_batches: int
    # Number of samples of one global batch handled by this rank: nmb * mbs.
    _global_batch_size_on_this_data_parallel_rank: int

    def __init__(
        self,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        global_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last: bool,
        pad_samples_to_global_batch_size=False,
    ) -> None:
        """Constructor of Megatron-LM style Batch Sampler.

        Args:
            total_samples: The size of dataset.
            consumed_samples: The number of samples that have been used.
            micro_batch_size: The size of each micro batch.
            global_batch_size: The size of global batch.
            data_parallel_rank: The value you can obtain via
                `parallel_state.get_data_parallel_rank()` of megatron.core.
            data_parallel_size: The value you can obtain via
                `parallel_state.get_data_parallel_world_size()` of megatron.core.
            drop_last: Whether a trailing incomplete global batch is excluded from ``__len__``.
            pad_samples_to_global_batch_size: Stored for use by subclasses.

        Raises:
            RuntimeError: If a size is not positive, the rank is not smaller than the
                data parallel size, or ``global_batch_size`` is rejected by
                ``update_global_batch_size``.
        """
        # Sanity checks.
        if total_samples <= 0:
            raise RuntimeError("no sample to consume: {}".format(total_samples))
        if micro_batch_size <= 0:
            raise RuntimeError(f"micro_batch_size size must be greater than 0, but {micro_batch_size}")
        if data_parallel_size <= 0:
            raise RuntimeError(f"data parallel size must be greater than 0, but {data_parallel_size}")
        if data_parallel_rank >= data_parallel_size:
            raise RuntimeError(
                "data_parallel_rank should be smaller than data size, but {} >= {}".format(
                    data_parallel_rank, data_parallel_size
                )
            )
        # Keep a copy of input params for later use.
        self.total_samples: int = total_samples
        self.consumed_samples: int = consumed_samples
        self.micro_batch_size: int = micro_batch_size
        self.data_parallel_rank: int = data_parallel_rank
        self.data_parallel_size: int = data_parallel_size
        self.drop_last: bool = drop_last
        self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size
        self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size

        self.update_global_batch_size(global_batch_size)

    def update_global_batch_size(self, new_global_batch_size: int) -> None:
        """Update the global batch size and the values derived from it.

        The new value is validated before any attribute is touched, so a rejected
        update leaves the sampler in its previous, consistent state.

        Args:
            new_global_batch_size: The new global batch size. Must be positive and
                divisible by ``micro_batch_size * data_parallel_size``.

        Raises:
            RuntimeError: If the new value is not positive or not divisible.
        """
        if new_global_batch_size <= 0:
            # A zero size would make __len__ divide by zero.
            raise RuntimeError(f"global_batch_size must be greater than 0, but {new_global_batch_size}")
        if new_global_batch_size % self.micro_batch_times_data_parallel_size != 0:
            raise RuntimeError(
                f"`global_batch_size` ({new_global_batch_size}) is not divisible by "
                f"`micro_batch_size ({self.micro_batch_size}) x data_parallel_size "
                f"({self.data_parallel_size})`"
            )
        self._global_batch_size = new_global_batch_size
        self._num_micro_batches = new_global_batch_size // self.micro_batch_times_data_parallel_size
        self._global_batch_size_on_this_data_parallel_rank = self._num_micro_batches * self.micro_batch_size

    @property
    def global_batch_size(self) -> int:
        """Return the current global batch size."""
        return self._global_batch_size

    @global_batch_size.setter
    def global_batch_size(self, new_global_batch_size: int) -> None:
        """Set the global batch size; emits a warning recommending ``update_global_batch_size``."""
        warnings.warn("`self.update_global_batch_size(new_global_batch_size)` is recommended.")
        self.update_global_batch_size(new_global_batch_size=new_global_batch_size)

    def __len__(self) -> int:
        """Length of Batch Sampler.

        .. note::
            When `rampup_batch_size` is enabled, the return value can be not exactly precise.

        """
        # consumed_samples may exceed total_samples, so wrap it to the current pass.
        num_available_samples: int = self.total_samples - self.consumed_samples % self.total_samples
        if self.drop_last:
            return num_available_samples // self.global_batch_size
        else:
            return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size

    @abc.abstractmethod
    def __iter__(self): ...


class MegatronPretrainingBatchSampler(BaseMegatronBatchSampler):
    """Sequential global-batch sampler.

    Collects consecutive indices into global batches and yields, for each one,
    the indices at positions ``rank, rank + dp, rank + 2 * dp, ...`` of the batch.
    """

    def get_start_end_idx(self) -> Tuple[int, int]:
        """Return ``(start, end)`` of the contiguous per-rank slice of a global batch.

        ``__iter__`` in this class selects indices by stride and does not call this helper.
        """
        first = self.data_parallel_rank * self._global_batch_size_on_this_data_parallel_rank
        return first, first + self._global_batch_size_on_this_data_parallel_rank

    def __iter__(self):
        pending = []
        resume_at = self.consumed_samples % self.total_samples
        for sample_id in range(resume_at, self.total_samples):
            pending.append(sample_id)
            if len(pending) != self._global_batch_size:
                continue
            # Every dp-th position, starting at this rank, belongs to this rank.
            positions = range(self.data_parallel_rank, self._global_batch_size, self.data_parallel_size)
            rank_indices = [pending[pos] for pos in positions]
            assert len(rank_indices) == self._global_batch_size_on_this_data_parallel_rank
            yield rank_indices
            pending = []

        # Nothing left over, or the incomplete global batch is to be dropped.
        if not pending or self.drop_last:
            return

        positions = range(self.data_parallel_rank, len(pending), self.data_parallel_size)
        rank_indices = [pending[pos] for pos in positions]
        if self.pad_samples_to_global_batch_size:
            # Fill this rank's share up to its full per-rank size with -1 markers.
            shortfall = self._global_batch_size // self.data_parallel_size - len(rank_indices)
            rank_indices = rank_indices + [-1] * shortfall
        yield rank_indices


@experimental
class MegatronPretrainingRandomBatchSampler(BaseMegatronBatchSampler):
    """Global-batch sampler that visits this rank's shard in a seeded random order.

    Each data-parallel rank owns one contiguous bucket of indices. On every pass
    the bucket is permuted with a generator seeded by ``seed + epoch`` and the
    indices are yielded in groups of this rank's share of a global batch.
    """

    # NOTE (mkozuki): [[Argument of `dataset` and `data_sharding`]]
    # The commit below suggests that the `dataset` and `data_sharding` arguments
    # are needed for ViT training. They are left out here to keep things simple.
    # commit: https://github.com/NVIDIA/Megatron-LM/commit/7a77abd9b6267dc0020a60b424b4748fc22790bb
    #
    # NOTE (degert): This class was partly re-written so that the length is correct when
    # consumed_samples exceeds total_samples, which happens when training for more than one
    # epoch with this sampler. An explicit seed was also added, which makes Dataset-side
    # shuffling in Nemo-Aligner unnecessary.
    #
    # This class does not currently work with pad_samples_to_global_batch_size=True
    def __init__(
        self,
        total_samples: int,
        consumed_samples: int,
        micro_batch_size: int,
        global_batch_size: int,
        data_parallel_rank: int,
        data_parallel_size: int,
        drop_last: bool,
        pad_samples_to_global_batch_size: bool = False,
        seed: int = 0,
    ) -> None:
        """Validate the configuration and remember the seed.

        Raises ``AssertionError`` when padding is requested and ``RuntimeError`` when
        ``drop_last=False`` is combined with ``micro_batch_size * data_parallel_size > 1``.
        """
        super().__init__(
            total_samples=total_samples,
            consumed_samples=consumed_samples,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            data_parallel_rank=data_parallel_rank,
            data_parallel_size=data_parallel_size,
            drop_last=drop_last,
            pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
        )
        assert (
            not pad_samples_to_global_batch_size
        ), "`MegatronPretrainingRandomBatchSampler` does not support sample padding"
        if self.micro_batch_times_data_parallel_size > 1 and not drop_last:
            raise RuntimeError(
                "`MegatronPretrainingRandomBatchSampler` does not support drop_last=False \
                  when micro_batch_size * data_parallel_size > 1. Please reduce your MBS and data parallelism to 1 \
                  if you want to use drop_last=False, or switch to drop_last=True to avoid this error"
            )
        self.seed = seed
        # Samples that do not fill a whole global batch.
        self.last_batch_size = self.total_samples % self._global_batch_size

    def __len__(self) -> int:
        """Length of Random Batch Sampler.

        ..note::
            When `rampup_batch_size` is enabled, the return value can be not exactly precise.

        """
        dropped_tail = self.last_batch_size if self.drop_last else 0
        epoch_samples = self.total_samples - dropped_tail
        # consumed_samples accumulates over epochs; wrap it to the current one.
        remaining = epoch_samples - self.consumed_samples % epoch_samples
        if not self.drop_last:
            return (remaining + self.global_batch_size - 1) // self.global_batch_size
        return remaining // self.global_batch_size

    def __iter__(self):
        epoch_samples = self.total_samples - self.last_batch_size
        self.epoch, epoch_progress = divmod(self.consumed_samples, epoch_samples)
        assert epoch_progress % self.micro_batch_times_data_parallel_size == 0

        # Shard the index space per rank, then skip what this rank already consumed.
        shard_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
        already_seen = epoch_progress // self.data_parallel_size
        shard_start = self.data_parallel_rank * shard_size

        rng = torch.Generator()
        rng.manual_seed(self.seed + self.epoch)
        permutation = torch.randperm(shard_size, generator=rng).tolist()
        shard_indices = [shard_start + offset for offset in permutation[already_seen:]]

        pending = []
        for sample_id in shard_indices:
            pending.append(sample_id)
            if len(pending) != self._global_batch_size_on_this_data_parallel_rank:
                continue
            # Progress is tracked in whole global batches.
            self.consumed_samples += self._global_batch_size
            yield pending
            pending = []

        # An incomplete final batch is only emitted when drop_last is disabled.
        if pending and not self.drop_last:
            yield pending
