o
    }oirP                     @   s   d dl mZ d dlmZmZmZmZ d dlZd dl	m
Z d dlmZmZ d dlmZ d dlmZmZmZ d dlmZ d dlmZ d d	lmZ d d
lmZ d dlmZmZm Z  d dl!m"Z" d dl#m$Z$ G dd dej%eZ&G dd de"Z'dS )    deepcopy)AnyDictLiteralOptionalN)EVAL_DATALOADERSTRAIN_DATALOADERS)parallel_state)WorkerConfigget_savable_loaderget_train_dataset)
DataLoader)Self)MultiModalSampleConfig)MultiModalTaskEncoder)IOMixinserializationtrack_io)MegatronDataSampler)loggingc                       s  e Zd ZdZdddddddde ddddfdeded	ed
edededB dedededB dee dee	 dee dee dee	 ddf fddZ
deje fddZd*ded fddZdefdd Zdefd!d"Zd+d#d$Zdeeef fd%d&Zd'eeef ddfd(d)Z  ZS ),EnergonMultiModalDataModulea  
    A PyTorch Lightning DataModule for handling multimodal datasets with images and text.

    This data module is designed to work with multimodal datasets that involve both images and text.
    It provides a seamless interface to load training and validation data, manage batching, and handle
    the state of the data pipeline across training epochs. The module integrates with the Megatron-Energon
    framework for efficient data handling in large-scale distributed training.

    Attributes:
    path (str): Path to the energon dataset.
    tokenizer (Tokenizer): The tokenizer used for processing text.
    image_processor (ImageProcessor): The image processor used for preprocessing images.
    seq_length (int): The maximum sequence length for tokenized text.
    micro_batch_size (int): The batch size for training and validation.
    num_workers (int): Number of workers for data loading.
    pin_memory (bool): Whether to pin memory in the DataLoader.
    multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples.
    task_encoder (MultiModalTaskEncoder): Encoder responsible for encoding and batching samples.
    init_global_step (int): The initial global step for the trainer, used for resuming training.
    data_sampler (SequentialMegatronSampler): Sampler responsible for generating sequential samples.
    train_dataloader_object (Optional): The DataLoader object for training data.
    val_dataloader_object (Optional): The DataLoader object for validation data.
    i      NTd   path
seq_lengthmicro_batch_sizeglobal_batch_sizenum_workersnum_val_workers
pin_memoryshuffle_buffer_sizemax_samples_per_sequencemultimodal_sample_configtask_encoderdecoder_seq_lengthpacking_buffer_sizevalidation_task_encoderreturnc                    s   t    || _|| _|| _|| _|| _|| _|| _|| _	|	| _
|| _|
| _|| _|p3t| j| j|d| _d| _t| j| j| j| jd| _d| _d| _|| _|pR| j| _|pX| j	| _|| _dS )a  
        Initialize the EnergonMultiModalDataModule.

        Parameters:
        path (str): Path to the dataset.
        tokenizer (Tokenizer): The tokenizer used for processing text.
        image_processor (ImageProcessor): The image processor used for preprocessing images.
        seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048.
        micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1.
        num_workers (int, optional): Number of workers for data loading. Defaults to 1.
        num_val_workers (int, optional): Number of workers for validation data loading. Defaults to num_workers.
        pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True.
        multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples.
        Defaults to MultiModalSampleConfig().
        shuffle_buffer_size (int, optional): Size of the shuffle buffer. Defaults to 100.
        max_samples_per_sequence (int, optional): Maximum number of samples per sequence to load from memory.
        Defaults to None (loads the whole tar file at once).
        task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples.
        If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None.
        decoder_seq_length (int, optional): The max sequence length for the decoder. Used in encoder-decoder models
        packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None.
        validation_task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding
        and batching samples for validation. Defaults to None and will be the same as task_encoder.
        **kwargs: Additional keyword arguments. Will be passed to get_train_dataset() of Energon
        )	tokenizerimage_processorr#   r   )seq_lendecoder_seq_lenr   r   N)super__init__r   r)   r*   r   r%   r   r   r   r    r#   r!   r"   r   r$   init_global_stepSequentialMegatronSamplerdata_samplertrain_dataloader_objectval_dataloader_objectr&   r'   r   kwargs)selfr   r)   r*   r   r   r   r   r   r    r!   r"   r#   r$   r%   r&   r'   r4   	__class__ a/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/multimodal/data/energon/base.pyr.   :   s>   
.
z$EnergonMultiModalDataModule.__init__c                 K   sT   dd |  D }| D ]}tt|stt| qtjt| fi |}|S )Nc                 S   s"   i | ]\}}|d vr|t |qS ))r*   r$   r'   r   ).0kvr8   r8   r9   
<dictcomp>   s
    z7EnergonMultiModalDataModule.io_init.<locals>.<dictcomp>)itemsvaluesr   find_node_traversertyper   fdlConfig)r5   r4   
