o
    }oiP                     @   s.  d dl mZ d dlmZmZ d dlmZmZmZm	Z	 d dl
Z
d dlZ
d dlmZ d dlm  mZ d dlmZmZmZmZ d dlmZmZmZmZmZmZ d dlmZmZm Z m!Z! d dl"m#Z# d d	l$m%Z% d d
l&m'Z' d dl(m)Z)m*Z* d dl+m,Z, d dl-m.Z.m/Z/m0Z0m1Z1m2Z2m3Z3 d dl4m5Z5 g dZ6G dd de,e'e)Z7G dd de,e'Z8G dd de,e'e*j9Z:G dd de,e'Z;G dd de,e'Z<G dd de,e'Z=G dd de,e'Z>G dd de7e*j9Z?eG d d! d!Z@eG d"d# d#ZAeG d$d% d%ZBeG d&d' d'ZC	 e*De7du re*jEe7e?d( dS dS ))    )OrderedDict)	dataclassfield)ListOptionalSetUnionN)MISSING
DictConfig
ListConfig	OmegaConf)JasperBlockMaskedConv1dParallelBlockSqueezeExciteinit_weightsjasper_activations)AttentivePoolLayerStatsPoolLayer
TDNNModuleTDNNSEModule)adapter_utils)	typecheck)
Exportable)AccessMixinadapter_mixins)NeuralModule)AcousticEncodedRepresentationLengthsType
LogitsTypeLogprobsType
NeuralTypeSpectrogramType)logging)ConvASRDecoderConvASREncoderConvASRDecoderClassificationc                       s   e Zd ZdZdd Zd%ddZedd	 Zed
d Z							d&de	de
de	de	de
dede
dee	 def fddZe dd Zde
fd d!Zed"e
fd#d$Z  ZS )'r%   z
    Convolutional encoder for ASR models. With this class you can implement JasperNet and QuartzNet models.

    Based on these papers:
        https://arxiv.org/pdf/1904.03288.pdf
        https://arxiv.org/pdf/1910.10261.pdf
    c                 K   sV   d}|   D ]\}}t|trd|_|d7 }qtj| fi | td| d d S Nr   F   Turned off  masked convolutions)named_modules
isinstancer   use_maskr   _prepare_for_exportr#   warning)selfkwargsm_countnamem r5   Y/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/asr/modules/conv_asr.pyr.   A   s   
z"ConvASREncoder._prepare_for_exportr(       c                 C   sF   t |  j}tj|| j||d}tj|jd f||d}t||gS )s
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        devicer   )size
fill_valuer:   )	next
parametersr:   torchrandn_feat_infullshapetuple)r0   	max_batchmax_dimr:   input_examplelensr5   r5   r6   rG   K   s   zConvASREncoder.input_examplec                 C   "   t tdt ttdt dS z*Returns definitions of module input ports.BDTrL   )audio_signallengthr   r!   r"   rD   r   r0   r5   r5   r6   input_typesV   
   
zConvASREncoder.input_typesc                 C   rI   z+Returns definitions of module output ports.rK   rL   )outputsencoded_lengthsr   r!   r   rD   r   rR   r5   r5   r6   output_types`   rT   zConvASREncoder.output_typesbatchaddTxavier_uniformF
activationfeat_innormalization_moderesidual_modenorm_groups	conv_maskframe_splicing	init_modequantizec              	      s@  t    t|trt|}t|  }t|drd|_|| }|| _	g }g }d| _
d| _t|D ]\}}g }|ddrG|| |}d| _
|dd}|dd}|dd	}|d
|}|dd}|dd}|dd	}|dd}|dd}|dd}|dd	}|t||d fi d|d d|d d|d d|d d|d d|d d|d|d|d
|d|d|d|d |d!|d|d|d"|d|d|d|d|d#|
d$| |d }|  jt|d trt|d d% nt|d 9  _q1|| _tjj| | _|  fd&d' d%| _d S )(NinplaceTFr(   residual_densegroups	separableheadsr\   ra   sese_reduction_ratio   se_context_sizese_interpolation_modenearestkernel_size_factor      ?stride_lastfuture_contextfiltersrepeatkernel_sizekernelstridedilationdropoutresidualnormalizationrb   r^   residual_panesrc   se_context_windowrf   	layer_idxr   c                       t |  dS N)moder   xre   r5   r6   <lambda>       z)ConvASREncoder.__init__.<locals>.<lambda>)super__init__r,   r   r   to_containerr   hasattrrg   rA   dense_residual_subsampling_factor	enumerategetappendr   r   int	_feat_outr?   nn
Sequentialencoderapplymax_audio_length)r0   jasperr^   r_   r`   ra   rb   rc   rd   re   rf   r   encoder_layersr   lcfg	dense_resri   rj   rk   rl   rm   r   rp   rr   rt   ru   	__class__r   r6   r   j   s   





	
(
zConvASREncoder.__init__c                 C   sF   | j |d|jd | |g|f\}}|d u r|d S |d |fS )N   )
seq_lengthr:   r\   )update_max_sequence_lengthr;   r:   r   r0   rO   rP   s_inputr5   r5   r6   forward   s
   zConvASREncoder.forwardr   c                 C   s  t j r t j|gt j|d}t jj|t jjjd | 	 }|| j
