o
    }oi                     @   s
  d dl Z d dlZd dlZd dlZd dlZd dlZd dlmZ d dlm	Z	m
Z
mZmZmZmZ d dlZd dlZd dlm  mZ d dlZd dlmZ d dlmZ d dlmZ d dlmZmZ d dlm Z m!Z! d dl"m#  m$  m%  m&  m'Z( d d	l)m*Z* d d
l"m+Z+m,Z,m-Z-m.Z.m/Z/m0Z0m1Z1m2Z2m3Z3m4Z4m5Z5 d dl6m7Z7 dZ8dZ9zd dl:Z:W n e;y   e<d Y nw z
d dl=m>Z> dZ?W n e@eAfy   dZ?Y nw G dd dZBG dd dZCdeeDeeD f de	deEdeEdejFf
ddZGdd ZHdEdeId eId!eEd"eJde
f
d#d$ZKdFd&d'ZLdeIde
fd(d)ZM	dEdeId*eJde
fd+d,ZNdeIde
fd-d.ZOdeIde
fd/d0ZPdeIde
fd1d2ZQdeIde
fd3d4ZRdeIde
fd5d6ZSde
fd7d8ZTd9d: ZUG d;d< d<eZVG d=d> d>eVZWeG d?d@ d@eXZYdGde
fdAdBZZG dCdD dDeZ[dS )H    N)	dataclass)AnyDictListSequenceTupleUnion)	rearrange)
DictConfig)Image)Datasetdefault_collate)CLIPImageProcessorSiglipImageProcessor)image_transform)DEFAULT_BOS_TOKENDEFAULT_EOS_TOKENDEFAULT_IM_END_TOKENDEFAULT_IM_START_TOKENDEFAULT_IMAGE_PATCH_TOKENDEFAULT_IMAGE_TOKENDEFAULT_LABELS_TOKENDEFAULT_PAD_TOKENDEFAULT_VID_END_TOKENDEFAULT_VID_START_TOKENDEFAULT_VIDEO_TOKEN)get_ltor_masks_and_position_ids   z;The package `decord` was not installed in this environment.)IndexedDatasetTFc                   @   s(   e Zd ZdZdd Zdd Zdd ZdS )	TarOrFolderImageLoadera  
    A class for loading images from a tar archive or a regular folder.

    This class provides functionality to open and read images from either a tar archive
    (.tar file) or a standard directory with image files. It builds an index of images
    if the source is a tar archive for efficient access.

    Attributes:
        image_folder (str): The path to the tar archive or image folder.
        tar_index (dict): A dictionary that maps file names to their tarfile member
                          objects if the image source is a tar archive.

    Methods:
        __init__(self, image_folder): Initializes the loader with the specified image folder.
        build_index(self): Builds an index of image file names and their corresponding
                           tarfile member objects for a tar archive.
        open_image(self, file_name): Opens and returns an image by its file name. The image
                                     is returned as an RGB PIL Image object.
    c                 C   s(   || _ i | _| j dr|   d S d S N.tar)image_folder	tar_indexendswithbuild_index)selfr#    r(   f/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/multimodal/data/neva/neva_dataset.py__init__Y   s
   zTarOrFolderImageLoader.__init__c                 C   N   t | jd}| D ]}|| j|j< qW d    d S 1 s w   Y  d S Nr)tarfileopenr#   
getmembersr$   namer'   tarmemberr(   r(   r)   r&   _   
   "z"TarOrFolderImageLoader.build_indexc                 C   s   | j dr<t| j d%}| j|}|r*||}t|dW  d    S W d    d S 1 s5w   Y  d S tt	j
| j |dS )Nr"   r-   RGB)r#   r%   r.   r/   r$   getextractfiler   convertospathjoin)r'   	file_namer3   r4   fr(   r(   r)   
open_imaged   s   

z!TarOrFolderImageLoader.open_imageN)__name__
__module____qualname____doc__r*   r&   r?   r(   r(   r(   r)   r    D   s
    r    c                   @   s0   e Zd ZdZdd Zdd Zdd Zdd	 Zd
S )TarOrFolderVideoLoadera  
    A class for loading videos from a tar archive or a regular folder.

    This class provides functionality to open and read videos from either a tar archive
    (.tar file) or a standard directory with video files. It builds an index of videos
    if the source is a tar archive for efficient access.

    Attributes:
        video_folder (str): The path to the tar archive or video folder.
        data_cfg (dict): A dictionary of configuration options for video decoding to frames
        tar_index (dict): A dictionary that maps file names to their tarfile member
                          objects if the video source is a tar archive.

    Methods:
        __init__(self, video_folder): Initializes the loader with the specified video folder.
        build_index(self): Builds an index of image file names and their corresponding
                           tarfile member objects for a tar archive.
        open_video(self, file_name): Opens and returns an video by its file name. The video
                                     is returned as a list of RGB PIL Image objects.
        flatten_frames(self, cap): Converts decord VideoReader video object to list of frame
                                   images based on data config information.
    c                 C   s.   || _ || _i | _| j dr|   d S d S r!   )video_folderdata_cfgr$   r%   r&   )r'   rE   rF   r(   r(   r)   r*      s   zTarOrFolderVideoLoader.__init__c                 C   r+   r,   )r.   r/   rE   r0   r$   r1   r2   r(   r(   r)   r&      r5   z"TarOrFolderVideoLoader.build_indexc                 C   s   | j dr>t| j d'}| j|}|r,||}t|}| 	|W  d    S W d    d S 1 s7w   Y  d S tt
