o
    }oig                  	   @   s  d dl Z d dlZd dlZd dlmZmZ d dlmZ d dlm	Z	 d dl
mZmZmZmZ d dlZd dlmZ d dlmZmZmZmZ d dlmZ d d	lmZ d d
lmZ d dlmZmZ d dl m!Z! d dl"m#Z#m$Z$ d dl%m&Z& d dl'm(Z( d dl)m*Z* G dd dZ+G dd dZ,	d,dede-de-dB defddZ.G dd dZ/G dd de/Z0G d d! d!e/Z1d"d# Z2d$e3e	B e4e3 B de4e3 fd%d&Z5dd'hfd(e6d)e7e3 de6fd*d+Z8dS )-    N)MappingSequence)BytesIO)Path)	GeneratorIterableListLiteral)groupby)AudioSourceMonoCut	RecordingSupervisionSegment)LibsndfileBackend)Cut)resolve_seed)LazyIteratorChainLazyJsonlIterator)	open_best)compute_num_samplesifnone)get_full_path)logging)is_datastore_pathc                   @   s   e Zd ZdZ						d deeB ee B ded	ed
ededee	d B dee
eef  dB ddfddZdeeddf fddZdefddZdd Z	d!dededededB def
ddZ	d!dedededB defddZdS )"LazyNeMoIteratora#  
    ``LazyNeMoIterator`` reads a NeMo (non-tarred) JSON manifest and converts it on the fly to an ``Iterable[Cut]``.
    It's used to create a ``lhotse.CutSet``.

    Currently, it requires the following keys in NeMo manifests:
    - "audio_filepath"
    - "duration"
    - "text" (overridable with ``text_field`` argument)

    Specially supported keys are:
    - [recommended] "sampling_rate" allows us to provide a valid Lhotse ``Recording`` object without checking the audio file
    - "offset" for partial recording reads
    - "lang" is mapped to Lhotse superivsion's language (overridable with ``lang_field`` argument)

    Every other key found in the manifest will be attached to Lhotse Cut and accessible via ``cut.custom[key]``.

    .. caution:: We will perform some I/O (as much as required by soundfile.info) to discover the sampling rate
        of the audio file. If this is not acceptable, convert the manifest to Lhotse format which contains
        sampling rate info. For pure metadata iteration purposes we also provide a ``metadata_only`` flag that
        will create only partially valid Lhotse objects (with metadata related to sampling rate / num samples missing).

    Example::

        >>> cuts = lhotse.CutSet(LazyNeMoIterator("nemo_manifests/train.json"))

    We allow attaching custom metadata to cuts from files other than the manifest via ``extra_fields`` argument.
    In the example below, we'll iterate file "questions.txt" together with the manifest and attach each line
    under ``cut.question`` using the field type ``text_iter``::

        >>> cuts = lhotse.CutSet(LazyNeMoIterator(
        ...     "nemo_manifests/train.json",
        ...     extra_fields=[{"type": "text_iter", "name": "question", "path": "questions.txt"}],
        ... ))

    We also support random sampling of lines with field type ``text_sample``::

        >>> cuts = lhotse.CutSet(LazyNeMoIterator(
        ...     "nemo_manifests/train.json",
        ...     extra_fields=[{"type": "text_sample", "name": "question", "path": "questions.txt"}],
        ... ))
    textlangFtrngNpath
text_field
lang_fieldmetadata_onlyshuffle_shards
shard_seed)
randomizedr   extra_fieldsreturnc           	      C   s|   || _ || _|| _t|}t|dkrt|d | _ntdd |D | j| jd| _|| _|| _	|| _
|| _t| j d S )N   r   c                 s   s    | ]}t |V  qd S Nr   ).0p r,   e/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/common/data/lhotse/nemo_adapters.py	<genexpr>d   s    z,LazyNeMoIterator.__init__.<locals>.<genexpr>)shuffle_itersseed)r   r"   r#   expand_sharded_filepathslenr   sourcer   r   r    r!   r%   validate_extra_fields)	selfr   r   r    r!   r"   r#   r%   pathsr,   r,   r-   __init__Q   s   