kr|dk r.|d }n|dk r6|d }|| _
t|  j}t jd| j
|d}t| d	rR|| _n| jd	|d
d |  D ]#\}}t|trq|j| j
| jd q^t|tr|j| j
| jd q^dS dS )zw
        Find global max audio length across all nodes in distributed training and update the max_audio_length
        )dtyper:   )opi  r   i'  g      ?r   r9   	seq_rangeF)
persistent)r   N)r?   distributedis_initializedtensorfloat32
all_reduceReduceOpMAXr   itemr   r=   r>   r:   aranger   r   register_bufferr+   r,   r   update_masked_lengthr   set_max_len)r0   r   r:   global_max_lenr   r3   r4   r5   r5   r6   r      s.   





z)ConvASREncoder.update_max_sequence_lengthreturnc                 C      | j S N)r   rR   r5   r5   r6   subsampling_factor      z!ConvASREncoder.subsampling_factor)r(   r7   )rZ   r[   r\   Tr(   r]   F)__name__
__module____qualname____doc__r.   rG   propertyrS   rY   strr   boolr   r   r   r   r   r   __classcell__r5   r5   r   r6   r%   8   sN    


	
	
Z
!r%   c                       s   e Zd ZdZdd Zd+ddZedd	 Zed
d Zde	fddZ
ede	fddZedd Zedd Z								d,de	dede	d e	d!ed"ed#ed$ee	 d%ee	 d&ef fd'd(Ze d-d)d*Z  ZS ).ParallelConvASREncoderzs
    Convolutional encoder for ASR models with parallel blocks. CarneliNet can be implemented with this class.
    c                 C   s@   d}|   D ]}t|trd|_|d7 }qtd| d d S r'   )modulesr,   r   r-   r#   r/   )r0   r2   r4   r5   r5   r6   r.      s   
z*ParallelConvASREncoder._prepare_for_exportr(      c                 C   *   t || j|t|  j}t|gS r8   r?   r@   rA   tor=   r>   r:   rD   r0   rE   rF   rG   r5   r5   r6   rG          
z$ParallelConvASREncoder.input_examplec                 C   
   t dgS )zHImplement this method to return a set of input names disabled for exportrP   setrR   r5   r5   r6   disabled_deployment_input_names	     
