o
    	۷iu                     @   sF  d dl Z d dlZd dlmZ d dlmZmZmZ d dlZ	d dl
Z
d dlmZ d dlmZ ddlmZ ddlmZ ddlmZ dd	lmZmZ dd
lmZ ddlmZ ddlmZmZmZmZm Z  ddl!m"Z"m#Z# ddl$m%Z% ddl&m'Z'm(Z(m)Z) ddl*m+Z+ e( rddl,m-Z- e).e/Z0ee'ddG dd deZ1G dd dej2Z3G dd dej2Z4G dd deZ5G dd deZ6G dd  d eZ7G d!d" d"ej2Z8G d#d$ d$ej2Z9		%	dVd&ej2d'e
j:d(e
j:d)e
j:d*ee
j: d+ee; d,e;d-ee
j: fd.d/Z<G d0d1 d1ej2Z=G d2d3 d3ej2Z>G d4d5 d5eZ?G d6d7 d7ej2Z@G d8d9 d9ej2ZAG d:d; d;eZBG d<d= d=ej2ZCG d>d? d?ej2ZDe'G d@dA dAe#ZE		 dWdBeFeGeGf dCe;dDeGd*ee
jH dEeGdFe	jIfdGdHZJe ZKe'G dIdJ dJeEZLe'dKdG dLdM dMeEZMdNZNe'dOdG dPdQ dQeEZOe'dRdG dSdT dTeEZPg dUZQdS )X    N)	dataclass)CallableOptionalUnion)CrossEntropyLoss   )ACT2FN)is_deepspeed_zero3_enabled)is_fsdp_managed_module)_prepare_4d_attention_mask#_prepare_4d_attention_mask_for_sdpa)FlashAttentionKwargs)GradientCheckpointingLayer)BaseModelOutputCausalLMOutputModelOutputSequenceClassifierOutputWav2Vec2BaseModelOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)auto_docstringis_torch_flex_attn_availablelogging   )UniSpeechConfig)make_flex_block_causal_maskzh
    Output type of [`UniSpeechForPreTrainingOutput`], with potential hidden states and attentions.
    )custom_introc                   @   s   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeej ed< dZeeej  ed< dZeeej  ed< dS )	UniSpeechForPreTrainingOutputa  
    loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
        paper](https://huggingface.co/papers/2006.11477).
    projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
        Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
        projected quantized states.
    projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
        Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
        target vectors for contrastive loss.
    codevector_perplexity (`torch.FloatTensor` of shape `(1,)`):
        The perplexity of the codevector distribution, used to measure the diversity of the codebook.
    Nlossprojected_statesprojected_quantized_statescodevector_perplexityhidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r    r!   r"   r#   tupler$    r-   r-   f/home/ubuntu/vllm_env/lib/python3.10/site-packages/transformers/models/unispeech/modeling_unispeech.pyr   :   s   
 r   c                       $   e Zd Z fddZdd Z  ZS )UniSpeechSamePadLayerc                    s*   t    |d dkrd| _d S d| _d S )N   r   r   )super__init__num_pad_remove)selfnum_conv_pos_embeddings	__class__r-   r.   r3   X   s   
 zUniSpeechSamePadLayer.__init__c                 C   s,   | j dkr|d d d d d | j  f }|S Nr   )r4   r5   r#   r-   r-   r.   forward\   s   
zUniSpeechSamePadLayer.forwardr%   r&   r'   r3   r;   __classcell__r-   r-   r7   r.   r0   W   s    r0   c                       r/   ) UniSpeechPositionalConvEmbeddingc                    s$  t    tj|j|j|j|jd |jd| _tjj	}t
tjjdr'tjjj	}t r{dd l}|jj| jjdd || jddd| _W d    n1 sLw   Y  t
| jdrd| jjjj}| jjjj}n| jj}| jj}|j| | |j| | n	|| jddd| _t|j| _t|j | _d S )	Nr1   )kernel_sizepaddinggroupsweight_normr   )modifier_rankweight)namedimparametrizations)r2   r3   nnConv1dhidden_sizer6   num_conv_pos_embedding_groupsconvutilsrB   hasattrrG   r	   	deepspeedzeroGatheredParametersrD   	original0	original1weight_gweight_vregister_external_parameterr0   r@   r   feat_extract_activation
activation)r5   configrB   rO   rT   rU   r7   r-   r.   r3   c   s4   

z)UniSpeechPositionalConvEmbedding.__init__c                 C   s:   | dd}| |}| |}| |}| dd}|S )Nr   r1   )	transposerL   r@   rX   r:   r-   r-   r.   r;      s   