j| j |}| 	|S )Nr"   r-   )rE   r%   r.   r/   r$   r7   r8   decordVideoReaderflatten_framesr:   r;   r<   )r'   r=   r3   r4   r>   capr(   r(   r)   
open_video   s   





z!TarOrFolderVideoLoader.open_videoc                    s>  | j d dkr d  }t|dS | j d dkr. t d   }t|dS | j d dkrC d  }t|dS | j d	 dkrdg } D ]}| }t|d}|| qN|S tt | j d	 }tj	dt d
 |t
d} fdd|D }t|| j d	 k r||d  t|| j d	 k s|S )Nsplice_single_framefirstr   r6   middle   lastr   
num_framesr   dtypec                    s$   g | ]}t  |  d qS )r6   )r   	fromarrayasnumpyr9   ).0irJ   r(   r)   
<listcomp>   s   $ z9TarOrFolderVideoLoader.flatten_frames.<locals>.<listcomp>)rF   rU   r   rT   r9   lenappendminnplinspaceint)r'   rJ   frameframes	rgb_frameimgrQ   indicesr(   rX   r)   rI      s.   z%TarOrFolderVideoLoader.flatten_framesN)r@   rA   rB   rC   r*   r&   rK   rI   r(   r(   r(   r)   rD   p   s    rD   texts	tokenizercontext_lengthadd_extra_tokenreturnc           
         s   |dks|dksJ dd}t | tr| g} d} fdd| D }tdd |D }t|| |}tjt||| tjd	}t|D ]!\}}	t|	|| krV|	d
||  }	t	|	||d
t|	f< qB|rj|d }|S )a  
    Returns the tokenized representation of given input string(s). If the list of tokens exceeds the context
    length plus the number of extra tokens, it gets truncated. If it's smaller, it gets padded with zeros.

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize.
    tokenizer : Any
        A tokenizer to be used for tokenization.
    context_length : int
        The context length to be used for the output tensor.
    add_extra_token : int
        Number of extra tokens to add, should be either 0 or 1.

    Returns
    -------
    torch.LongTensor
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length + add_extra_token].
    r   r   z*`add_extra_token` should be either 0 or 1.FTc                       g | ]}  |qS r(   )text_to_ids)rV   trf   r(   r)   rY          ztokenize.<locals>.<listcomp>c                 S   s   g | ]}t |qS r(   )rZ   rV   tokenr(   r(   r)   rY          rR   N)

isinstancestrmaxr\   torchzerosrZ   long	enumeratetensor)
re   rf   rg   rh   texts_is_strtokensmax_lenresultrW   rp   r(   rm   r)   tokenize   s    
r~   c                    s    fdd|D S )a  
    Returns the token id for a given token.

    Parameters
    ----------
    tokenizer : nemo tokenizer
        A tokenizer to be used for tokenization.
    tokens : list
        A list of tokens to get the token id for.

    Returns
    -------
    List
        The token ids.
    c                    rj   r(   )token_to_idro   rm   r(   r)   rY      rn   z"get_tokens_ids.<locals>.<listcomp>r(   )rf   r{   r(   rm   r)   get_tokens_ids   s   r   sourcesmultimodal_cfgcur_token_len	use_plainc                 C   s6  |d }|d }|d }|}|dkrt }n	|dkrt}n| S |s"| S |d }	|d dkr0|d	 }|}
|dkr:|
|	9 }
|d
 rEt| |
 }nt| |
d  }t| | t|  }|dkr |ddr |ddsmtd|d d }|	d}}|dkrt|d d |	}|| }n|}|| }|d dd}t| }t| }t| }|dkr
t| t| }}tt	}}|d
 r|||  | g| }||||  | g7 }d