z6ParallelConvASREncoder.disabled_deployment_input_namesc                 C   r   )zIImplement this method to return a set of output names disabled for exportrW   r   rR   r5   r5   r6    disabled_deployment_output_names  r   z7ParallelConvASREncoder.disabled_deployment_output_names	save_pathc                 C      d S r   r5   )r0   r   r5   r5   r6   save_to  s   zParallelConvASREncoder.save_torestore_pathc                 C   r   r   r5   )clsr   r5   r5   r6   restore_from  s   z#ParallelConvASREncoder.restore_fromc                 C   rI   rJ   rQ   rR   r5   r5   r6   rS     rT   z"ParallelConvASREncoder.input_typesc                 C   rI   rU   rX   rR   r5   r5   r6   rY   $  rT   z#ParallelConvASREncoder.output_typesrZ   r[   r\   Tr]   NFr^   r_   r`   ra   rb   rc   rd   re   aggregation_moderf   c                    s>  t    t|trt|}t|  }|| }|| _g }g }d| _|D ]}g }|	ddr8|
| |}d| _|	dd}|	dd}|	dd}|	d	|}|	d
d}|	dd}|	dd}|	dd}|	dd}|	dd}|	dd}
|	dd}|	dd}g }|d D ][}|
t||d fi d|d d|gd|d d|d d|d d|d d|d|d|d	|d |d!|d"|d#|d$|d
|d|d%|d|d|d|d&| qt|dkr|
|d'  n|
t||
||||d d( |d }q$|| _tjj| | _|  fd)d* d S )+NFrh   Tri   r(   rj   rk   r\   ra   rl   rm   rn   ro   rp   rq   rr   rs   rt   r   sumblock_dropoutg        parallel_residual_modery   rv   rw   rx   rz   r{   r|   r}   r~   rb   r^   r   rc   r   rf   r   )r   block_dropout_probra   
in_filtersout_filtersc                    r   r   r   r   r   r5   r6   r     r   z1ParallelConvASREncoder.__init__.<locals>.<lambda>)r   r   r,   r   r   r   r   rA   r   r   r   r   lenr   r   r?   r   r   r   r   )r0   r   r^   r_   r`   ra   rb   rc   rd   re   r   rf   r   r   r   r   ri   rj   rk   rl   rm   r   rp   rr   rt   r   r   parallel_blocksrx   r   r   r6   r   .  s   




	


zParallelConvASREncoder.__init__c                 C   s0   |  |g|f\}}|d u r|d S |d |fS )Nr\   )r   r   r5   r5   r6   r     s   zParallelConvASREncoder.forwardr(   r   )rZ   r[   r\   Tr(   r]   NFr   )r   r   r   r   r.   rG   r   r   r   r   r   classmethodr   rS   rY   r   r   r   r   r   r   r   r5   r5   r   r6   r      s\    
	


	
	
`r   c                       s   e Zd ZdZedd Zedd Zd fd	d
	Ze dd Z	dddZ
dd Zdedef fddZdefddZedd Zedd Z  ZS )r$   zSimple ASR Decoder for use with CTC-based models such as JasperNet and QuartzNet

    Based on these papers:
       https://arxiv.org/pdf/1904.03288.pdf
       https://arxiv.org/pdf/1910.10261.pdf
       https://arxiv.org/pdf/2005.04290.pdf
    c                 C      t dtdt iS Nencoder_outputrK   r   r!   r   rR   r5   r5   r6   rS        zConvASRDecoder.input_typesc                 C   r   )NlogprobsrL   rN   rM   )r   r!   r    rR   r5   r5   r6   rY     r   zConvASRDecoder.output_typesr]   NTc                    s   t    |d u r|dk rtd|dkr"t|}td| d |d ur;|t|kr8td| dt| || _|| _|rD|d n|| _t	j
t	j
j| j| jddd	| _|  fd
d tjg}| | d| _d S )Nr   zWNeither of the vocabulary and num_classes are set! At least one of them need to be set.zDnum_classes of ConvASRDecoder is set to the size of the vocabulary: .z}If vocabulary is specified, it's length should be equal to the num_classes.                         Instead got: num_classes=z and len(vocabulary)=r(   Trx   biasc                    r   r   r   r   r   r5   r6   r     r   z)ConvASRDecoder.__init__.<locals>.<lambda>rs   )r   r   
ValueErrorr   r#   info_ConvASRDecoder__vocabularyrA   _num_classesr?   r   r   Conv1ddecoder_layersr   r   LINEAR_ADAPTER_CLASSPATHset_accepted_adapter_typestemperature)r0   r_   num_classesre   
vocabulary	add_blankaccepted_adaptersr   r   r6   r     s2   


