o
    }oil                     @   s  d dl Z d dlZd dlZd dlZd dlmZ d dlmZ d dlm	Z	m
Z
mZmZmZ d dlZd dlZd dlmZ d dlmZmZmZ d dlmZ d dlmZmZ d d	lmZmZ d d
lmZ d dl m!Z! d dl"m#Z# d dl$m%Z% d dl&m'Z' d dl(m)Z) eG dd dZ*eG dd dZ+de,de,de
e,e	f fddZ-dIdee* de	de+fddZ.dee,edf fddZ/		 	!	"	#dJd$e
e,e	f d%e0d&ee0 d'e1d(e0d)e0d*e0fd+d,Z2dKd.ee
 d%e0d&e0dB d*e0fd/d0Z3d1ed2e0defd3d4Z4G d5d6 d6ej5Z6G d7d8 d8ej7Z8G d9d: d:ej9j:j;Z<	dLd;e
e,e	f d<e	de	de6fd=d>Z=	dLd;e
e,e	f d?e0d@e0d<e	de	de#fdAdBZ>dIde	fdCdDZ?	dIde	fdEdFZ@d?e0d@e0fdGdHZAdS )M    N)	dataclass)isclose)AnyDictListOptionalUnion)AudioSamples)
DictConfig
ListConfig	open_dict)Tensor)audio_to_textaudio_to_text_dataset)WhiteNoisePerturbationprocess_augmentations)AudioSegment)read_manifest)ConcatDataset)get_full_path)Serialization)loggingc                   @      e Zd ZU dZedB ed< dZeedf ed< dZ	eedf ed< dZ
eedf ed< dZeedf ed< dZeedf ed< dZeedf ed< dS )	AudioNoiseItemN	sample_idaudio	audio_lennoise	noise_lennoisy_audionoisy_audio_len)__name__
__module____qualname__r   str__annotations__r   r   r   r   r   r   r   r     r&   r&   Y/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/asr/data/ssl_dataset.pyr   '      
 r   c                   @   r   )	AudioNoiseBatchNr   r   r   r   r   r   r    )r!   r"   r#   r   listr%   r   r   r   r   r   r   r   r    r&   r&   r&   r'   r)   2   r(   r)   linemanifest_filereturnc                 C   s   t | }d|v r|d|d< nd|v r|d|d< ntd| t|d |d|d< d|vr6d|d< d|d	< t|d |d |d	 |d
d|dd|dd|dd|ddd}|S )z
    Specialized function to parse the manifest file by ignoring text,
    such that nemo dataset can save time on tokenizing text.
    audio_filename
audio_fileaudio_filepathz:No 'audio_filename' or 'audio_filepath' in manifest item: )r/   r,   durationN textoffsetspeakerorig_sample_ratetoken_labelslang)r/   r1   r3   r4   r5   orig_srr7   r8   )jsonloadspopKeyErrorr   dictget)r+   r,   itemr&   r&   r'   _parse_manifest_item=   s*   






rA   batchbatch_augmentorc                 C   s  dd | D }dd | D }t | }dd | D }dd | D }dd | D }dd | D }g }	g }
g }t|D ]h\}}|d}||k rUd|| f}tjj||}|	| || }|d}||k rud|| f}tjj||}|
|d |  || }|d}||k rd|| f}tjj||}||d |  q:t	|	
 }t	| }t	|

 }t	| }t	|
 }t	| }t||||||d	}|d ur||}|S )
