o
    }oiUU                     @   s\  d dl Z d dlZd dlmZ d dlmZ d dlmZ d dlm	Z	 d dl
mZmZmZmZ d dlZd dlZd dlmZ d dlmZ d d	lmZ d d
lmZ d dlmZ d dlmZmZmZ d dl m!Z!m"Z" d dl#m$Z$ d dl%m&Z&m'Z' d dl(m)Z) d dl*m+Z+ 	 G dd dZ,	 eG dd de,eZ-eG dd dZ.e'e-de-fddZ/	 eG dd de,eZ0eG dd dZ1e'e0de0fdd Z2	 eG d!d" d"e,eZ3e'e3de3fd#d$Z4eG d%d& d&Z5	 eG d'd( d(Z6eG d)d* d*Z7eG d+d, d,e,eZ8de8d-ed. d/e9fd0d1Z:e'e8de8fd2d3Z;eG d4d5 d5Z<G d6d7 d7Z=dS )8    N)deque)	dataclass)groupby)Path)IteratorLiteralOptionalUnion)	Recording)CustomFieldMixin)Cut)resolve_seed)
load_jsonl)AudioTarWriterJsonlShardWriterTarIterator)Pathlikeis_valid_url)expand_sharded_filepaths)apply_prompt_format_fnregistered_prompt_format_fn)get_full_path)TokenizerWrapperc                   @   s`   e Zd Zdd ZededB fddZededB fddZededB fd	d
ZdddZ	dS )Formattablec                 C   s   d | _ d | _d | _d | _d S N)	input_idscontext_ids
answer_idsmaskself r!   e/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/common/data/lhotse/text_adapters.py__init__,   s   
zFormattable.__init__returnNc                 C      | j d u rd S | j jd S Nr   )r   shaper   r!   r!   r"   input_length2      
zFormattable.input_lengthc                 C   r%   r&   )r   r'   r   r!   r!   r"   output_length8   r)   zFormattable.output_lengthc                 C   r%   r&   )r   r'   r   r!   r!   r"   total_length>   r)   zFormattable.total_lengthc                 C   s:   t | |}|d | _|d | _|d| _|d| _| S )Nr   r   r   r   )r   r   r   getr   r   )r    promptansr!   r!   r"   apply_prompt_formatD   s   


zFormattable.apply_prompt_format)r$   r   )
__name__
__module____qualname__r#   propertyintr(   r*   r+   r/   r!   r!   r!   r"   r   +   s    r   c                   @   sZ   e Zd ZU dZeed< dZedB ed< dZee	j
 ed< dZeed< dedd fd	d
ZdS )TextExamplezN
    Represents a single text example. Useful e.g. for language modeling.
    textNlanguagetokenscustom	tokenizerr$   c                 C   s   t || j| j| _| S r   )npasarrayr6   r7   r8   r    r:   r!   r!   r"   tokenize]   s   zTextExample.tokenize)r0   r1   r2   __doc__str__annotations__r7   r8   r   r;   ndarrayr9   dictr   r>   r!   r!   r!   r"   r5   R   s   
 r5   c                   @   t   e Zd ZU dZeeee f ed< dZe	dB ed< dZ
eed< dZeeed f ed	< d
d Zdee fddZdS )LhotseTextAdapterzj
    ``LhotseTextAdapter`` is used to read a text file and wrap
    each line into a ``TextExample``.
    pathsNr7   Fshuffle_shardstrngrH   
randomized
shard_seedc                 C      t | j| _d S r   r   rF   r   r!   r!   r"   __post_init__n      zLhotseTextAdapter.__post_init__r$   c              	   c   sx    | j }| jrt| j}t|| |D ]#}t|}|D ]
}t|| j	dV  qW d    n1 s4w   Y  qd S Nr7   )
rF   rG   r   rK   randomRandomshuffleopenr5   r7   )r    rF   seedpathfliner!   r!   r"   __iter__q   s   

zLhotseTextAdapter.__iter__)r0   r1   r2   r?   r	   r   listrA   r7   r@   rG   boolrK   r4   r   rN   r   r5   rZ   r!   r!   r!   r"   rE   b   s   
 rE   examplec                 C   s   | |jd| jidgS )Nmessageroleslots)encode_dialogOUTPUT_ROLEr6   r]   r-   r!   r!   r"   %default_text_example_prompt_format_fn|   s   re   c                   @   sP   e Zd ZU dZeed< eed< dZedB ed< dZeed< de	dd fd	d
Z
dS )SourceTargetTextExamplez
    Represents a pair of text examples. Useful e.g. for sequence-to-sequence tasks.
    Supports a ``question`` field, used as the prompt for LLM.
    sourcetargetNquestionr9   r:   r$   c                 C   s8   | j || _ | j|| _| jd ur| j|| _| S r   )rg   r>   rh   ri   r=   r!   r!   r"   r>      s
   
z SourceTargetTextExample.tokenize)r0   r1   r2   r?   r5   rA   ri   r9   rC   r   r>   r!   r!   r!   r"   rf      s   
 rf   c                   @   s   e Zd ZU dZeeee f ed< eeee f ed< dZe	dB ed< dZ
e	dB ed< dZeed< dZe	ed< d	Zeed
< dZeeed f ed< dd Zdee fddZdS )LhotseTextPairAdapterae  
    ``LhotseTextAdapter`` is used to read a tuple of N text files
    (e.g., a pair of files with translations in different languages)
    and wrap them in a ``TextExample`` object to enable dataloading
    with Lhotse together with training examples in audio modality.

    Provide ``questions_path`` to enable randomly sampling lines with questions.
    source_pathstarget_pathsNsource_languagetarget_languagequestions_pathquestions_languageFrG   rH   rI   rK   c                 C   s   d}t | jttfrt | jttfsJ |n+t | jtr#t | jts'J |t| jt| jksBJ dt| j dt| j dt| j| _t| j| _d S )Nz>Both source and target must be a single path or lists of pathszSource (z) and target (z0) path lists must have the same number of items.)
isinstancerk   r@   r   rl   r[   lenr   )r    
ASSERT_MSGr!   r!   r"   rN      s    
z#LhotseTextPairAdapter.__post_init__r$   c                 c   s<   t | j}t|}tt| j| j}| jr|	| d }| j
d ur?t| j
}dd |D }W d    n1 s:w   Y  |D ]Z\}}t|J}t|6}	t||	D ](\}
}tt|
 | jdt| | jd|d urwt||| jdnd dV  qTW d    n1 sw   Y  W d    n1 sw   Y  qAd S )Nc                 S      g | ]}|  qS r!   )strip).0qr!   r!   r"   
<listcomp>       z2LhotseTextPairAdapter.__iter__.<locals>.<listcomp>rQ   )rg   rh   ri   )r   rK   rR   rS   r[   ziprk   rl   rG   rT   ro   rU   rf   r5   ru   rm   rn   choicerp   )r    rV   rngrF   	questionsrX   source_pathtarget_pathfsftlsltr!   r!   r"   rZ      s6   




 zLhotseTextPairAdapter.__iter__)r0   r1   r2   r?   r	   r   r[   rA   rm   r@   rn   ro   rp   rG   r\   rK   r4   r   rN   r   rf   rZ   r!   r!   r!   r"   rj      s   
 	rj   c                 C   sR   | j d ur| j j d| jj }n| jj}|dd|id|jd| jjidgS )N userr^   r_   )ri   r6   rg   rb   rc   rh   )r]   r-   ctxr!   r!   r"    default_src_tgt_prompt_format_fn   s   
r   c                   @   sB   e Zd ZU eed< dZedB ed< dZedB ed< dZeed< dS )NeMoSFTExampledataNr7   metadatar9   )	r0   r1   r2   rC   rA   r7   r@   r   r9   r!   r!   r!   r"   r      s
   
 r   c                    s@   d| j v r| j d rtd    fdd| j d D S )NsystemzDefault prompt format for NeMoSFTExample doesn't support 'system' prompt. Please specialize the prompt_format_fn for PromptFormatter of type c                    s0   g | ]}|d  dkrdn j d|d idqS )fromUserr   r^   valuer_   )rc   rv   turnr-   r!   r"   rx      s    "z0default_sft_prompt_format_fn.<locals>.<listcomp>conversations)r   RuntimeErrorrb   rd   r!   r   r"   default_sft_prompt_format_fn   s   
r   c                   @   rD   )NeMoSFTJsonlAdaptera  
    ``NeMoSFTJsonlAdapter`` is used to read a NeMo LM SFT Chat JSONL file and yield objects of type
    ``NeMoSFTExample`` that can be sampled with Lhotse.

    We expect the following schema (contained in a single line per example)::

        {
            "conversations": [
                {
                    "value": str,
                    "from": "User" | "Assistant",
                    "canonical_form": str,
                    "label": str | null
                },
                ...
            ],
            "mask": "User" | "Assistant",
            "system": str,
            "dataset": str,
            "category": str,
        }
    rF   Nr7   FrG   rH   rI   rK   c                 C   rL   r   rM   r   r!   r!   r"   rN   "  rO   z!NeMoSFTJsonlAdapter.__post_init__r$   c                 c   sT    | j }| jrt| j}t|| |D ]}t|D ]
}t|| j	dV  qqd S rP   )
rF   rG   r   rK   rR   rS   rT   r   r   r7   )r    rF   rV   rW   r   r!   r!   r"   rZ   %  s   
zNeMoSFTJsonlAdapter.__iter__)r0   r1   r2   r?   r	   r   r[   rA   r7   r@   rG   r\   rK   r4   r   rN   r   r   rZ   r!   r!   r!   r"   r     s   
 r   c                   @   s&   e Zd ZU eed< eed< dd ZdS )TextTurnr   r`   c                 C   s   d| j  | jdS )Nr6   )typer   r   )r`   titler   r   r!   r!   r"   to_dict9  s   zTextTurn.to_dictN)r0   r1   r2   r@   rA   r   r!   r!   r!   r"   r   4  s   
 r   c                   @   s.   e Zd ZU eed< eed< eed< dd ZdS )	AudioTurncutr`   audio_locator_tagc                 C   sH   | j jr| j jjd jdvsJ dd| j | j j| j jjd jdS )Nr   >   sharmemoryznCannot serialize AudioTurn to dict because it doesn't reference an audio file (the audio is stored in memory).audio)r   r   durationr   )	r   has_recording	recordingsourcesr   r`   r   r   rg   r   r!   r!   r"   r   C  s   zAudioTurn.to_dictN)r0   r1   r2   r   rA   r@   r   r!   r!   r!   r"   r   =  s
   
 r   c                   @   s   e Zd ZU eed< eeeB  ed< dZe	ed< dZ
