o
    }oi%                 	   @   s  d dl Z d dlZd dlZd dlZd dlZd dlmZ d dlmZ d dl	m
Z
mZmZmZmZ d dlZd dlZd dlZd dlmZ d dlmZ d dlmZ d dlmZ d d	lmZmZmZ d d
lmZm Z m!Z!m"Z" d dl#m$Z$m%Z%m&Z&m'Z'm(Z(m)Z)m*Z*m+Z+m,Z,m-Z-m.Z.m/Z/m0Z0m1Z1 d dl2m3Z3 d dl4m5Z5 z
d dl6m7Z7 dZ8W n e9e:fy   dZ7dZ8Y nw dZ;ej<ej=ej>ej?ddZ@G dd de3ZAG dd deAZBG dd de3ZCG dd de3ZDG dd de3ZEG dd dejFjGjHjIZJdS )    N)defaultdict)Path)CallableDictListOptionalUnion)	rearrange)tqdm)WaveformFeaturizer)AudioSegment)BaseTokenizerEnglishCharsTokenizerEnglishPhonemesTokenizer)BetaBinomialInterpolator beta_binomial_prior_distributiongeneral_paddingget_base_dir)DATA_STR2DATA_CLASSMAIN_DATA_TYPESAlignPriorMatrix	DurationsEnergyLMTokensLogMelP_voicedPitchReferenceAudio	SpeakerIDTTSDataTypeVoiced_maskWithLens)Dataset)logging)
NormalizerTFg&.>)hannhammingblackmanbartlettnonec                9       s  e Zd Z																									dMd	eeeee ee f d
edeee	egee f f de
ee  de
eee	egef f  de
e de
e de
ee  de
eeef  de
e de
e de
eeef  dede
e de
e de
e de
e dede
e de
e dededed e
e d!e
e d"ed#ed$ef8 fd%d&Zed'd( Zd)d* Zd+d, Zd-d. Zd/d0 Zd1d2 Zd3d4 Zd5d6 Zd7d8 Zd9d: Zd;d< Zd=d> Zd?d@ ZdAdB ZdCdD ZdEdF Z dGdH Z!dIdJ Z"dKdL Z#  Z$S )N
TTSDatasetNF   r%   P   r   T   manifest_filepathsample_ratetext_tokenizertokenstext_normalizertext_normalizer_call_kwargstext_tokenizer_pad_idsup_data_typessup_data_pathmax_durationmin_durationignore_filetrimtrim_reftrim_top_dbtrim_frame_lengthtrim_hop_lengthn_fft
win_length
hop_lengthwindown_melslowfreqhighfreqsegment_max_durationpitch_augmentcache_pitch_augmentpad_multiplec           )   
      s  t    | _d _t jtr|j _t jdd _n|du r&t	d|du r.t	d| _ jdu r8dnd _
