o
    ih                     @   s  d Z ddlZddlmZ ddlmZmZmZ ddlZddlm	Z	 ddl
mZ ddlmZ dd	lmZmZ dd
lmZmZ ddlmZ ddlmZmZmZmZ ddlmZ ddlmZ ddlm Z m!Z! ddl"m#Z#m$Z$ G dd de	j%Z&G dd de	j%Z'G dd deZ(G dd de Z)G dd de	j%Z*G dd deZ+eG dd  d eZ,ed!d"G d#d$ d$e,Z-eG d%d& d&eZ.ed'd"G d(d) d)e,Z/g d*Z0dS )+zPyTorch Parakeet model.    N)	dataclass)CallableOptionalUnion)nn   )ACT2FN)GradientCheckpointingLayer)BaseModelOutputCausalLMOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)ModelOutputTransformersKwargsauto_docstringcan_return_tuple)check_model_inputs   )%FastSpeech2ConformerConvolutionModule)LlamaAttentioneager_attention_forward   )ParakeetCTCConfigParakeetEncoderConfigc                       sL   e Zd ZU dZejed< d
def fddZe	 dejfdd	Z
  ZS )$ParakeetEncoderRelPositionalEncodingz*Relative positional encoding for Parakeet.inv_freqNconfigc                    sZ   t    |j| _d}d|tjd|jdtjdj|tjd|j   }| j	d|dd	 d S )
Ng     @      ?r   r   dtype)devicer    r   F)
persistent)
super__init__max_position_embeddingstorcharangehidden_sizeint64tofloatregister_buffer)selfr   r!   baser   	__class__ a/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/parakeet/modular_parakeet.pyr$   )   s   
 z-ParakeetEncoderRelPositionalEncoding.__init__hidden_statesc                 C   sF  |j d }|| jkrtd| d| j dtj|d | d|jd}| jd d d d f  |j d dd	|j}|d d d d f  }t
|jjtrW|jjdkrW|jjnd	}tj|d
d4 | |  dd}| }| }	tj||	gdd}
|
jg |
j d d dR  }
W d    n1 sw   Y  |
j	|jdS )Nr   zSequence Length: z= has to be less or equal than config.max_position_embeddings .r!   r   mpscpuF)device_typeenabledr   dimr   )shaper%   
ValueErrorr&   r'   r!   r   r+   expandr*   
isinstancetypestrautocast	transposesincosstackreshaper    )r-   r3   
seq_lengthposition_idsinv_freq_expandedposition_ids_expandedr9   freqsrF   rG   	pos_embedr1   r1   r2   forward7   s2   

. z,ParakeetEncoderRelPositionalEncoding.forwardN)__name__
__module____qualname____doc__r&   Tensor__annotations__r   r$   no_gradrP   __classcell__r1   r1   r/   r2   r   $   s   
 
r   c                       s*   e Zd Zdef fddZdd Z  ZS )ParakeetEncoderFeedForwardr   c                    sR   t    tj|j|j|jd| _t|j	 | _
tj|j|j|jd| _|j| _d S )Nbias)r#   r$   r   Linearr(   intermediate_sizeattention_biaslinear1r   
hidden_act
activationlinear2activation_dropoutr-   r   r/   r1   r2   r$   W   s
   
z#ParakeetEncoderFeedForward.__init__c                 C   s4   |  | |}tjj|| j| jd}| |}|S )Nptraining)rb   r`   r   
functionaldropoutrd   rh   rc   )r-   r3   r1   r1   r2   rP   ^   s   
z"ParakeetEncoderFeedForward.forward)rR   rS   rT   r   r$   rP   rY   r1   r1   r/   r2   rZ   V   s    rZ   c                       s$   e Zd Zddef fddZ  ZS ) ParakeetEncoderConvolutionModuleNr   c                    s   t  || d S rQ   )r#   r$   )r-   r   module_configr/   r1   r2   r$   f   s   z)ParakeetEncoderConvolutionModule.__init__rQ   )rR   rS   rT   r   r$   rY   r1   r1   r/   r2   rk   e   s    rk   c                       sr   e Zd ZdZdedef fddZ	ddejde	ej d	e	ej d
e
e deejejf f
ddZdd Z  ZS )ParakeetEncoderAttentionztMulti-head attention with relative positional encoding. See section 3.3 of https://huggingface.co/papers/1901.02860.r   	layer_idxc                    sf   t  j||d d| _tj|j|j| j dd| _t	t
|j| j| _t	t
|j| j| _d S )N)rn   Fr[   )r#   r$   	is_causalr   r]   r(   num_attention_headshead_dimrelative_k_proj	Parameterr&   zerosbias_ubias_vr-   r   rn   r/   r1   r2   r$   m   s
   z!ParakeetEncoderAttention.__init__Nr3   position_embeddingsattention_maskkwargsreturnc              	   K   s  |j d d }|\}}||d| jf}| ||dd}	| ||dd}