eed< ededB fddZededB fd	d
ZededB fddZedefddZedefddZdd Zdee fddZdS )NeMoMultimodalConversationidturnsNtoken_equivalent_durationr9   r$   c                 C   (   | j d u rd S t| d}| j jd | S )Ncontextr   )r   _compute_num_audio_tokensr'   r    extrar!   r!   r"   r(   W     

z'NeMoMultimodalConversation.input_lengthc                 C   r   )Nanswerr   )r   r   r'   r   r!   r!   r"   r*   ^  r   z(NeMoMultimodalConversation.output_lengthc                 C   r   )Nallr   )r   r   r'   r   r!   r!   r"   r+   e  r   z'NeMoMultimodalConversation.total_lengthc                 C      t dd | jD S )Nc                 s       | ]}t |tV  qd S r   )rq   r   rv   tr!   r!   r"   	<genexpr>n      z=NeMoMultimodalConversation.has_audio_turns.<locals>.<genexpr>anyr   r   r!   r!   r"   has_audio_turnsl     z*NeMoMultimodalConversation.has_audio_turnsc                 C   r   )Nc                 s   r   r   )rq   r   r   r!   r!   r"   r   r  r   z<NeMoMultimodalConversation.has_text_turns.<locals>.<genexpr>r   r   r!   r!   r"   has_text_turnsp  r   z)NeMoMultimodalConversation.has_text_turnsc                 C   s   | j dd | jD dS )Nc                 S   rt   r!   )r   r   r!   r!   r"   rx   u  ry   z6NeMoMultimodalConversation.to_dict.<locals>.<listcomp>)r   r   r   r   r   r!   r!   r"   r   t  s   z"NeMoMultimodalConversation.to_dictc                 C   s   dd | j D S )Nc                 S   s   g | ]
}t |tr|jqS r!   )rq   r   r   r   r!   r!   r"   rx   x      z8NeMoMultimodalConversation.list_cuts.<locals>.<listcomp>)r   r   r!   r!   r"   	list_cutsw  rO   z$NeMoMultimodalConversation.list_cuts)r0   r1   r2   r@   rA   r[   r   r   r   floatr9   rC   r3   r4   r(   r*   r+   r\   r   r   r   r   r   r!   r!   r!   r"   r   P  s"   
 r   mode)r   r   r   r$   c                    s    j sdS  jd usJ d|dkr jd d }n|dkr& jdd  }n|dkr. j}ntd| t fdd	|D S )
