o
    wiJ                     @   s   d dl mZmZmZ d dlmZ d dlZd dl	Z	d dl
mZmZ d dlmZ d dlmZmZ d dlmZ d dlmZ G dd	 d	ejZG d
d deZdS )    )DictListOptionalN)EVAL_DATALOADERSTRAIN_DATALOADERS)data)
DataLoaderDataset)MegatronDataSampler)loggingc                       s   e Zd ZdZ														d'd	ed
ee dededededeee  dededededededef fddZd(de	ddfddZ
defddZdefd d!Zdefd"d#Zd$edefd%d&Z  ZS ))MockDataModulea  A mock LightningDataModule for generating synthetic data for Llama4 models.

    This module creates dummy datasets (train, validation, test) using `MockLlama4Dataset`
    for testing or development purposes without requiring actual data.

    Args:
        seq_length (int): The sequence length for text tokens. Defaults to 2048.
        decoder_seq_length (Optional[int]): The sequence length for the decoder (if applicable). Defaults to None.
        tokenizer (Optional): Tokenizer object.
        image_processor (Optional): Image processor object.
        micro_batch_size (int): Micro batch size per GPU. Defaults to 4.
        global_batch_size (int): Global batch size across all GPUs. Defaults to 8.
        rampup_batch_size (Optional[List[int]]): Ramp-up schedule for batch size. Defaults to None.
        num_train_samples (int): Number of synthetic samples for the training set. Defaults to 10,000,000.
        num_val_samples (int): Number of synthetic samples for the validation set. Defaults to 10,000,000.
        num_test_samples (int): Number of synthetic samples for the test set. Defaults to 10,000,000.
        num_workers (int): Number of worker processes for data loading. Defaults to 8.
        pin_memory (bool): Whether to pin memory for faster data transfer to GPU. Defaults to True.
        persistent_workers (bool): Whether to keep worker processes alive between epochs. Defaults to False.
        packed_sequence (bool): Whether to use packed sequences for efficiency. Defaults to False.
       N      逖 TF
seq_lengthdecoder_seq_length	tokenizerimage_processormicro_batch_sizeglobal_batch_sizerampup_batch_sizenum_train_samplesnum_val_samplesnum_test_samplesnum_workers
pin_memorypersistent_workerspacked_sequencec                    s   t    || _|| _|| _|| _|| _|	| _|
| _|| _	|| _
|| _|| _|du s.|du rQtd ddlm} ddlm} |d}|pI|d| _|pO|j| _t| j| j|||d| _dS )a=  A mock LightningDataModule for generating synthetic data for Llama4 models.

        This module creates dummy datasets (train, validation, test) using `MockLlama4Dataset`
        for testing or development purposes without requiring actual data.

        Args:
            seq_length (int): The sequence length for text tokens. Defaults to 2048.
            decoder_seq_length (Optional[int]): The sequence length for the decoder (if applicable). Defaults to None.
            tokenizer (Optional): Tokenizer object.
            image_processor (Optional): Image processor object.
            micro_batch_size (int): Micro batch size per GPU. Defaults to 4.
            global_batch_size (int): Global batch size across all GPUs. Defaults to 8.
            rampup_batch_size (Optional[List[int]]): Ramp-up schedule for batch size. Defaults to None.
            num_train_samples (int): Number of synthetic samples for the training set. Defaults to 10,000,000.
            num_val_samples (int): Number of synthetic samples for the validation set. Defaults to 10,000,000.
            num_test_samples (int): Number of synthetic samples for the test set. Defaults to 10,000,000.
            num_workers (int): Number of worker processes for data loading. Defaults to 8.
            pin_memory (bool): Whether to pin memory for faster data transfer to GPU. Defaults to True.
            persistent_workers (bool): Whether to keep worker processes alive between epochs. Defaults to False.
            packed_sequence (bool): Whether to use packed sequences for efficiency. Defaults to False.
        NzdProcessor or tokenizer are not provided! Fall back to `'meta-llama/Llama-4-Scout-17B-16E-Instruct'`.r   )AutoProcessor)AutoTokenizerz)meta-llama/Llama-4-Scout-17B-16E-Instruct)seq_lendecoder_seq_lenr   r   r   )super__init__r   r"   r   r   r   r   r   r   r   r   r   r   warningtransformersr   =nemo.collections.common.tokenizers.huggingface.auto_tokenizerr    from_pretrainedr   r   r
   data_sampler)selfr   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    	processor	__class__ b/home/ubuntu/sommelier/.venv/lib/python3.10/site-packages/nemo/collections/vlm/llama4/data/mock.pyr$   3   s8   