Nc                 S      g | ]}|j qS r&   )r   .0xr&   r&   r'   
<listcomp>c       z+_audio_noise_collate_fn.<locals>.<listcomp>c                 S   rD   r&   )r   rE   r&   r&   r'   rH   d   rI   c                 S   rD   r&   )r   rE   r&   r&   r'   rH   g   rI   c                 S   rD   r&   )r   rE   r&   r&   r'   rH   h   rI   c                 S   rD   r&   )r   rE   r&   r&   r'   rH   j   rI   c                 S   rD   r&   )r    rE   r&   r&   r'   rH   k   rI   r   )r   r   r   r   r   r    )maxr@   	enumeratesizetorchnn
functionalpadappendstackfloatlongr)   )rB   rC   audiosaudio_lengthsmax_audio_lennoisesnoise_lengthsnoisy_audiosnoisy_audio_lengthsaudio_signal_listnoise_signal_listnoisy_audio_signal_listir   r   rP   r   r   r   r    audio_signalnoise_signalnoisy_audio_signaloutputr&   r&   r'   _audio_noise_collate_fnb   sZ   



	rd   noise_manifestc                 C   sp   | du rg S t | tr| d} g }| D ]!}t|}tt|D ]}t|| d ||| d< q || q|S )zG
    load noise manifest from a single or a list of manifest files
    N,r0   )
isinstancer$   splitr   rangelenr   extend)re   
noise_datamanifest	curr_datar_   r&   r&   r'   load_noise_manifest   s   

ro   Td   samplesample_raterW   
pad_to_maxmin_white_noise_dbmax_white_noise_db	max_trialc              	   C   sz  |du rdn|| }|  dd}|  dd}	|durM|durM||krMd}
|
|k rLtjd|| }	tj| d |	||d}t|jdkrDn|
d7 }
|
|k s(ntj| d |	||d}t|jdkrwt	d	|  d
|	 d| d t
||d| tj|jtjd}t|d }|dur|r|d|k rd||d f}tjj||}||fS |d| }t| }||fS )a  
    Load noise audio from the manifest item, and apply white noise if the loaded noise audio is empty.
    Args:
        sample: a sample from the noise manifest
        sample_rate: target sample rate to resample the noise audio
        max_audio_len: the maximum audio length to load
        pad_to_max: whether to pad the audio to max_audio_len
        min_white_noise_db: the minimum white noise level in dB
        max_white_noise_db: the maximum white noise level in dB
        max_trial: the maximum number of trials to load noise audio before giving up
    Returns:
        noise: the loaded noise audio
        noise_len: the length of the loaded noise audio
    Nr1   r4   g        r   r0   )r/   r4   r1   	target_sr   zLoaded noise audio is empty: z, with sampled offset=z, duration=z. Adding white noise.)	min_level	max_level)dtype)r?   nprandomuniformr   	from_filesumsamplesr   warningr   perturbrM   tensorrS   rL   rT   rN   rO   rP   )rs   rt   rW   ru   rv   rw   rx   max_durr1   r4   cntaudio_segmentr   r   rP   r&   r&   r'   load_noise_audio   sN   r      rl   c           	   
   C   s  d}t | }t | }||k rt| dkrz| tjt|  }t	|||\}}W ||fS  t
yt } z7td| d| d |d7 }||krjtd| d t | t | fW  Y d}~S W Y d}~nd}~ww ||k rt| dks||fS )	a  
    Randomly sample noise audio from the noise manifest.
    Args:
        noise_data: the noise manifest data
        sample_rate: target sample rate to resample the noise audio
        max_audio_len: the maximum audio length to load
        max_trial: the maximum number of trials to load noise audio before giving up
    Returns:
        noise_audio: the sampled noise audio
        noise_len: the length of the sampled noise audio
    r   z&Error loading noise audio with config z and exception: z, retrying.rz   z!Failed to load noise audio after z  attempts, returning zero noise.N)rM   zerosrS   r   rT   rj   r~   r   randintr   	Exceptionr   r   )	rl   rt   rW   rx   r   noise_audior   noise_sampleer&   r&   r'   sample_noise   s(   (r   r   min_lenc                 C   s   ddg}|  d|k rS|dkr-|  ddkr-tt||  d }| |d| } | S |dks8|  ddkrItjj| d||  d f} | S t	d| d| | S )a  
    Pad audio to min_len with the specified mode
    Args:
        audio: the input audio tensor
        min_len: the minimum length to pad to
        pad_audio_mode: the padding mode, either 'repeat' or 'zero'
    Returns:
        audio: the padded audio tensor
    repeatzeror   NUnsupported pad_audio_mode: z, must be one of )