|}n,|||d   | g}||||  | g|d  7 }||||d   | g7 }d
|}|| | }n|d
 r|| }n||d  }|| | }| D ]v}|d }|d r`||d d v s8J |d d |d |d d< |tjj tjjd  d |d d  |d d< |rt||d d v snJ ||d d< |d dkr|||}||d< q"|D ]}|d |||d< qq"| S ) a  
    Preprocesses multimodal sources based on the provided configuration.

    This function modifies the sources for multimodal data processing. It checks if the data is multimodal and
    adjusts the token lengths accordingly. It also handles the start and end tokens for images and replaces
    image tokens in conversations.

    Parameters:
    - sources (dict): A dictionary containing the multimodal sources to be processed.
    - multimodal_cfg (dict): A configuration dictionary specifying various options for multimodal processing.
      It includes keys like 'is_multimodal', 'use_im_start_end', and 'sep_image_conv_front'.
    - cur_token_len (int): The current length of tokens to be considered for image processing.
    - use_plain (bool, optional): A boolean flag to use plain image token replacement without additional processing.
      Defaults to False.

    Returns:
    - dict: The processed sources dictionary after applying multimodal preprocessing steps.
    is_multimodal
model_type
media_typeimagevideorQ   mm_mlp_adapter_typemlp_downsample   use_im_start_endrO   use_litaFlitaNzLITA config is missinglita_video_archr   temporal_all_resolutionsample_framesvisual_token_formatv1im_vid_start_end r   conversationssep_image_conv_frontvalue: conv_templateinterleaved)r   r   r   r   r   r7   
ValueErrorr\   r   r   r<   replacestripconversation_libdefault_conversationseproles)r   r   r   r   r   r   r   image_token_lendefault_tokenrQ   num_patchesreplace_tokenr   num_temporal_tokensnum_spatial_tokensr   
num_tokensr   media_start	media_endimage_patchimage_start	image_end	vid_startvid_endreplace_token_listsourceconversationupdated_conversationturnr(   r(   r)   preprocess_multimodal   s   







 


r   squarec           
      C   s   t | ts
t | trj|dkr;t|jt|j}}|| }d\}}tt|| |}| j|ddd|idd d }|S |d	kr]d
d }	|	|tdd | j	D }| j|ddd d }|S | j|ddd d }|S |dksrJ d| |}|S )Nkeepi     ptFshortest_edgereturn_tensorsdo_center_cropsizepixel_valuesr   padc                 S   ~   | j \}}||kr| S ||kr't| j||f|}|| d|| d f |S t| j||f|}|| || d df |S Nr   rO   r   r   newmodepastepil_imgbackground_colorwidthheightr}   r(   r(   r)   expand2square}     
z$process_image.<locals>.expand2squarec                 s       | ]	}t |d  V  qdS    Nr_   rV   xr(   r(   r)   	<genexpr>      z process_image.<locals>.<genexpr>r   r   CNeMo image transform with setting `image_aspect_ratio` to `square`.)
rr   r   r   rt   r   r\   r_   
preprocesstuple
image_mean)
	processorr   image_aspect_ratiomax_hwmin_hwaspect_ratior|   min_lenr   r   r(   r(   r)   process_imagep  s0   r   c                 C   sZ  t j }|jd |jd d}g }t| D ]J\}}|d }||d d  |jd kr1|dd }g |_t|D ] \}}	||	d  }
|
|j|d  ksPJ | ||
|	d  q8||  q|	d	}t
|||	d
|d}|  }d}d}t||D ]z\}}||}||dd g|dd  }d}t|D ]T\}}|dkr nK||}t|dkr n>|d  |7  < |dkrt||}t||d }nt||| }t|||d  }t|||| < ||7 }qt||d< q|r|ddddf  }|ddddf  }ntj|ddd}t|dddf< t||dS )a  
    Preprocesses sources for the LLaMA 3 model configuration.

    The function applies prompt templates and tokenizes the conversations according to the LLaMA 2 model specifications.
    It involves special handling of tokens, masking of labels, and adjustments based on configuration settings.

    Parameters:
    - sources (dict): A dictionary of sources containing conversations to be processed.
    - tokenizer: The tokenizer to be used for processing the text.
    - cfg: Configuration settings for preprocessing, including context length and additional tokens.

    Returns:
    - Dict: A dictionary containing tokenized and labeled data suitable for the LLaMA 2 model.
      This includes tokens, labels, and any special processing as defined in the configuration.
    r   r   humangptr   fromNrO   r   rh   rg   re   rf   rg   rh   z/<|start_header_id|>assistant<|end_header_id|>

z*<|start_header_id|>user<|end_header_id|>

r   r   shiftsdimsr{   labels)r   conv_llava_llama_3copyr   rx   messagesappend_messager[   
get_promptr7   r~   clonedetachzipsplitr<   rZ   rk   IGNORE_INDEX
contiguousru   rolldict)r   rf   cfgconvr   r   rW   r   jsentencerolerh   r{   r   r   	round_sepr   targetroundscur_lenrouparts	round_leninstruction_lenr(   r(   r)   preprocess_llama_3  sf   


 

r	  
is_mistralc                 C   s  |rt j }nt j }|jd |jd d}g }t| D ]J\}}|d }||d d  |jd kr9|dd }g |_t|D ] \}	}
||
d  }||j|	d  ksXJ | |||
d  q@||	  q|
d	}t|||
d
|d}td }|ttg}t||\}}}d|||k< d|||k< d|||k< |  }|rd}nd}t||D ]s\}}||j}d}t|D ]\\}}|dkr nS||}t|dkr nF|d  |7  < t|||j }|rt||d d }nt||d d }|dkr
|d8 }n|d7 }t|||| < ||7 }qt||d< q|r>|ddddf  }|ddddf  }ntj|ddd}t|dddf< t||dS )a  
    Preprocesses sources for the LLaMA 2 model configuration.

    The function applies prompt templates and tokenizes the conversations according to the LLaMA 2 model specifications.
    It involves special handling of tokens, masking of labels, and adjustments based on configuration settings.

    Parameters:
    - sources (dict): A dictionary of sources containing conversations to be processed.
    - tokenizer: The tokenizer to be used for processing the text.
    - cfg: Configuration settings for preprocessing, including context length and additional tokens.

    Returns:
    - Dict: A dictionary containing tokenized and labeled data suitable for the LLaMA 2 model.
      This includes tokens, labels, and any special processing as defined in the configuration.
    r   r   r   r   r   NrO   r   rh   rg   r   llama_2z[/INST]z[/INST] r   r   r   r   )r   conv_mistralr   conv_llava_llama_2r   rx   r   r   r[   r   r7   r~   r   r   r   r   r   r   r   r   sep2rZ   rk   r   r   ru   r   r   )r   rf   r   r
  r   r   r   rW   r   r   r   r   rh   r{   image_patch_tokenDEFAULT_TOKENSimg_patch_idbos_ideos_idr   r   r   r  r  r  r  r  r  r  r(   r(   r)   preprocess_llama_2  sz   

	




r  c                    sZ  	 t j }g }t| D ]_\}}|d }d}t|D ]5\}}|d dkr>|jd |d< |d }	||d |	 |d s=d}q|jd |d< ||d |d  q| }
|rf|
d	rf|
d
td	  d }
|	|
 q|
d}t|||
d|d}|  }d d}t||D ]o\}}| } |d
d g fdd|dd
 D  }t|ttt|ksJ d}t|D ]5\}}|dkr n,||}t|dkr nt||d | }t||}t|||| < ||7 }qt||d
< q|r|d
d
d
df  }|d
d
dd
f  }ntj|ddd}t|d
d
df< t||dS )a  
    Preprocess sources for Yi-1.5 34b model configuration.

    The function applies prompt templates and tokenizes the conversations according to the Yi-1.5 34b model specifications.
    It involves special handling of tokens, masking of labels, and adjustments based on configuration settings.

    This template works with the following tokenizer configs:
    - model.tokenizer.library='huggingface'
    - model.tokenizer.type='01-ai/Yi-1.5-34B'
    - model.tokenizer.additional_special_tokens='{additional_special_tokens: ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>"]}'
    At inference time, add end string to stop sampling:
    - inference.end_strings='["<|im_end|>"]'

    Parameters:
    - sources (dict): A dictionary of sources containing conversations to be processed.
    - tokenizer: The tokenizer to be used for processing the text.
    - cfg: Configuration settings for preprocessing, including context length and additional tokens.

    Returns:
    - Dict: A dictionary containing tokenized and labeled data suitable for the LLaMA 2 model.
      This includes tokens, labels, and any special processing as defined in the configuration.
    r   FrO   r   r   r   Tr   z
<|im_end|>N
rh   rg   r   z<|im_start|>user
z<|im_start|>assistant
c                    s   g | ]} | qS r(   r(   r   r  r(   r)   rY     rq   z%preprocess_yi_34b.<locals>.<listcomp>r   r   r   r   )r   conv_yi_34br   rx   r   r   r   r%   rZ   r[   r7   r~   r   r   r   r   r<   summaprk   r   r   ru   r   r   )r   rf   r   r   r   rW   r   strip_end_for_inferencer   r   contextrh   r{   r   r   r   r  r  r  r  r  r  r  r(   r  r)   preprocess_yi_34b`  sn   