&
zMockDataModule.__init__ stagereturnc                 C   s   | j }| jr| jdkr|| j }td| d t| j| jd| j|| jd| _	t| j| jd| j
|| jd| _t| j| jd| j|| jd| _dS )	al  Sets up the mock datasets for the specified stage.

        Initializes `MockLlama4Dataset` instances for train, validation, and test splits.
        Adjusts sequence length if packed sequences are used.

        Args:
            stage (str): The stage for which to set up data ('fit', 'validate', 'test', or '').
                         Defaults to "".
           zPacked sequence is used with mock dataset. Sequence length for each sample is update to `seq_length // self.micro_batch_size = z`!train)r   validtestN)r   r   r   r   r%   MockLlama4Datasetr   r   r   	_train_dsr   _validation_dsr   _test_ds)r*   r1   r   r.   r.   r/   setupx   s@   

zMockDataModule.setupc                 C       t | ds
| d | | jS )z,Returns the DataLoader for the training set.r8   fit)hasattrr;   _create_dataloaderr8   r*   r.   r.   r/   train_dataloader      

zMockDataModule.train_dataloaderc                 C   r<   )z.Returns the DataLoader for the validation set.r9   validate)r>   r;   r?   r9   r@   r.   r.   r/   val_dataloader   rB   zMockDataModule.val_dataloaderc                 C   r<   )z(Returns the DataLoader for the test set.r:   r6   )r>   r;   r?   r:   r@   r.   r.   r/   test_dataloader   rB   zMockDataModule.test_dataloaderdatasetc                 K   s"   t |f| j| j| j|jd|S )a&  Creates a DataLoader for the given dataset.

        Args:
            dataset (Dataset): The dataset to wrap in a DataLoader.
            **kwargs: Additional arguments passed to the DataLoader constructor.

        Returns:
            DataLoader: The configured DataLoader instance.
        )r   r   r   
collate_fn)r   r   r   r   rG   )r*   rF   kwargsr.   r.   r/   r?      s   
z!MockDataModule._create_dataloader)r   NNNr   r   Nr   r   r   r   TFF)r0   )__name__
__module____qualname____doc__intr   r   boolr$   strr;   r   rA   r   rD   rE   r	   r   r?   __classcell__r.   r.   r,   r/   r      sd    
	
E*r   c                       s   e Zd ZdZ					ddeded	ed
edededdf fddZdefddZ	dede
jfddZdedeeejf fddZdeeeejf  deeeej f fddZdd Z  ZS )r7   a9  A mock Dataset implementation for generating synthetic Llama4 data.

    Produces batches containing dummy image tensors and random token sequences,
    mimicking the structure expected by Llama4 models.

    Args:
        tokenizer: Tokenizer object to determine vocabulary size.
        image_processor: Image processor object to determine image dimensions.
        name (str): Name of the dataset split (e.g., "train", "valid", "test").
        num_samples (int): Total number of synthetic samples in this dataset.
        seq_length (int): Sequence length for the generated token sequences.
        seed (int): Random seed for data generation reproducibility. Defaults to 42.
        packed_sequence (bool): Whether the data should be formatted for packed sequences.
                                Defaults to False.
        pixel_shuffle_ratio (float): Ratio used for calculating the image sequence length
                                     after potential pixel shuffling. Defaults to 0.5.
        num_image_embeddings_per_tile (int): Number of embeddings produced per image tile
                                             by the vision encoder (before pixel shuffle).
                                             Defaults to 576.
        num_tiles_per_image (int): Number of tiles the image is split into. Defaults to 1.
    *   F      ?@  r3   namenum_samplesr   seedr   pixel_shuffle_ratior2   Nc                    s   t    || _|| _|| _|| _d| _|j}|d |d | _| _	|| _
|| _|| _|	| _|
| _|| _t|	|
 | | | _tj| jtjd| _tj| jtjd| _dS )a  A mock Dataset implementation for generating synthetic Llama4 data.

        Produces batches containing dummy image tensors and random token sequences,
        mimicking the structure expected by Llama4 models.

        Args:
            tokenizer: Tokenizer object to determine vocabulary size.
            image_processor: Image processor object to determine image dimensions.
            name (str): Name of the dataset split (e.g., "train", "valid", "test").
            num_samples (int): Total number of synthetic samples in this dataset.
            seq_length (int): Sequence length for the generated token sequences.
            seed (int): Random seed for data generation reproducibility. Defaults to 42.
            packed_sequence (bool): Whether the data should be formatted for packed sequences.
                                    Defaults to False.
            pixel_shuffle_ratio (float): Ratio used for calculating the image sequence length
                                         after potential pixel shuffling. Defaults to 0.5.
            num_image_embeddings_per_tile (int): Number of embeddings produced per image tile
                                                 by the vision encoder (before pixel shuffle).
                                                 Defaults to 576.
            num_tiles_per_image (int): Number of tiles the image is split into. Defaults to 1.
        i@ heightwidth)dtypeN)r#   r$   rT   r   r   r   
