o
    }oi8Y                     @   s  d dl Z d dlZd dlZd dlmZmZ d dlmZmZm	Z	m
Z
 d dlZd dlZd dlZd dlZd dlmZ d dlmZmZ d dlmZ d dlmZ d dlmZ d d	lmZ d d
lmZ d dlmZ d dl m!Z! d dl"m#Z# d dl$m%Z% d dl&m'Z'm(Z( dgZ)G dd de#eZ*dS )    N)ABCabstractmethod)DictListOptionalUnion)Trainer)
DictConfig	OmegaConf)tqdm))inject_dataloader_value_from_model_config)ChannelSelectorType)audio_to_audio_dataset)LhotseAudioToTargetDataset)AudioMetricWrapper)!get_lhotse_dataloader_from_config)ModelPT)PretrainedModelInfo)loggingmodel_utilsAudioToAudioModelc                       s  e Zd ZdZdGdedef fddZdd ZdHd
efddZ	dHd
efddZ
edIded
efddZ fddZ fddZdJdefddZdJddZdIded
efddZdJdefddZdJdefdd Zd!ee fd"d#Zd$eeeef  fd%d&Zd'eeeef  fd(d)Zd*eeeef  fd+d,Zd!ed-d.fd/d0Zed1ejd2ed-ejfd3d4Z e! 	5			dKd6e"e d7ed8ed9ee d:ee# d;ee d-e"e fd<d=Z$e%dLd?d@Z&dAdB Z' fdCdDZ(dEdF Z)  Z*S )Mr   zBase class for audio-to-audio models.

    Args:
        cfg: A DictConfig object with the configuration parameters.
        trainer: A Trainer object to be used for training.
    Ncfgtrainerc                    s   t  j||d |   d S )N)r   r   )super__init___setup_loss)selfr   r   	__class__ `/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/audio/models/audio_to_audio.pyr   2   s   zAudioToAudioModel.__init__c                 C   s2   d| j v rt| j j| _dS td d| _dS )zSetup loss for this model.lossz*No loss function is defined in the config.N)_cfgr   from_config_dictr!   r   warningr   r   r   r    r   7   s   


zAudioToAudioModel._setup_lossvaltagc                 C   sd   |dkrt | jtrt| j}|S d}|S |dkr*t | jtr&t| j}|S d}|S td| d)Nr&      testzUnexpected tag .)
isinstance_validation_dlr   len_test_dl
ValueError)r   r'   num_dataloadersr   r   r    _get_num_dataloaders?   s   z&AudioToAudioModel._get_num_dataloadersc              	   C   s  |  |}td|| t| dr(|| jv r(t| j| |kr(td|| dS | jddu r7td dS | jd | }du rKtd| dS d|v rWtd| d	t| dsbt	j
 | _t	j
 | j|< t|D ]O}i }| D ])\}}td
|| t|}|dd}	|dd}
ttj||	|
d||< qvt	j
|}| j| || j td||d| qndS )aZ  Setup metrics for this model for all available dataloaders.

        When using multiple DataLoaders, it is recommended to initialize separate modular
        metric instances for each DataLoader and use them separately.

        Reference:
            - https://torchmetrics.readthedocs.io/en/stable/pages/lightning.html#common-pitfalls
        zFound %d dataloaders for %smetricsz=Found %d metrics for tag %s, not necesary to initialize againNz&No metrics configured in model.metricsz-No metrics configured for %s in model.metricsr!   z[Loss is automatically included in the metrics, it should not be specified in model.metrics.r*   z#Initialize %s for dataloader_idx %schannelmetric_using_batch_averaging)metricr3   r4   z'Setup metrics for %s, dataloader %d: %sz, )r1   r   debughasattrr2   r-   r   getr/   torchnn
ModuleDict
ModuleListrangeitemsr
   to_containerpopr   hydrautilsinstantiateappendtodeviceinfojoin)r   r'   r0   metrics_cfgdataloader_idxmetrics_dataloader_idxnamer   cfg_dictcfg_channelcfg_batch_averagingr   r   r    _setup_metricsI   sJ   







z AudioToAudioModel._setup_metricsr   rJ   c                 C   s   d S Nr   )r   batch	batch_idxrJ   r'   r   r   r    evaluation_step   s   z!AudioToAudioModel.evaluation_stepc                       |  d t  S Nr&   )rP   r   on_validation_startr%   r   r   r    rW         

z%AudioToAudioModel.on_validation_startc                    rU   Nr)   )rP   r   on_test_startr%   r   r   r    rZ      rX   zAudioToAudioModel.on_test_startc                 C   V   |  |||d}t| jjttfr#t| jjdkr#| j| | |S | j| |S )Nr&   r(   )	rT   r+   r   val_dataloaderslisttupler-   validation_step_outputsrD   r   rR   rS   rJ   output_dictr   r   r    validation_step      "z!AudioToAudioModel.validation_stepc                 C   r[   )Nr)   r(   )	rT   r+   r   test_dataloadersr]   r^   r-   test_step_outputsrD   r`   r   r   r    	test_step   rc   zAudioToAudioModel.test_stepc           	         s   t  fdd|D  }  d|i}t| dr> | jv r>| j  |  D ]\}}| }|  ||  d| < q(  d|d|iS )Nc                    s   g | ]	}|  d  qS )_lossr   ).0xr'   r   r    
<listcomp>   s    z@AudioToAudioModel.multi_evaluation_epoch_end.<locals>.<listcomp>rg   r2   _log)r9   stackmeanr7   r2   r>   computereset)	r   outputsrJ   r'   	loss_meantensorboard_logsrL   r5   valuer   rj   r    multi_evaluation_epoch_end   s   z,AudioToAudioModel.multi_evaluation_epoch_endc                 C      |  ||dS rV   rv   r   rr   rJ   r   r   r    multi_validation_epoch_end      z,AudioToAudioModel.multi_validation_epoch_endc                 C   rw   rY   rx   ry   r   r   r    multi_test_epoch_end   r{   z&AudioToAudioModel.multi_test_epoch_endconfigc                 C   s  t | j|dd |ddrt|| j| jt dS |dd}|r%td|ddr/td	d
|v rC|d
 d u rCt	d|  d S t