z(UniSpeechPositionalConvEmbedding.forwardr<   r-   r-   r7   r.   r>   b   s    !r>   c                       &   e Zd Zd fdd	Zdd Z  ZS )UniSpeechNoLayerNormConvLayerr   c                    sj   t    |dkr|j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _d S )Nr   r   r?   stridebias)r2   r3   conv_dimin_conv_dimout_conv_dimrH   rI   conv_kernelconv_stride	conv_biasrL   r   rW   rX   r5   rY   layer_idr7   r-   r.   r3      s   
z&UniSpeechNoLayerNormConvLayer.__init__c                 C   s   |  |}| |}|S N)rL   rX   r:   r-   r-   r.   r;      s   

z%UniSpeechNoLayerNormConvLayer.forwardr   r<   r-   r-   r7   r.   r\      s    r\   c                       r[   )UniSpeechLayerNormConvLayerr   c                    s|   t    |dkr|j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
tj| jdd| _t|j | _d S )Nr   r   r]   T)elementwise_affine)r2   r3   r`   ra   rb   rH   rI   rc   rd   re   rL   	LayerNorm
layer_normr   rW   rX   rf   r7   r-   r.   r3      s   
z$UniSpeechLayerNormConvLayer.__init__c                 C   s:   |  |}|dd}| |}|dd}| |}|S )N)rL   rZ   rm   rX   r:   r-   r-   r.   r;      s   


z#UniSpeechLayerNormConvLayer.forwardri   r<   r-   r-   r7   r.   rj      s    rj   c                       r[   )UniSpeechGroupNormConvLayerr   c                    s   t    |dkr|j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _tj| j| jdd| _d S )Nr   r   r]   T)
num_groupsnum_channelsaffine)r2   r3   r`   ra   rb   rH   rI   rc   rd   re   rL   r   rW   rX   	GroupNormrm   rf   r7   r-   r.   r3      s   
z$UniSpeechGroupNormConvLayer.__init__c                 C   s"   |  |}| |}| |}|S rh   )rL   rm   rX   r:   r-   r-   r.   r;      s   


z#UniSpeechGroupNormConvLayer.forwardri   r<   r-   r-   r7   r.   rp      s    rp   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )UniSpeechFeatureEncoderz.Construct the features from raw audio waveformc                    s   t     jdkr t ddg fddt jd D  }n jdkr2 fddt jD }n	td	 j d
t|| _	d| _
d| _d S )Ngroupr   rg   c                    s   g | ]
}t  |d  dqS )r   rw   )r\   .0irY   r-   r.   
<listcomp>   s    z4UniSpeechFeatureEncoder.__init__.<locals>.<listcomp>r   layerc                    s   g | ]}t  |d qS )rw   )rj   rx   r{   r-   r.   r|      s    z`config.feat_extract_norm` is z), but has to be one of ['group', 'layer']FT)r2   r3   feat_extract_normrp   rangenum_feat_extract_layers
ValueErrorrH   
ModuleListconv_layersgradient_checkpointing_requires_grad)r5   rY   r   r7   r{   r.   r3      s   





z UniSpeechFeatureEncoder.__init__c                 C   s   |   D ]}d|_qd| _d S NF)
parametersrequires_gradr   r5   paramr-   r-   r.   _freeze_parameters   s   
z*UniSpeechFeatureEncoder._freeze_parametersc                 C   s:   |d d d f }| j r| jrd|_| jD ]}||}q|S )NT)r   trainingr   r   )r5   input_valuesr#   
conv_layerr-   r-   r.   r;      s   

zUniSpeechFeatureEncoder.forward)r%   r&   r'   r(   r3   r   r;   r=   r-   r-   r7   r.   ru      s
    ru   c                       r/   )UniSpeechFeatureProjectionc                    sJ   t    tj|jd |jd| _t|jd |j| _	t
