# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Blendable dataset."""

import os
import subprocess
import time

import numpy as np
import torch

from nemo.utils import logging
from nemo.utils.app_state import AppState


class BlendableDataset(torch.utils.data.Dataset):
    """Blend several datasets into one, drawing from each in proportion to its weight.

    At construction time the compiled ``helpers.build_blending_indices`` routine fills two
    arrays of length ``size``:

    - ``dataset_index[i]``: which dataset global sample ``i`` is taken from.
    - ``dataset_sample_index[i]``: the sample index inside that dataset.

    Args:
        datasets: sequence of datasets to blend. Must contain fewer than 255 entries, because
            ``dataset_index`` is stored as ``uint8``.
        weights: one weight per dataset; normalized here so they sum to 1.
        size: number of samples in the blended dataset.
    """

    def __init__(self, datasets, weights, size):
        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)

        self.size = size

        # Normalize weights so they sum to 1.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0
        weights /= sum_weights

        # Build indices.
        start_time = time.time()
        # Dataset ids are stored as uint8 below.
        assert num_datasets < 255
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        # torch.distributed.barrier()/get_rank() raise when no default process group has been
        # initialized, so only synchronize when one exists (allows single-process use).
        distributed = self._is_distributed_initialized()
        app_state = AppState()
        try:
            # Only local rank 0 builds the helpers; the barrier makes the other ranks wait.
            if app_state.local_rank == 0:
                compile_helper()
            if distributed:
                torch.distributed.barrier()
            from nemo.collections.common.data import helpers
        except ImportError as err:
            raise ImportError(
                'Could not compile megatron dataset C++ helper functions and therefore '
                'cannot import helpers python file.'
            ) from err

        helpers.build_blending_indices(
            self.dataset_index,
            self.dataset_sample_index,
            weights,
            num_datasets,
            self.size,
            # NOTE(review): presumably a verbosity flag, enabled on global rank 0 only.
            self._global_rank() == 0,
        )
        logging.info(
            '> elapsed time for building blendable dataset indices: ' '{:.2f} (sec)'.format(time.time() - start_time)
        )

    @staticmethod
    def _is_distributed_initialized():
        """Return True only when torch.distributed is available and a process group exists."""
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    @classmethod
    def _global_rank(cls):
        """Return the global rank, or 0 when running without an initialized process group."""
        return torch.distributed.get_rank() if cls._is_distributed_initialized() else 0

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return the sample mapped to global index ``idx``.

        If the precomputed sample index exceeds the target dataset's length, it wraps around
        modulo that length and a warning is logged.
        """
        dataset_idx = self.dataset_index[idx]
        sample_idx = self.dataset_sample_index[idx]
        dataset_size = len(self.datasets[dataset_idx])
        # Ensure the sample index doesn't exceed the dataset size
        if sample_idx >= dataset_size:
            logging.warning(f"Index {sample_idx} out of bounds for dataset {dataset_idx}. Reusing existing examples.")
            sample_idx = sample_idx % dataset_size
            logging.warning(f"Reusing index {sample_idx} for dataset {dataset_idx}.")

        return self.datasets[dataset_idx][sample_idx]

    def create_data_mmap(self):
        """Call ``create_data_mmap()`` on every blended dataset."""
        for dataset in self.datasets:
            dataset.create_data_mmap()


class MemoryEfficientBlendableDataset(torch.utils.data.Dataset):
    """
    A BlendableDataset implementation that uses less memory than the original implementation.
    Indices are computed algorithmically instead of storing them in memory.

    How indices are derived:

    - The normalized weights are quantized into ``weight_bins`` bins, each owned by one dataset.
    - Dataset ``i`` owns ``int(w_i * weight_bins)`` bins; bins lost to truncation are given to
      the last dataset.
    - Global index ``idx`` falls in bin ``idx % weight_bins``.
    - Each full pass over the bins advances a dataset's sample index by the number of bins that
      dataset owns. The result wraps modulo the dataset's length, so short datasets repeat.

    Args:
        datasets: sequence of sized, indexable datasets.
        weights: one strictly positive weight per dataset.
        size: number of samples in the blended dataset.
        weight_bins: resolution of the weight quantization (capped at ``size``).

    To test call: MemoryEfficientBlendableDataset.test_index_blending()
    """

    def __init__(self, datasets, weights, size, weight_bins=100):
        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)

        weight_bins = min(weight_bins, size)

        self.size = size
        self.weight_bins = weight_bins

        # Normalize weights.
        weights = np.array(weights, dtype=np.float64)
        assert (weights > 0.0).all()
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0
        self.weights = weights / sum_weights

        # Assign bins to datasets in proportion to their weights; ds_bias[b] is the rank of
        # bin b among the bins owned by the same dataset (0, 1, 2, ...).
        ds_index = []
        ds_bias = []
        for i, w in enumerate(self.weights):
            n = int(w * weight_bins)
            ds_index.extend([i] * n)
            ds_bias.extend(range(n))

        # Make sure arrays have length of weight_bins: leftover bins (lost to int truncation)
        # go to the last dataset, continuing *its own* bias sequence so that no bias value is
        # duplicated (a duplicate would repeat some samples and skip others).
        last_ds = num_datasets - 1
        n_pad = weight_bins - len(ds_index)
        n_last = ds_index.count(last_ds)
        ds_index.extend([last_ds] * n_pad)
        ds_bias.extend(range(n_last, n_last + n_pad))

        self.ds_index = np.array(ds_index, dtype=np.uint32)
        self.ds_index_size = np.array([(self.ds_index == i).sum() for i in range(num_datasets)], dtype=np.uint32)
        assert (self.ds_index_size > 0).all(), (
            "Some datasets have no samples in the blendable dataset, "
            "increase weight_bins or the offending weight. "
            f"ds_index_size = {self.ds_index_size}"
        )
        self.ds_bias = np.array(ds_bias, dtype=np.uint32)

        # int64 so datasets with more than 2**32 samples do not overflow.
        self.ds_size = np.array([len(ds) for ds in datasets], dtype=np.int64)

    def get_ds_sample_idx(self, idx):
        """Returns ds index and sample index (within the ds) for the given index in the blendable dataset."""

        bin_idx = idx % self.weight_bins
        ds_idx = self.ds_index[bin_idx]
        epoch = idx // self.weight_bins
        # Do the arithmetic in Python ints: mixing Python ints with uint32 scalars can overflow
        # (or raise under NumPy 2 promotion rules) for large idx.
        sample_idx = (int(self.ds_bias[bin_idx]) + epoch * int(self.ds_index_size[ds_idx])) % int(
            self.ds_size[ds_idx]
        )

        return ds_idx, sample_idx

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        ds_idx, sample_idx = self.get_ds_sample_idx(idx)

        return self.datasets[ds_idx][sample_idx]

    @classmethod
    def test_index_blending(cls):
        """Visualize indices of blended dataset"""

        import matplotlib.pyplot as plt

        plt.ion()

        class DS(torch.utils.data.Dataset):
            """Minimal sized dataset used only for the visualization."""

            def __init__(self, size, data):
                self.size = size
                self.data = data

            def __len__(self):
                return self.size

            def __getitem__(self, idx):
                return self.data[idx]

        for weight_bins in [10, 100]:
            blend_ds = MemoryEfficientBlendableDataset(
                [DS(10, "a"), DS(10, "b"), DS(10, "c")], [0.5, 0.3, 0.2], 50, weight_bins=weight_bins
            )

            ds_sample_idx_list = [blend_ds.get_ds_sample_idx(i) for i in range(50)]
            ds_list, sample_list = list(zip(*ds_sample_idx_list))

            plt.figure()
            plt.plot(ds_list, label="ds idx")
            plt.plot(sample_list, label="sample")
            plt.legend()
            plt.grid()
            plt.title(f"weight_bins={weight_bins}")


def compile_helper():
    """Compile the C++ dataset helpers at runtime by running ``make`` in this module's directory.

    Make sure this is invoked on a single process.

    Exits the interpreter with status 1 if ``make`` cannot be executed or the build fails.
    """
    import sys

    path = os.path.abspath(os.path.dirname(__file__))
    try:
        ret = subprocess.run(['make', '-C', path])
    except OSError as err:
        # e.g. `make` not installed: report it through the same log-and-exit path as a failed build.
        logging.error(f"Could not run make to build C++ dataset helpers: {err}")
        sys.exit(1)
    if ret.returncode != 0:
        logging.error("Making C++ dataset helpers module failed, exiting.")
        sys.exit(1)