zLazyNeMoIterator.__init__c                 #   s    t | j  fdd| jpdD }| jD ]Y}|ddrqtt|dt| jdd}|d}|d	d }| j	||||d
d d}|j
t|j|jd|j|| j|| jd ||_|D ]}|| qc|V  qd S )Nc                       g | ]}t d  i|qS r0   
ExtraField	from_dictr*   	field_cfgr9   r,   r-   
<listcomp>o       z-LazyNeMoIterator.__iter__.<locals>.<listcomp>r,   _skipmeFaudio_filepath)force_cachedurationoffsetsampling_rate)
audio_pathrE   rD   rF   r   idrecording_idstartrD   r   language)r   r#   r%   r3   getr   strpopr   _create_cutsupervisionsappendr   rI   rJ   rD   r   r    custom	attach_to)r5   r%   datarG   rD   rE   cutextra_fieldr,   r9   r-   __iter__l   s6   





zLazyNeMoIterator.__iter__c                 C   
   t | jS r(   r2   r3   r5   r,   r,   r-   __len__      
zLazyNeMoIterator.__len__c                 C   
   t | |S r(   r   r5   otherr,   r,   r-   __add__   r]   zLazyNeMoIterator.__add__rG   rE   rD   rF   c                 C   s   | j s2| |||}| }|d ur0|j||dd}|j dt|d ddt|d d|_|S t|d}t|d}t|||dg t|t	d	dgd
dg||| t
|| |dd}|S )NTrE   rD   preserve_id-g      Y@06di>          r   dummy typechannelsr3   )rI   sourcesrF   rD   num_samples)rI   rK   rD   channelrQ   	recording)r!   _create_recordingto_cuttruncaterI   roundr   r   r   r   r   )r5   rG   rE   rD   rF   rp   rV   srr,   r,   r-   rP      s0   ,

zLazyNeMoIterator._create_cutc                 C   sL   |d ur!t |r
dnd}t|t|dg|dg|t|||dgdS t|S )Nurlfiler   rj   )rI   rm   rF   rn   rD   channel_ids)r   r   r   r   	from_file)r5   rG   rD   rF   source_typer,   r,   r-   rq      s   
	z"LazyNeMoIterator._create_recording)r   r   FFr   Nr(   )__name__
__module____qualname____doc__rN   r   listboolintr	   dictr7   r   r   rX   r\   rb   floatrP   r   rq   r,   r,   r,   r-   r   &   sd    -
	