| ||dd}t}| jj	dkrDt
| jj	 }|	| jd| jjd| j }|	| jd| jjd| j }| |}||d| jj| j}||dddd }| |}|dd |f }|| j }|d ur|| td}|| f||
||| jsd	n| j| jd
|\}}|jg |dR   }| |}||fS )Nr5   r   r   eagerr   r   .z-inf        )querykeyvaluery   rj   scaling)r>   rq   q_projviewrE   k_projv_projr   r   _attn_implementationr   ru   rp   rv   rr   permute
_rel_shiftr   masked_fill_logical_notr+   rh   attention_dropoutrI   
contiguouso_proj)r-   r3   rx   ry   rz   input_shape
batch_sizerJ   hidden_shapequery_states
key_statesvalue_statesattention_interfacequery_states_with_bias_uquery_states_with_bias_vrelative_key_states	matrix_bdattn_outputattn_weightsr1   r1   r2   rP   w   sL   




z ParakeetEncoderAttention.forwardc                 C   sX   |j \}}}}tjj|dd}|||d|}|ddddddf ||||}|S )ztRelative position shift for Shaw et al. style attention. See appendix B of https://huggingface.co/papers/1901.02860.)r   r   )padr5   Nr   )r>   r   ri   r   r   )r-   attention_scoresr   	num_headsquery_lengthposition_lengthr1   r1   r2   r      s
   &z#ParakeetEncoderAttention._rel_shiftrQ   )rR   rS   rT   rU   r   intr$   r&   rV   r   r   r   tuplerP   r   rY   r1   r1   r/   r2   rm   j   s     
9rm   c                       sP   e Zd Zdef fddZdejdejfddZ	dd	ejd
ejfddZ
  ZS ) ParakeetEncoderSubsamplingConv2Dr   c                    s  t    |j| _|j| _|j| _| jd d | _t	t
|j| _t | _| jtjd| j| j| j| jd | jt  t| jd D ]-}| jtj| j| j| j| j| j| jd | jtj| j| jdd | jt  qH|j| j| j  }tj|j| |jdd| _d S )Nr   r   )kernel_sizestridepadding)r   r   r   groupsr   Tr[   )r#   r$   subsampling_conv_kernel_sizer   subsampling_conv_strider   subsampling_conv_channelschannelsr   r   mathlog2subsampling_factor
num_layersr   
ModuleListlayersappendConv2dReLUrangenum_mel_binsr]   r(   linear)r-   r   i
out_lengthr/   r1   r2   r$      s4   

z)ParakeetEncoderSubsamplingConv2D.__init__input_lengths
conv_layerc                 C   sV   t |dr)|jdkr)|j}|jd }|jd }||d  |d  | | d }|S |S )Nr   )r   r   r   r   )hasattrr   r   r   )r-   r   r   r   r   r   output_lengthsr1   r1   r2   _get_output_length   s   

 z3ParakeetEncoderSubsamplingConv2D._get_output_lengthNinput_featuresry   c                 C   s   | d}|d ur|dnd }| jD ]9}||}t|tjrL|d urL| ||}|jd }tj	||j
d|d d d f k }||d d d d d d f 9 }q|dd|jd |jd d}| |}|S )Nr   r5   r   r6   r   )	unsqueezesumr   rA   r   r   r   r>   r&   r'   r!   rE   rI   r   )r-   r   ry   r3   current_lengthslayercurrent_seq_lengthchannel_maskr1   r1   r2   rP      s   


"
z(ParakeetEncoderSubsamplingConv2D.forwardrQ   )rR   rS   rT   r   r$   r&   rV   r   r   r   rP   rY   r1   r1   r/   r2   r      s    # r   c                       sd   e Zd Zddedee f fddZ		ddejdeej deej d	e	e
 d