| _ jdu rGd _ntsMtdt jtrW jjn j _|dura|ni  _t|trl|g}| _g  _g }d} jD ]} tt|  d	}!td
|  d t|!D ]}"t|"}#|#d |#d d|#v r|#d ndd|#v r|#d ndd|#v r|#d ndd}$d|#v r|#d |$d< n"d|#v r|#d |$d< n|#d }% jdur j|%fi  j}%|%|$d<  j
r |$d |$d< ||$  jtj |#d |d   |$d du rtd d}|dur$||#d 7 }qW d   n	1 s0w   Y  qytdt!| d |durRtd|d dd t"#||||
| _$t%dd  j$D  _&| _'t( j'd _)| _*|dur{|nt+j, _-|dur|nd  _.|dur|nd! _/|dur|nd" _0| _1| _2| _3| _4| _5| _6| _7| _8|p j4 _9| _: j:pǈ j4d#  _;t<j=t>j?j@ j' j4 j5 j6 j7d$t<jAd%Bd _CztD j8 W n tEy   tFd& j8 d'tGtDH  dw  fd(d) _I|	durt|	jJddd* |	 _Kg  _L|durZ|D ] }&ztM|& }'W n tEy>   tFd&|& d+w  jL|' q&d,|v sQd-|v rZd.|vrZt	d/tN jL _O jLD ]}(t d0|(jP d1i | qc| _QdS )2a  Dataset which can be used for training spectrogram generators and end-to-end TTS models.
        It loads main data types (audio, text) and specified supplementary data types (log mel, durations, align prior matrix, pitch, energy, speaker id).
        Some supplementary data types will be computed on the fly and saved in the sup_data_path if they did not exist before.
        Saved folder can be changed for some supplementary data types (see keyword args section).
        Arguments for supplementary data should be also specified in this class, and they will be used from kwargs (see keyword args section).
        Args:
            manifest_filepath (Union[str, Path, List[str], List[Path]]): Path(s) to the .json manifests containing information on the
                dataset. Each line in the .json file should be valid json. Note: the .json file itself is not valid
                json. Each line should contain the following:
                    "audio_filepath": <PATH_TO_WAV>,
                    "text": <THE_TRANSCRIPT>,
                    "normalized_text": <NORMALIZED_TRANSCRIPT> (Optional),
                    "mel_filepath": <PATH_TO_LOG_MEL_PT> (Optional),
                    "duration": <Duration of audio clip in seconds> (Optional),
            sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to.
            text_tokenizer (Optional[Union[BaseTokenizer, Callable[[str], List[int]]]]): BaseTokenizer or callable which represents text tokenizer.
            tokens (Optional[List[str]]): Tokens from text_tokenizer. Should be specified if text_tokenizer is not BaseTokenizer.
            text_normalizer (Optional[Union[Normalizer, Callable[[str], str]]]): Normalizer or callable which represents text normalizer.
            text_normalizer_call_kwargs (Optional[Dict]): Additional arguments for text_normalizer function.
            text_tokenizer_pad_id (Optional[int]): Index of padding. Should be specified if text_tokenizer is not BaseTokenizer.
            sup_data_types (Optional[List[str]]): List of supplementary data types.
            sup_data_path (Optional[Union[Path, str]]): A folder that contains or will contain supplementary data (e.g. pitch).
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths
                that will be pruned prior to training. Defaults to None which does not prune.
            trim (bool): Whether to apply `librosa.effects.trim` to trim leading and trailing silence from an audio
                signal. Defaults to False.
            trim_ref (Optional[float]): the reference amplitude. By default, it uses `np.max` and compares to the peak
                amplitude in the signal.
            trim_top_db (Optional[int]): the threshold (in decibels) below reference to consider as silence.
                Defaults to 60.
            trim_frame_length (Optional[int]): the number of samples per analysis frame. Defaults to 2048.
            trim_hop_length (Optional[int]): the number of samples between analysis frames. Defaults to 512.
            n_fft (int): The number of fft samples. Defaults to 1024
            win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
            hop_length (Optional[int]): The hope length between fft computations. Defaults to None which uses n_fft//4.
            window (str): One of 'hann', 'hamming', 'blackman','bartlett', 'none'. Which corresponds to the
                equivalent torch window function.
            n_mels (int): The number of mel filters. Defaults to 80.
            lowfreq (int): The lowfreq input to the mel filter calculation. Defaults to 0.
            highfreq (Optional[int]): The highfreq input to the mel filter calculation. Defaults to None.
        Keyword Args:
            log_mel_folder (Optional[Union[Path, str]]): The folder that contains or will contain log mel spectrograms.
            pitch_folder (Optional[Union[Path, str]]): The folder that contains or will contain pitch.
            voiced_mask_folder (Optional[Union[Path, str]]): The folder that contains or will contain voiced mask of the pitch
            p_voiced_folder (Optional[Union[Path, str]]): The folder that contains or will contain p_voiced(probability) of the pitch
            energy_folder (Optional[Union[Path, str]]): The folder that contains or will contain energy.
            durs_file (Optional[str]): String path to pickled durations location.
            durs_type (Optional[str]): Type of durations. Currently, supported only "aligner-based".
            use_beta_binomial_interpolator (Optional[bool]): Whether to use beta-binomial interpolator for calculating alignment prior matrix. Defaults to False.
            pitch_fmin (Optional[float]): The fmin input to librosa.pyin. Defaults to librosa.note_to_hz('C2').
            pitch_fmax (Optional[float]): The fmax input to librosa.pyin. Defaults to librosa.note_to_hz('C7').
            pitch_mean (Optional[float]): The mean that we use to normalize the pitch.
            pitch_std (Optional[float]): The std that we use to normalize the pitch.
            segment_max_duration (Optional[float]): If audio length is greater than segment_max_duration, take a random segment of segment_max_duration (Used for SV task in SSLDisentangler)
            pitch_augment (bool): Whether to apply pitch-shift transform and return a pitch-shifted audio. If set as False, audio_shifted will be None (used in SSLDisentangler)
            cache_pitch_augment (bool): Whether to cache pitch augmented audio or not. Defaults to False (used in SSLDisentangler)
            pad_multiple (int): If audio length is not divisible by pad_multiple, pad the audio with zeros to make it divisible by pad_multiple (used in SSLDisentangler)
            pitch_norm (Optional[bool]): Whether to normalize pitch or not. If True, requires providing either
                pitch_stats_path or (pitch_mean and pitch_std).
            pitch_stats_path (Optional[Path, str]): Path to file containing speaker level pitch statistics.
            reference_audio_type (Optional[str]): Criterion for the selection of reference audios for the GlobalStyleToken submodule. Currently, supported values are "ground-truth" (reference audio = ground truth audio, like in the original GST paper) and "same-speaker" (reference audio = random audio from the same speaker). Defaults to "same-speaker".
        Nphoneme_probabilityzNtext_tokenizer_pad_id must be specified if text_tokenizer is not BaseTokenizerz?tokens must be specified if text_tokenizer is not BaseTokenizerTFz`nemo_text_processing` is not installed, see https://github.com/NVIDIA/NeMo-text-processing for details. If you wish to continue without text normalization, please remove the text_normalizer part in your TTS yaml file.r   rLoading dataset from .audio_filepathtextmel_filepathdurationspeaker)rN   original_textrP   rQ   
speaker_idnormalized_texttext_normalizedtext_tokens   QNot all audio files have duration information. Duration logging will be disabled.Loaded dataset with  files.Dataset contains   .2f hours.c                 S      g | ]}|d  qS rN    .0itemrb   rb   U/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/tts/data/dataset.py
<listcomp>      z'TTSDataset.__init__.<locals>.<listcomp>r/   <   i   i      )srr?   rC   fminfmaxdtypez'Current implementation doesn't support z  window. Please choose one from c              	      s<   t j|  j j jr jddt jddS d ddS )NF)periodicT)inputr?   rA   r@   rB   return_complex)torchstftr?   hop_lenr@   tofloat)xself	window_fnrb   rf   <lambda>7  s    z%TTSDataset.__init__.<locals>.<lambda>)parentsexist_okz type.voiced_maskp_voicedpitchzPlease add 'pitch' to sup_data_types in YAML because 'pitch' is required when using either 'voiced_mask' or 'p_voiced' or both.add_rb   )Rsuper__init__r0   rJ   
isinstancer   padr4   getattr
ValueError
cache_textr2   text_normalizer_callPYNINI_AVAILABLEImportErrorr$   	normalizer3   strr.   lengthsopenr   
expanduserr#   infor
   jsonloadsappendospathgetsizelenr*   filter_filesdatar   base_data_dirr/   r   
featurizerr:   npmaxr;   r<   r=   r>   rF   rG   rH   r?   rC   rD   rE   rB   r@   rA   rv   rt   tensorlibrosafiltersmelrx   	unsqueezefbWINDOW_FN_SUPPORTEDKeyErrorNotImplementedErrorlistkeysru   mkdirr6   r5   r   setsup_data_types_setnamerI   ))r{   r.   r/   r0   r1   r2   r3   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   r>   r?   r@   rA   rB   rC   rD   rE   rF   rG   rH   rI   kwargsr   total_durationmanifest_fileflinere   	file_inforO   d_as_strsup_data_type	data_type	__class__rz   rf   r   R   s   
d







 
'






 
zTTSDataset.__init__c                 C   sD  |r+t d| d tt| d}tt|}W d    n1 s&w   Y  g }|d ur3dnd }d}	| D ]@}
|
d }|d ur^|rK|
d |k sS|r^|
d |kr^||
d 7 }|	d7 }	q9|rt||v rt|	d7 }	||
d 7 }|| q9|	|
 q9t d|	 d	t