(r   c                   @   s   e Zd ZdZ						d!deeB ee B deeB d	ed
ee	d B dedededee
eef  dB ddfddZded  fddZd"ddZedee fddZdeee
ef ddf fddZdeeddf fddZdefddZdd  ZdS )#LazyNeMoTarredIteratora  
    ``LazyNeMoTarredIterator`` reads a NeMo tarred JSON manifest and converts it on the fly to an ``Iterable[Cut]``.
    It's used to create a ``lhotse.CutSet``.

    Currently, it requires the following keys in NeMo manifests:
    - "audio_filepath"
    - "duration"
    - "text" (overridable with text_field argument)
    - "shard_id"

    Specially supported keys are:
    - "lang" is mapped to Lhotse superivsion's language (overridable with ``lang_field`` argument)

    Every other key found in the manifest will be attached to Lhotse Cut and accessible via ``cut.custom[key]``.

    Args ``manifest_path`` and ``tar_paths`` can be either a path/string to a single file, or a string in NeMo format
    that indicates multiple paths (e.g. "[[data/bucket0/tarred_audio_paths.json],[data/bucket1/...]]").
    We discover shard ids from sharded tar and json files by parsing the input specifier/path and
    searching for the following pattern: ``(manifest|audio)[^/]*_(\d+)[^/]*\.(json|tar)``.
    It allows filenames such as ``manifest_0.json``, ``manifest_0_normalized.json``, ``manifest_normalized_0.json``,
    ``manifest_0.jsonl.gz``, etc. (anologusly the same applies to tar files).

    We also support generalized input specifiers that imitate webdataset's pipes (also very similar to Kaldi's pipes).
    These are arbitrary shell commands to be lazily executed which yield manifest or tar audio contents.
    For example, ``tar_paths`` can be set to ``pipe:ais get ais://my-bucket/audio_{0..127}.tar -``
    to indicate that we want to read tarred audio data from shards on an AIStore bucket.
    This can be used for other cloud storage APIs such as S3, GCS, etc.
    The same mechanism applies to ``manifest_path``.

    If your data has been filtered so that the JSON manifests refer to just a subset of recordings,
    set ``skip_missing_manifest_entries` to ``True``.
    This will still read the tar files sequentially (very fast) and discard the audio files that
    are not present in the corresponding manifest.

    The ``shard_seed`` argument is used to seed the RNG shuffling the shards.
    By default, it's ``trng`` which samples a seed number from OS-provided TRNG (see Python ``secrets`` module).
    Seed is resolved lazily so that every dataloading worker may sample a different one.
    Override with an integer value for deterministic behaviour and consult Lhotse documentation for details:
    https://lhotse.readthedocs.io/en/latest/datasets.html#handling-random-seeds

    Example of CutSet with inter-shard shuffling enabled::

        >>> cuts = lhotse.CutSet(LazyNeMoTarredIterator(
        ...     manifest_path=["nemo_manifests/sharded_manifests/manifest_0.json", ...],
        ...     tar_paths=["nemo_manifests/audio_0.tar", ...],
        ...     shuffle_shards=True,
        ... ))

    We allow attaching custom metadata to cuts from files other than the manifest via ``extra_fields`` argument.
    In the example below, we'll iterate file "questions.txt" together with the manifest and attach each line
    under ``cut.question`` using the field type ``text_iter``::

        >>> cuts = lhotse.CutSet(LazyNeMoTarredIterator(
        ...     manifest_path=["nemo_manifests/sharded_manifests/manifest_0.json", ...],
        ...     tar_paths=["nemo_manifests/audio_0.tar", ...],
        ...     extra_fields=[{"type": "text_iter", "name": "question", "path": "questions.txt"}],
        ... ))

    We also support random sampling of lines with field type ``text_sample``::

        >>> cuts = lhotse.CutSet(LazyNeMoTarredIterator(
        ...     manifest_path=["nemo_manifests/sharded_manifests/manifest_0.json", ...],
        ...     tar_paths=["nemo_manifests/audio_0.tar", ...],
        ...     extra_fields=[{"type": "text_sample", "name": "question", "path": "questions.txt"}],
        ... ))
    Fr   r   r   Nmanifest_path	tar_pathsr"   r#   )r   r$   r   r    skip_missing_manifest_entriesr%   r&   c	                 C   sx  || _ |  t|| _t| jdkr-td| jd  d t| jd | _td| j| _	n@t
d}	g }
| jD ]!}|	|}|d usNJ d|	j d| d	|
t|d q7d
d t|
| jD | _	t| j	  | _t|| _t
d}g }
| jD ]!}||}|d usJ d|j d| d	|
t|d q|tt|
| j| _|| _|| _|| _|| _|| _|   d S )Nr'   zYou are using Lhotse dataloading for tarred audio with a non-sharded manifest.
                            This will incur significant memory overhead and slow-down training. To prevent this error message
                            please shard file 'r   z' using 'scripts/speech_recognition/convert_to_tarred_audio_dataset.py'
                            WITHOUT '--no_shard_manifest'shard_idzmanifest[^/]*_(\d+)[^/]*\.jsonzQCannot determine shard_id from manifest input specified: we searched with regex 'z' in input ''c                 S   s   i | ]	\}}|t |qS r,   r)   )r*   sidr+   r,   r,   r-   
<dictcomp>0  s    z3LazyNeMoTarredIterator.__init__.<locals>.<dictcomp>zaudio[^/]*_(\d+)[^/]*\.tarzLCannot determine shard_id from tar input specifier: we searched with regex ')r   r1   r6   r2   r   warningr   r3   r
   shard_id_to_manifestrecompilesearchpatternrR   r   groupzipr   valuesr   r   shard_id_to_tar_pathr"   r#   r   r    r%   	_validate)r5   r   r   r"   r#   r   r    r   r%   json_pattern	shard_idsr+   mtar_patternr,   r,   r-   r7     sZ   









zLazyNeMoTarredIterator.__init__c                    s4   t  jdkr
 gS  fddt j j D S )zEConvert this iterator to a list of separate iterators for each shard.r'   c              
      s*   g | ]\}}t ||d  j j jdqS )F)r   r   r"   r#   r   r    )r   r#   r   r    )r*   r   tarpathr[   r,   r-   r?   M  s    	z4LazyNeMoTarredIterator.to_shards.<locals>.<listcomp>)r2   r6   r   r   r   r[   r,   r[   r-   	to_shardsF  s
   
	z LazyNeMoTarredIterator.to_shardsc              
   C   sX   t | j}t | j}||ks%J d| j d| j dt| dt| d	t| j d S )Nz8Mismatch between shard IDs. Details:
* JSON manifest(s) z
* Tar files: z%
* JSON manifest(s) indicate(s) IDs: z 
* Tar path(s) indicate(s) IDs: 
)setr   r   r6   r   sortedr4   r%   )r5   shard_ids_tarsshard_ids_manifestr,   r,   r-   r   Y  s   


z LazyNeMoTarredIterator._validatec                 C   s   t | j S r(   )r   r   keysr[   r,   r,   r-   r   e  s   z LazyNeMoTarredIterator.shard_idsc           	      c   s    t jt|ddddG}|D ];}z||j }|| }|||fV  W q tyJ } z| jr6W Y d }~qtd| d| d|j d|d }~ww W d    d S 1 sVw   Y  d S )	Nrb)modezr|*)fileobjr   z)Mismatched entry between JSON manifest ('z') and tar file ('z+'). Cannot locate JSON entry for tar file 'r   )	tarfileopenr   nameextractfilereadKeyErrorr   RuntimeError)	r5   tar_pathshard_manifestr   tartar_inforU   	raw_audioer,   r,   r-   _iter_sequentiali  s*   
"z'LazyNeMoTarredIterator._iter_sequentialc                 #   s   | j }t| j| jrt| fdd| jpdD }t	d |D ]}t
| jdkr5| j| n| jd }dtdtf fd	d
}t|| j| }| j| }z| |||D ]\}}	}
tt|	}t|
jtdtt|j|	dgt|j|j|jd}g }t||
j  dd dD ]J}|!ddrqt"||!dd|!dd}|j#$t%|j&|j'd|j|!| j(|!| j)d t*||_+||_,||_-|D ]}|.| q|$| q~~	|E d H  qZW q' t/j0y   t12d| Y q'w d S )Nc                    r8   r9   r:   r=   r9   r,   r-   r?     r@   z3LazyNeMoTarredIterator.__iter__.<locals>.<listcomp>r,   z-^(?P<stem>.+)(?P<sub>-sub\d+)(?P<ext>\.\w+)?$r'   r   dr&   c                    s8     | d  } }d ur|dt|dd S |S )NrB   stemextri   )matchr   r   )r   kr   )offset_patternr,   r-   basename  s
   z1LazyNeMoTarredIterator.__iter__.<locals>.basenamememoryrj   )rI   rm   rF   rn   rD   c                 S   s   | d S )NrB   r,   )r   r,   r,   r-   <lambda>  s    z1LazyNeMoTarredIterator.__iter__.<locals>.<lambda>)keyrA   FrE   rg   rD   )rE   rD   rH   zOSkipping tar file due to read errors (unstable storage or bad file?): tar_path=)3r   r   r#   r"   randomRandomshuffler%   r   r   r2   r6   r   rN   r
   r   r   r   	soundfileinfor   r   r   r   r   rangerl   r   
samplerateframesrD   r   r   rM   'make_cut_with_subset_inmemory_recordingrQ   rR   r   rI   rJ   r   r    _to_custom_attr_dictrS   manifest_origin
tar_originrT   r   	ReadErrorr   r   )r5   r   r%   r   r   r   r   r   rU   r   r   metarp   cuts_for_recordingrV   rW   r,   )r   r0   r-   rX   y  sn   

"




%zLazyNeMoTarredIterator.__iter__c                 C   rY   r(   rZ   r[   r,   r,   r-   r\     r]   zLazyNeMoTarredIterator.__len__c                 C   r^   r(   r_   r`   r,   r,   r-   rb     r]   zLazyNeMoTarredIterator.__add__)Fr   r   r   FN)r&   N)r{   r|   r}   r~   rN   r   r   r   r   r	   r   r7   r   r   r   propertyr   r   tuplebytesr   r   rX   r\   rb   r,   r,   r,   r-   r      sF    G
	