zConvASRDecoder.__init__c                 C   sz   |   r|dd}| |}|dd}| jdkr-tjjj| |dd| j ddS tjjj| |ddddS )Nr(   r   rs   r\   dim)	is_adapter_available	transposeforward_enabled_adaptersr   r?   r   
functionallog_softmaxr   )r0   r   r5   r5   r6   r     s   

 zConvASRDecoder.forwardr(   r   c                 C   r   r   r   r   r5   r5   r6   rG     r   zConvASRDecoder.input_examplec                 K   ^   d}|   D ]}t|jdkrd|_|d7 }q|dkr$td| d tj| fi | d S Nr   r   Fr(   r)   r*   r   typer   r-   r#   r/   r   r.   r0   r1   r2   r4   r5   r5   r6   r.        z"ConvASRDecoder._prepare_for_exportr3   cfgc                    s   |  |}t j||d d S )N)r3   r	  )_update_adapter_cfg_input_dimr   add_adapter)r0   r3   r	  r   r5   r6   r    s   
zConvASRDecoder.add_adapterc                 C   s   t j| || jd}|S N)
module_dim)r   update_adapter_cfg_input_dimrA   )r0   r	  r5   r5   r6   r
       z,ConvASRDecoder._update_adapter_cfg_input_dimc                 C   r   r   )r   rR   r5   r5   r6   r     r   zConvASRDecoder.vocabularyc                 C   r   r   r   rR   r5   r5   r6   num_classes_with_blank  r   z%ConvASRDecoder.num_classes_with_blank)r]   NTr   )r   r   r   r   r   rS   rY   r   r   r   rG   r.   r   r
   r  r
  r   r  r   r5   r5   r   r6   r$     s"    

 

	
r$   c                       sh   e Zd ZdZedd Zedd Z							
	d fdd	Ze dd Z	dddZ
dd Z  ZS )ConvASRDecoderReconstructionz<ASR Decoder for reconstructing masked regions of spectrogramc                 C   r   r   r   rR   r5   r5   r6   rS      r   z(ConvASRDecoderReconstruction.input_typesc                 C   s.   | j rtdtdt iS tdtdt iS )Noutr   )apply_softmaxr   r!   r    r   rR   r5   r5   r6   rY     s   z)ConvASRDecoderReconstruction.output_typesr      r]   reluTFc                    s  t    || dkr|dk s|d dkrtdt|  }|| _|| _|| _tj| j| jdddg| _	t
|D ]X}| j	| |	r]| j	tj| j| j|d|d d d dd| jd n| j	tj| j| j|d|d d d| jd	 | j	tj| j| jddd | j	tj| jd
dd q8t
|D ]8}| j	| | j	tj| j| j|d| j|d d | j	tj| j| jddd | j	tj| jd
dd q| j	| | j	tj| j| jddd tj| j	 | _	|
| _|  fdd d S )Nr      r   zVKernel size in this decoder needs to be >= 3 and odd when using at least 1 conv layer.r(   Tr   )rz   paddingoutput_paddingr   ri   )rz   r  r   ri   gMbP?g?)epsmomentum)r   ri   r  c                    r   r   r   r   r   r5   r6   r   V  r   z7ConvASRDecoderReconstruction.__init__.<locals>.<lambda>)r   r   r   r   r_   feat_outfeat_hiddenr   r   r   ranger   ConvTranspose1dBatchNorm1dr   r  r   )r0   r_   r  r  stride_layersnon_stride_layersrx   re   r^   stride_transposer  ir   r   r6   r     sn   
 


z%ConvASRDecoderReconstruction.__init__c                 C   s.   |  |dd}| jrtjjj|dd}|S )Nr\   r   )r   r   r  r?   r   r  r  )r0   r   r  r5   r5   r6   r   X  s   z$ConvASRDecoderReconstruction.forwardr(   r   c                 C   r   r   r   r   r5   r5   r6   rG   _  r   z*ConvASRDecoderReconstruction.input_examplec                 K   r  r  r  r  r5   r5   r6   r.   h  r  z0ConvASRDecoderReconstruction._prepare_for_export)r   r   r  r]   r  TFr   )r   r   r   r   r   rS   rY   r   r   r   rG   r.   r   r5   r5   r   r6   r    s$    

