o
    i[                     @   s  d dl Z d dlZd dlmZmZmZ d dlZd dlmZ ddlm	Z	 ddl
mZ ddlmZ ddlmZ dd	lmZmZmZ dd
lmZmZ ddlmZ ddlmZmZmZmZ ddlmZ ddl m!Z!m"Z" ddl#m$Z$m%Z% e&e'Z(			d-dej)dej*dej*dej*deej* dee+ de+deej* fddZ,G dd dej)Z-G dd  d eZ.eG d!d" d"eZ/ed#d$G d%d& d&e/Z0G d'd( d(ej)Z1ed)d$G d*d+ d+e/eZ2g d,Z3dS ).    N)CallableOptionalUnion)nn   )ACT2FN)Cache)GenerationMixin)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPastCausalLMOutputWithPast)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)TransformersKwargsauto_docstringcan_return_tuplelogging)check_model_inputs   )	AutoModelAutoModelForCausalLM   )VoxtralConfigVoxtralEncoderConfig        modulequerykeyvalueattention_maskscalingdropout	head_maskc                 K   s   |d u r| dd }t||dd| }	|d ur5|jdkr5|	|d d d d d d d |jd f  }	tjj|	dd}	|d urK|	|	dddd }	tjj
|	|| jd	}	t|	|}
|
dd }
|
|	fS )
N      r   r      )dimr   ptraining)sizetorchmatmul	transposendimshaper   
functionalsoftmaxviewr#   r,   
contiguous)r   r   r   r    r!   r"   r#   r$   kwargsattn_weightsattn_output r:   `/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/voxtral/modeling_voxtral.pyeager_attention_forward-   s   *r<   c                       s   e Zd ZdZ						ddededed	ed
ededee dee f fddZ	de
jdedefddZ			dde
jdee
j dee
j dedee
jee
j eee
j  f f
ddZ  ZS )VoxtralAttentionz=Multi-headed attention from 'Attention Is All You Need' paperr   FTN	embed_dim	num_headsr#   
is_decoderbias	is_causal	layer_idxconfigc	           	         s   t    || _|| _|| _|| | _|| _| j| | jkr*td| j d| d| jd | _|| _	|| _
|d u rG|rGtd| jj d || _tj||dd| _tj|||d| _tj|||d| _tj|||d| _d S )	Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).r&   zInstantiating a decoder z without passing `layer_idx` is not recommended and will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` when creating this class.FrA   )super__init__r>   r?   r#   head_dimrD   
ValueErrorr"   r@   rB   loggerwarning_once	__class____name__rC   r   Lineark_projv_projq_projout_proj)	selfr>   r?   r#   r@   rA   rB   rC   rD   rL   r:   r;   rG   N   s0   


zVoxtralAttention.__init__tensorseq_lenbszc                 C   s    | ||| j| jdd S )Nr   r   )r5   r?   rH   r0   r6   )rS   rU   rV   rW   r:   r:   r;   _shapev   s    zVoxtralAttention._shapehidden_statesr!   layer_head_maskoutput_attentionsreturnc                 K   s   |  \}}}| | || j ||}	| | |d|}
| | |d|}t}| jjdkr6t	| jj }|| |	|
||f| j
sBdn| jd||d|\}}|||d }| |}||fS )z#Input shape: Batch x Time x Channelr%   eagerr         ?)r#   r"   r[   r$   )r-   rX   rQ   r"   rO   rP   r<   rD   _attn_implementationr   r,   r#   reshaper6   rR   )rS   rY   r!   rZ   r[   r7   rW   tgt_len_query_states
key_statesvalue_statesattention_interfacer9   r8   r:   r:   r;   forwardy   s0   



zVoxtralAttention.forward)r   FTFNN)NNF)rM   
__module____qualname____doc__intfloatboolr   r   rG   r.   TensorrX   tuplerg   __classcell__r:   r:   rT   r;   r=   K   sP    	(r=   c                       sL   e Zd Zdef fddZ	ddejdejdejded	ejf
d
dZ  Z	S )VoxtralEncoderLayerrD   c                    s   t    |j| _t| j|j|j|d| _t	| j| _
|j| _t|j | _|j| _t| j|j| _t|j| j| _t	| j| _d S )N)r>   r?   r#   rD   )rF   rG   d_modelr>   r=   encoder_attention_headsattention_dropout	self_attnr   	LayerNormself_attn_layer_normr#   r   activation_functionactivation_fnactivation_dropoutrN   encoder_ffn_dimfc1fc2final_layer_normrS   rD   rT   r:   r;   rG      s   
zVoxtralEncoderLayer.__init__FrY   r!   rZ   r[   r\   c                 C   s   |}|  |}| j||||d\}}tjj|| j| jd}|| }|}| |}| | |}tjj|| j	| jd}| 
|}tjj|| j| jd}|| }|jtjkrgt|jjd }tj|| |d}||fS )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )rY   r!   rZ   r[   r*   i  )minmax)rw   ru   r   r3   r#   r,   r~   ry   r|   rz   r}   dtyper.   float16finfor   clamp)rS   rY   r!   rZ   r[   residualr8   clamp_valuer:   r:   r;   rg      s*   



zVoxtralEncoderLayer.forward)F)
rM   rh   ri   r   rG   r.   rn   rm   rg   rp   r:   r:   rT   r;   rq      s    rq   c                   @   sF   e Zd ZU eed< dZdZdZdZdZ	dZ
dZdZdZdZdd ZdS )VoxtralPreTrainedModelrD   modelTNpast_key_valuesc                 C   s   t | jdr
| jjn| jjj}t|tjtjfr0|jj	j
d|d |jd ur.|jj	  d S d S t|tjrE|jj	d |jj	  d S t|tjrd|jj	j
d|d |jd urf|jj	|j   d S d S d S )Ninitializer_ranger   )meanstdr^   )hasattrrD   r   audio_config
isinstancer   rN   Conv1dweightdatanormal_rA   zero_rv   fill_	Embeddingpadding_idx)rS   r   r   r:   r:   r;   _init_weights   s$   



z$VoxtralPreTrainedModel._init_weights)rM   rh   ri   r   __annotations__base_model_prefixsupports_gradient_checkpointing_no_split_modules_skip_keys_device_placement_supports_flash_attn_supports_sdpa_supports_flex_attn_supports_cache_class_supports_attention_backend_can_compile_fullgraphr   r:   r:   r:   r;   r      s   
 r   z:
    The Voxtral encoder, which is a Whisper encoder.
    )custom_introc                       s   e Zd ZU dZeed< dZdgZee	dZ
def fddZdd	 Zd
ejfddZdejfddZe	ddee fddZdejfddZ  ZS )VoxtralEncoderz
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`VoxtralEncoderLayer`].

    Args:
        config: VoxtralEncoderConfig
    rD   input_featuresrq   )
attentionsrY   c                    s   t     j| _ j| _ j} j| _ j| _ j	| _	 j
r%t|nd| _tj| j|ddd| _tj||dddd| _t| j	|| _| jd t fdd	t jD | _t j| _tjddd
| _d| _|   d S )Nr^   r   r   )kernel_sizepaddingr   )r   strider   Fc                    s   g | ]}t  qS r:   )rq   ).0rb   rD   r:   r;   
<listcomp>0  s    z+VoxtralEncoder.__init__.<locals>.<listcomp>)r   )rF   rG   r#   encoder_layerdrop	layerdroprr   num_mel_binspad_token_idr   max_source_positionsscale_embeddingmathsqrtembed_scaler   r   conv1conv2r   embed_positionsrequires_grad_
ModuleListrangeencoder_layerslayersrv   
layer_norm	AvgPool1d
avg_poolergradient_checkpointing	post_init)rS   rD   r>   rT   r   r;   rG     s"    zVoxtralEncoder.__init__c                 C   s   |   D ]}d|_qd| _d S )NF)
parametersrequires_grad_requires_grad)rS   paramr:   r:   r;   _freeze_parameters9  s   
z!VoxtralEncoder._freeze_parametersr\   c                 C   s   | j S Nr   rS   r:   r:   r;   get_input_embeddings>  s   z#VoxtralEncoder.get_input_embeddingsr    c                 C   s
   || _ d S r   r   rS   r    r:   r:   r;   set_input_embeddingsA     
z#VoxtralEncoder.set_input_embeddingsNr7   c                 K   s  | j j| jjd  | jjd  }|jd |kr(td| d|jd  d| d|j| jjj	| jjj
d}tj| |}tj| |}|ddd	}| jj}|| |j	}tjj|| j| jd
}t| jD ]\}}	|	||dd}
|
d }qj| |}t|dS )a  
        Args:
            input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
            attention_mask (`torch.Tensor`)`, *optional*):
                Voxtral does not support masking of the `input_features`, this argument is preserved for compatibility,
                but it is not used. By default the silence in the input log mel spectrogram are ignored.
        r   r%   z:Qwen2Audio expects the mel input features to be of length z, but found z-. Make sure to pad the input mel features to .)r   devicer   r   r*   N)r!   rZ   )last_hidden_state)rD   r   r   r   r   r2   rI   tor   r   r   r   r3   gelupermuter   r#   r,   	enumerater   r   r   )rS   r   r!   r7   expected_seq_lengthinputs_embeds	embed_posrY   idxencoder_layerlayer_outputsr:   r:   r;   rg   D  s.    