ejf
ddZ  ZS )ParakeetEncoderBlockNr   rn   c                    s   t    d| _t|| _t||| _t|| _t|| _	t
|j| _t
|j| _t
|j| _t
|j| _t
|j| _d S NF)r#   r$   gradient_checkpointingrZ   feed_forward1rm   	self_attnrk   convfeed_forward2r   	LayerNormr(   norm_feed_forward1norm_self_att	norm_convnorm_feed_forward2norm_outrw   r/   r1   r2   r$      s   



zParakeetEncoderBlock.__init__r3   ry   rx   rz   r{   c                 K   s   |}|  | |}|d|  }| |}| jd|||d|\}}|| }| j| ||d}	||	 }| | |}
|d|
  }| |}|S )Ng      ?)r3   ry   rx   )ry   r1   )	r   r   r   r   r   r   r   r   r   )r-   r3   ry   rx   rz   residualnormalized_hidden_statesr   _conv_output
ff2_outputr1   r1   r2   rP     s$   


zParakeetEncoderBlock.forwardrQ   NN)rR   rS   rT   r   r   r   r$   r&   rV   r   r   rP   rY   r1   r1   r/   r2   r      s    r   c                       s   e Zd ZU eed< dZdZdZdgZdZ	dZ
dZdZdZdZeedZ fdd	Zd
ejfddZddejdee fddZ  ZS )ParakeetPreTrainedModelr   modelr   Tr   F)r3   
attentionsc                    sj   t  | t| jdr| jj}n	t| j dd}t|tr3|j	j
jd|d |jj
jd|d d S d S )Ninitializer_rangeg{Gz?r}   )meanstd)r#   _init_weightsr   r   r   getattrget_text_configrA   rm   ru   datanormal_rv   )r-   moduler   r/   r1   r2   r   B  s   

z%ParakeetPreTrainedModel._init_weightsr   c           
      C   s   t | jtr
| jjn| j}|j}|j}tt|j	}|d d d }|| }|}t
|D ]}	t|jtjd| |d }t|}q-|jtjdS )Nr   r   r   r   )rA   r   r   encoder_configr   r   r   r   r   r   r   r&   divr*   r+   floor)
r-   r   r   r   r   r   all_paddingsadd_padlengthsr   r1   r1   r2   _get_subsampling_output_lengthP  s   z6ParakeetPreTrainedModel._get_subsampling_output_lengthNry   target_lengthc                 C   sH   |  |d}|dur|n| }tj||jd|dddf k }|S )z
        Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
        when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is padded)
        r5   Nr6   )r   r   maxr&   r'   r!   )r-   ry   r   r   
max_lengthr1   r1   r2   _get_output_attention_maska  s    z2ParakeetPreTrainedModel._get_output_attention_maskrQ   )rR   rS   rT   r   rW   base_model_prefixmain_input_namesupports_gradient_checkpointing_no_split_modules_supports_flat_attention_mask_supports_sdpa_supports_flex_attn_supports_flash_attn_can_compile_fullgraph_supports_attention_backendr   rm   _can_record_outputsr   r&   rV   r   r   r   r   rY   r1   r1   r/   r2   r   -  s$   
 "r   z{
    The Parakeet Encoder model, based on the [Fast Conformer architecture](https://huggingface.co/papers/2305.05084).
    )custom_introc                       sf   e Zd ZU eed< dZdef fddZeee		dde
jdee
j dee d	efd
dZ  ZS )ParakeetEncoderr   encoderc                    s   t     | _d| _ j| _ j| _ j| _ jr!t	 j
nd| _t | _t | _t fddt jD | _|   d S )NFr   c                    s   g | ]}t  |qS r1   )r   ).0rn   r   r1   r2   
<listcomp>  s    z,ParakeetEncoder.__init__.<locals>.<listcomp>)r#   r$   r   r   rj   dropout_positions	layerdropscale_inputr   sqrtr(   input_scaler   subsamplingr   encode_positionsr   r   r   num_hidden_layersr   	post_initre   r/   r   r2   r$   v  s   