rL   intr~   ceilr   rM   rN   rO   rP   
ValueError)r   r   pad_audio_modeallowed_modenum_repeatsr&   r&   r'   	pad_audio  s   
r   c                	       sp   e Zd Zedd Z				ddedB dedB ded	ef fd
dZde	fddZ
dee	 defddZ  ZS )AudioNoiseDatasetc                 C      d S Nr&   selfr&   r&   r'   output_types)     zAudioNoiseDataset.output_typesN      ?r   re   rC   min_audio_len_secsr   c                    s>   t  jddtd| || _|| _t|| _|| _|| _d S )Nr   bos_idmanifest_parse_funcr&   	super__init__rA   re   rC   ro   rl   r   r   r   re   rC   r   r   kwargs	__class__r&   r'   r   .  s   	

zAudioNoiseDataset.__init__r-   c           
   	   C   s   | j j| }|j}|d u rd}| jj|j||j| j|j| j	d}|
ddkr0td| d t| j| jj }t||| j}t|jd  }t| j| jj| \}}tt||||||| |d}	|	S )Nr   )r4   r1   trimr9   channel_selectorzLoaded audio has zero length: .r   r   r   r   r   r   r    )manifest_processor
collectionr4   
featurizerprocessr/   r1   r   r9   r   rL   r   r   r   r   rt   r   r   rM   r   shaperT   r   rl   r@   r   r$   )
r   indexrs   r4   r   r   r   r   r   r@   r&   r&   r'   __getitem__>  s8   	zAudioNoiseDataset.__getitem__rB   c                 C      t || jS r   rd   rC   r   rB   r&   r&   r'   _collate_fn`     zAudioNoiseDataset._collate_fnNNr   r   )r!   r"   r#   propertyr   r$   r   rS   r   r   r   r   r)   r   __classcell__r&   r&   r   r'   r   (  s$    
"r   c                	       s|   e Zd Zedd Z				ddedB dedB ded	ef fd
dZdd Z	de
de
fddZdee defddZ  ZS )TarredAudioNoiseDatasetc                 C   r   r   r&   r   r&   r&   r'   r   e  r   z$TarredAudioNoiseDataset.output_typesNr   r   re   rC   r   r   c                    s>   t  jddtd| || _|| _t|| _|| _|| _dS )a  
        Args:
            noise_manifest: the noise manifest file
            batch_augmentor: the batch augmentor
            min_audio_len_secs: the minimum audio length in seconds, audios shorter than this will be padded
            pad_audio_mode: the padding mode for audios shorter than min_audio_len_secs, either 'repeat' or 'zero'
            **kwargs: other arguments for TarredAudioToCharDataset

        r   r   Nr&   r   r   r   r&   r'   r   j  s   

z TarredAudioNoiseDataset.__init__c              
   C   s$  |\}}}t jt j|\}}| jjj| | }| jj| }|j}	|	du r)d}	zt	|}
| j
j|
|	|j| j|jd}|
  W n tyY } ztd| d| dd}~ww t| j| j
j }t||| j}t|jd  }t| j| j
j| \}}tt||||||| |d}|S )z\Builds the training sample by combining the data from the WebDataset with the manifest info.Nr   )r4   r1   r   r9   zError reading audio sample: z, with exception: r   r   ) ospathsplitextbasenamer   r   mappingr4   ioBytesIOr   r   r1   r   r9   closer   RuntimeErrorr   r   rt   r   r   rM   r   r   rT   r   rl   r@   r   r$   )r   tupaudio_bytesr.   	offset_idfile_id_manifest_idxmanifest_entryr4   audio_filestreamr   r   r   r   r   r   r@   r&   r&   r'   _build_sample  sF   

	z%TarredAudioNoiseDataset._build_sampler   r-   c                 C   s   t | j| jj }|d|k rK| jdkr,t t||d }||d | }|S | jdkrBt	j
j|d||d f}|S td| j d|S )Nr   r   r   r   z#, must be one of ['repeat', 'zero'])r   r   r   rt   rL   r   r~   r   r   rM   rN   rO   rP   r   )r   r   r   r   r&   r&   r'   
_pad_audio  s   

z"TarredAudioNoiseDataset._pad_audiorB   c                 C   r   r   r   r   r&   r&   r'   r     r   z#TarredAudioNoiseDataset._collate_fnr   )r!   r"   r#   r   r   r$   r   rS   r   r   r   r   r   r   r)   r   r   r&   r&   r   r'   r   d  s&    
+r   c                       s4   e Zd ZddedB def fddZdd Z  ZS )	LhotseAudioNoiseDatasetNre   batch_augmentor_cfgc                    s>   t    |rt|}nd }|| _t|| _tdd| _d S )NT)fault_tolerant)	r   r   r   from_config_dictrC   ro   rl   r	   
load_audio)r   re   r   rC   r   r&   r'   r     s   

z LhotseAudioNoiseDataset.__init__c                    sN    \ fddD  fddttD }t|jS )Nc                    s   g | ]}t  j|j|jqS r&   )r   rl   sampling_ratenum_samples)rF   cutr   r&   r'   rH     s    z7LhotseAudioNoiseDataset.__getitem__.<locals>.<listcomp>c                    sX   g | ](}t t| j|  | | d  | d | | d    | dqS )r   rz   r   )r   r$   id)rF   r_   )
audio_lensrU   cutssampled_noisesr&   r'   rH     s    


)r   ri   rj   rd   rC   )r   r   itemsr&   )r   rU   r   r   r   r'   r     s   

z#LhotseAudioNoiseDataset.__getitem__NN)r!   r"   r#   r$   r
   r   r   r   r&   r&   r   r'   r     s    r   config	augmentorc                 C   sb   t | dd || d | dd | d | dd|| dd | dd | d	d| d
d d}|S )Nre   manifest_filepathlabelsrt   
int_valuesFmax_durationmin_durationtrim_silencer   )re   rC   r   r   rt   r   r   r   r   r   r   )r   r?   )r   r   rC   datasetr&   r&   r'   get_audio_noise_dataset  s   






r   global_rank
world_sizec           
      C   s   | d }g }t |dkr!t|d ts!td|  | d d }|D ]}t| }||d< t||d}	||	 q#t	|| 
dd| 
dd	| 
d
d| 
dd | 
dd| 
dd ||d	}	|	S )Nr   rz   r   z%removing an extra nesting level from )r   r   concat_sampling_techniquetemperatureconcat_sampling_temperature   concat_sampling_scaleconcat_sampling_probabilitiesconcat_shuffleTconcat_sampling_seedsampling_techniquesampling_temperaturesampling_scalesampling_probabilitiesshuffleseedr   r   )rj   rg   r$   r   infocopydeepcopyr   rQ   r   r?   )
r   r   r   r   rC   manifest_filepathsdatasetsr   confr   r&   r&   r'   get_concat_audio_noise_dataset  s,   






r  c              	      s  | d }| d }g t |}t |}| dd }|r2t|D ]\}	}
t|
tr-|
dkr1tdq t|t|krItdt| dt| dtt||D ]\}\}}t|d	kr`|d }t|d	krj|d }d
|v rtd|v rtdnd}t	
d| d| d|  td&i d| dd d|d|d|d| dd d| d d| ddd|d|d| dd d| dd d| ddd| ddd |d!|d"| |r fd#d$t|| D  qP  qPt j| |d%S )'Ntarred_audio_filepathsr   bucketing_weightsr   z(bucket weights must be positive integerszmanifest_filepaths (length=z%) and tarred_audio_filepaths (length=z*) need to have the same number of buckets.rz   _OP__CL_TFz%Loading TarredAudioNoiseDataset from z and z, shard=re   rC   audio_tar_filepathsr   rt   r   r   	shuffle_nr   r   r   r   shard_strategytarred_shard_strategyscattershard_manifestsr   r   c                    s   g | ]}  qS r&   )rQ   )rF   r   r   r  r&   r'   rH   E  s    z2get_tarred_audio_noise_dataset.<locals>.<listcomp>)r  	ds_configrankr&   )r   convert_to_config_listr?   rK   rg   r   r   rj   zipr   r  r   ri   rQ   get_chain_dataset)r   r  r   r   r   rC   r  r  r	  idxweightdataset_idxtarred_audio_filepathr   is_sharded_manifestr&   r  r'   get_tarred_audio_noise_dataset  s~   

	
r  c                 C   s   | d }| d }g }t t||D ]"\}	\}
}t| }||d< |
|d< t||||||d}|| qt|| dd| dd| dd	| d
d | dd| dd ||d	}|S )Nr  r   r   r  r   r   r   rC   r   r   r   r   r   rz   r   r   Tr   r   )rK   r  r  r  r  rQ   r   r?   )r   r  r   r   r   rC   r  r  r  r  r  r   r  r   r&   r&   r'   %get_concat_tarred_audio_noise_datasetL  s<   