.

r  c                 C   s~  t j }|jd |jd d}g }t| D ]J\}}|d }||d d  |jd kr1|dd }g |_t|D ] \}}	||	d  }
|
|j|d  ksPJ | ||
|	d  q8||  q|	d	}t
|||	d
|d}td }|ttg}t||\}}}d|||k< d|||k< d|||k< |  }|j|jd  d }t||D ]c\}}||j}d}t|D ]L\}}|dkr nC||}t|dkr n6|d  |7  < t|||j }t||d d }|dkr|d8 }|d8 }t|||| < ||7 }qt||d< q|r)|ddddf  }|ddddf  }ntj|ddd}t|dddf< t||dS )a  
    Preprocesses sources for the Vicuna V1 model configuration.

    Similar to `preprocess_llama_2`, this function applies prompt templates and performs tokenization, but it is tailored
    for the Vicuna V1 model. It includes specific handling for token translations, label masking, and tokenizer configuration.

    Parameters:
    - sources (dict): A dictionary of sources containing conversations to be processed.
    - tokenizer: The tokenizer to be used for processing the text.
    - cfg: Configuration settings for preprocessing, which may include context length and additional tokens.

    Returns:
    - Dict: A dictionary containing the processed data, including tokens and labels, formatted for the Vicuna V1 model.
    r   r   r   r   r   NrO   r   rh   rg   r   r  r   r   r   r   r   )r   conv_vicuna_v1r   r   rx   r   r   r[   r   r7   r~   r   r   r   r   r   r   r   r   r   r  rZ   rk   r   r   ru   r   r   )r   rf   r   r   r   r   rW   r   r   r   r   rh   r{   r  r  r  r  r  r   r   r   r  r  r  r  r  r  r  r(   r(   r)   preprocess_v1  sn   

	


r  c                 C   s$  g }| D ]	}| |d  q|d}t|||d|d}|d }t| }t| }	t| }
||	|
tg}t||\}}}}d|||k< | 	 }t
|||k< t
|||k< t
||dk< t
|||k< |r||ddddf  }|ddd	df  }ntj|ddd
}t
|dddf< t||dS )zDtokenize the interleaved prompt and mask the text part of the promptr   rh   rg   r   r   r   Nr   r   r   r   )r[   r7   r~   r   r   r   r   r   r   r   r   r   ru   r   r   )r   rf   r   r   r   rh   r{   r   r  image_start_tokenimage_end_tokenr  r  img_start_id
img_end_idpad_idr   r(   r(   r)   preprocess_interleaved_prompt!  s>   
r$  c              
   C   s  	 t j }g }| D ]m}g |_|d|j|_d}t|d D ]E\}}|d dkrS|jd |d< d|vr8d|d< t|d  d	 |d
  }	|	|d |	 |d
 sRd}q|jd |d< |	|d |d
  q|
 }