j|d}t|drR|j}nt|jd dra|jd j}n	|jd jd j}tjjj||d ||dd|d |dd|dddS )Nsample_rate)key
use_lhotseF)global_rank
world_sizedataset	is_concatzConcat not implemented	is_tarredTarred datasets not supportedmanifest_filepathzJCould not load dataset as `manifest_filepath` was None. Provided config : r}   
collate_fnr   
batch_size	drop_lastshufflenum_workers
pin_memory)r   r   r   r   r   r   r   )r   r   r8   r   r   r   r   NotImplementedErrorr   r$   r   get_audio_to_target_datasetr7   r   datasetsr9   rB   data
DataLoader)r   r}   r   r   r   r   r   r    _setup_dataloader_from_config   s8   



z/AudioToAudioModel._setup_dataloader_from_configtrain_data_configc                 C   sL   d|vrd|d< | j d|d | j|d| _d|v r"|d r$tddS dS )	aR  
        Sets up the training data loader via a Dict-like object.

        Args:
            train_data_config: A config that contains the information regarding construction
                of a training dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset`
        r   Ttraindataset_namer}   r   r   r   N)_update_dataset_configr   	_train_dlr   )r   r   r   r   r    setup_training_data   s   z%AudioToAudioModel.setup_training_dataval_data_configc                 C   0   d|vrd|d< | j d|d | j|d| _dS )aT  
        Sets up the validation data loader via a Dict-like object.

        Args:
            val_data_config: A config that contains the information regarding construction
                of a validation dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset`
        r   F
validationr   r   N)r   r   r,   )r   r   r   r   r    setup_validation_data      z'AudioToAudioModel.setup_validation_datatest_data_configc                 C   r   )aI  
        Sets up the test data loader via a Dict-like object.

        Args:
            test_data_config: A config that contains the information regarding construction
                of a test dataset.

        Supported Datasets:
            -   :class:`~nemo.collections.asr.data.audio_to_audio.AudioToTargetDataset`
        r   Fr)   r   r   N)r   r   r.   )r   r   r   r   r    setup_test_data  r   z!AudioToAudioModel.setup_test_datareturnztorch.utils.data.DataLoaderc                 C   s^   |d | j |d |dddd|d d|dt|d t d d	d

}| jt|d}|S )aI  Prepare a dataloader for processing files.

        Args:
            config: A python dictionary which contains the following keys:
                manifest_filepath: path to a manifest file
                input_key: key with audio filepaths in the manifest
                input_channel_selector: Optional, used to select a subset of channels from input audio files
                batch_size: batch size for the dataloader
                num_workers: number of workers for the dataloader

        Returns:
            A pytorch DataLoader for the given manifest filepath.
        r   	input_keyinput_channel_selectorNr   Fr   r(   T)
r   r~   r   r   
target_keytarget_channel_selectorr   r   r   r   r   )r~   r8   minos	cpu_countr   r	   )r   r}   	dl_configtemporary_dataloaderr   r   r    _setup_process_dataloader  s   
z+AudioToAudioModel._setup_process_dataloaderinputbatch_lengthc                 C   s.   |  d}|| }d|f}tjj| |ddS )zTrim or pad the output to match the batch length.

        Args:
            input: tensor with shape (B, C, T)
            batch_length: int

        Returns:
            Tensor with shape (B, C, T), where T matches the
            batch length.
        r   constant)sizer9   r:   
functionalpad)r   r   input_length
pad_lengthr   r   r   r    match_batch_length<  s   
z$AudioToAudioModel.match_batch_lengthr(   paths2audio_files
output_dirr   r   r   	input_dirc              	   C   s  |du s
t |dkri S |du rt|t d }g }| j}t|  j}	z|   | 	  t
 }