|j| _d S )Nro   eps)r2   r3   rH   rl   r`   layer_norm_epsrm   LinearrJ   
projectionDropoutfeat_proj_dropoutdropoutr5   rY   r7   r-   r.   r3     s   
z#UniSpeechFeatureProjection.__init__c                 C   s&   |  |}| |}| |}||fS rh   )rm   r   r   )r5   r#   norm_hidden_statesr-   r-   r.   r;     s   


z"UniSpeechFeatureProjection.forwardr<   r-   r-   r7   r.   r      s    r           modulequerykeyvalueattention_maskscalingr   	head_maskc                 K   s   |d u r| dd }t||dd| }	|d ur|	| }	tjj|	dd}	|d ur5|	|dddd }	tjj|	|| j	d}	t|	|}
|
dd
 }
|
|	fS )Nro         r1   r   rF   r   )pr   )sizer)   matmulrZ   rH   
functionalsoftmaxviewr   r   
contiguous)r   r   r   r   r   r   r   r   kwargsattn_weightsattn_outputr-   r-   r.   eager_attention_forward  s   r   c                       s   e Zd ZdZ					ddededed	ed
ededee f fddZ					dde
jdee
j dee
j dee
j dee dee dee
jee
j eee
j  f fddZ  ZS )UniSpeechAttentionz=Multi-headed attention from 'Attention Is All You Need' paperr   FTN	embed_dim	num_headsr   
is_decoderr_   	is_causalrY   c                    s   t    || _|| _|| _|| | _|| _| j| | jkr*td| j d| d| jd | _|| _	|| _
tj|||d| _tj|||d| _tj|||d| _tj|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).r   )r_   )r2   r3   r   r   r   head_dimrY   r   r   r   r   rH   r   k_projv_projq_projout_proj)r5   r   r   r   r   r_   r   rY   r7   r-   r.   r3   0  s&   



zUniSpeechAttention.__init__r#   key_value_statesr   layer_head_maskoutput_attentionsr   returnc                 K   s  |du}|j dd \}}	|r|j d n|	}
||	d| jf}||
d| jf}| |j| dd}|r4|n|}| |j| dd}| |j| dd}t}| jj	dkr\t
| jj	 }|| ||||f| jshdn| j| j||d|\}}|||	d }| |}||dfS )z#Input shape: Batch x Time x ChannelNro   r   r1   eagerr   )r   r   r   r   )shaper   r   r   rZ   r   r   r   rY   _attn_implementationr   r   r   r   reshaper   r   )r5   r#   r   r   r   r   r   is_cross_attentionbsztgt_lensrc_lenq_input_shapekv_input_shapequery_statescurrent_states
key_statesvalue_statesattention_interfacer   r   r-   r-   r.   r;   O  s:   



zUniSpeechAttention.forward)r   FTFN)NNNF)r%   r&   r'   r(   intfloatboolr   r   r3   r)   Tensorr   r   r,   r;   r=   r-   r-   r7   r.   r   -  sR    "	
r   c                       r/   )UniSpeechFeedForwardc                    sp   t    t|j| _t|j|j| _	t
|jtr"t|j | _n|j| _t|j|j| _t|j| _d S rh   )r2   r3   rH   r   activation_dropoutintermediate_dropoutr   rJ   intermediate_sizeintermediate_dense
isinstance
hidden_actstrr   intermediate_act_fnoutput_densehidden_dropoutoutput_dropoutr   r7   r-   r.   r3     s   
zUniSpeechFeedForward.__init__c                 C   s6   |  |}| |}| |}| |}| |}|S rh   )r   r   r   r   r   r:   r-   r-   r.   r;     s   




zUniSpeechFeedForward.forwardr<   r-   r-   r7   r.   r     s    r   c                       s&   e Zd Z fddZdddZ  ZS )UniSpeechEncoderLayerc                    sh   t    t|j|j|jd|d| _t|j	| _
tj|j|jd| _t|| _tj|j|jd| _d S )NFr   r   r   r   rY   r   )r2   r3   r   rJ   num_attention_headsattention_dropout	attentionrH   r   r   r   rl   r   rm   r   feed_forwardfinal_layer_normr   r7   r-   r.   r3     s   

zUniSpeechEncoderLayer.__init__NFc                 C   sf   |}| j |||d\}}}| |}|| }| |}|| | }| |}|f}|r1||f7 }|S Nr   r   )r   r   rm   r   r   r5   r#   r   r   attn_residualr   _outputsr-   r-   r.   r;     s   



zUniSpeechEncoderLayer.forwardr   r<   r-   r-   r7   r.   r     s    r   c                       sj   e Zd Z fddZ				ddejdeej ded	ed
ef
ddZ	de
ejdf dejfddZ  ZS )UniSpeechEncoderc                    f   t     | _t | _tj j jd| _	t
 j| _t fddt jD | _d| _d S )Nr   c                       g | ]}t  qS r-   )r   ry   r   r{   r-   r.   r|         z-UniSpeechEncoder.__init__.<locals>.<listcomp>Fr2   r3   rY   r>   pos_conv_embedrH   rl   rJ   r   rm   r   r   r   r   r   num_hidden_layerslayersr   r   r7   r{   r.   r3     s   

 
zUniSpeechEncoder.__init__NFTr#   r   r   output_hidden_statesreturn_dictc                 C   s*  |rdnd }|r
dnd }|d ur"| ddd|jd }d|| < | ||}| |}	||	 }| |}| |}t pAt| }
| j	D ]3}|rN||f }t
g }| jo[|| jjk }|r`|
rk||||d}|d }|rod}|rx||d f }qE|r||f }|stdd	 |||fD S t|||d
S )Nr-   ro   r   r1   r   r   NNc                 s       | ]	}|d ur|V  qd S rh   r-   ry   vr-   r-   r.   	<genexpr>       z+UniSpeechEncoder.forward.<locals>.<genexpr>last_hidden_stater#   r$   )	unsqueezerepeatr   _update_full_maskr   rm   r   r	   r
   r   r)   randr   rY   	layerdropr,   r   r5   r#   r   r   r   r   all_hidden_statesall_self_attentionsexpand_attention_maskposition_embeddingssynced_gpusr}   dropout_probabilityskip_the_layerlayer_outputsr-   r-   r.   r;     sL   