r  c           	      C   s  d| v rt | d ||d}nd }d| v rt| d }nd }| dd}|r| dd d u r:td|   d| d< | d d	krd
| vrrtd|   t|  dt| d  gt| d  | d
< W d    n1 slw   Y  ntt	| d
 dddst
d|  | d }| ddrd| v r| d d u sd| v r| d d u rtd|   d S |r| dd| d  nd}|rt| |||||d}|S t| |||||d}|S d| v r| d d u rtd|   d S |rt| ||||d}|S t| ||d}|S )Nr   )r   r   rC   	is_concatFr   zhConcat dataset requires `concat_sampling_technique` but it was not provided, using round-robin. Config: zround-robinr   r   z]Concat dataset requires `concat_sampling_probabilities` list, using uniform weights. Config: rz   r   gư>)abs_tolzN`concat_sampling_probabilities` need to sum to 1 with 1e-6 tolerance. Config: r   	is_tarredr  znCould not load dataset as `manifest_filepath` was None or `tarred_audio_filepaths` is None. Provided config : r     
batch_sizer   r  zJCould not load dataset as `manifest_filepath` was None. Provided config : )r   r   r   r   rC   )r   r   rC   )r   r   r   r?   r   r   r   rj   r   r   r   r  r  r  r   )	r   r   r   r   rC   r   r   r  r   r&   r&   r'   #get_audio_noise_dataset_from_configp  s   
	r%  r   )NTrp   rq   rr   )Nr   r   )Br  r   r:   r   dataclassesr   mathr   typingr   r   r   r   r   numpyr~   rM   lhotse.datasetr	   	omegaconfr
   r   r   r   nemo.collections.asr.datar   r   0nemo.collections.asr.parts.preprocessing.perturbr   r   0nemo.collections.asr.parts.preprocessing.segmentr   /nemo.collections.asr.parts.utils.manifest_utilsr   $nemo.collections.common.data.datasetr   4nemo.collections.common.parts.preprocessing.manifestr   nemo.core.classesr   
nemo.utilsr   r   r)   r$   rA   rd   ro   r   boolr   r   r   AudioToCharDatasetr   TarredAudioToCharDatasetr   utilsdataDatasetr   r   r  r  r  r%  r&   r&   r&   r'   <module>   s   

%9

$G<["



!9
$