vocab_sizesizeimage_heightimage_widthlengthrV   r   num_image_embeddings_per_tilenum_tiles_per_imagerW   rM   _img_seq_lentorchonesfloat	loss_maskarangeint64position_ids)r*   r   r   rT   rU   r   rV   r   rW   r`   ra   r\   r,   r.   r/   r$      s&   
"zMockLlama4Dataset.__init__c                 C   s   | j S )z-Returns the number of samples in the dataset.)r_   r@   r.   r.   r/   __len__  s   zMockLlama4Dataset.__len__idxc                 C   s,   t jj| j| d}|j| j| jgt jdS )zGenerates a random sequence of token IDs (unused in current __getitem__).

        Args:
            idx (int): Index of the sample, used for seeding the random generator.

        Returns:
            np.ndarray: An array of random token IDs.
        rV   r\   rZ   )nprandomdefault_rngrV   integersr[   r   rh   )r*   rk   np_genr.   r.   r/   	_get_text  s   	zMockLlama4Dataset._get_textc                 C   s   t jj| j| d}t|j| j| jd gt j	d}d|dd| j
 < | }t|j| jd| j| jgt jd }|dd }|dd }|||| j| jd	S )
a  Generates a single synthetic data sample.

        Creates random tensors for 'media' (images), 'tokens', and 'labels'.
        The 'tokens' sequence includes placeholder IDs where image features
        would normally be inserted.

        Args:
            idx (int): Index of the sample to generate. Used for seeding.

        Returns:
            Dict[str, torch.Tensor]: A dictionary containing:
                - "media": A dummy image tensor [num_tiles, 3, H, W].
                - "tokens": Input token sequence [seq_length].
                - "labels": Target token sequence (shifted tokens) [seq_length].
                - "loss_mask": Mask indicating which tokens contribute to loss [seq_length].
                - "position_ids": Positional IDs for the sequence [seq_length].
        rl   r3   rm   i       N)mediatokenslabelsrf   ri   )rn   ro   rp   rV   rc   
from_numpyrq   r[   r   rh   rb   clonera   r]   r^   float32bfloat16rf   ri   )r*   rk   rr   rx   ry   imagesr.   r.   r/   __getitem__)  s4   

zMockLlama4Dataset.__getitem__batchc              	   C   s   t j|}d|d< |d jdg|d jdd R  |d< | jrsddlm} |d }|jd }| j}t	j
d|d	 | |t	j|jd
}t	j
d|d	 | |t	j|jd
}d}	||||||||	d}
|
|d< dD ]}|| d	d||< qf|S )aY  Collates a batch of samples from the dataset.

        Uses the default PyTorch collate function and then performs specific adjustments:
        - Sets 'attention_mask' to None.
        - Reshapes 'media' tensor.
        - If `packed_sequence` is True, it prepares the batch for packed sequence format
          by calculating cumulative sequence lengths (`cu_seqlens`) and reshaping
          relevant tensors ('tokens', 'labels', 'loss_mask', 'position_ids').

        Args:
            batch (List[Dict[str, torch.Tensor]]): A list of individual samples (dictionaries)
                                                  from `__getitem__`.

        Returns:
            Dict[str, Optional[torch.Tensor]]: The collated batch, ready for model input.
                                               Includes 'packed_seq_params' if packing is enabled.
        Nattention_maskrw   rv   rt   r   )PackedSeqParamsrx   r3   )steprZ   devicethd)cu_seqlens_qcu_seqlens_kvcu_seqlens_q_paddedcu_seqlens_kv_paddedmax_seqlen_qmax_seqlen_kv
qkv_formatpacked_seq_params)rx   ry   rf   ri   )r   
dataloaderdefault_collatereshapeshaper   megatron.core.packed_seq_paramsr   r   rc   rg   int32r   )r*   r   collated_batchr   rx   
batch_sizevalid_seqlen
cu_seqlenscu_seqlens_paddedr   r   keyr.   r.   r/   _collate_fn[  s8   (
	zMockLlama4Dataset._collate_fnc                 C   s
   |  |S )a  Method passed to the DataLoader's `collate_fn` argument.

        Simply calls the internal `_collate_fn` implementation. This structure allows for
        potential future additions like neural type checking within this wrapper method.

        Args:
            batch: A list of samples fetched from the dataset.

        Returns:
            The collated batch dictionary.
        )r   )r*   r   r.   r.   r/   rG     s   
zMockLlama4Dataset.collate_fn)rQ   FrR   rS   r3   )rI   rJ   rK   rL   rO   rM   rN   re   r$   rj   rn   ndarrayrs   r   rc   Tensorr   r   r   r   rG   rP   r.   r.   r,   r/   r7      s6    	:.24r7   )typingr   r   r   lightning.pytorchpytorchplnumpyrn   rc   !lightning.pytorch.utilities.typesr   r   torch.utilsr   torch.utils.datar   r	   nemo.lightning.pytorch.pluginsr
   
nemo.utilsr   LightningDataModuler   r7   r.   r.   r.   r/   <module>   s    -