zUniSpeechEncoder.forwardinputs_embedsc                 C      |d ur>| j jdkrd|v r|}|S d }|S | j jdkr$t||j}|S | j jdkr8t|tjr6t|dd}|S t||j}|S Nflash_attention_2r   sdpaflex_attentionF)r   	rY   r   r   dtyper   r)   r   r   r   r5   r   r
  r-   r-   r.   r        z"UniSpeechEncoder._update_full_maskNFFT)r%   r&   r'   r3   r)   tensorr   r   r   r;   r   r   r=   r-   r-   r7   r.   r     s,    
<r   c                       s,   e Zd Z fddZdejfddZ  ZS )UniSpeechAttnAdapterLayerc                    sZ   t    |j| _|j| _t| j| _t	| j| j| _
t | _t	| j| j| _dS )z
        Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
        up training throughput.
        N)r2   r3   adapter_attn_dim	input_dimrJ   
hidden_dimrH   rl   normr   linear_1ReLUact_fnlinear_2r   r7   r-   r.   r3     s   

z"UniSpeechAttnAdapterLayer.__init__r#   c                 C   s,   |  |}| |}| |}| |}|S rh   )r  r  r  r  r:   r-   r-   r.   r;   -  s
   



z!UniSpeechAttnAdapterLayer.forward)r%   r&   r'   r3   r)   r*   r;   r=   r-   r-   r7   r.   r    s    r  c                       s@   e Zd Z fddZ		d
dejdeej defdd	Z  Z	S )$UniSpeechEncoderLayerStableLayerNormc                    s   t    t|j|j|jd|d| _t|j	| _
tj|j|jd| _t|| _tj|j|jd| _t|dd d urAt|| _d S d | _d S )NFr   r   r  )r2   r3   r   rJ   r   r   r   rH   r   r   r   rl   r   rm   r   r   r   getattrr  adapter_layerr   r7   r-   r.   r3   8  s   


z-UniSpeechEncoderLayerStableLayerNorm.__init__NFr#   r   r   c                 C   sz   |}|  |}| j|||d\}}}| |}|| }|| | | }| jd ur1|| | }|f}|r;||f7 }|S r   )rm   r   r   r   r   r!  r   r-   r-   r.   r;   K  s   



z,UniSpeechEncoderLayerStableLayerNorm.forwardr   )
r%   r&   r'   r3   r)   r   r   r   r;   r=   r-   r-   r7   r.   r  7  s    r  c                       sL   e Zd Z fddZ				dddZdeejdf d	ejfd
dZ  Z	S )UniSpeechEncoderStableLayerNormc                    r   )Nr   c                    r   r-   )r  r   r{   r-   r.   r|   m  r   z<UniSpeechEncoderStableLayerNorm.__init__.<locals>.<listcomp>Fr   r   r7   r{   r.   r3   f  s   


z(UniSpeechEncoderStableLayerNorm.__init__NFTc                 C   s*  |rdnd }|r
dnd }|d ur"| ddd|jd }d|| < | ||}| |}	||	 }| |}t p<t| }
| jD ]3}|rI||f }t	
g }| joV|| jjk }|r[|
rf||||d}|d }|rjd}|rs||d f }q@| |}|r||f }|stdd	 |||fD S t|||d
S )Nr-   ro   r   r1   r   r   r   c                 s   r   rh   r-   r   r-   r-   r.   r     r   z:UniSpeechEncoderStableLayerNorm.forward.<locals>.<genexpr>r   )r   r   r   r   r   r   r	   r
   r   r)   r   r   rY   r   rm   r,   r   r  r-   r-   r.   r;   q  sL   