t
t
j t }tj|d}t|ddd }|D ]}|tj|dd	}|t|d
  qOW d   n1 sow   Y  |d|t|t ||d}tj|st| | |}d}t|ddD ]}|d }|d }|jdkr|d}| j||	||	d\}}t |!dD ]W}|durtjj"|| |d}ntj#|| }tj||}tjtj$|sttj$| ||ddd|| f % & }t'||j(| j)d |d7 }|*| q~~qW d   n	1 s)w   Y  W | j+|d |du r>| ,  t
|
 |S | j+|d |du rT| ,  t
|
 w )a`  
        Takes paths to audio files and returns a list of paths to processed
        audios.

        Args:
            paths2audio_files: paths to audio files to be processed
            output_dir: directory to save the processed files
            batch_size: (int) batch size to use during inference.
            num_workers: Number of workers for the dataloader
            input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio.
                            If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`.
            input_dir: Optional, directory that contains the input files. If provided, the output directory will mirror the input directory structure.

        Returns:
            Paths to processed audio signals.
        Nr   r(   zmanifest.jsonwzutf-8)encoding)path)input_filepathduration
r   )r   r   r   r   r   
Processing)desc   )input_signalr   )startfloat)modeT)-r-   r   r   r   trainingnext
parametersrF   evalfreezer   get_verbosityset_verbosityWARNINGtempfileTemporaryDirectoryr   rH   openlibrosaget_durationwritejsondumpsisdirmakedirsr   r   ndim	unsqueezeforwardrE   r=   r   relpathbasenamedirnamecpunumpysfTr~   rD   r   unfreeze)r   r   r   r   r   r   r   paths2processed_filesr   rF   logging_leveltmpdirtemporary_manifest_filepathfp
audio_fileentryr}   r   file_idx
test_batchr   r   processed_batchrl   example_idxfilepath_relativeoutput_fileoutput_signalr   r   r    processN  s~   
	




"C


zAudioToAudioModel.processList[PretrainedModelInfo]c                 C   s   t | }|S )z
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        )r   &resolve_subclass_pretrained_model_info)clslist_of_modelsr   r   r    list_available_models  s   
	z'AudioToAudioModel.list_available_modelsc                 C   s2   d| _ d| jv r| jd r| jd | _ dS dS dS )aT  
        Utility method that must be explicitly called by the subclass in order to support optional optimization flags.
        This method is the only valid place to access self.cfg prior to DDP training occurs.

        The subclass may chose not to support this method, therefore all variables here must be checked via hasattr()
        Fskip_nan_gradN)_skip_nan_gradr"   r%   r   r   r    setup_optimization_flags  s   z*AudioToAudioModel.setup_optimization_flagsc                    s   t    t| drf| jrht|  j}tjdg|tj	d}| 
 D ]"\}}|jdurDt|j p:t|j  }|sD|d } nq"tj rUtjj|tjjjd |dk rjtd | jdd	 dS dS dS dS )
zH
        zero-out the gradients which any of them is NAN or INF
        r   r(   )rF   dtypeNr   )opzCdetected inf or nan values in gradients! Setting gradients to zero.F)set_to_none)r   on_after_backwardr7   r   r   r   rF   r9   tensorfloat32named_parametersgradisnananyisinfdistributedis_initialized
all_reduceReduceOpMINr   r$   	zero_grad)r   rF   valid_gradients
param_nameparamis_not_nan_or_infr   r   r    r     s$   

"

z#AudioToAudioModel.on_after_backwardc                 C   s   | j dd| _| jsg S g }ddlm} t| jtr| j}n| jg}t|D ]\}}|	|||| j
j| jj| jj| j| jddd q'|S )zW
        Create an callback to add audio/spectrogram into tensorboard & wandb.
        
log_configNr   ) SpeechEnhancementLoggingCallbackmax_utts)data_loaderdata_loader_idxloggerslog_tensorboard	log_wandbr~   r  )r   r8   r  ,nemo.collections.audio.parts.utils.callbacksr  r+   r,   r   	enumeraterD   r   r  r  r  r~   )r   log_callbacksr  data_loadersr  r  r   r   r    configure_callbacks  s*   z%AudioToAudioModel.configure_callbacksrQ   )r&   )r   r&   )r   )r(   NNN)r   r   )+__name__
__module____qualname____doc__r	   r   r   r   strr1   rP   r   intrT   rW   rZ   rb   rf   rv   rz   r|   r   r   r   r   r   r   r   r   staticmethodr9   Tensorr   no_gradr   r   r   classmethodr   r   r   r  __classcell__r   r   r   r    r   *   s^    
<
+z)+r   r   r   abcr   r   typingr   r   r   r   rA   r   	soundfiler   r9   lightning.pytorchr   	omegaconfr	   r
   r   /nemo.collections.asr.data.audio_to_text_datasetr   0nemo.collections.asr.parts.preprocessing.segmentr   nemo.collections.audio.datar   1nemo.collections.audio.data.audio_to_audio_lhotser   $nemo.collections.audio.metrics.audior   #nemo.collections.common.data.lhotser   nemo.core.classesr   nemo.core.classes.commonr   
nemo.utilsr   r   __all__r   r   r   r   r    <module>   s.   