7
 Er   rg   rp   rE   rD   r&   c                 C   s   |   }|dkr|du st|| j dk r|S z
|j||dd}W n ty> } ztd| d| d|  d	| |d}~ww t }t j||	 |j
d
d |d t| j| j
|j|jtd| j| dgd}|  S )a  
    This method is built specifically to optimize CPU memory usage during dataloading
    when reading tarfiles containing very long recordings (1h+).
    Normally each cut would hold a reference to the long in-memory recording and load
    the necessary subset of audio (there wouldn't be a separate copy of the long recording for each cut).
    This is fairly efficient already, but we don't actually need to hold the unused full recording in memory.
    Instead, we re-create each cut so that it only holds a reference to the subset of recording necessary.
    This allows us to discard unused data which would otherwise be held in memory as part of sampling buffering.
    rg   Ng?Trc   z'Lhotse cut.truncate failed with offset=z, duration=z, recording=z: wav)rF   formatr   r   rj   )rI   rF   rn   rD   rm   )rr   absrD   rs   	Exceptionr   r   r   
save_audio
load_audiorF   seekr   rI   rn   r   rx   getvalue)rp   rE   rD   rV   r   
audiobytesnew_recordingr,   r,   r-   r     s:   "
r   c                       sn   e Zd ZdZi Zdd Z fddZededd fdd	Z	e
d
edefddZe
dee fddZ  ZS )r;   Nc                 C   s   t  r(   )NotImplementedError)r5   rV   r,   r,   r-   rT     s   zExtraField.attach_toc                    s.   | j tjvr| tj| j< t jdi | d S )Nr,   )r{   r;   SUPPORTED_TYPESTYPEsuper__init_subclass__)clskwargs	__class__r,   r-   r     s   zExtraField.__init_subclass__rU   r&   c                 C   sF   | d t jv sJ d| d  t j| d  di dd |  D S )Nrk   zUnknown transform type: c                 S   s   i | ]\}}|d kr||qS )rk   r,   r*   r   vr,   r,   r-   r     r@   z(ExtraField.from_dict.<locals>.<dictcomp>r,   )r;   r   items)rU   r,   r,   r-   r<     s    &zExtraField.from_dict
field_typec                 C   s
   || j v S r(   )r   )r   r   r,   r,   r-   is_supported     