z'UniSpeechEncoderStableLayerNorm.forwardr   r
  c                 C   r  r  r  r  r-   r-   r.   r     r  z1UniSpeechEncoderStableLayerNorm._update_full_maskr  )
r%   r&   r'   r3   r;   r   r)   r   r   r=   r-   r-   r7   r.   r"  e  s    
>r"  c                       s4   e Zd ZdZ fddZedd Zdd Z  ZS )UniSpeechGumbelVectorQuantizerz
    Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
    GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
    c                    s   t    |j| _|j| _|j| j dkr"td|j d| j dt	t
d| j| j |j| j | _t|jd | j| j | _d| _d S )Nr   z`config.codevector_dim z5 must be divisible by `config.num_codevector_groups` z for concatenationr   ro   r1   )r2   r3   num_codevector_groupsrq   num_codevectors_per_groupnum_varscodevector_dimr   rH   	Parameterr)   r*   codevectorsr   r`   weight_projtemperaturer   r7   r-   r.   r3     s   


z'UniSpeechGumbelVectorQuantizer.__init__c                 C   s8   | j dd}ttj|t|d  dd  }|S )Nr   r   gHz>ro   )meanr)   expsumlog)probsmarginal_probs
perplexityr-   r-   r.   _compute_perplexity  s   (z2UniSpeechGumbelVectorQuantizer._compute_perplexityc                 C   s  |j \}}}| |}||| | j d}| jr?tjj| | j	dd
|}tj||| | jd dd}| |}n$|jdd}|j|j  d|ddd}||| | jd}| |}||| d}|d| j }	|	|| | j| jd}
|
d||d}
|
|fS )Nro   T)tauhardr   r         ?rn   )r   r*  r   rq   r   rH   r   gumbel_softmaxr   r+  type_asr)   r   r3  argmax	new_zerosscatter_r   r)  r&  r.  )r5   r#   
batch_sizesequence_lengthrJ   codevector_probscodevector_soft_distr2  codevector_idxcodevectors_per_groupr)  r-   r-   r.   r;     s0   

z&UniSpeechGumbelVectorQuantizer.forward)	r%   r&   r'   r(   r3   staticmethodr3  r;   r=   r-   r-   r7   r.   r#    s    
r#  c                   @   sb   e Zd ZU eed< dZdZdZdZdZ	dZ
dd Zdeejef fdd	Zd
edejfddZdS )UniSpeechPreTrainedModelrY   	unispeechr   Tc              	   C   s  t |tr|jjjjddd |jjj  tj	
|j dS t |trItj	j|jjddtd|jjd |jj   d tj	|jjd dS t |trqtd|jj }tj	j
|jj| |d tj	j
|jj| |d dS t |tjr|jjjd| jjd |jdur|jj  dS dS t |tjtjfr|jj  |jjd dS t |tjrtj	|j |jdurt|j|j|jd   }tj	j
|j| |d dS dS dS )	zInitialize the weightsr   r   )r,  stdr   r1   )abNr6  )r   r#  r*  rD   datanormal_r_   zero_rH   inituniform_r)  r>   rL   mathsqrtr?   in_channels	constant_r   r   in_featuresr   rY   initializer_rangerl   rt   fill_rI   kaiming_normal_rA   )r5   r   kr-   r-   r.   _init_weights  s<   

 


z&UniSpeechPreTrainedModel._init_weightsinput_lengthsc                 C   s4   dd }t | jj| jjD ]
\}}||||}q|S )zH
        Computes the output length of the convolutional layers
        c                 S   s   t j| | |ddd S )Nfloor)rounding_moder   )r)   div)input_lengthr?   r^   r-   r-   r.   _conv_out_length<  s   zSUniSpeechPreTrainedModel._get_feat_extract_output_lengths.<locals>._conv_out_length)ziprY   rc   rd   )r5   rW  r\  r?   r^   r-   r-   r.    _get_feat_extract_output_lengths7  s   z9UniSpeechPreTrainedModel._get_feat_extract_output_lengthsfeature_vector_lengthr   c                 C   s   |j ddd d df }| |tj}|jd }tj||f|j|jd}d|tj	|jd |jd|d f< |