cfg_kwargsvalcfgr8   r8   r9   io_init   s   z#EnergonMultiModalDataModule.io_initrE   split)trainrE   c              
   C   sV   |dvrt d|dkr| j}n| j}t| jf| j||| j|| j| jd| j	}|S )a  
        Provide the dataset for training or validation.

        This method retrieves the dataset for the specified split (either 'train' or 'val') and configures
        it according to the worker configuration.

        Parameters:
        worker_config: Configuration for the data loader workers.
        split (Literal['train', 'val'], optional): The data split to retrieve ('train' or 'val'). Defaults to 'val'.

        Returns:
        Dataset: The dataset configured for the specified split.
        >   rE   rI   z=Invalid value for split. Allowed values are 'train' or 'val'.rI   )
batch_sizer$   worker_configr&   
split_partr!   r"   )

ValueErrorr$   r'   r   r   r   r&   r!   r"   r4   )r5   rK   rH   r$   _datasetr8   r8   r9   datasets_provider   s&   	z-EnergonMultiModalDataModule.datasets_providerc              	   C   s   | j r| j j| _| j| j_td| j  | jr| jS t s0td| j	  t
| j	}n&t }t }t }td| d| d| d t
||| j	|ddd	}| j|d
d}t||d}|| _| jS )a  
        Initialize and return the training DataLoader.

        This method initializes the DataLoader for the training dataset. It uses the global step
        from the trainer to configure the data sampler and ensures that the parallel state is initialized
        correctly for distributed training.

        Returns:
        TRAIN_DATALOADERS: The DataLoader for the training dataset.
        z?Multimodal train dataloader initializing with init_global_step zfMuiltimodal data loader parallel state is not initialized,using default worker config with no_workers z4 Multimodal  train dataloader initializing withrank  world_size  data_parallel_group z ****** Nr   rank
world_sizer   data_parallel_groupworker_debug_pathworker_log_levelrI   rH   rK   )trainerglobal_stepr/   r1   r   infor2   r
   is_initializedr   r   default_worker_configget_data_parallel_rankget_data_parallel_world_sizeget_data_parallel_grouprO   r   )r5   rK   rS   rT   rU   train_datasetenergon_dataloaderr8   r8   r9   train_dataloader   sJ   