|rr|
dd	 }
||
 q
|d}t|||d|d}|  }|j|jd  d	 }tt d}t||D ]\}}||j}|j|dd g}tdt|dD ]}||j|||d   qd}t|D ]I\}}|dkr n@||}t|dkr n3||d }|r| nd}t||d | | }t|||j }t|||| < ||7 }qt||d< q|rG|ddddf  }|ddddf  }ntj|ddd}t|dddf< t||dS )!  
    Preprocess a given set of conversational sources using nvgpt conversation template

    This function processes conversations by first ensuring the conversation starts with a 'human' role, then tokenizes the conversations, applies specific token replacements, and finally masks labels for training purposes.

    Parameters:
    - sources: A dictionary containing conversational data. Expected format is a dict of conversations, where each conversation is a list of messages, and each message is a dict with 'from' (role) and 'value' (message text).
    - tokenizer: A tokenizer from the Hugging Face Transformers library used for tokenizing the conversations.
    - cfg: Configuration settings which include 'add_extra_token' (bool) to determine if an extra token should be added to the tokenized output, and 'context_length' for specifying the tokenization context length.

    Returns:
    - Dict: A dictionary containing two keys:
        - 'tokens': A tensor of tokenized conversation data.
        - 'labels': A tensor of labels for the conversation data, used for training models. Labels are masked based on the conversation structure.

    Note:
    - The function includes specific token replacements (e.g., DEFAULT_IMAGE_PATCH_TOKEN, <s>, </s>) and masking techniques for labels.
    - It is designed to work with conversational data where messages alternate between a 'human' and a 'gpt' role.
    - The function asserts that each message in a conversation alternates between the defined roles and skips messages not starting with the 'human' role.
    systemFr   rO   r   r   labelzjquality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:4r  r   Tr   