zVoxtralEncoder.forwardinput_lengthsc                 C   s(   |d d d }|d d d }||fS )zs
        Computes the output length of the convolutional layers and the output length of the audio encoder
        r   r   r:   )rS   r   output_lengthsr:   r:   r;    _get_feat_extract_output_lengthsu  s   z/VoxtralEncoder._get_feat_extract_output_lengthsr   )rM   rh   ri   rj   r   r   main_input_namer   r=   rq   _can_record_outputsrG   r   r   Moduler   r   r   r   r   rg   r.   
LongTensorr   rp   r:   r:   rT   r;   r     s$   
 	0r   c                       s*   e Zd Zdef fddZdd Z  ZS )VoxtralMultiModalProjectorrD   c                    sN   t    tj|jj|jjdd| _t	|j
 | _tj|jj|jjdd| _d S )NFrE   )rF   rG   r   rN   r   intermediate_sizetext_confighidden_sizelinear_1r   projector_hidden_actactlinear_2r   rT   r:   r;   rG     s   
z#VoxtralMultiModalProjector.__init__c                 C   s"   |  |}| |}| |}|S r   )r   r   r   )rS   audio_featuresrY   r:   r:   r;   rg     s   


z"VoxtralMultiModalProjector.forward)rM   rh   ri   r   rG   rg   rp   r:   r:   rT   r;   r   ~  s    r   zs
    The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model.
    c                       s4  e Zd ZdgZddiZddgdgfiZdgZ fddZd	d
 Zdd Z	dd Z