M

	r  c                	       st   e Zd ZdZdddZedd Zedd	 Z	
		ddedede	e
 def fddZdd Zedd Z  ZS )r&   zSimple ASR Decoder for use with classification models such as JasperNet and QuartzNet

    Based on these papers:
       https://arxiv.org/pdf/2005.04290.pdf
    r(   r   c                 C   r   r   r   r   r5   r5   r6   rG   z  r   z*ConvASRDecoderClassification.input_examplec                 C   r   r   r   rR   r5   r5   r6   rS     r   z(ConvASRDecoderClassification.input_typesc                 C   r   )NlogitsrL   rM   )r   r!   r   rR   r5   r5   r6   rY     r   z)ConvASRDecoderClassification.output_typesr]   Tavgr_   r   re   return_logitsc                    s   t    || _|| _|| _|dkrtjd| _n|dkr&tj	d| _nt
dtjtjj| j| jdd| _|  fdd d S )	Nr(  r(   maxz?Pooling type chosen is not valid. Must be either `avg` or `max`Tr   c                    r   r   r   r   r   r5   r6   r     r   z7ConvASRDecoderClassification.__init__.<locals>.<lambda>)r   r   rA   _return_logitsr   r?   r   AdaptiveAvgPool1dpoolingAdaptiveMaxPool1dr   r   Linearr   r   )r0   r_   r   re   r)  pooling_typer   r   r6   r     s   
 z%ConvASRDecoderClassification.__init__c                 K   sF   |  \}}}| |||}| |}| jr|S tjjj|ddS )Nr\   r   )	r;   r.  viewr   r,  r?   r   r  softmax)r0   r   r1   rZ   in_channels	timestepsr&  r5   r5   r6   r     s   
z$ConvASRDecoderClassification.forwardc                 C   r   r   r  rR   r5   r5   r6   r     r   z(ConvASRDecoderClassification.num_classesr   )r]   Tr(  )r   r   r   r   rG   r   rS   rY   r   r   r   r   r   r   r   r   r5   r5   r   r6   r&   s  s,    
	

r&   c                       sb   e Zd ZdZedd Zedd Z		dded	ed
ededede	f fddZ
dddZ  ZS )ECAPAEncodera  
    Modified ECAPA Encoder layer without Res2Net module for faster training and inference which achieves
    better numbers on speaker diarization tasks
    Reference: ECAPA-TDNN Embeddings for Speaker Diarization (https://arxiv.org/pdf/2104.01466.pdf)

    input:
        feat_in: input feature shape (mel spec feature shape)
        filters: list of filter shapes for SE_TDNN modules
        kernel_sizes: list of kernel shapes for SE_TDNN modules
        dilations: list of dilations for group conv se layer
        scale: scale value to group wider conv channels (deafult:8)

    output:
        outputs : encoded output
        output_length: masked output lengths
    c                 C   rI   rJ   rQ   rR   r5   r5   r6   rS     rT   zECAPAEncoder.input_typesc                 C   rI   rU   rX   rR   r5   r5   r6   rY     rT   zECAPAEncoder.output_typesrn   r]   r_   rv   kernel_sizes	dilationsscalere   c                    s   t    t | _| jt||d |d |d d tt|d D ]}| jt	|| ||d  |d||d  ||d  d q$t|d |d |d |d | _
|  fdd	 d S )
Nr   )rx   r{   r   r(      )group_scalese_channelsrx   r{   r\   c                    r   r   r   r   r   r5   r6   r     r   z'ECAPAEncoder.__init__.<locals>.<lambda>)r   r   r   
ModuleListlayersr   r   r  r   r   feature_aggr   )r0   r_   rv   r7  r8  r9  re   r$  r   r   r6   r     s    
	
$


 
zECAPAEncoder.__init__Nc                 C   sR   |}g }| j D ]}|||d}|| qtj|dd  dd}| |}||fS )N)rP   r(   r   )r>  r   r?   catr?  )r0   rO   rP   r   rV   layerr5   r5   r6   r     s   

zECAPAEncoder.forward)rn   r]   r   )r   r   r   r   r   rS   rY   r   listr   r   r   r   r5   r5   r   r6   r6    s,    
	
r6  c                       s   e Zd ZdZdddZedd Zedd	 Z		
			ddedede	e