zParakeetEncoder.__init__Nr   ry   rz   r{   c           	      K   s   |  ||}|| j }| |}tjj|| j| jd}tjj|| j| jd}|durN| j||j	d d}|
dd|j	d d}||dd@ }|
d}| jD ] }d}| jrdtg }|| jk rdd}|sq||f||d	|}qQt|d
S )a  
        Example:

        ```python
        >>> from transformers import AutoProcessor, ParakeetEncoder
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/parakeet-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> encoder = ParakeetEncoder.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"])
        >>> encoder_outputs = encoder(**inputs)

        >>> print(encoder_outputs.last_hidden_state.shape)
        ```
        rf   Nr   r   r5   r   FT)ry   rx   )last_hidden_state)r  r  r  r   ri   rj   rh   r  r   r>   r   r@   rE   r   r&   randr  r
   )	r-   r   ry   rz   r3   rx   encoder_layerto_dropdropout_probabilityr1   r1   r2   rP     s:   







zParakeetEncoder.forwardrQ   )rR   rS   rT   r   rW   r   r$   r   r   r   r&   rV   r   r   r   r
   rP   rY   r1   r1   r/   r2   r   m  s"   
 r   c                   @   sf   e Zd ZU dZejed< dZee	ej
  ed< dZee	e	ej
   ed< dZee	e	ej
   ed< dS )ParakeetGenerateOutputal  
    Outputs of Parakeet models.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    	sequencesNlogitsr   r3   )rR   rS   rT   rU   r&   
LongTensorrW   r  r   r   FloatTensorr   r3   r1   r1   r1   r2   r    s   
 
r  zS
    Parakeet Encoder with a Connectionist Temporal Classification (CTC) head.
    c                       s   e Zd ZU eed< def fddZee		ddej	de
ej	 de
ej	 dee d	ef
d
dZe 		ddej	de
ej	 dedee d	eeejf f
ddZ  ZS )ParakeetForCTCr   c                    s<   t  | t|j| _tj|jj|jdd| _	| 
  d S )Nr   r   )r#   r$   r   r   r   r   Conv1dr(   
vocab_sizectc_headr	  re   r/   r1   r2   r$     s   zParakeetForCTC.__init__Nr   ry   labelsrz   r{   c              
   K   s  | j d||d|}|j}| |dddd}d}|dur|dur'|ntj|tjd}| |d}	|| j	j
k}
|
d}||
}tjj|dtjddd}tjjjd	d
 tjj|||	|| j	j
| j	j| j	jd}W d   n1 s{w   Y  t|||j|jdS )a  
        Example:

        ```python
        >>> from transformers import AutoProcessor, ParakeetForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/parakeet-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = ParakeetForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> outputs = model(**inputs)

        >>> print(outputs.loss)
        ```r   ry   r   r   Nr   r5   )r<   r    r   F)r:   )blank	reductionzero_infinity)lossr  r3   r   r1   )r   r  r  rE   r&   	ones_likelongr   r   r   pad_token_idmasked_selectr   ri   log_softmaxfloat32backendscudnnflagsctc_lossctc_loss_reductionctc_zero_infinityr   r3   r   )r-   r   ry   r  rz   encoder_outputsr3   r  r  r   labels_masktarget_lengthsflattened_targets	log_probsr1   r1   r2   rP     sD   

zParakeetForCTC.forwardFreturn_dict_in_generatec                 K   st   d|d< | j d
||d|}|jjdd}|dur+| j||jd d}| jj|| < |r8t||j|j|j	d	S |S )a3  
        Example:

        ```python
        >>> from transformers import AutoProcessor, ParakeetForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/parakeet-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = ParakeetForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> predicted_ids = model.generate(**inputs)
        >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        >>> print(transcription)
        ```
        Treturn_dictr  r5   r;   Nr   r
  )r  r  r   r3   r1   )
rP   r  argmaxr   r>   r   r!  r  r   r3   )r-   r   ry   r0  rz   outputsr  r1   r1   r2   generate=  s&   zParakeetForCTC.generater   r   )rR   rS   rT   r   rW   r$   r   r   r&   rV   r   r   r   r   rP   rX   boolr   r  r  r4  rY   r1   r1   r/   r2   r    s@   
 Gr  )r  r   r   )1rU   r   dataclassesr   typingr   r   r   r&   r   activationsr   modeling_layersr	   modeling_outputsr
   r   modeling_utilsr   r   processing_utilsr   utilsr   r   r   r   utils.genericr   4fastspeech2_conformer.modeling_fastspeech2_conformerr   llama.modeling_llamar   r   configuration_parakeetr   r   Moduler   rZ   rk   rm   r   r   r   r   r  r  __all__r1   r1   r1   r2   <module>   sH   2OE/?W 