<extra_id_1>rh   rg   r   zquality:.*
N   r   r   r   r   ) r   
conv_nvgptr   r   r7   r&  rx   r   r   r   r   rstripr[   r~   r   r   r   recompiler   r   r<   rangerZ   searchgrouprk   r   r   ru   r   r   )r   rf   r   r   r   r   r  rW   r   r   r  rh   r{   r   r   labels_str_regexpr   r  r  	re_roundsconv_idxr  r  r  match
labels_strr  r  r(   r(   r)   preprocess_nvgptO  s~   

 

r6  c              
   C   s  	 t j }g }| D ]w}g |_|d|j|_d}t|d D ]F\}}|d dkrT|jd |d< d|v rAt|d  d |d	  }	n|d	 }	|	|d |	 |d	 sSd
}q|jd |d< |	|d |d	  q|
 }
|r||
dr||
dtd  d }
||
 q
|d}t|||d|d}|  }|j|jd  d }t||D ]\}}||j}|j|dd g}tdt|dD ]}||j|||d   qd}t|D ]U\}}|dkr nL||}t|dkr n?ttt d|d }t||d | |r|d d|  nd }t|||j }t|||| < ||7 }qt||d< q|rU|ddddf  }|ddddf  }ntj|ddd}t|dddf< t||dS )r%  r&  Fr   rO   r   r   r'  r  r   Tr   r(  Nrh   rg   r   r)  r   z.*?\nr   r   r   ) r   conv_nv_dpor   r   r7   r&  rx   r   r   r   r   r%   rZ   r[   r~   r   r   r   r   r   r<   r.  r,  r/  escaperk   endr   r   ru   r   r   )r   rf   r   r   r   r   r  rW   r   r   r  rh   r{   r   r   r   r  r  r2  r3  r  r  r  labels_matchr  r  r(   r(   r)   preprocess_nv_dpo  s~   


 
.
r;  c                 C   s  g }| D ]!}|d }t |dksJ |d d |d d  d }|| q|d}t|||d|d	}|  }t|| D ]\}	}|d }t ||d d }
t|	d
|
< qA|rv|d
d
d
df 	 }|d
d
dd
f 	 }nt
j|ddd}t|d
d
df< t||dS )aV  
    Preprocesses plain text sources (no template) for tokenization and label generation.

    This function concatenates conversations with an end signal, tokenizes them, and prepares labels for training.
    It handles sources with a specific structure (expecting two elements in 'conversations') and includes the
    option to add an extra token as specified in the configuration. The function also applies masking to the labels.

    Parameters:
    - sources: A list of source dictionaries. Each source dictionary should have a key 'conversations'
      containing a list of conversation parts.
    - tokenizer: The tokenizer to be used for converting text to tokens.
    - cfg: Configuration dictionary which may include 'context_length' and 'add_extra_token' settings.

    Returns:
    - Dict: A dictionary containing tokenized data and corresponding labels. This includes 'tokens' which are the
      tokenized representations of the conversations, and 'labels' which are used for training the model. The labels
      have specific indices masked with IGNORE_INDEX as per the preprocessing logic.
    r   rO   r   r   r   r  rh   rg   r   Nr   r   r   )rZ   r[   r7   r~   r   r   r   rk   r   r   ru   r   r   )r   rf   r   r   r   r   rh   r{   r   r  tokenized_lenr(   r(   r)   preprocess_plain1  s6   
r=  c                 C   s   | j dv rt|| j| jS | j dkrt|| j| jS | j dkr't|| j| jS | j dkr4t|| j| jS | j dkrAt|| j| jS | j dkrPt|| j| jddS | j d	kr]t|| j| jS | j d
krjt	|| j| jS | j dkrwt
|| j| jS td| j  d)N)nvgpt
nv_steerlmnv_dpor   r  llama_3mistralT)r
  yi_34bplainr   zConversation template `z` is not supported in Neva now.)r   r6  rf   r   r;  r  r  r	  r  r=  r$  r   )r'   r   r(   r(   r)   preprocess_conversationsk  s&   








rE  c                       sN   e Zd ZdZdededef fddZdd Zd	eee	j
f fd
dZ  ZS )LazySupervisedDataset#Dataset for supervised fine-tuning.	data_pathr   rF   c                    s   t t|   |d ur&t|d}t|}W d    n1 s w   Y  ng }td || _|| _	|| _
|d | _|d | _|d | _|d | _| jrRt| jnd | _| jrat| j|| _d S d | _d S )Nr-   z%Formatting inputs...Skip in lazy moder   r#   rE   image_processor)superrF  r*   r/   jsonloadloggingwarningrf   list_data_dictr   r   r#   rE   r   r    image_loaderrD   video_loader)r'   rH  rf   r   rF   filerO  	__class__r(   r)   r*     s"   




"zLazySupervisedDataset.__init__c                 C   s
   t | jS N)rZ   rO  r'   r(   r(   r)   __len__  s   
