o
    i5                     @   s   d Z ddlZddlZddlZddlZddlmZ ddlmZm	Z	 ddl
Z
ddlZddlmZ ddlmZ ddlmZ ddlmZ dd	lmZ G d
d deZG dd deZdd ZdS )zEncoder definition.    N)Path)OptionalTuple)FileLock)check_argument_types)
AbsEncoder)make_pad_mask)	LayerNormc                )       s   e Zd ZdZ												
						
		d/dedededededededededededededededededed ed!ef( fd"d#Zd$efd%d&Z		'd0d(e
jd)e
jd*e
jd$ee
je
jee
j f fd+d,Zd-d. Z  ZS )1FairseqHubertEncodera6  FairSeq Hubert encoder module, used for loading pretrained weight and finetuning

    Args:
        input_size: input dim
        hubert_url: url to Hubert pretrained model
        hubert_dir_path: directory to download the Wav2Vec2.0 pretrained model.
        output_size: dimension of attention
        normalize_before: whether to use layer_norm before the first block
        freeze_finetune_updates: steps that freeze all layers except output layer
            before tuning the whole model (nessasary to prevent overfit).
        dropout_rate: dropout rate
        activation_dropout: dropout rate in activation function
        attention_dropout: dropout rate in attention
    Hubert specific Args:
        Please refer to:
        https://github.com/pytorch/fairseq/blob/master/fairseq/models/hubert/hubert.py
    ./   Fr           皙?
         ?staticT@         ?
input_size
hubert_urlhubert_dir_pathoutput_sizenormalize_beforefreeze_finetune_updatesdropout_rateactivation_dropoutattention_dropoutmask_length	mask_probmask_selection
mask_other
apply_maskmask_channel_lengthmask_channel_probmask_channel_othermask_channel_selection	layerdropfeature_grad_multc           !         s  t  sJ t   || _zdd l}ddlm} W n ty/ } z
td td |d }~ww |||	|
||||||||||d}|dkr|| _	t
j| j	t
dd}td	d
 |D rvzdd | D }W n tyu } z|d }~ww tjd| j	dd d d}t|}|jddd}t|| _W d    n1 sw   Y  td| jd | jd d| jd }|j}| jd d }t|| _n&t||| _	|jj| j	g|dd\}| _} |d }| jj j!}t|" | _|| _#t$||sz|j%j&}W n ty } ztd |d }~ww || _'|| _(| j(r%t)|| _*|r:||kr:t
j+,t
j+-||| _.nd | _.|| _/| 0dt
1dg d S )Nr   HubertModel)Error: FairSeq is not properly installed.BPlease install FairSeq: cd ${MAIN_ROOT}/tools && make fairseq.done)dropoutr   r   r   r   r   r    r"   r#   r%   r$   encoder_layerdropr'   dataespnetcpu)map_locationc                 s   s    | ]}d |v V  qdS )zencoder.encoderN ).0kr2   r2   V/home/ubuntu/.local/lib/python3.10/site-packages/espnet2/asr/encoder/hubert_encoder.py	<genexpr>i   s    z0FairseqHubertEncoder.__init__.<locals>.<genexpr>c                 S   s&   i | ]\}}d |vr| dd|qS )label_embs_concatzencoder.encoder. )replace)r3   r4   vr2   r2   r5   
<dictcomp>k   s
    z1FairseqHubertEncoder.__init__.<locals>.<dictcomp>/zconfig.yamlrzutf-8)encodingr   hubert_dict)r   r@   encoder_confr   F)arg_overridesstrictzQError: pretrained models should be within: 'HubertModel, Hubertctc' classes, etc.num_updatesr2   )2r   super__init__r!   fairseqfairseq.models.hubert.hubertr)   	Exceptionprinthubert_model_pathtorchloaddeviceallitemsospathjoinsplitr   openyaml	safe_loadpretrained_cfgFairseqHubertPretrainEncoderencodercopydeepcopypretrained_paramsdownload_hubertcheckpoint_utilsload_model_ensemble_and_taskmodelencoder_embed_dim
state_dict_output_size
isinstancehubert_encoderhubert_modelencodersr   r	   
after_normnn
SequentialLinearoutput_layerr   register_buffer
LongTensor)!selfr   r   r   r   r   r   r   r   r   r   r   r   r    r!   r"   r#   r$   r%   r&   r'   rG   r)   erB   sstateconfig_filefra   dmodelstask	__class__r2   r5   rF   /   s   




zFairseqHubertEncoder.__init__returnc                 C      | j S Nrd   rp   r2   r2   r5   r         z FairseqHubertEncoder.output_sizeNxs_padilensprev_statesc                 C   s  t ||j}| j| jk}| j| jkr|  jd7  _n|r3| j| jd kr3|  jd7  _td n|  jd7  _|s@t nt	
  | j||| joN| jddd}W d   n1 s]w   Y  |d }|d }~| jdd}| jdur|| |}| jr| |}||dfS )	zForward Hubert ASR Encoder.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
           z$Start fine-tuning hubert parameters!TN)padding_maskmaskfeatures_onlyrm   xr   )dim)r   torN   r   rD   logginginforL   no_grad
contextlibnullcontextrh   r!   trainingsumrm   r   ri   )rp   r   r   r   masksftenc_outputsolensr2   r2   r5   forward   s4   
	