eef  dedededef fddZ		d ddZe d!ddZ  ZS )"SpeakerDecodera  
    Speaker Decoder creates the final neural layers that maps from the outputs
    of Jasper Encoder to the embedding layer followed by speaker based softmax loss.

    Args:
        feat_in (int): Number of channels being input to this module
        num_classes (int): Number of unique speakers in dataset
        emb_sizes (list) : shapes of intermediate embedding layers (we consider speaker embbeddings
            from 1st of this layers). Defaults to [1024,1024]
        pool_mode (str) : Pooling strategy type. options are 'xvector','tap', 'attention'
            Defaults to 'xvector (mean and variance)'
            tap (temporal average pooling: just mean)
            attention (attention based pooling)
        init_mode (str): Describes how neural network parameters are
            initialized. Options are ['xavier_uniform', 'xavier_normal',
            'kaiming_uniform','kaiming_normal'].
            Defaults to "xavier_uniform".
    r(   r   c                 C   r   r   )r?   r@   input_feat_inr   r=   r>   r:   rD   r   r5   r5   r6   rG     r   zSpeakerDecoder.input_examplec                 C   s"   t tdt tdt dddS )NrK   )rL   T)optional)r   rP   )r   r!   r   r   rR   r5   r5   r6   rS     s
   
zSpeakerDecoder.input_typesc                 C   s   t tdt tdt dS )Nr'  )r&  embs)r   r!   r   r   rR   r5   r5   r6   rY   '  s
   

zSpeakerDecoder.output_typesxvectorFr:  r]   r_   r   	emb_sizes	pool_modeangularattention_channelsre   c                    s0  t    || _d| _| jrdnd}t|tu r|gn|}|| _| | _| jdks/| jdkr:t	|| jd| _
d}	n| jdkrHt||d	| _
d
}	| j
jg}
|D ]	}|
t| qOg }t|
d d |
dd  D ]\}}| j||d|	d}|| qht|| _tj|
d | j|d| _|  fdd d S )Nr   FTrG  tap)r_   rI  linear	attention)inp_filtersrK  convr\   r(   )
learn_meanaffine_typer+  c                    r   r   r   r   r   r5   r6   r   V  r   z)SpeakerDecoder.__init__.<locals>.<lambda>)r   r   rJ  emb_idr  r   r   lowerrI  r   _poolingr   r_   r   zipaffine_layerr   r=  
emb_layersr0  finalr   )r0   r_   r   rH  rI  rJ  rK  re   r   rR  shapesr;   rX  shape_in	shape_outrA  r   r   r6   r   0  s.   




"zSpeakerDecoder.__init__TrP  c                 C   sZ   |dkrt t j|dddt j||dd}|S t t ||t j||ddt  }|S )NrP  T)affinetrack_running_statsr(   )rx   )r   r   r   r   r0  ReLU)r0   	inp_shape	out_shaperQ  rR  rA  r5   r5   r6   rW  X  s   
zSpeakerDecoder.affine_layerNc           	      C   s   |  ||}g }| jD ]}|||d | j |}}|| q|d}| jrA| j D ]
}tj	|ddd}q.tj	|ddd}| |}||d dfS )Nr\   r   r(   )pr   )
rU  rX  rS  r   squeezerJ  rY  r>   F	normalize)	r0   r   rP   poolrF  rA  embWr  r5   r5   r6   r   n  s   


zSpeakerDecoder.forwardr   )r   rG  Fr:  r]   )TrP  r   )r   r   r   r   rG   r   rS   rY   r   r   r   rB  r   r   r   rW  r   r   r   r5   r5   r   r6   rC    s@    
	

,
rC  c                       s~   e Zd ZdedefddZdefddZddee d
efddZ	de
e fddZdefddZdee f fddZ  ZS )ConvASREncoderAdapterr3   r	  c                 C   s6   | j D ]}| ||}||   ||| qd S r   )r   r
  r   get_accepted_adapter_typesr  )r0   r3   r	  jasper_blockr5   r5   r6   r    s
   
z!ConvASREncoderAdapter.add_adapterr   c                 C   s   t dd | jD S )Nc                 S   s   g | ]}|  qS r5   )r   ).0rk  r5   r5   r6   
<listcomp>  s    z>ConvASREncoderAdapter.is_adapter_available.<locals>.<listcomp>)anyr   rR   r5   r5   r6   r     s   z*ConvASREncoderAdapter.is_adapter_availableNTenabledc                 C   s   | j D ]	}|j||d qd S )N)r3   ro  )r   set_enabled_adapters)r0   r3   ro  rk  r5   r5   r6   rp    s   
z*ConvASREncoderAdapter.set_enabled_adaptersc                 C   s2   t g }| jD ]	}||  qtt|}|S r   )r   r   updateget_enabled_adapterssortedrB  )r0   namesrk  r5   r5   r6   rr    s
   