| d
 |d urt d|d dd|| d dd |S )NzUsing z to prune dataset.rbr   rN   rQ   r-   zPruned z files. Final dataset contains z filesr]   r^   z hours. Final dataset contains r_   )r#   r   r   r   r   r   pickleloadremover   r   )r   r9   r8   r7   r   r   wavs_to_ignorefiltered_datapruned_durationpruned_itemsre   
audio_pathrb   rb   rf   r   \  s@   
zTTSDataset.filter_filesc                 K   X   | dd | _| jd u rt| jtj | _nt| jtr"t| j| _| jjddd d S )Nlog_mel_folderTr   r~   )	popr   r   r6   r   r   r   r   r   r{   r   rb   rb   rf   add_log_mel     
zTTSDataset.add_log_melc                 K   sh   | d}| d}t|}g | _dd | jD D ]}|| }|dkr+| j| qt| dd S )N	durs_file	durs_typec                 S   s   g | ]	}t |d  jqS ra   )r   stem)rd   drb   rb   rf   rg     s    z,TTSDataset.add_durations.<locals>.<listcomp>zaligner-basedzP duration type is not supported. Only aligner-based is supported at this moment.)r   rt   r   dursr   r   r   )r{   r   r   r   audio_stem2durstagr   rb   rb   rf   add_durations  s   


zTTSDataset.add_durationsc                 K   sH   | dd| _| jsd|v r| jstd d| _| jr"t | _d S d S )Nuse_beta_binomial_interpolatorFzphoneme_probability is not None, but use_beta_binomial_interpolator=False, we set use_beta_binomial_interpolator=True manually to use phoneme_probability.T)r   r   r   r#   warningr   beta_binomial_interpolatorr   rb   rb   rf   add_align_prior_matrix  s   z!TTSDataset.add_align_prior_matrixc                 K   sf  | dd | _| jd u rt| jtj | _nt| jtr"t| j| _| jjddd | dt	
d| _| dt	
d| _| dd | _| d	d | _| d
d| _| dd }| jr| jd u | jd u kssJ d| j d| j d| jd u |d u ksJ d| j d| j d| |d urtt|ddd}t|| _W d    d S 1 sw   Y  d S d S )Npitch_folderTr   
pitch_fminC2
pitch_fmaxC7
pitch_mean	pitch_std
pitch_normFpitch_stats_pathz*Found only 1 of (pitch_mean, pitch_std): (z, )zYpitch_norm requires exactly 1 of (pitch_mean, pitch_std) or pitch_stats_path. Provided: (z) and rK   zutf-8)encoding)r   r   r   r6   r   r   r   r   r   r   
note_to_hzr   r   r   r   r   r   r   r   pitch_stats)r{   r   r   pitch_frb   rb   rf   	add_pitch  s@   
"zTTSDataset.add_pitchc                 K   >   | dd | _| jd u rt| jtj | _| jjddd d S )Nvoiced_mask_folderTr   )r   r   r   r6   r    r   r   r   rb   rb   rf   add_voiced_mask     
zTTSDataset.add_voiced_maskc                 K   r   )Np_voiced_folderTr   )r   r   r   r6   r   r   r   r   rb   rb   rf   add_p_voiced  r   zTTSDataset.add_p_voicedc                 K   r   )Nenergy_folderTr   )	r   r   r   r6   r   r   r   r   r   r   rb   rb   rf   
add_energy  r   zTTSDataset.add_energyc                 K   s   d S Nrb   r   rb   rb   rf   add_speaker_id  s   zTTSDataset.add_speaker_idc                    s   | dd}|dkr4t jv sJ dttt jD ]\}}|d  | q fdd _d S |dkr?dd  _d S t	d	| d
)Nreference_audio_typezsame-speakerz(Please add speaker_id in sup_data_types.rT   c                    s    j tt| d   S )NrT   )r   randomchoicetuplesampler{   speaker_to_index_maprb   rf   r}     s    z0TTSDataset.add_reference_audio.<locals>.<lambda>zground-truthc                 S   s   | S r   rb   r   rb   rb   rf   r}     s    zReference audio type "z" is not supported.)
r   r   r5   r   r   	enumerater   addget_reference_for_sampler   )r{   r   r   ir   rb   r   rf   add_reference_audio  s   zTTSDataset.add_reference_audioc                 C   sz   t jj|jjdd) | |}|jt jt jfv rt 	|}t 
|ddt }W d    |S 1 s6w   Y  |S )NFenabledrX   )rt   ampautocastdevicetyperu   rp   cfloatcdoubleview_as_realsqrtpowsumEPSILON)r{   audiospecrb   rb   rf   get_spec  s   


zTTSDataset.get_specc                 C   sx   t jj|jjdd( | |}t | j|j	|}t 
t j|t |j	jd}W d    |S 1 s5w   Y  |S )NFr   )min)rt   r   r  r  r  r  matmulr   rw   rp   logclampfinfotiny)r{   r  r  r   log_melrb   rb   rf   get_log_mel  s   
 
zTTSDataset.get_log_melc           	      C   s   t | j| d }| r| jrt|}|S tjdd}tjdd}t	||g}t
jj|||d}t|}| jrDt|| |S )Nz_pitch_shift.ptr   r-   rk   )rl   n_steps)r   r6   existsrH   rt   r   r   r   uniformr   r   effectspitch_shiftr   save)	r{   r  rl   rel_audio_path_as_text_idaudio_shifted_pathaudio_shiftedchoice1choice2	shift_valrb   rb   rf   r  
  s   

zTTSDataset.pitch_shiftc                 C   sN   | j dkr%|jd | j  dkr%t|tj| j |jd | j   tjdg}|S )Nr-   r   ro   )rI   shapert   catzerosrx   )r{   wavrb   rb   rf   _pad_wav_to_multiple  s   
$zTTSDataset._pad_wav_to_multiplec           *      C   s  | j | }t|d | jd}t|dd}| jd urdd|v rd|d | jkrdt| j| j	 }t
j|d | j	|| jd}d }| jrFJ t|j}| jdkrV| |}|t|jd  }}	nQ| jj|d | j| j| j| j| jd	}| jdkr| |}d }| jr| |   | j	|}| | ksJ d
 | | |t|jd  }}	d|v rt|d  }