dg d
dg }|S )Nro   r   r   )r  devicer   )r`  )cumsumr^  tor)   longr   zerosr  r`  arangeflipr   )r5   r_  r   non_padded_lengthsoutput_lengthsr<  r-   r-   r.   "_get_feature_vector_attention_maskF  s   
"z;UniSpeechPreTrainedModel._get_feature_vector_attention_maskN)r%   r&   r'   r   r+   base_model_prefixmain_input_namesupports_gradient_checkpointing_supports_flash_attn_supports_sdpa_supports_flex_attnrV  r   r)   
LongTensorr   r^  ri  r-   r-   r-   r.   rC    s   
 !rC  r   	mask_probmask_length	min_masksr   c                    s  | \}dk rt dkrt d d dtjd   fdd}|dur:| d	 n
fd
dt|D }tj	|ft
d}g }	|}
|
dkrZ|S |D ];}||}tjjt|d  |dd}t|dkr}d }n|d }t|tj|
| tjd| g}|	| q\t|	}	t|	dddddf ||
f}	|	||
 }	tddddf }t|||
f||
 }|	| }	|	 d krd |	|	d k< t||	dd	 |S )an  
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
    r   z&`mask_length` has to be bigger than 0.zO`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: z and `sequence_length`: `c                    sX   t |     }t|}| kr }| d  |k r*t| d  d}|S )z;Given input length, compute how many spans should be maskedr   r   )r   max)r[  num_masked_spanepsilonrr  rq  rs  r=  r-   r.   compute_num_masked_span|  s   
z6_compute_mask_indices.<locals>.compute_num_masked_spanNro   c                    s   g | ]} qS r-   r-   r   )r=  r-   r.   r|     s    z)_compute_mask_indices.<locals>.<listcomp>r  r   F)replace)r   nprandomr   itemdetachr.  tolistr   rd  r   choicere  lenconcatenateonesint32appendarraybroadcast_tor   ru  put_along_axis)r   rq  rr  r   rs  r<  ry  rW  spec_aug_maskspec_aug_mask_idxsmax_num_masked_spanr[  rv  spec_aug_mask_idxdummy_mask_idxoffsetsr-   rw  r.   _compute_mask_indicesV  s\   

r  c                       s   e Zd Zdef fddZ		ddejdeej deej fdd	Z	e
					dd
eej deej deej dee dee dee deeef fddZ  ZS )UniSpeechModelrY   c                    sz   t  | || _t|| _t|| _|jdks|jdkr)t	
t|j | _|jr2t|| _nt|| _|   d S )Nr   )r2   r3   rY   ru   feature_extractorr   feature_projectionmask_time_probmask_feature_probrH   r(  r)   r   rJ   rL  masked_spec_embeddo_stable_layer_normr"  encoderr   	post_initr   r7   r-   r.   r3     s   


zUniSpeechModel.__init__Nr#   mask_time_indicesr   c                 C   s  t | jdds	|S | \}}}|dur| j|j||< n-| jjdkrK| jrKt||f| jj| jj	|| jj
d}tj||jtjd}| j|j||< | jjdkr| jrt||f| jj| jj| jjd}tj||jtjd}|dddf d|d}d||< |S )	z
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).
        apply_spec_augmentTNr   )rq  rr  r   rs  )r`  r  )rq  rr  rs  ro   )r   rY   r   r  rb  r  r  r   r  mask_time_lengthmask_time_min_masksr)   r  r`  r   r  mask_feature_lengthmask_feature_min_masksexpand)r5   r#   r  r   r<  r=  rJ   mask_feature_indicesr-   r-   r.   _mask_hidden_states  s4   z"UniSpeechModel._mask_hidden_statesr   r   r   r   r   c           
      C   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}| |}|dd}|dur6| |jd |}| |\}}| j	|||d}| j
|||||d}	|	d }|s_||f|	dd  S t|||	j|	jdS )a/  
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        Nr   r1   )r  r   r   r   r   r   r   )r   extract_featuresr#   r$   )rY   r   r   use_return_dictr  rZ   ri  r   r  r  r  UniSpeechBaseModelOutputr#   r$   )
r5   r   r   r  r   r   r   r  r#   encoder_outputsr-   r-   r.   r;     s8   
zUniSpeechModel.forwardr   NNNNN)r%   r&   r'   r   r3   r)   r*   r   rp  r  r   r   r   r   r,   r  r;   r=   r-   r-   r7   r.   r    s@    
.
r  zZ
    UniSpeech Model with a vector-quantization module and ctc loss for pre-training.
    c                       s   e Zd Zdef fddZdefddZdd Zd	d
 Ze		dde