zFairseqHubertEncoder.forwardc                 C   s    | j j| jdd td d S )NF)rC   z,Pretrained Hubert model parameters reloaded!)rh   load_state_dictr]   r   r   r   r2   r2   r5   reload_pretrained_parameters   s   z1FairseqHubertEncoder.reload_pretrained_parameters)r   r   r   Fr   r   r   r   r   r   r   r   Tr   r   r   r   r   r   r}   )__name__
__module____qualname____doc__intstrboolfloatrF   r   rL   Tensorr   r   r   r   __classcell__r2   r2   ry   r5   r
      s    	
 
3r
   c                       s   e Zd ZdZ														d*d
ededededededededededededef fddZdd Z	defddZ
	d+dejd ejd!ejd"ejd#ejdeejejeej f fd$d%Zd&d' Zd(d) Z  ZS ),rY   a  FairSeq Hubert pretrain encoder module, only used for pretraining stage

    Args:
        input_size: input dim
        output_size: dimension of attention
        linear_units: dimension of feedforward layers
        attention_heads: the number of heads of multi head attention
        num_blocks: the number of encoder blocks
        dropout_rate: dropout rate
        attention_dropout_rate: dropout rate in attention
        hubert_dict: target dictionary for Hubert pretraining
        label_rate: label frame rate. -1 for sequence label
        sample_rate: target sample rate.
        use_amp: whether to use automatic mixed precision
        normalize_before: whether to use layer_norm before the first block
    r         r   
./dict.txtd   F>  r   r   linear_unitsattention_heads
num_blocksr   attention_dropout_rateactivation_dropout_rater@   
label_ratecheckpoint_activationssample_rateuse_ampc              
      s:  t  sJ t   || _|| _zddlm} ddlm} ddlm	} ddlm
} W n ty@ } z
td td |d }~ww ||||||||
|d	}i ||}| | _| D ]\}}t| j|rlt| j|| q[| }|
|d	}| D ]\}}t||rt||| qy| }| ||	 || j|| j| _d S )
Nr   )
Dictionary)HubertConfigr(   )HubertPretrainingConfigr*   r+   )	rb   encoder_ffn_embed_dimencoder_attention_headsencoder_layers	final_dimr,   r   r   r   )r   r   )r   rE   rF   rd   r   fairseq.data.dictionaryr   rH   r   r)   r   rI   rJ   cfgrP   hasattrsetattr_build_dictionarydictionariesrZ   )rp   r   r   r   r   r   r   r   r   r@   r   r   r   r   kwargsr   r   r)   r   rq   cfg_overideskeyvaluehubert_task_cfghubert_task_cfg_overidesrv   ry   r2   r5   rF     sT   


z%FairseqHubertPretrainEncoder.__init__c                 C   sV   t j| r t|dg  t|dg  t|di  ||  n|d |g| _d S )Nsymbolscountindices0)rQ   rR   existsr   add_from_file
add_symbolr   )rp   
dictionaryhubert_dict_pathr2   r2   r5   r   A  s   
z.FairseqHubertPretrainEncoder._build_dictionaryr{   c                 C   r|   r}   r~   r   r2   r2   r5   r   L  r   z(FairseqHubertPretrainEncoder.output_sizeNr   r   ys_padys_pad_lengthr   c                 C   sJ   |    t||j}|dddt|f }| j||d|gdd}|S )zForward Hubert Pretrain Encoder.

        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
            prev_states: Not to be used now.
        Returns:
            position embedded tensor and mask
        NTF)r   r   target_listr   )cast_mask_embr   r   rN   minrZ   )rp   r   r   r   r   r   r   r   r2   r2   r5   r   O  s   z$FairseqHubertPretrainEncoder.forwardc                 C   s<   | j r| jjjtjjkrtj| jj	 | j_d S d S d S r}   )
r   rZ   mask_embdtyperL   cuda
HalfTensorrj   	Parameterhalfr   r2   r2   r5   r   l  s   z*FairseqHubertPretrainEncoder.cast_mask_embc                 C   s@   t jt | jj | j_t	
d| jjj d| j  d S )Nz4Hubert mask embedding re-initiallized!,             z,             )rL   rj   r   r   r   rb   uniform_rZ   r   r   r   r   r   r   r2   r2   r5   r   p  s   z9FairseqHubertPretrainEncoder.reload_pretrained_parameters)r   r   r   r   r   r   r   r   r   r   Fr   Fr}   )r   r   r   r   r   r   r   r   rF   r   r   rL   r   r   r   r   r   r   r   r2   r2   ry   r5   rY      sx    	
@	
rY   c                 C   s   t j|dd | dd }t j||}t|d 0 t j|s2tj	| | t
d|  nt
d| d W d    |S W d    |S 1 sNw   Y  |S )	NT)exist_okr<   r=   z.lockzHubert model downloaded zHubert model z already exists.)rQ   makedirsrT   rR   rS   r   r   rL   hubdownload_url_to_filer   r   )	model_urldir_path
model_name
model_pathr2   r2   r5   r^   {  s   

r^   )r   r   r[   r   rQ   pathlibr   typingr   r   rL   rV   filelockr   	typeguardr   espnet2.asr.encoder.abs_encoderr   &espnet.nets.pytorch_backend.nets_utilsr   2espnet.nets.pytorch_backend.transformer.layer_normr	   r
   rY   r^   r2   r2   r2   r5   <module>   s&   	 T 