tt!|
 }n| "|d }t| }
tt!| }d\}}t#| j$v r0|d }|d urt|% rt&|}n| j'| d }|% rt&|}n| (|}t)|| |*d}t|jd  }d }t+| j$v r=| j,| }d }t-| j$v re| (|jd }| j.r]t/| 0||1 }nt/t2||}g }t3 }t4t5t6t7gD ]7\}}|| j$v rt8| |j9 d}|| d }|% r|:|j9t&|;  qq|<||j9|f qqt!|dkrt=j>| | j?| j@| jA| j	dd}|D ]\}}}|:|t/|| ;  t)|B|| q|Bdd }|Bdd }|Bdd }|Bdd }|d urjtt!| }| jCrj| jDd ur| jEd ur| jD}| jE} n>| jFrWd|v r:t|d | jFv r:| jFt|d  }!nd| jFv rF| jFd }!ntGd| d|!d }|!d } ntGd||8 }d||| k< ||  }d\}"}#tH| j$v r| jI| d }$|$% rt&|$; }"n| J|}%tjKjL|%*ddd; }"t)|"|$ tt!|" }#d }&tM| j$v rt|d  }&d\}'}(tN| j$v r| O|})| jj|)d | j| j| j| j| jd	}'t|'jd  }(||	|
||||||||"|#|&||||'|(fS ) NrN    /_rQ   	target_sr
n_segmentsr:   r-   r   )r:   r;   r<   r=   r>   z{} != {}rW   rU   )NNrP   .ptrX   _folder        )rm   rn   frame_lengthrl   fill_nar   pitch_lengthr   r   rT   defaultzCould not find pitch stats for rM   r   r   z+Missing statistics for pitch normalization.)axis)Pr   r   relative_tor   with_suffixr   replacerF   intr/   r   segment_from_filer:   rG   rt   r   samplesrI   r'  r#  longr   processr;   r<   r=   r>   r  cpudetachnumpysizeformatr   r0   r   r   r  r   r   r  r  squeezer   r   r   r   
from_numpyr   re   r   localsr   r   r    r   r   r   __setitem__rx   r   r   pyinr   r   r@   getr   r   r   r   r   r   r   r  linalgnormr   r   r   )*r{   indexr   rel_audio_pathr  r-  featuresr  r  audio_lengthrO   text_length	tokenizedr  log_mel_lengthmel_path	durationsalign_prior_matrixmel_lennon_exist_voiced_indexmy_varr   voiced_itemvoiced_foldervoiced_filepathvoiced_tuplevoiced_namer   r3  r   r   sample_pitch_meansample_pitch_stdr   energyenergy_lengthenergy_pathr  rT   reference_audioreference_audio_length	referencerb   rb   rf   __getitem__"  s*  





	










zTTSDataset.__getitem__c                 C   
   t | jS r   r   r   r{   rb   rb   rf   __len__     
zTTSDataset.__len__c                 C   sV   g }t | j D ]}|||j  t|tr&t|tr&|||j d  qt|S )N_lens)r   r5   r   r   
issubclassr   r!   r   )r{   	data_dictresultr   rb   rb   rf   	join_data  s   zTTSDataset.join_datac           3      C   s  t | \}}}}}}}}}}	}
}}}}}}}t| }t| }t| jv r+t|nd }t| jv r;tdd |D nd }t| jv rHt|	 nd }t| jv rUt| nd }t| jv rbt| nd }t| jv rtt	
|d d jj}t| jv rt	t|tdd |D tdd |D ng }g g g g g g g g g g g f\}}}}}}
}}}}}t|D ]\}}|\} }!}"}#}$}%}&}'}(})}*}+},}-}.}/}0}1t| |! |} ||  t|"|# || jd}"||" |/d urt|/|! |}/||/ t| jv r|t|$|%||d t| jv r|t|&t|&| t| jv r-|'||d |'jd d |'jd f< t| jv r>|t|(|) | t| jv rO|t|-|) | t| jv r`|t|.|) | t| jv rq|
t|*|+ | t| jv r|||, t| jv r|t|0|1 | qi d	t	|d
t	|dt	|dt	|dt| jv rt	|nd dt| jv rt	|nd dt| jv rt	|nd dt| jv r|nd dt| jv rt	|nd dt| jv rt	|	nd dt| jv rt	|
nd dt| jv rt	|nd dt| jv r t	|nd dt| jv r.t	|nd dt| jv r<t	|nd d|/d urIt	|nd dt| jv rWt	|nd dt| jv ret	|nd i}2|2S )Nc                 S   s   g | ]}t |qS rb   )r   rd   r   rb   rb   rf   rg     rh   z1TTSDataset.general_collate_fn.<locals>.<listcomp>r   rk   c                 S      g | ]}|j d  qS r   r#  rd   prior_irb   rb   rf   rg         c                 S   rq  )r-   rs  rt  rb   rb   rf   rg     rv  )	pad_valuer-   r  
audio_lensrO   	text_lensr  log_mel_lensrS  rT  r   
pitch_lensr_  energy_lensrT   r   r   r  rb  reference_audio_lens)zipr   re   r   r   r   r   r   r   rt   r  rp   r  r   r%  r   r   r   r   r4   r#  r    r   r   stack)3r{   batchr*  audio_lengthstokens_lengthslog_mel_lengthsdurations_listalign_prior_matrices_listpitchespitches_lengthsenergiesenergies_lengthsvoiced_masks	p_voicedsreference_audio_lengthsmax_audio_lenmax_tokens_lenmax_log_mel_lenmax_durations_lenmax_pitches_lenmax_energies_lenmax_reference_audio_lenlog_mel_padalign_prior_matricesaudiosr1   log_melsspeaker_idsaudios_shiftedreference_audiosr   sample_tupler  	audio_lentoken	token_lenr  log_mel_lenrS  rT  r   r3  r_  r`  rT   r   r   r  rb  reference_audios_lengthrm  rb   rb   rf   general_collate_fn  s.   




 




	
zTTSDataset.general_collate_fnc                 C   s   |  |}| |}|S r   )r  ro  )r{   r  rm  joined_datarb   rb   rf   _collate_fn  s   

zTTSDataset._collate_fn)NNNNNNNNNFNNNNr+   NNr%   r,   r   NNFTr-   )%__name__
__module____qualname__r   r   r   r   r9  r   r   r   r$   r   rx   boolr   staticmethodr   r   r   r   r   r   r   r   r   r   r  r  r  r'  re  ri  ro  r  r  __classcell__rb   rb   r   rf   r*   Q   s    

	
  