z*ConvASREncoderAdapter.get_enabled_adaptersblockc                 C   s   t j| ||jd}|S r  )r   r  planes)r0   ru  r	  r5   r5   r6   r
    r  z3ConvASREncoderAdapter._update_adapter_cfg_input_dimc                    s0   t   }t|dkr| tjg |  }|S )Nr   )r   rj  r   r   r   r   )r0   typesr   r5   r6   rj    s   
z0ConvASREncoderAdapter.get_accepted_adapter_types)NT)r   r   r   r   dictr  r   r   r   rp  r   rr  r   r
  r   r  rj  r   r5   r5   r   r6   ri    s    ri  c                   @   s   e Zd ZU eZeed< eZeed< eZe	e ed< eZ
e	e ed< eZe	e ed< eZeed< eZeed< dZeed	< d
Zeed< dZeed< dZeed< d
Zeed< d
Zeed< dZeed< dZeed< dZeed< dZeed< d
Zeed< dS )JasperEncoderConfigrv   rw   ry   rz   r{   r|   r}   r(   ri   Frj   r\   rk   r[   ra   rh   rl   rn   rm   ro   rq   rp   rs   rr   rt   N)r   r   r   r	   rv   r   __annotations__rw   ry   r   rz   r{   r|   floatr}   r   ri   rj   rk   ra   r   rh   rl   rm   ro   rp   rr   rt   r5   r5   r5   r6   ry    s&   
 ry  c                   @   s   e Zd ZU dZeed< eedZe	e
e  ed< eZeed< eZeed< dZeed< d	Zeed
< dZeed< dZeed< dZeed< dZe	e ed< dS )ConvASREncoderConfigz+nemo.collections.asr.modules.ConvASREncoder_target_default_factoryr   r^   r_   rZ   r`   r[   ra   r\   rb   Trc   r(   rd   r]   re   N)r   r   r   r}  r   rz  r   rB  r   r   r   ry  r	   r^   r_   r   r`   ra   rb   rc   r   rd   re   r5   r5   r5   r6   r|    s   
 r|  c                   @   s\   e Zd ZU dZeed< eZeed< eZ	eed< dZ
ee ed< eedZeee  ed< d	S )
ConvASRDecoderConfigz+nemo.collections.asr.modules.ConvASRDecoderr}  r_   r   r]   re   r~  r   N)r   r   r   r}  r   rz  r	   r_   r   r   re   r   r   rB  r   r   r5   r5   r5   r6   r    s   
 r  c                   @   sZ   e Zd ZU dZeed< eZeed< eZ	eed< dZ
ee ed< dZeed< d	Zeed
< dS )"ConvASRDecoderClassificationConfigz9nemo.collections.asr.modules.ConvASRDecoderClassificationr}  r_   r   r]   re   Tr)  r(  r1  N)r   r   r   r}  r   rz  r	   r_   r   r   re   r   r)  r   r1  r5   r5   r5   r6   r    s   
 r  )
base_classadapter_class)Fcollectionsr   dataclassesr   r   typingr   r   r   r   r?   torch.distributedtorch.nnr   torch.nn.functionalr  rd  	omegaconfr	   r
   r   r   ,nemo.collections.asr.parts.submodules.jasperr   r   r   r   r   r   4nemo.collections.asr.parts.submodules.tdnn_attentionr   r   r   r    nemo.collections.asr.parts.utilsr   nemo.core.classes.commonr   nemo.core.classes.exportabler   nemo.core.classes.mixinsr   r   nemo.core.classes.moduler   nemo.core.neural_typesr   r   r   r    r!   r"   
nemo.utilsr#   __all__r%   r   AdapterModuleMixinr$   r  r&   r6  rC  ri  ry  r|  r  r  get_registered_adapterregister_adapterr5   r5   r5   r6   <module>   sR      < %fv@N ,	