jde
jde
jdefddZe				ddee
j dee
j dee dee dee deeef fddZ  ZS )UniSpeechForPreTrainingrY   c                    s~   t  | t|| _t|j| _t|| _	t
|j|j| _t
|j|j| _t
|j|j| _t|j| _|   d S rh   )r2   r3   r  rD  rH   r   feat_quantizer_dropoutdropout_featuresr#  	quantizerr   r'  proj_codevector_dim	project_qrJ   project_hidnum_ctc_classesctc_projfinal_dropoutr   r  r   r7   r-   r.   r3   M  s   

z UniSpeechForPreTraining.__init__r+  c                 C   s   || j _dS )zb
        Set the Gumbel softmax temperature to a given value. Only necessary for training
        N)r  r+  )r5   r+  r-   r-   r.   set_gumbel_temperature\  s   z.UniSpeechForPreTraining.set_gumbel_temperaturec                 C      t dt |   dS z
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. Please use the equivalent `freeze_feature_encoder` method instead.NwarningswarnFutureWarningfreeze_feature_encoderr5   r-   r-   r.   freeze_feature_extractorb  
   z0UniSpeechForPreTraining.freeze_feature_extractorc                 C      | j j  dS 
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        NrD  r  r   r  r-   r-   r.   r  n     z.UniSpeechForPreTraining.freeze_feature_encoderr   target_featuresnegative_featurespredicted_featuresc                 C   s@   t j| |gdd} t j| |  dd}|| }|| }|S )z
        Compute logits for contrastive loss based using cosine similarity as the distance measure between
        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
        r   r   ro   )r)   catcosine_similarityr   r8  )r  r  r  r+  logitsr-   r-   r.   compute_contrastive_logitsu  s
   
z2UniSpeechForPreTraining.compute_contrastive_logitsNr   r   r   r   r   r   c                 C   sJ  |dur|n| j j}| j|||||d}|d }| |d }| |\}	}
| |	| jjj}	| 	|	}	t
|d|d| j j}|dd}t
| |j}|dd}|d}||d|	| d }| |}| |}d}|s|dur|||	|
f|dd  S ||	|
f|dd  S t|||	|
|j|jdS )	a  
        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining

        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
        >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
        >>> # TODO: Add full pretraining example
        ```Nr  r   r   ro   r   r1   )r   r    r!   r"   r#   r$   )rY   r  rD  r  r  r  rb  rD   r  r  r)   emptyr   rS  replace_probrZ   	bernoullir   r`  r   masked_fillr   r  r   r#   r$   )r5   r   r   r   r   r   r   transformer_featuresr  quantized_featuresr"   prob_replace_matrixsampled_replace_matrixr  r   r-   r-   r.   r;     sL   




zUniSpeechForPreTraining.forward)r   )NNNN)r%   r&   r'   r   r3   r   r  r  r  rB  r)   r*   r  r   r   r   r   r   r,   r   r;   r=   r-   r-   r7   r.   r  G  sD    
r  r1   zq
    UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
    c                       s   e Zd Zddee f fddZdd Zdd Zd	d
 Zdd Z	e
					ddeej deej dee dee dee deej deeef fddZ  ZS )UniSpeechForCTCNtarget_langc                    s~   t  | t|| _t|j| _|| _|j	du r#t