&
"
 D
 r*   c                       s@   e Zd Z fddZdd Zdd Z fddZd	d
 Z  ZS )MixerTTSXDatasetc                    s   t  jdi | d S )Nrb   )r   r   r   r   rb   rf   r     s   zMixerTTSXDataset.__init__c                 C   s   ddl m} |d| _| jd| _| jd}i | _t| jD ]5\}}|d }t	| j
ts7t	| j
ts7J | j
|}| jj|dd}| j
jrQ|g| |g }|| j|< q!d S )	Nr   )AlbertTokenizerzalbert-base-v2z<pad>u   ▁rU   F)add_special_tokens)transformersr  from_pretrainedlm_model_tokenizer_convert_token_to_idlm_padding_valueid2lm_tokensr   r   r   r0   r   r   text_preprocessing_funcencodepad_with_space)r{   r  space_valuer   r   rU   preprocess_text_as_tts_inputlm_tokens_as_idsrb   rb   rf   _albert  s    zMixerTTSXDataset._albertc                 K   s,   | d}|dkr|   d S t| d)Nlm_modelalbertzD lm model is not supported. Only albert is supported at this moment.)r   r  r   )r{   r   r  rb   rb   rf   add_lm_tokens  s   
zMixerTTSXDataset.add_lm_tokensc                    sr   t  |\}}}}}}}}	}
}}}}}}}d }t| jv r't| j|  }||||||||	|
|||||||fS r   )r   re  r   r   rt   r   r  r<  )r{   rK  r  rN  rO   rO  r  rQ  rS  rT  r   r3  r_  r`  rT   r   r   r*  	lm_tokensr   rb   rf   re    sL   

zMixerTTSXDataset.__getitem__c                 C   s   t t| }| t t|d d  }|d }t| jv rHtjt|tdd |D f| j	d}t
|D ]\}}|||d |jd f< q3||tj< | |}|S )N   c                 S   rq  rr  rs  )rd   r  rb   rb   rf   rg     rv  z0MixerTTSXDataset._collate_fn.<locals>.<listcomp>)
fill_valuer   )r   r~  r  r   r   rt   fullr   r   r  r   r#  r   ro  )r{   r  rm  lm_tokens_listr  r   lm_tokens_ir  rb   rb   rf   r    s   


zMixerTTSXDataset._collate_fn)	r  r  r  r   r  r  re  r  r  rb   rb   r   rf   r    s    
,r  c                       s   e Zd Z							ddeeeee ee f dedee dee	 dee	 deeeef  d	ee
 d
e
dee f fddZdd Zdd Zdd Z  ZS )VocoderDatasetNFr.   r/   r-  r7   r8   r9   r:   load_precomputed_melrA   c
              	      s  t    |r|	du rtd|du rtdt|tr|g}|| _g }
d}| jD ]q}tt| d^}t	
d| d t|D ]H}t|}d|vrV|rVtd	| |d
 d|v ra|d ndd|v rj|d ndd}|
| |d du rt	
d d}|dur||d 7 }qBW d   n1 sw   Y  q)t	
dt|
 d |durt	
d|d dd t|
||||| _tdd | jD | _|| _t|d| _|| _|| _|	| _|| _dS )a	  Dataset which can be used for training and fine-tuning vocoder with pre-computed mel-spectrograms.
        Args:
            manifest_filepath (Union[str, Path, List[str], List[Path]]): Path(s) to the .json manifests containing
            information on the dataset. Each line in the .json file should be valid json. Note: the .json file itself
            is not valid json. Each line should contain the following:
                "audio_filepath": <PATH_TO_WAV>,
                "duration": <Duration of audio clip in seconds> (Optional),
                "mel_filepath": <PATH_TO_LOG_MEL> (Optional, can be in .npy (numpy.save) or .pt (torch.save) format)
            sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to.
            n_segments (int): The length of audio in samples to load. For example, given a sample rate of 16kHz, and
                n_segments=16000, a random 1-second section of audio from the clip will be loaded. The section will
                be randomly sampled everytime the audio is batched. Can be set to None to load the entire audio.
                Must be specified if load_precomputed_mel is True.
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths
                that will be pruned prior to training. Defaults to None which does not prune.
            trim (bool): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
            load_precomputed_mel (bool): Whether to load precomputed mel (useful for fine-tuning).
                Note: Requires "mel_filepath" to be set in the manifest file.
            hop_length (Optional[int]): The hope length between fft computations. Must be specified if load_precomputed_mel is True.
        Nz>hop_length must be specified when load_precomputed_mel is Truez>n_segments must be specified when load_precomputed_mel is Truer   rK   rL   rM   rP   zmel_filepath is missing in rN   rQ   )rN   rP   rQ   rY   rZ   r[   r\   r]   r^   r_   c                 S   r`   ra   rb   rc   rb   rb   rf   rg   P  rh   z+VocoderDataset.__init__.<locals>.<listcomp>ri   )r   r   r   r   r   r.   r   r   r   r#   r   r
   r   r   r   r   r*   r   r   r   r   r  r   r   r/   r-  rA   r:   )r{   r.   r/   r-  r7   r8   r9   r:   r  rA   r   r   r   r   r   re   r   r   rb   rf   r     sZ   
&




zVocoderDataset.__init__c                 C   s   | j rtjjj|S dd |D }tjt|t|tj	d}t
|D ]\}}|| dd|d d|d  q#|tj|tjdfS )Nc                 S   s   g | ]\}}|qS rb   rb   )rd   r*  r  rb   rb   rf   rg   ^  rh   z.VocoderDataset._collate_fn.<locals>.<listcomp>ro   r   )r  rt   utilsr   
dataloaderdefault_collater%  r   r   rx   r   narrowrA  copy_r   r<  )r{   r  r  audio_signalr   r   rb   rb   rf   r  Z  s   (zVocoderDataset._collate_fnc           	      C   s  | j | }| js4tj|d | j| jd ur| jnd| jd}t|j	}|t|j
d  }}||fS | jj|d | jd}|t|j
d  }}t|d jdkr`tt|d }nt|d }t| j| j }t|| jkrtd|j
d | d }|d d ||| f }||| j || | j  }ntjj|d||j
d  f}tjj|d| jt| f}|t||fS )	NrN   r   r+  r   )r:   rP   z.npyr-   )r   r  r   r:  r/   r-  r:   rt   r   r;  r#  r<  r   r=  r   suffixrD  r   r   mathceilrA   r   r   randintnn
functionalr   )	r{   rK  r   rM  r  rN  r   framesstartrb   rb   rf   re  f  s0   
zVocoderDataset.__getitem__c                 C   rf  r   rg  rh  rb   rb   rf   ri    rj  zVocoderDataset.__len__)NNNNFFN)r  r  r  r   r   r   r   r9  r   rx   r  r   r  re  ri  r  rb   rb   r   rf   r    s<    	
_"r  c                   @   s:   e Zd Zdeeef fddZdd Zdd Zdd	 Z	d
S )!PairedRealFakeSpectrogramsDatasetr.   c                 C   s   t |}t | 1}td|  g | _|D ]}t| }d|v s'J d|v s-J | j| qW d    n1 s>w   Y  tdt	|  d d S )Nz(Loading paired spectrogram dataset from rP   mel_gt_filepathzManifest describes z spectrogram pairs)
r   r   r#   r   manifestr   r   stripr   r   )r{   r.   r   r   entryrb   rb   rf   r     s   	z*PairedRealFakeSpectrogramsDataset.__init__c                 C   rf  r   )r   r  rh  rb   rb   rf   ri    rj  z)PairedRealFakeSpectrogramsDataset.__len__c                 C   s>   | j | }t|d }t|d }t|jt|jfS )NrP   r  )r  r   r   rt   rD  T)r{   rK  r  	pred_spec	true_specrb   rb   rf   re    s   
z-PairedRealFakeSpectrogramsDataset.__getitem__c                 C   sb   t | \}}dd |D }tjjjj|dd}tjjjj|dd}t|}t|dt|d|fS )Nc                 S   rq  rr  rs  )rd   r  rb   rb   rf   rg     rv  zAPairedRealFakeSpectrogramsDataset._collate_fn.<locals>.<listcomp>T)batch_firstzb l c -> b c l)r~  rt   r  r  rnnpad_sequence
LongTensorr	   )r{   r  
pred_specs
true_specsr   rb   rb   rf   r    s   
z-PairedRealFakeSpectrogramsDataset._collate_fnN)
r  r  r  r   r   r   r   ri  re  r  rb   rb   rb   rf   r    s    

r  c                    @   s   e Zd Z												d$deeeee ee f dededee d	ee	 d
ee	 deeeef  dee
 dee
 dee	 dee	 dee deeeef  deeeef  dee fddZdd Zdd Zdd Zdd Zdd Zd d! Zd"d# ZdS )%FastPitchSSLDatasetr+   NF
per_sampler.   r/   ssl_content_emb_typerI   r7   r8   r9   r:   pitch_conditioningr   r   pitch_normalizationsup_data_dirspeaker_stats_pitch_fpspeaker_conditioning_typec              	   C   s  |dv sJ t |tr|g}|| _g }d}| jD ]u}tt| db}td| d t|D ]L}t	
|}d|vr@d|d< |d d|v rK|d nd	d|v rT|d ndd
|v r]|d
 ndd}|| |d d	u rstd d	}|d	ur}||d 7 }q1W d	   n1 sw   Y  qtdt| d |d	urtd|d dd t|||||| _tdd | jD | _t|d| _|| _|| _|| _|| _|
| _|| _|	| _|| _|| _|d	u rtj| jd}|| _ | jdkr>i | _!|d	u rtj|d}tj"|sJ d#|t|d}t	$|}|D ]}|| | j!t%|< qW d	   d	S 1 s7w   Y  d	S d	S )a  Dataset used for training FastPitchModel_SSL model.
        Requires supplementary data created using scripts/ssl_tts/make_supdata.py
        Args:
            manifest_filepath (Union[str, Path, List[str], List[Path]]): Path(s) to the .json manifests containing
            information on the dataset. Each line in the .json file should be valid json. Note: the .json file itself
            is not valid json. Each line should contain the following:
                "audio_filepath": <PATH_TO_WAV>,
                "speaker" : <SPEAKER NUM>
                "duration": <Duration of audio clip in seconds> (Optional)
            sample_rate (int): The sample rate of the audio. Or the sample rate that we will resample all files to.
            ssl_content_emb_type (str): One of ["probs", "embedding", "log_probs", "embedding_and_probs"].
                Indicated which output to use as content embedding.
            max_duration (Optional[float]): Max duration of audio clips in seconds. All samples exceeding this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            min_duration (Optional[float]): Min duration of audio clips in seconds. All samples lower than this will be
                pruned prior to training. Note: Requires "duration" to be set in the manifest file. It does not load
                audio to compute duration. Defaults to None which does not prune.
            ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths
                that will be pruned prior to training. Defaults to None which does not prune.
            trim (bool): Whether to apply `librosa.effects.trim` to trim leading and trailing silence from an audio
                signal. Defaults to False.
            pitch_conditioning (bool): Whether to load pitch contour or not
            pitch_mean (Optional[float]): If using global normalization, normalize using these statistics.
                Also used if speaker stats are not available for the given speaker
            pitch_std (Optional[float]): If using global normalization, normalize using these statistics.
                Also used if speaker stats are not available for the given speaker
            pitch_normalization (str): Can be one of ['speaker_wise', 'global', 'none']. Indicates the kind of pitch normalization.
            sup_data_dir (Optional[Union[str, Path]]): Data directory containing pre-computed embeddings/statistics. If set as
            speaker_stats_pitch_fp (Optional[Union[str, Path]]): Path to the json containing speaker pitch stats.
                If set as None, tries to lookup for a default filename (speaker_pitch_stats.json) in sup_data_dir.
                Needed if we use pitch_normalization is "speaker_wise"
            speaker_conditioning_type (Optional[str]): Can be one of ["per_sample", "mean", "interpolate"]. Defaults to "per_sample"
                per_sample: Speaker embedding computed from the same utterance
                mean: Speaker embedding for all utterances of a given speaker is the same and equal to the mean speaker embedding.
                interpolate: Interpolate b/w per_sample and mean speaker embedding.
        )probs	embedding	log_probsembedding_and_probsr   rK   rL   rM   rR   rN   rQ   N
dataset_id)rN   rQ   rR   r  rY   rZ   r[   r\   r]   r^   r_   c                 S   r`   ra   rb   rc   rb   rb   rf   rg     rh   z0FastPitchSSLDataset.__init__.<locals>.<listcomp>ri   sup_dataspeaker_wisezspeaker_pitch_stats.jsonzkspeaker_stats_pitch_fp {} does not exist. Make sure to run scripts/ssl_tts/make_supdata.py before training.)&r   r   r.   r   r   r   r#   r   r
   r   r   r   r   r*   r   r   r   r   r   r   r/   r:   rI   r  r   r   r  r  r  r   r   joinr  speaker_statsr  rB  r   r9  )r{   r.   r/   r  rI   r7   r8   r9   r:   r  r   r   r  r  r  r  r   r   r   r   r   re   r   speaker_stats_rawkeyrb   rb   rf   r     s   7





$zFastPitchSSLDataset.__init__c                 C   s   t j|| jd| jd}|j}t|t|jd  }}|jd | j	 dkrHt
|tj| j	|jd | j	  tjdg}t|jd  }||fS )Nr   r+  r   ro   )r   r:  r/   r:   r;  rt   r   r#  r<  rI   r$  r%  rx   )r{   rN   rM  audio_samplesr  rN  rb   rb   rf   _get_wav_from_filepath2  s    $z*FastPitchSSLDataset._get_wav_from_filepathc                 C   s   | j  d| d}d| d}d| d}tj| j|}tj| j|}tj| j|}tj|r9t|}ntd| dtj|rMt|}	ntd| dtj|rat|}
ntd| dt	|j
d	  }||	||
fS )
N_content_embedding_r.  speaker_embedding_duration_embedding_zContent embedding file R does not exist. Make sure to run scripts/ssl_tts/make_supdata.py before training.zSpeaker embedding file zDuration file r-   )r  r   r   r  r  r  rt   r   r   r   r#  r<  )r{   wav_text_idcontent_emb_fnspeaker_emb_fnduration_fncontent_emb_fpspeaker_emb_fpduration_fpcontent_embeddingspeaker_embeddingrQ   encoded_lenrb   rb   rf   get_ssl_featuresF  s.   