zExtraField.is_supportedc                 C   rY   r(   )r   r   )r   r,   r,   r-   supported_types  r   zExtraField.supported_types)r{   r|   r}   r   r   rT   r   staticmethodr   r<   classmethodrN   r   r   r   r   __classcell__r,   r,   r   r-   r;     s    r;   c                   @   s4   e Zd ZdZddedefddZdd Zd	d
 ZdS )TextIteratorExtraField	text_iterNr   r   c                 C   s   || _ || _d | _d S r(   )r   r   iteratorr5   r   r   r0   r,   r,   r-   r7     s   
zTextIteratorExtraField.__init__c                 C   s*   | j d u rtttjt| j| _ d S d S r(   )r   itermaprN   stripr   r   r[   r,   r,   r-   _maybe_init  s   
z"TextIteratorExtraField._maybe_initc              	   C   sR   |    zt| j}W n ty   td| j d| j dw t|| j| |S )NzNot enough lines in file z to attach to cuts under field .)r   nextr   StopIterationr   r   r   setattrr5   rV   attached_valuer,   r,   r-   rT     s   z TextIteratorExtraField.attach_tor(   )r{   r|   r}   r   rN   r7   r   rT   r,   r,   r,   r-   r     s
    r   c                   @   s:   e Zd ZdZdededeeB fddZdd Zd	d
 ZdS )TextSampleExtraFieldtext_sampler   r   r0   c                 C   s"   || _ || _|| _d | _d | _d S r(   )r   r   r0   