d| j dt|dr.|jr.|jn|j}t||j	| _|   dS )a3  
        target_lang (`str`, *optional*):
            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
            adapter.<lang>.bin. Only relevant when using an instance of [`UniSpeechForCTC`] with adapters. Uses 'eng' by
            default.
        NzYou are trying to instantiate z with a configuration that does not define the vocabulary size of the language model head. Please instantiate the model as follows: `UniSpeechForCTC.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of your model's configuration.add_adapter)r2   r3   r  rD  rH   r   r  r   r  
vocab_sizer   r8   rN   r  output_hidden_sizerJ   r   lm_headr  )r5   rY   r  r  r7   r-   r.   r3     s   

zUniSpeechForCTC.__init__c                 C   sv   | j }|durt| jdddu rtd| d|du r,t| jdddur,td dS |dur9| j|dd dS dS )a'  
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        Nr  zCannot pass `target_lang`: z- if `config.adapter_attn_dim` is not defined.z)By default `target_lang` is set to 'eng'.T)
force_load)r  r   rY   r   loggerinfoload_adapter)r5   r  r-   r-   r.   tie_weights  s   zUniSpeechForCTC.tie_weightsc                 C   r  )r  r  Nr  r  r-   r-   r.   r    r  z(UniSpeechForCTC.freeze_feature_extractorc                 C   r  r  r  r  r-   r-   r.   r    r  z&UniSpeechForCTC.freeze_feature_encoderc                 C      | j  D ]}d|_qdS z
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        FNrD  r   r   r   r-   r-   r.   freeze_base_model     z!UniSpeechForCTC.freeze_base_modelr   r   r   r   r   labelsr   c              
   C   s|  |dur|n| j j}|dur| | j jkrtd| j j | j|||||d}|d }| |}| |}	d}
|dur|durC|ntj	|tj
d}| |dtj
}|dk}|d}||}tjj|	dtjddd}tjjjd	d
 tjj||||| j j| j j| j jd}
W d   n1 sw   Y  |s|	f|td  }|
dur|
f| S |S t|
|	|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        Nz$Label values must be <= vocab_size: r  r   rz  ro   )rF   r  r   F)enabled)blank	reductionzero_infinityr   r  r#   r$   )rY   r  ru  r  r   rD  r   r  r)   	ones_likerc  r^  r.  rb  masked_selectrH   r   log_softmaxfloat32rZ   backendscudnnflagsctc_losspad_token_idctc_loss_reductionctc_zero_infinity_HIDDEN_STATES_START_POSITIONr   r#   r$   )r5   r   r   r   r   r   r  r   r#   r  r   rW  labels_masktarget_lengthsflattened_targets	log_probsoutputr-   r-   r.   r;   '  sN   



zUniSpeechForCTC.forwardrh   r  )r%   r&   r'   r   r   r3   r  r  r  r  r   r)   r   r   r   r,   r   r;   r=   r-   r-   r7   r.   r    s6    
r  z
    UniSpeech Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
    SUPERB Keyword Spotting.
    c                       s   e Zd Z fddZdd Zdd Zdd Ze										dd
ee	j
 dee	j
 dee dee dee dee	j
 deeef fddZ  ZS )"UniSpeechForSequenceClassificationc                    s   t  | t|dr|jrtdt|| _|jd }|jr*t	
t|| | _t	|j|j| _t	|j|j| _|   d S )Nr  z`Sequence classification does not support the use of UniSpeech adapters (config.add_adapter=True)r   )r2   r3   rN   r  r   r  rD  r   use_weighted_layer_sumrH   r(  r)   r  layer_weightsr   rJ   classifier_proj_size	projector
num_labels
classifierr  )r5   rY   
num_layersr7   r-   r.   r3   v  s   

z+UniSpeechForSequenceClassification.__init__c                 C   r  r  r  r  r-   r-   r.   r    r  z;UniSpeechForSequenceClassification.freeze_feature_extractorc                 C   r  r  r  r  r-   r-   r.   r    r  z9UniSpeechForSequenceClassification.freeze_feature_encoderc                 C   r  r  r  r   r-   r-   r.   r    r  z4UniSpeechForSequenceClassification.freeze_base_modelNr   r   r   r   r   r  r   c                 C   sz  |dur|n| j j}| j jrdn|}| j|||||d}| j jrB|t }tj|dd}tjj	| j
dd}	||	ddd jdd}n|d }| |}|du rV|jdd}
n+| |jd |}|ddd|jd }d	|| < |jdd|jdddd }
| |
}d}|durt }||d| j j|d}|s|f|td  }|dur|f| S |S t|||j|jd
S )a  
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
            into a tensor of type `torch.FloatTensor`. See [`UniSpeechProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        NTr  r   r   ro   r   r1   r   r  )rY   r  r  rD  r  r)   stackrH   r   r   r  r   r.  r  r,  ri  r   r   r   r   r   r  r   r#   r$   )r5   r   r   r   r   r   r  r   r#   norm_weightspooled_outputpadding_maskexpand_padding_maskr  r   loss_fctr  r-   r-   r.   r;     sH   

 
z*UniSpeechForSequenceClassification.forwardr  )r%   r&   r'   r3   r  r  r  r   r   r)   r   r   r   r,   r   r;   r=   r-   r-   r7   r.   r  o  s4    
r  )r  r  r  r  rC  )Nr   Nr9   )RrM  r  dataclassesr   typingr   r   r   numpyr|  r)   torch.nnrH   r   activationsr   integrations.deepspeedr	   integrations.fsdpr
   modeling_attn_mask_utilsr   r   modeling_flash_attention_utilsr   modeling_layersr   modeling_outputsr   r   r   r   r   modeling_utilsr   r   processing_utilsr   rM   r   r   r   configuration_unispeechr   integrations.flex_attentionr   
get_loggerr%   r  r   Moduler0   r>   r\   rj   rp   ru   r   r   r   r   r   r   r   r   r  r  r"  r#  rC  r,   r   rp  ndarrayr  r  r  r  r  r  r  __all__r-   r-   r-   r.   <module>   s   
-)
X$].aFM

wv  s