z$FastPitchSSLDataset.get_ssl_featuresc                 C   B   d| d}t j| j|}t j|rt|S td| d)Npitch_contour_r.  zPitch contour file r  r   r   r  r  r  rt   r   r   )r{   r  pitch_contour_fnpitch_contour_fprb   rb   rf   get_pitch_contourg     

z%FastPitchSSLDataset.get_pitch_contourc                 C   r  )N	mel_spec_r.  zMel spectrogram file r  r  )r{   r  mel_spec_fnmel_spec_fprb   rb   rf   get_mel_spectrogramq  r  z'FastPitchSSLDataset.get_mel_spectrogramc                 C   s  t t}|D ]}|D ]}|| ||  q
qtdd |d D }tdd |d D }tdd |d D }g }|d D ]}	tjjj|	d	||	d	 fd	d
}
||
 q>g }|d D ]}tjjj|d	||d fd	d
}|| q]g }|d D ]}tjjj|d	||d	 fd	d
}|| q|g }|d D ]}tjjj|d	||d fd	d
}|| qg }|d D ]}tjjj|d	||d	 fdd
}|| q||d< ||d< ||d< ||d< ||d< |D ]}t	|| ||< q|S )z
        Collate function for FastPitchModel_SSL.
        Pads the tensors in the batch with zeros to match length of the longest sequence in the batch.
        Used in fastpitch_ssl.py
        c                 S      g | ]}|  qS rb   re   )rd   