zLazySupervisedDataset.__len__ri   c                    s  t |tjr
t|}j| }t |tr|g}t|dks!J dd|d v rt j| d ts>j| d gj| d< g }j| d D ]$}j|}|d u r\t	
d| d tj|jd }|| qGtg }|rt|}jd }|d jd | }|d jd	 | }	jd
 dkr|d	 dkr|d7 }|	d	 dkr|	d7 }	||	 }
tt|j|
jdkd}nd|d v rt j| d ts݈j| d gj| d< g }j| d D ]}j|}|d u rt	
d| d t jts	t jtrhjd dkr<t|jt|j}}|| }d\}}tt|| |}jj|ddd|idd }n=jd dkr]dd   fdd|D }jj|ddd }njj|ddd }njd dkstJ d|}|| q|}|rt|}jd }|d jd | }|d jd  | }	jd
 dkr|d	 dkr|d7 }|	d	 dkr|	d7 }	||	 }
tt|j|
jdkd}n
tg }t|}t|}t |trt |d! d |d" d d#}jd$ rVt jtrjj!d% jj!d& g}njd' }|jd t"k r<t"|jd  }tj#|d(|d |d ftj$d)}tj%||fdd*}jd+ dkrJ||d< |S jd+ dkrV||d< |S ),Nr   z&Don't know why it is wrapped to a listr   r   zImage z could not be found!r   	patch_dimrO   r   r   rD  )r   r   zVideo r   r   r   Fr   r   r   r   c                 S   r   r   r   r   r(   r(   r)   r     r   z8LazySupervisedDataset.__getitem__.<locals>.expand2squarec                    s(   g | ]} |t d d jjD qS )c                 s   r   r   r   r   r(   r(   r)   r     r   z?LazySupervisedDataset.__getitem__.<locals>.<listcomp>.<genexpr>)r   r   r   )rV   r`   r   r'   r(   r)   rY     s    z5LazySupervisedDataset.__getitem__.<locals>.<listcomp>r   r   r   r   r{   r   r   r   r   r   	crop_sizer)  rR   dimr   )&rr   r]   integerr_   rO  rZ   listrP  r?   rM  rN  r   r   r   r[   ru   ry   stackshaper   r   deepcopyr   rQ  rK   r   r   rt   r   r\   r   rE  r   r[  MAX_NUM_IMAGESrv   floatcat)r'   rW   r   images
image_filer   media_tensorsrX  height_num_patcheswidth_num_patchesr   videos
video_filera   r   r   r   r|   r   r   	data_dictr[  padding_sizezero_paddingr(   rY  r)   __getitem__  s   











 z!LazySupervisedDataset.__getitem__)r@   rA   rB   rC   rs   r   r*   rW  r   ru   Tensorrp  __classcell__r(   r(   rS  r)   rF    s
     rF  c                       s.   e Zd ZdZdededef fddZ  ZS )NevaDatasetrG  rH  r   rF   c                    s4  | drtt| |||| d S | drtt| d ||| td |d dkr|d }t|dD ]Y}t|}g |d< |d D ]A}t	
d	|d
 }	|	D ])}
|
ddd }tj||}tj|sstd|  qQ|d | qQt	dt|d
 |d
< qE| j| q6d S d S td| d)Nz.jsonz.jsonlz)Loading image inputs from SteerLM Datasetr   r   r#   r-   r   z<img src="([^"]+)"r   r   /r   zImage not found: z<img src="([^"]+)">zFormatting of z is not supported in Neva.)r%   rJ  rs  r*   rM  rN  r/   rK  loadsr,  finditerr0  r   r:   r;   r<   isfiler[   subr   rO  r   )r'   rH  rf   r   rF   r#   linerecordr   matchesr4  
image_name
image_pathrS  r(   r)   r*   ,  s0   



zNevaDataset.__init__)r@   rA   rB   rC   rs   r   r*   rr  r(   r(   rS  r)   rs  )  s    "rs  c                   @   sD   e Zd ZU dZeed< ejed< dee	 de	e
ejf fddZdS )	 DataCollatorForSupervisedDatasetz,Collate examples for supervised fine-tuning.	model_cfgrf   	instancesri   c                 C   s   d|d v }t dd |D }|d d d d }|D ]A}||d jd  }t|d d|fdd|d< t|d	 d|fdd
|d	< |r\|d d
 |kr\t|d t|gfd|d< q|rt dd |D }t dd |D }|D ]C}||d jd  }t|d d|fd||d< |d }	||	jd  }
tj|
g|	jdd  R |	j|	j	d}tj|	|fdd|d< qst
|}| j}| j}|d }|d	 }|jdd}|dkr|d}n|dkr|d}ntd| |r>|d }g }|D ])}|g  tdt|d D ]}||d  ||  }|d
 tt| qqt|}tj| tj|j	d}tj| tj|j	d}nt||j|jddddd\}}}d||d