z,EnergonMultiModalDataModule.train_dataloaderc                 C   s   | j r| j S t std| j  t| j}n%t	 }t
 }t }td| d| d|  t||| j|ddd}| j|dd	}t||d
}|| _ | j S )an  
        Initialize and return the validation DataLoader.

        This method initializes the DataLoader for the validation dataset. It ensures that the parallel state
        is initialized correctly for distributed training and returns a configured DataLoader object.

        Returns:
        EVAL_DATALOADERS: The DataLoader for the validation dataset.
        zjMuiltimodal val data loader parallel state is not initialized,using default worker config with no_workers zrank rP   rQ   Nr   rR   rE   rX   rY   )r3   r
   r]   r   r\   r   r   r^   r   r_   r`   ra   rO   r   )r5   rK   rS   rT   rU   val_datasetenergon_loaderr8   r8   r9   val_dataloader   s2   
z*EnergonMultiModalDataModule.val_dataloaderc                 C   s   t d dS )z
        Return None as test dataset split does not exist.

        This method overrides the test_dataloader method and returns None since the test dataset split
        is not defined or used in this module.

        Returns:
        None
        z7Multimodal dataloader test dataset split does not existN)r   warningr5   r8   r8   r9   test_dataloader  s   

z+EnergonMultiModalDataModule.test_dataloaderc                 C   s   | j r?| j j}g }t pt pt pt dkr!|jdd}| j	| j j
| j }|du r2g }td|  ||dS td i S )aT  
        Save the state of the data module.

        This method is called when saving a checkpoint. It generates and saves the state of the data module,
        including the state of the dataloader and the number of consumed samples.

        Returns:
        Dict[str, Any]: A dictionary containing the state of the data module.
        r   )global_dst_rankNzEMultimodal data loader saving dataloader state dict consumed samples )dataloader_stateconsumed_sampleszHtrainer object not connected to data module object returning empty state)rZ   rd   r
   get_context_parallel_rank get_pipeline_model_parallel_rankget_tensor_model_parallel_rankget_expert_model_parallel_ranksave_state_globalr1   compute_consumed_samplesr[   r/   r   r\   rh   )r5   dataloader_objstaterm   r8   r8   r9   
state_dict  s*   	

z&EnergonMultiModalDataModule.state_dictrv   c              
   C   s  d|vrt d|   dS |d }z | jr'| jj | t d nt d|  t	dW n t
yN } zt d|  W Y d}~nd}~ww zdd	lm} W n ttfym   t d
 dd	lm} Y nw |d }|| j_|| j_t d|  ||dd dS )az  
        Load the state of the data module from a checkpoint.

        This method is called when loading a checkpoint. It restores the state of the data module,
        including the state of the dataloader and the number of consumed samples.

        Parameters:
        state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module.
        rl   zpData loader state cannot be resumed from state_dict, it does not have the required key dataloader_state. It has Nz$Multimodal dataloader state restoredz%Cannot restore state from state_dict zhCannot restore state from state_dict: Is the trainer object is initialized and attached to datamodule???zFailed to dataloader restore state due to [Please ensure you are using same version of energon while saving and loading, Continuing without restoring data loader] : r   )update_num_microbatcheszCMegatron num_microbatches_calculator not found, using Apex version.rm   z<Multimodal dataloader load state dict with consumed_samples F)rm   consistency_check)r   rh   keysrZ   
datamodulerd   restore_state_globalr\   errorrM   	Exception)megatron.core.num_microbatches_calculatorrw   ImportErrorModuleNotFoundError(apex.transformer.pipeline_parallel.utilsr1   init_consumed_samplesprev_consumed_samples)r5   rv   ru   erw   rm   r8   r8   r9   load_state_dictB  sN   


z+EnergonMultiModalDataModule.load_state_dict)rE   )r(   N)__name__
__module____qualname____doc__r   strintboolr   r   r.   rB   rC   r   rG   r   rO   r	   rd   r   rg   rj   r   r   rv   r   __classcell__r8   r8   r6   r9   r   !   sj    	
N%,
&"(r   c                       sr   e Zd ZdZ					ddededed	ed
ee f
 fddZdedefddZe	de
eef fddZ  ZS )r0   a  
    A data sampler for sequential sampling in Megatron, designed to handle large datasets efficiently.

    This class extends the MegatronDataSampler to support sequential sampling for large datasets.
    It includes functionality for handling micro-batches and tracking consumed samples across training steps.

    Attributes:
    seq_len (int): The sequence length for each sample.
    micro_batch_size (int): The number of samples in each micro-batch.
    init_consumed_samples (int): The initial number of samples that have been consumed (used for resuming training).
    prev_consumed_samples (int): Tracks the number of consumed samples before the current step.
    if_first_step (int): Flag to indicate if it's the first training step.
    prev_global_batch_size (Optional[int]): The global batch size from the previous step.
    init_global_step (int): The initial global step at the start of training.
          r   Nr+   r   r   r   r,   c                    s   t  j||||||d dS )a  
        Initialize the SequentialMegatronSampler.

        Parameters:
        seq_len (int): The sequence length for each sample.
        micro_batch_size (int, optional): The number of samples in each micro-batch. Defaults to 4.
        init_consumed_samples (int, optional): The initial number of samples that have been consumed. Defaults to 0.
        init_global_step (int, optional): The initial global step at the start of training. Defaults to 0.
        )r+   r,   r   r   r   r/   N)r-   r.   )r5   r+   r   r   r   r,   r/   r6   r8   r9   r.     s   
z"SequentialMegatronSampler.__init__
dataloaderr(   c                 C   s   |S )ay  
        Transform the DataLoader for sequential sampling.

        This method returns the DataLoader as is, but it can be overridden to apply specific transformations to
        the DataLoader if needed.

        Parameters:
        dataloader (DataLoader): The original DataLoader to be transformed.

        Returns:
        DataLoader: The transformed DataLoader.
        r8   )r5   r   r8   r8   r9   transform_dataloader  s   z.SequentialMegatronSampler.transform_dataloaderc                 C   s   | j | j| jdS )al  
        Return the keyword arguments required for Megatron data handling.

        This property provides the necessary arguments that Megatron uses to handle data, including sequence length,
        micro-batch size, and the number of micro-batches.

        Returns:
        Dict[str, Any]: A dictionary containing the Megatron data handling arguments.
        )r   r   num_microbatches)r+   r   r   ri   r8   r8   r9   megatron_data_kwargs  s   z.SequentialMegatronSampler.megatron_data_kwargs)r   r   r   Nr   )r   r   r   r   r   r   r.   r   r   propertyr   r   r   r   r   r8   r8   r6   r9   r0   u  s*     r0   )(copyr   typingr   r   r   r   fiddlerB   lightning.pytorchpytorchpl!lightning.pytorch.utilities.typesr   r	   megatron.corer
   megatron.energonr   r   r   torch.utils.datar   typing_extensionsr   /nemo.collections.multimodal.data.energon.configr   5nemo.collections.multimodal.data.energon.task_encoderr   nemo.lightning.io.mixinr   r   r   nemo.lightning.pytorch.pluginsr   
nemo.utilsr   LightningDataModuler   r0   r8   r8   r8   r9   <module>   s$     V