dd Zdd Zdd ZdejfddZdejfddZee										d+deej deej deej deej dee d eej d!eej d"ee d#eej d$eeejf d%ee d&efd'd(Z fd)d*Z  Z S ),VoxtralForConditionalGenerationzlm_head.weightlm_headcolwise_reprY   logitsr   c                    sH   t  | |jj| _t|j| _t|j| _	t
|| _|   d S r   )rF   rG   r   
vocab_sizer   from_configr   audio_towerr   language_modelr   multi_modal_projectorr   r   rT   r:   r;   rG     s   

z(VoxtralForConditionalGeneration.__init__c                 C   
   | j  S r   )r   r   r   r:   r:   r;   r     r   z4VoxtralForConditionalGeneration.get_input_embeddingsc                 C      | j | d S r   )r   r   r   r:   r:   r;   r        z4VoxtralForConditionalGeneration.set_input_embeddingsc                 C   r   r   )r   get_output_embeddingsr   r:   r:   r;   r     r   z5VoxtralForConditionalGeneration.get_output_embeddingsc                 C   r   r   )r   set_output_embeddings)rS   new_embeddingsr:   r:   r;   r     r   z5VoxtralForConditionalGeneration.set_output_embeddingsc                 C   r   r   )r   set_decoder)rS   decoderr:   r:   r;   r     r   z+VoxtralForConditionalGeneration.set_decoderc                 C   r   r   )r   get_decoderr   r:   r:   r;   r     r   z+VoxtralForConditionalGeneration.get_decoderr   c                 C   s0   |  |}|j}|d| jjj}| |}|S )a  
        This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector.
        Args:
            input_features (`torch.FloatTensor`):
                Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
                and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]

        Returns:
            `torch.FloatTensor`:
                The audio embeddings.
        r%   )r   r   r`   rD   r   r   r   )rS   r   audio_outputsaudio_hidden_statesaudio_embedsr:   r:   r;   get_audio_features  s
   

z2VoxtralForConditionalGeneration.get_audio_featuresc                 C   s   t dt | |S )NzUThe method `get_audio_embeds` is deprecated. Please use `get_audio_features` instead.)warningswarnFutureWarningr  )rS   r   r:   r:   r;   get_audio_embeds  s   
z0VoxtralForConditionalGeneration.get_audio_embedsNr   	input_idsr!   position_idsr   r   labels	use_cachecache_positionlogits_to_keepr7   r\   c                 K   s   |du r
|   |}|dur.|dur.| |}|| jjkd}|||j||j}| jd|||||||	|
d|}|S )aj  
        Example:

        ```python
        >>> from transformers import VoxtralForConditionalGeneration, AutoProcessor
        >>> import torch

        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
        >>> repo_id = "mistralai/Voxtral-Mini-3B-2507"

        >>> processor = AutoProcessor.from_pretrained(repo_id)
        >>> model = VoxtralForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map=device)

        >>> conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "audio",
                        "url": "https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/dude_where_is_my_car.wav",
                    },
                    {"type": "text", "text": "What can you tell me about this audio?"},
                ],
            }
        ]

        >>> inputs = processor.apply_chat_template(conversation)
        >>> inputs = inputs.to(device, dtype=torch.bfloat16)

        >>> outputs = model.generate(**inputs, max_new_tokens=30)
        >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
        ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."]
        ```Nr%   )r!   r	  r   r   r
  r  r  r  r:   )	r   r  rD   audio_token_id	unsqueezemasked_scatterr   r   r   )rS   r  r   r!   r	  r   r   r
  r  r  r  r7   r  audio_token_maskoutputsr:   r:   r;   rg     s*   1
	z'VoxtralForConditionalGeneration.forwardc                    sH   | dd }|d}t j|i |}|d ur"|d dkr"||d< |S )Nr   r  r   )popgetrF   prepare_inputs_for_generation)rS   argsr7   r   r  model_inputsrT   r:   r;   r    s   
z=VoxtralForConditionalGeneration.prepare_inputs_for_generation)
NNNNNNNNNr   )!rM   rh   ri   _tied_weights_keys_tp_plan_pp_plan_keep_in_fp32_modules_strictrG   r   r   r   r   r   r   r.   FloatTensorr  r  r   r   r   r   rn   r   rm   r   rk   r   r   r   rg   r  rp   r:   r:   rT   r;   r     sh    
	
Hr   )r   r   r   )Nr   N)4r   r  typingr   r   r   r.   r   activationsr   cache_utilsr   
generationr	   modeling_layersr
   modeling_outputsr   r   r   modeling_utilsr   r   processing_utilsr   utilsr   r   r   r   utils.genericr   autor   r   configuration_voxtralr   r   
get_loggerrM   rJ   r   rn   rl   r<   r=   rq   r   r   r   r   __all__r:   r:   r:   r;   <module>   sh   
	
Z?#q 