populationrngr   r,   r,   r-   r7   %  s
   
zTextSampleExtraField.__init__c                 C   s<   | j d u rtttjt| j| _ tt	| j
| _d S d S r(   )r   r   r   rN   r   r   r   r   r   r   r0   r   r[   r,   r,   r-   r   ,  s   
z TextSampleExtraField._maybe_initc                 C   s(   |    | j| j}t|| j| |S r(   )r   r   choicer   r   r   r   r,   r,   r-   rT   1  s   zTextSampleExtraField.attach_toN)	r{   r|   r}   r   rN   r   r7   r   rT   r,   r,   r,   r-   r   "  s
    r   c                 C   s   | d u rd S t | tsJ d| | D ]8}t |ts%J d|d| |d}t|s>J dt  d|d| d|v sLJ d|d| qd S )	NzZThe argument provided to 'extra_fields' must be a list of dicts. We received extra_fields=z>Each item in 'extra_fields' must be a dict. We received field=z in extra_fields=rk   zZEach item in 'extra_fields' must contain a 'type' field with one of the supported values (z). We got field_type=r   zwEach item in 'extra_fields' must contain a 'name' field so that the field is available under cut.<name>.We found field=)
isinstancer   r   rM   r;   r   r   )r%   fieldr   r,   r,   r-   r4   8  s>   

r4   r6   c                 C   s.   ddl m} t| trt| } || ddddS )Nr   )r1   	replicater'   )shard_strategy
world_sizeglobal_rank)'nemo.collections.asr.data.audio_to_textr1   r   r   rN   )r6   _expand_sharded_filepathsr,   r,   r-   r1   N  s   
r1   rB   r   _excluded_fieldsc                    s    fdd|   D S )Nc                    s   i | ]\}}| vr||qS r,   r,   r   r  r,   r-   r   Y  r@   z(_to_custom_attr_dict.<locals>.<dictcomp>)r   )r   r  r,   r	  r-   r   X  s   r   )rg   N)9r   r   r   collections.abcr   r   ior   pathlibr   typingr   r   r   r	   r   cytoolzr
   lhotser   r   r   r   lhotse.audio.backendr   
lhotse.cutr   lhotse.dataset.dataloadingr   lhotse.lazyr   r   lhotse.serializationr   lhotse.utilsr   r   4nemo.collections.common.parts.preprocessing.manifestr   
nemo.utilsr   nemo.utils.data_utilsr   r   r   r   r   r;   r   r   r4   rN   r   r1   r   r   r   r,   r,   r,   r-   <module>   sN    & |
-"&