k< d||d
k< d||d
k< |d u rit|dkrtt|d}n
|dkr~t|d}||||||d}|r||d< |S )N
cu_seqlensr   c                 s       | ]
}|d  j d V  qdS )r{   r   Nra  rV   instancer(   r(   r)   r   W      z<DataCollatorForSupervisedDataset.__call__.<locals>.<genexpr>r   @   r{   constantr   r   c                 s   r  )r  r   Nr  r  r(   r(   r)   r   a  r  c                 s   r  )r   r   Nr  r  r(   r(   r)   r   b  r  r   )rS   devicer\  r   r   zUnsupported media type eod_mask_lossF)data	eod_tokenr  reset_attention_maskreset_position_idsg        zb T c h w -> b T 1 c h wzb T F c h w -> b T F c h w)r{   r   attention_mask	loss_maskposition_idsmedia)rt   ra  Fr   ru   re  	IntTensorrv   rS   r  r   rf   r  r  r7   r   r[   r.  rZ   extendr_  
LongTensoronesr   rd  rw   r   r  NotImplementedErrorr	   )r'   r  packed_sequencer|   r  pad_len
max_len_cumax_len_image
pad_len_cur   num_pad
pad_tensorbatchrf   r  r{   r   r   r  r  r  	cu_seqlenindseqlenr  r  r(   r(   r)   __call__U  s    (





z)DataCollatorForSupervisedDataset.__call__N)r@   rA   rB   rC   r
   __annotations__transformersPreTrainedTokenizerr   r   rs   ru   rq  r  r(   r(   r(   r)   r~  N  s
   
 
$r~  c           
      C   sX  |j }|j}d}t|ddrd}|jdd}|dur|n|j}t| |td"i d|jd	|j	d
|j
d
dd|ddd|jjjd|d|ddd|ddd|jdt|jddd|d|d|jd|ddd|dddt|jdddt|jdi d|jddt|dd|dd|dddd }	t|	|	d!S )#z5Make dataset and collator for supervised fine-tuning.r   no_seqlen_plus_one_input_tokensFr   r[  r   r   Nr   r   r   r>  r   rX  r#   rE   r   r   rI  rh   rg   r   r   rQ   r   r   r   r   linearrL   sep_token_between_frames)rL   rQ   r  )rf   rH  r   rF   )train_dataseteval_datasetr(   )r  mm_cfggetattrvision_encoderr7   rH  rs  r   r   r   llmrX  r   encoder_seq_length)
rf   rI  r  each_file_from_pathrF   r  rh   r[  rH  r  r(   r(   r)   make_supervised_data_module  sl   
	



r  c                   @   sH   e Zd Zddedeeef fddZdd Zdeee	j
f fd	d
ZdS )NevaPackedSeqDatatsetr  rH  r[  c                 C   s   t || _|| _d S rU  )r   dsr[  )r'   rH  r[  r(   r(   r)   r*     s   

zNevaPackedSeqDatatset.__init__c                 C   s   t | jjd S )Nr   )rZ   r  document_indicesrV  r(   r(   r)   rW    s   zNevaPackedSeqDatatset.__len__ri   c                 C   sl   | j j| }t| j | t| j |d  t| j |d  t| j |d  jddg| jR  d}|S )Nr   rO   r)  r   )r  r{   r   r   )r  r  ru   r  r  FloatTensorreshaper[  )r'   rW   	doc_startr  r(   r(   r)   rp    s   $z!NevaPackedSeqDatatset.__getitem__N)r  )r@   rA   rB   rs   r   r_   r*   rW  r   ru   rq  rp  r(   r(   r(   r)   r    s    r  )F)r   rU  )\r   rK  rM  r:   r,  r.   dataclassesr   typingr   r   r   r   r   r   numpyr]   ru   torch.nn.functionalnn
functionalr  r  einopsr	   	omegaconfr
   PILr   torch.utils.datar   r   r   r   2nemo.collections.multimodal.data.neva.conversationcollections
multimodalr  nevar   r   Anemo.collections.multimodal.data.clip.augmentations.augmentationsr   r   r   r   r   r   r   r   r   r   r   r   2nemo.collections.nlp.modules.common.megatron.utilsr   rc  r   rG   	ExceptionrN  &megatron.core.datasets.indexed_datasetr   HAVE_MEGATRON_COREImportErrorModuleNotFoundErrorr    rD   rs   r_   r  r~   r   r   boolr   r   r	  r  r  r  r$  r6  r;  r=  rE  rF  rs  objectr~  r  r  r(   r(   r(   r)   <module>   s    $4,M
/ 
q$
`
p
b
_
.
p
r
: (%X*