# Reconstructed from a compiled (.pyc) dump of transformers/models/sew/modeling_sew.py:
# the names, signatures and docstrings below are taken verbatim from the dump; method
# bodies were re-derived to match them and should be read as a best-effort recovery.
"""PyTorch SEW model."""

import math
import warnings
from typing import Callable, Optional, Union

import numpy as np
import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import auto_docstring, logging
from .configuration_sew import SEWConfig


logger = logging.get_logger(__name__)
t|j | _d S )Nr   r   kernel_sizestridebias)super__init__conv_dimin_conv_dimout_conv_dimr   Conv1dconv_kernelconv_stride	conv_biasconvr   feat_extract_activation
activationselfconfiglayer_id	__class__ Z/home/ubuntu/vllm_env/lib/python3.10/site-packages/transformers/models/sew/modeling_sew.pyr   /   s   
z SEWNoLayerNormConvLayer.__init__c                 C   s   |  |}| |}|S N)r&   r(   r*   hidden_statesr/   r/   r0   forward=   s   

zSEWNoLayerNormConvLayer.forwardr   __name__
class SEWLayerNormConvLayer(nn.Module):
    def __init__(self, config, layer_id=0):
        super().__init__()
        self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
        self.out_conv_dim = config.conv_dim[layer_id]

        self.conv = nn.Conv1d(
            self.in_conv_dim,
            self.out_conv_dim,
            kernel_size=config.conv_kernel[layer_id],
            stride=config.conv_stride[layer_id],
            bias=config.conv_bias,
        )
        self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        hidden_states = self.conv(hidden_states)

        hidden_states = hidden_states.transpose(-2, -1)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = hidden_states.transpose(-2, -1)

        hidden_states = self.activation(hidden_states)
        return hidden_states

zSEWLayerNormConvLayer.forwardr5   r6   r/   r/   r-   r0   r;   C   s    r;   c                       r   )SEWGroupNormConvLayerr   c                    s   t    |dkr|j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _tj| j| jdd| _d S )Nr   r   r   T)
num_groupsnum_channelsaffine)r   r   r   r    r!   r   r"   r#   r$   r%   r&   r   r'   r(   	GroupNormr>   r)   r-   r/   r0   r   _   s   
zSEWGroupNormConvLayer.__init__c                 C   "   |  |}| |}| |}|S r1   )r&   r>   r(   r2   r/   r/   r0   r4   o   s   


zSEWGroupNormConvLayer.forwardr5   r6   r/   r/   r-   r0   rB   ^   s    rB   c                       $   e Zd Z fddZdd Z  ZS )SEWPositionalConvEmbeddingc                    s(  t    tj|j|j|j|jd |j|jd| _tj	j
}ttj	jdr)tj	jj
}t r}dd l}|jj| jjdd || jddd| _W d    n1 sNw   Y  t| jdrf| jjjj}| jjjj}n| jj}| jj}|j| | |j| | n	|| jddd| _t|j| _t|j | _d S )	N   )r   paddinggroupsr   weight_normr   modifier_rankweight)namedimparametrizations)r   r   r   r"   hidden_sizenum_conv_pos_embeddingsnum_conv_pos_embedding_groupssqueeze_factorr&   utilsrM   hasattrrS   r	   	deepspeedzeroGatheredParametersrP   	original0	original1weight_gweight_vregister_external_parameterSEWSamePadLayerrK   r   r'   r(   )r*   r+   rM   rZ   r_   r`   r-   r/   r0   r   w   s6   
	
z#SEWPositionalConvEmbedding.__init__c                 C   rG   r1   )r&   rK   r(   r2   r/   r/   r0   r4      s   


z"SEWPositionalConvEmbedding.forwardr6   r/   r/   r-   r0   rI   v   s    "rI   c                       rH   )rb   c                    s*   t    |d dkrd| _d S d| _d S )NrJ   r   r   )r   r   num_pad_remove)r*   rU   r-   r/   r0   r      s   
 zSEWSamePadLayer.__init__c                 C   s,   | j dkr|d d d d d | j  f }|S Nr   )rc   r2   r/   r/   r0   r4      s   
zSEWSamePadLayer.forwardr6   r/   r/   r-   r0   rb      s    rb   c                       rH   )SEWUpsamplingc                    s:   t    t|j|j|j | _t|j | _	|j| _d S r1   )
r   r   r   LinearrT   rW   
projectionr   r'   r(   r*   r+   r-   r/   r0   r      s   
zSEWUpsampling.__init__c                 C   sd   |  |}| |}| jdkr0| \}}}|| j }|| j }|||| j|}||||}|S )Nr   )rg   r(   rW   sizereshape)r*   r3   bszsrc_lensrc_embed_dimtgt_lentgt_embed_dimr/   r/   r0   r4      s   




zSEWUpsampling.forwardr6   r/   r/   r-   r0   re      s    re   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )SEWFeatureEncoderz.Construct the features from raw audio waveformc                    s   t     jdkr t ddg fddt jd D  }n jdkr2 fddt jD }n	td	 j d
t|| _	d| _
d| _d S )Ngroupr   r,   c                    s   g | ]
}t  |d  dqS )r   rr   )r   .0ir+   r/   r0   
<listcomp>   s    z.SEWFeatureEncoder.__init__.<locals>.<listcomp>r   layerc                    s   g | ]}t  |d qS )rr   )r;   rs   rv   r/   r0   rw      s    z`config.feat_extract_norm` is z), but has to be one of ['group', 'layer']FT)r   r   feat_extract_normrB   rangenum_feat_extract_layers
ValueErrorr   
ModuleListconv_layersgradient_checkpointing_requires_grad)r*   r+   r~   r-   rv   r0   r      s   




zSEWFeatureEncoder.__init__c                 C   s   |   D ]}d|_qd| _d S NF)
parametersrequires_gradr   r*   paramr/   r/   r0   _freeze_parameters   s   
z$SEWFeatureEncoder._freeze_parametersc                 C   s:   |d d d f }| j r| jrd|_| jD ]}||}q|S )NT)r   trainingr   r~   )r*   input_valuesr3   
conv_layerr/   r/   r0   r4      s   

zSEWFeatureEncoder.forward)r7   r8   r9   __doc__r   r   r4   r:   r/   r/   r-   r0   rp      s
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)

    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

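# Shape note for eager_attention_forward (a sketch of the convention this
# attention-interface style assumes): query/key/value arrive as
# (batch, num_heads, seq_len, head_dim) and attn_output is returned transposed to
# (batch, seq_len, num_heads, head_dim) so the caller can merge heads in one reshape.
#
#     >>> q = k = v = torch.zeros(2, 4, 10, 16)
#     >>> out, _ = eager_attention_forward(nn.Identity(), q, k, v, None)
#     >>> tuple(out.shape)
#     (2, 10, 4, 16)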
ededee f fddZ					dde
jdee
j dee
j dee
j dee dee dee
jee
j eee
j  f fddZ  ZS )SEWAttentionz=Multi-headed attention from 'Attention Is All You Need' paperr   FTN	embed_dim	num_headsr   
is_decoderr   	is_causalr+   c                    s   t    || _|| _|| _|| | _|| _| j| | jkr*td| j d| d| jd | _|| _	|| _
tj|||d| _tj|||d| _tj|||d| _tj|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).r   )r   )r   r   r   r   r   head_dimr+   r|   r   r   r   r   rf   k_projv_projq_projout_proj)r*   r   r   r   r   r   r   r+   r-   r/   r0   r   	  s&   



zSEWAttention.__init__r3   key_value_statesr   layer_head_maskoutput_attentionsr   returnc                 K   s  |du}|j dd \}}	|r|j d n|	}
||	d| jf}||
d| jf}| |j| dd}|r4|n|}| |j| dd}| |j| dd}t}| jj	dkr\t
| jj	 }|| ||||f| jshdn| j| j||d|\}}|||	d }| |}||dfS )z#Input shape: Batch x Time x ChannelNr@   r   rJ   eagerr   )r   r   r   r   )shaper   r   r   rA   r   r   r   r+   _attn_implementationr   r   r   r   rj   r   r   )r*   r3   r   r   r   r   r   is_cross_attentionrk   rn   rl   q_input_shapekv_input_shapequery_statescurrent_states
key_statesvalue_statesattention_interfacer   r   r/   r/   r0   r4   (  s:   



zSEWAttention.forward)r   FTFN)NNNF)r7   r8   r9   r   intfloatboolr   r   r   r   Tensorr   r   tupler4   r:   r/   r/   r-   r0   r     sR    "	
class SEWFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.intermediate_dropout = nn.Dropout(config.activation_dropout)

        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        hidden_states = self.intermediate_dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.intermediate_dropout(hidden_states)

        hidden_states = self.output_dense(hidden_states)
        hidden_states = self.output_dropout(hidden_states)
        return hidden_states

zSEWFeedForward.forwardr6   r/   r/   r-   r0   r   ^  s    r   c                       s&   e Zd Z fddZdddZ  ZS )SEWEncoderLayerc                    sh   t    t|j|j|jd|d| _t|j	| _
tj|j|jd| _t|| _tj|j|jd| _d S )NF)r   r   r   r   r+   eps)r   r   r   rT   num_attention_headsattention_dropout	attentionr   r   r   r   r=   layer_norm_epsr>   r   feed_forwardfinal_layer_normrh   r-   r/   r0   r   w  s   

zSEWEncoderLayer.__init__NFc                 C   sf   |}| j |||d\}}}| |}|| }| |}|| | }| |}|f}|r1||f7 }|S )Nr   r   )r   r   r>   r   r   )r*   r3   r   r   attn_residualr   _outputsr/   r/   r0   r4     s   



zSEWEncoderLayer.forwardr   r6   r/   r/   r-   r0   r   v  s    r   c                       s.   e Zd Z fddZ				dddZ  ZS )	
class SEWEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = SEWPositionalConvEmbedding(config)
        self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.upsample = SEWUpsampling(config)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            if self.config._attn_implementation == "flash_attention_2":
                # make sure padded tokens output 0
                hidden_states[~expand_attention_mask] = 0.0
                # 2d mask is passed through the layers
                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
            else:
                # make sure padded tokens output 0
                hidden_states[~expand_attention_mask] = 0.0
                input_lengths = (attention_mask.long()).sum(-1)
                # apply pooling formula to get real output_lengths
                output_lengths = input_lengths // self.config.squeeze_factor
                max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
                attention_ids = (
                    torch.arange(max_encoder_length, device=output_lengths.device)
                    .view(1, -1)
                    .expand(output_lengths.shape[0], -1)
                )
                attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()

                # extend attention_mask
                attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
                attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
                attention_mask = attention_mask.expand(
                    attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
                )

        n_input_timesteps = hidden_states.shape[1]

        hidden_states = hidden_states.transpose(1, 2)
        position_embeddings = self.pos_conv_embed(hidden_states)
        pooled_hidden_states = self.pool(hidden_states)
        min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
        hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
        hidden_states = hidden_states.transpose(1, 2)

        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
            dropout_probability = torch.rand([])
            skip_the_layer = self.training and (dropout_probability < self.config.layerdrop)
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                layer_outputs = layer(
                    hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.upsample(hidden_states)
        if hidden_states.shape[1] < n_input_timesteps:
            hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

dd Zdeejef fd	d
ZdedejfddZdS )SEWPreTrainedModelr+   sewr   TFc              	   C   s  t |tr)tjj|jjddtd|jj	d |jj
   d tj|jjd nt |tjr;|jjjd| jjd n}t |tjtjfrR|jj  |jjd nft |tjrt rddl}t|drt|d	r|jj|j|jgdd
 tj|jj W d   n1 sw   Y  n*|jj|jdd
 tj|jj W d   n1 sw   Y  ntj|jj t |tjtjfr|jdur|jj  dS dS dS )zInitialize the weightsr   rJ   r   )meanstdr   r   Nr`   r_   rN   )r   rI   r   initnormal_r&   rP   mathsqrtr   in_channels	constant_r   rf   datar+   initializer_ranger=   rF   zero_fill_r"   r	   rZ   rY   r[   r\   r`   r_   kaiming_normal_)r*   r   rZ   r/   r/   r0   _init_weights
  s8   
 z SEWPreTrainedModel._init_weightsr   c                 C   s4   dd }t | jj| jjD ]
\}}||||}q|S )zH
        Computes the output length of the convolutional layers
        c                 S   s   t j| | |ddd S )Nfloor)rounding_moder   )r   div)input_lengthr   r   r/   r/   r0   _conv_out_length/  s   zMSEWPreTrainedModel._get_feat_extract_output_lengths.<locals>._conv_out_length)zipr+   r#   r$   )r*   r   r  r   r   r/   r/   r0    _get_feat_extract_output_lengths*  s   z3SEWPreTrainedModel._get_feat_extract_output_lengthsfeature_vector_lengthr   c                 C   s~   |  |dtj}|jd }tj||f|j|jd}d|tj	|jd |jd|d f< |
dgd
dg }|S )Nr@   r   )r   r   r   r   )r   r   r   r   r   r   zerosr   r   r   flipcumsumr   )r*   r!  r   r   
batch_sizer/   r/   r0   "_get_feature_vector_attention_mask9  s   
"z5SEWPreTrainedModel._get_feature_vector_attention_maskN)r7   r8   r9   r   __annotations__base_model_prefixmain_input_namesupports_gradient_checkpointing_supports_flash_attn_supports_sdpa_supports_flex_attnr  r   r   
LongTensorr   r   r&  r/   r/   r/   r0   r
     s   
  r
  r   	mask_probmask_length	min_masksr   c                    s  | \}dk rt dkrt d d dtjd   fdd}|dur:| d	 n
fd
dt|D }tj	|ft
d}g }	|}
|
dkrZ|S |D ];}||}tjjt|d  |dd}t|dkr}d }n|d }t|tj|
| tjd| g}|	| q\t|	}	t|	dddddf ||
f}	|	||
 }	tddddf }t|||
f||
 }|	| }	|	 d krd |	|	d k< t||	dd	 |S )an  
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    input_lengths = (
        attention_mask.detach().sum(-1).tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index as a dummy index to pad the vector so all batch rows
        # have the same number of spans despite the probabilistic rounding
        if len(spec_aug_mask_idx) == 0:
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask

eej deej deej dee dee dee deeef fddZ  ZS )SEWModelr+   c                    s   t  | || _t|| _tj|jd |jd| _	|jd |j
k| _| jr1t|jd |j
| _t|j| _|jdksB|jdkrNtt|j
 | _t|| _|   d S )Nr@   r   r   )r   r   r+   rp   feature_extractorr   r=   r   r   r>   rT   project_featuresrf   feature_projectionr   feat_proj_dropoutfeature_dropoutmask_time_probmask_feature_prob	Parameterr   r   uniform_masked_spec_embedr   encoder	post_initrh   r-   r/   r0   r     s   

zSEWModel.__init__Nr3   mask_time_indicesr   c                 C   s  t | jdds	|S | \}}}|dur| j|j||< n-| jjdkrK| jrKt||f| jj| jj	|| jj
d}tj||jtjd}| j|j||< | jjdkr| jrt||f| jj| jj| jjd}tj||jtjd}|dddf d|d}d||< |S )	z
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).
        apply_spec_augmentTNr   )r/  r0  r   r1  )r   r   )r/  r0  r1  r@   )getattrr+   ri   rY  r   r   rU  r   rN  mask_time_lengthmask_time_min_masksr   tensorr   r   rV  mask_feature_lengthmask_feature_min_masksr   )r*   r3   r\  r   r%  r7  rT   mask_feature_indicesr/   r/   r0   _mask_hidden_states  s4   zSEWModel._mask_hidden_statesr   r   r   r   r   c           
      C   s   |dur|n| j j}|dur|n| j j}|dur|n| j j}| |}|dd}| |}| jr6| |}| 	|}|durH| 
|jd |}| j||d}| j|||||d}	|	d }|sh|f|	dd  S t||	j|	jdS )a/  
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        Nr   rJ   )r\  r   r   r   r   r   r   )r+   r   r   use_return_dictrP  rA   r>   rQ  rR  rT  r&  r   re  rZ  r   r3   r   )
r*   r   r   r\  r   r   r   extract_featuresr3   encoder_outputsr/   r/   r0   r4     s8   



zSEWModel.forwardr   NNNNN)r7   r8   r9   r   r   r   FloatTensorr   r.  re  r   r   r   r   r   r   r4   r:   r/   r/   r-   r0   rO    s@    
.
_HIDDEN_STATES_START_POSITION = 1


@auto_docstring(
    custom_intro="""
    SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
    """
)
class SEWForCTC(SEWPreTrainedModel):
    def __init__(self, config, target_lang: Optional[str] = None):
        r"""
        target_lang (`str`, *optional*):
            Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
            adapter.<lang>.bin. Only relevant when using an instance of [`SEWForCTC`] with adapters. Uses 'eng' by
            default.
        """
        super().__init__(config)

        self.sew = SEWModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that does not define the"
                " vocabulary size of the language model head. Please instantiate the model as follows:"
                " `SEWForCTC.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of your model's"
                " configuration."
            )
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self):
        """
        This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
        passing `target_lang=...` to `from_pretrained(...)`.

        This method is **not** supposed to be called by the user and is prone to be changed in the future.
        """
        target_lang = self.target_lang

        if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
            raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
        elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
            logger.info("By default `target_lang` is set to 'eng'.")
        elif target_lang is not None:
            self.load_adapter(target_lang, force_load=True)

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.sew.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.sew.parameters():
            param.requires_grad = False

    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[tuple, CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
            Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
            the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
            All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
            config.vocab_size - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")

        outputs = self.sew(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100 when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

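# Hedged CTC usage sketch (checkpoint name and `waveform` are illustrative):
#
#     >>> from transformers import AutoProcessor, SEWForCTC
#     >>> processor = AutoProcessor.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")
#     >>> model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")
#     >>> inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
#     >>> pred_ids = model(**inputs).logits.argmax(dim=-1)
#     >>> transcription = processor.batch_decode(pred_ids)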
@auto_docstring(
    custom_intro="""
    SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
    SUPERB Keyword Spotting.
    """
)
class SEWForSequenceClassification(SEWPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of SEW adapters (config.add_adapter=True)"
            )
        self.sew = SEWModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        self.sew.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.sew.parameters():
            param.requires_grad = False

    @auto_docstring
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[tuple, SequenceClassifierOutput]:
        r"""
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
            To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
            into a tensor of type `torch.FloatTensor`. See [`SEWProcessor.__call__`] for details.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.sew(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

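# Hedged audio-classification usage sketch (checkpoint name and `waveform` are
# illustrative):
#
#     >>> from transformers import AutoFeatureExtractor, SEWForSequenceClassification
#     >>> extractor = AutoFeatureExtractor.from_pretrained("anton-l/sew-mid-100k-ft-keyword-spotting")
#     >>> model = SEWForSequenceClassification.from_pretrained("anton-l/sew-mid-100k-ft-keyword-spotting")
#     >>> inputs = extractor(waveform, sampling_rate=16000, return_tensors="pt")
#     >>> predicted_id = model(**inputs).logits.argmax(-1).item()
#     >>> model.config.id2label[predicted_id]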
r  )rm  r  rO  r
  )Nr   Nrd   )Ar  r~  typingr   r   r   numpyr:  r   r   torch.nnr   activationsr   integrations.deepspeedr	   integrations.fsdpr
   modeling_flash_attention_utilsr   modeling_layersr   modeling_outputsr   r   r   modeling_utilsr   r   processing_utilsr   rX   r   r   configuration_sewr   
get_loggerr7   rv  r   r;   rB   ModulerI   rb   re   rp   r   r   r   r   r   r   r   r
  r   r   r.  ndarrayrN  rO  r  rm  r  __all__r/   r/   r/   r0   <module>   s   
+,
X$fI

wz s