_audio_lenrb   rb   rf   rg     rh   z6FastPitchSSLDataset.pad_collate_fn.<locals>.<listcomp>r  c                 S   r  rb   r  )rd   _mel_lenrb   rb   rf   rg     rh   rU  c                 S   r  rb   r  )rd   _encoded_lenrb   rb   rf   rg     rh   r  r  r   )valuemel_spectrogramr-   pitch_contourr  rQ   r0  )
r   r   r   r   rt   r  r  r   rA  r  )r{   r  final_batchrowr  r  max_mel_lenmax_encoded_lenaudios_paddedr  audio_paddedmels_paddedr   
mel_paddedpitch_contours_paddedr  pitch_contour_paddedcontent_embeddings_paddedencodedencoded_paddeddurations_paddedrQ   duration_paddedrb   rb   rf   pad_collate_fn{  sL   """"z"FastPitchSSLDataset.pad_collate_fnc                 C   s  | j | }t|d | jd}t|dd}t|d 	 }t|d 	 }| 
|d \}}d }	| jr?| |}	| |\}
}}}| jdkre|d | jv s]J d|d | j|d  }n:| jd	kr|d | jv szJ d|d | j|d  }|}tjd
d}|d|  ||  }tj|dd}|| }d }d }| |}t|jd 	 }|	d ur7| jdv r| j| j}}| jdkr| j|d  d }| j|d  d }t|st|s|d
ks|d
krtd|d  | j}| j}n| jdkr
| j}| j}|	| }	d|	|	| k< |	| }	|	jtjkr7td|d  td t |jd }	|||
|||	|||||d}|S )NrN   r(  r)  r*  rR   r  meanz{} not in speaker embinterpolater   r-   rX   )p)r  globalr  r   r   z*NaN found in pitch mean/std for speaker {}r+  r0  zinvalid pitch contour for {}zSetting pitch contour to 0)r  r  r  r  r  r  rR   r  rU  r  rQ   )!r   r   r6  r   r7  r   r8  rt   r   r<  r  r  r
  r  r  mean_speaker_embeddingsrB  r   r   r  rJ  r  r#  r  r   r   r  isnanr#   r   rp   float32r%  )r{   rK  r   rL  r  rR   r  r  rN  r  r  r  r  rQ   e1e2interpolate_factorl2_normr  rU  r(  stdre   rb   rb   rf   re    sv   


 
 


$
zFastPitchSSLDataset.__getitem__c                 C   rf  r   rg  rh  rb   rb   rf   ri    rj  zFastPitchSSLDataset.__len__)r+   NNNFFNNNNNr  )r  r  r  r   r   r   r   r9  r   rx   r  r   r  r  r
  r  r'  re  ri  rb   rb   rb   rf   r    sh    	

 !

5Ir  c                       sV   e Zd ZdZd fdd	Zdd Zdd	 ZdddZdd Zde	ddfddZ
  ZS )DistributedBucketSamplera  
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
    NTc                    sV   t  j||||d |j| _|| _|| _|  \| _| _t| j| _	| j	| j
 | _d S )N)num_replicasrankshuffle)r   r   r   
batch_size
boundaries_create_bucketsbucketsnum_samples_per_bucketr	  
total_sizer5  num_samples)r{   datasetr8  r9  r5  r6  r7  r   rb   rf   r     s   z!DistributedBucketSampler.__init__c           	      C   s   dd t t| jd D }t t| jD ]}| j| }| |}|dkr,|| | qt t|d ddD ]}t|| dkrN|| | j|d  q7g }| j| j }t t|D ]}t|| }|||  | }|||  q]||fS )Nc                 S   s   g | ]}g qS rb   rb   )rd   r*  rb   rb   rf   rg     s    z<DistributedBucketSampler._create_buckets.<locals>.<listcomp>r-   r   r   )	ranger   r9  r   _bisectr   r   r5  r8  )	r{   r;  r   length
idx_bucketr<  total_batch_size
len_bucketremrb   rb   rf   r:    s&   


z(DistributedBucketSampler._create_bucketsc                    s  t  }|| j g }| jr$| jD ]|t jt|d	  qn| jD ]|t
tt q'g  tt| jD ]W}| j| t}|| }| j| }|| }||||   |d ||   }|| jd | j }tt|| j D ]}fdd||| j |d | j  D }	 |	 qxq>| jrt jt |d	 }
 fdd|
D   | _t| j| j | jksJ t| jS )N)	generatorc                       g | ]} | qS rb   rb   )rd   idx)bucketrb   rf   rg   B  rh   z5DistributedBucketSampler.__iter__.<locals>.<listcomp>r-   c                    rH  rb   rb   rp  )batchesrb   rf   rg   G  rh   )rt   	Generatormanual_seedepochr7  r;  r   randpermr   tolistr   r@  r<  r6  r5  r8  rK  r>  iter)r{   gindicesr   rE  
ids_bucketnum_samples_bucketrF  jr  	batch_idsrb   )rK  rJ  rf   __iter__&  s8   



 *
z!DistributedBucketSampler.__iter__r   c                 C   s   |d u rt | jd }||kr>|| d }| j| |k r'|| j|d  kr'|S || j| kr5| |||S | ||d |S dS )Nr-   rX   r   )r   r9  rA  )r{   ry   lohimidrb   rb   rf   rA  M  s    z DistributedBucketSampler._bisectc                 C   s   | j | j S r   )r>  r8  rh  rb   rb   rf   ri  \  s   z DistributedBucketSampler.__len__rN  returnc                 C   s
   || _ dS )a(  
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.
        Args:
            epoch (int): Epoch number.
        N)rN  )r{   rN  rb   rb   rf   	set_epoch_  s   
z"DistributedBucketSampler.set_epoch)NNT)r   N)r  r  r  __doc__r   r:  rX  rA  ri  r9  r]  r  rb   rb   r   rf   r4    s    	

'r4  )Kr   r  r   r   r   collectionsr   pathlibr   typingr   r   r   r   r   r   r@  r   rt   einopsr	   r
   1nemo.collections.asr.parts.preprocessing.featuresr   0nemo.collections.asr.parts.preprocessing.segmentr   @nemo.collections.common.tokenizers.text_to_speech.tts_tokenizersr   r   r   2nemo.collections.tts.parts.utils.tts_dataset_utilsr   r   r   r   )nemo.collections.tts.torch.tts_data_typesr   r   r   r   r   r   r   r   r   r   r   r   r    r!   nemo.core.classesr"   
nemo.utilsr#   1nemo_text_processing.text_normalization.normalizer$   r   r   ModuleNotFoundErrorr
  hann_windowhamming_windowblackman_windowbartlett_windowr   r*   r  r  r  r  r  r   distributedDistributedSamplerr4  rb   rb   rb   rf   <module>   sb   @	      Kd %  N