Nr   a_  Cannot compute the length of a NeMoMultimodalConversation: token_equivalent_duration must be set in order to estimate the number of tokens equivalent to audio turns. Did you forget to set token_equivalent_duration option in your dataloading config? Tip: generally it should be set to frame_shift * total_subsampling_factor of your audio encoder model.r   r   r   z4invalid mode for number of audio token computation: c                    s.   g | ]}t |trt|jj j d  qS )   )rq   r   mathceilr   r   r   r   r]   r!   r"   rx     s    z-_compute_num_audio_tokens.<locals>.<listcomp>)r   r   r   r   sum)r]   r   r   r!   r   r"   r   {  s"   
r   c                 C   s4   t dd | jD dd d}dd |D }||S )Nc                 S   s.   g | ]}|j d t|tr|jn|jidqS )r^   r_   )r`   rq   r   r   r   r   r!   r!   r"   rx     s    zDdefault_multimodal_conversation_prompt_format_fn.<locals>.<listcomp>c                 S   s   | d S )Nr`   r!   )r   r!   r!   r"   <lambda>  s    zBdefault_multimodal_conversation_prompt_format_fn.<locals>.<lambda>)keyc                 S   s.   g | ]\}}|d d dd |D idqS )r^   r   c                 s   s    | ]	}|d  d V  qdS )ra   r^   Nr!   r   r!   r!   r"   r     s    zNdefault_multimodal_conversation_prompt_format_fn.<locals>.<listcomp>.<genexpr>r_   )join)rv   r`   turn_grpr!   r!   r"   rx     s    )r   r   rb   )r]   r-   r   r!   r!   r"   0default_multimodal_conversation_prompt_format_fn  s   

r   c                   @   s   e Zd ZU dZeee B ed< eed< dZeee B ed< dZe	ed< dZ
eed< d	Zeeed
 f ed< dd Zdee fddZdd Zdd ZdS )&NeMoMultimodalConversationJsonlAdaptera  
    ``NeMoMultimodalConversationJsonlAdapter`` is used to read a NeMo multimodal conversation JSONL
    and yield objects of type ``NeMoMultimodalConversation`` that can be sampled with Lhotse.

    We expect the following schema (contained in a single line per example)::

        {
            "id": str,
            "conversations": [
                {
                    "value": str,  # text message or path to audio
                    "from": "User" | "Assistant",
                    "type": "text" | "audio",
                    "duration": float,  # only for audio
                },
                ...
            ],
        }
    manifest_filepathr   Ntarred_audio_filepathsr   FrG   rH   rI   rK   c                 C   s\   t | j| _| jd ur*t | j| _t| jt| jks,J t| j dt| j d S d S )Nz != )r   r   r   rr   r   r!   r!   r"   rN     s   

z4NeMoMultimodalConversationJsonlAdapter.__post_init__r$   c                 c   s0    | j d ur|  E d H  d S |  E d H  d S r   )r   	_iter_tar_iter_jsonlr   r!   r!   r"   rZ     s   
z/NeMoMultimodalConversationJsonlAdapter.__iter__c                 #   s    t tjj}jrtj}t|	| |D ]q\}}t
t|}t|D ]b}dd |d D }g  |D ];}t|\}	}
t|
}
|	 }|
|d ks[J d|d  d|
|j|d ksoJ d|d  d	|j | q9t  t|d
  fdd|d D dV  q*qd S )Nc                 S   s   g | ]
}|d  dkr|qS )r   r   r!   r   r!   r!   r"   rx     r   zDNeMoMultimodalConversationJsonlAdapter._iter_tar.<locals>.<listcomp>r   r   z9Mismatch between JSONL and tar. JSONL defines audio path=z. but we got the following from tar audio_path=r   z=Mismatch between JSONL and tar. JSONL defines audio duration=z0 but we got the following from tar cut.duration=r   c                    sL   g | ]"}|d  dkrt |d |d  dnt  |d  jdqS r   r6   r   r   )r   r`   )r   r`   r   )r   lowerr   popleftr   r   cutsr    r!   r"   rx     s&    r   )r[   rz   r   r   rG   r   rK   rR   rS   rT   iterr   r   nextr@   to_cutr   appendr   r   )r    rF   rV   
jsonl_pathtar_pathtarr   audio_turnsr   r   
audio_pathr   r!   r   r"   r     s<   

z0NeMoMultimodalConversationJsonlAdapter._iter_tarc                 #   sn    j }jrtj}t|| |D ] t D ]}t|d  fdd|d D j	dV  qqd S )Nr   c                    s\   g | ]*}|d  dkrt |d |d  dnttt|d   |d  jdqS r   )r   r   r   r
   	from_filer   r   r   r   rW   r    r!   r"   rx     s&    zFNeMoMultimodalConversationJsonlAdapter._iter_jsonl.<locals>.<listcomp>r   )r   r   r   )
r   rG   r   rK   rR   rS   rT   r   r   r   )r    rF   rV   r   r!   r   r"   r     s    

z2NeMoMultimodalConversationJsonlAdapter._iter_jsonl)r0   r1   r2   r?   r@   r[   rA   r   r   r   rG   r\   rK   r	   r4   r   rN   r   r   rZ   r   r   r!   r!   r!   r"   r     s   
 -r   c                   @   s^   e Zd ZddedefddZdefddZd	d
 Zdd Z	dd Z
dd Zdd Zdd ZdS )#NeMoMultimodalConversationTarWriterd   
output_dir
shard_sizec                 C   s    || _ || _|   |   d S r   )r   r   _reset_setup_writers)r    r   r   r!   r!   r"   r#   #  s   z,NeMoMultimodalConversationTarWriter.__init__r]   c                 C   s   |    | }|d D ]}|d dkr t|d dj|d< q| j| | D ]}|js7J d| | j	|j
j| |j|j
 q+|  jd7  _d S )Nr   r   r   r   z.flaczTCannot serialize multimodal conversation with cuts that have no recordings. We got: r   )_maybe_increment_shardr   r   with_suffixnamemanifest_writerwriter   r   
tar_writerr   r   
load_audiosampling_rate	item_cntr)r    r]   
serializedr   r   r!   r!   r"   r   )  s    z)NeMoMultimodalConversationTarWriter.writec                 C   s   | j   | j  d S r   )r   closer   r   r!   r!   r"   r   7  s   
z)NeMoMultimodalConversationTarWriter.closec                 C   s    |    | j  | j  | S r   )r   r   	__enter__r   r   r!   r!   r"   r   ;  s   

z-NeMoMultimodalConversationTarWriter.__enter__c                 O   s   |    d S r   )r   )r    argskwargsr!   r!   r"   __exit__A  s   z,NeMoMultimodalConversationTarWriter.__exit__c                 C   sB   | j dkr| j | j dkrd| _ |  jd7  _|   d S d S d S )Nr   r   )r   r   	shard_idxr   r   r!   r!   r"   r   D  s
   z:NeMoMultimodalConversationTarWriter._maybe_increment_shardc                 C   s   d| _ d| _d S r&   )r   r   r   r!   r!   r"   r   J  s   
z*NeMoMultimodalConversationTarWriter._resetc                 C   s\   t | jst| jjdd t| j d| j dd d| _t| j d| j dd d| _d S )NT)exist_okz
/manifest_z.jsonl)r   z/audio_z.tar)	r   r   r   mkdirr   r   r   r   r   r   r!   r!   r"   r   N  s   
"z2NeMoMultimodalConversationTarWriter._setup_writersN)r   )r0   r1   r2   r@   r4   r#   r   r   r   r   r   r   r   r   r!   r!   r!   r"   r   "  s    r   )>r   rR   collectionsr   dataclassesr   	itertoolsr   pathlibr   typingr   r   r   r	   numpyr;   torchlhotser
   lhotse.customr   
lhotse.cutr   lhotse.dataset.dataloadingr   lhotse.serializationr   lhotse.sharr   r   r   lhotse.utilsr   r   1nemo.collections.common.data.lhotse.nemo_adaptersr   &nemo.collections.common.data.prompt_fnr   r   4nemo.collections.common.parts.preprocessing.manifestr   6nemo.collections.common.tokenizers.aggregate_tokenizerr   r   r5   rE   re   rf   rj   r   r   r   r   r   r   r   r4   r   r   r   r   r!   r!   r!   r"   <module>   sp   "7**w