o
    ei                     @   s  d dl mZ d dlmZ d dlmZ d dlZd dlmZ ddl	m
Z ddlmZ ddlmZmZ dd	lmZ dd
lmZmZmZ ddlmZ ddlmZ ddlmZmZ ddlmZm Z  ddl!m"Z"m#Z# ddl$m%Z% ddl&m'Z'm(Z(m)Z)m*Z*m+Z+ ddl,m-Z-m.Z. ddl/m0Z0 ddl1m2Z2 ddl3m4Z4 ddl5m6Z6m7Z7 ddl8m9Z9 e+:e;Z<ee)ddG dd de'Z=edG dd  d ej>Z?G d!d" d"ej>Z@G d#d$ d$ej>ZAd%d& ZBed'dPd(d)ZCd*ejDd+eEd,ejDfd-d.ZF	/dQd0ej>d1ejDd2ejDd3ejDd4ejDdB d5eGd6eGd7e%e( fd8d9ZHeeCG d:d; d;ej>ZIG d<d= d=eZJe)d>de)G d?d@ d@e#ZKe)G dAdB dBeKZLG dCdD dDej>ZMe)dEdG dFdG dGeKeZNG dHdI dIej>ZOe)G dJdK dKeKZPe)dLdG dMdN dNeKe9ZQg dOZRdS )R    )Callable)	dataclass)OptionalN   )initialization)ACT2FN)CacheDynamicCache)GenerationMixin)use_kernel_forward_from_hubuse_kernel_func_from_hubuse_kernelized_func)create_causal_mask)GradientCheckpointingLayer)BaseModelOutputWithPastCausalLMOutputWithPast)ROPE_INIT_FUNCTIONSdynamic_rope_update)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)ModelOutputTransformersKwargsauto_docstringcan_return_tuplelogging)maybe_autocastmerge_with_config_defaults)is_torchdynamo_compiling)capture_outputs   )	AutoModel   )	CsmConfigCsmDepthDecoderConfig)CsmGenerationMixinz:
    Base class for the model autoregressive outputs.
    )custom_introc                   @   s   e Zd ZU dZdZejdB ed< dZejdB ed< dZ	e
dB ed< dZeejdf dB ed< dZeejdf dB ed< dZejdB ed	< dZejdB ed
< dZe
dB ed< dZeejdf dB ed< dZeejdf dB ed< dZejdB ed< dS )CsmOutputWithPasta	  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    depth_decoder_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction) of the depth decoder model.
    depth_decoder_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the depth decoder (scores for each vocabulary token before SoftMax).
    depth_decoder_past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
    depth_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    depth_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.
    backbone_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction) of the backbone model.
    Nlosslogitspast_key_values.hidden_states
attentionsdepth_decoder_lossdepth_decoder_logitsdepth_decoder_past_key_valuesdepth_decoder_hidden_statesdepth_decoder_attentionsbackbone_loss)__name__
__module____qualname____doc__r(   torchFloatTensor__annotations__r)   r*   r   r+   tupler,   r-   r.   r/   r0   r1   r2    r;   r;   b/home/ubuntu/transcripts/venv/lib/python3.10/site-packages/transformers/models/csm/modeling_csm.pyr'   3   s   
 r'   RMSNormc                       sF   e Zd Zddeddf fddZdejdejfdd	Zd
d Z  Z	S )
CsmRMSNormư>epsreturnNc                    s&   t    tt|| _|| _dS )z9
        CsmRMSNorm is equivalent to T5LayerNorm
        N)super__init__nn	Parameterr7   onesweightvariance_epsilon)selfhidden_sizer@   	__class__r;   r<   rC   e   s   

zCsmRMSNorm.__init__r+   c                 C   sJ   |j }|tj}|djddd}|t|| j  }| j|| S )Nr    T)keepdim)	dtypetor7   float32powmeanrsqrtrH   rG   )rI   r+   input_dtypevariancer;   r;   r<   forwardm   s
   zCsmRMSNorm.forwardc                 C   s   t | jj d| j S )Nz, eps=)r:   rG   shaperH   rI   r;   r;   r<   
extra_reprt   s   zCsmRMSNorm.extra_repr)r?   )
r3   r4   r5   floatrC   r7   TensorrW   rZ   __classcell__r;   r;   rK   r<   r>   c   s    r>   c                       s~   e Zd ZU ejed< ddef fddZe			ddedB de	d de
dB d	ed
ef fddZe edd Z  ZS )CsmRotaryEmbeddinginv_freqNconfigc                    s   t    |j| _|j| _|| _| jjd | _| j}| jdkr$t	| j }|| j|\}| _
| jd|dd | jd| dd d S )N	rope_typedefaultr_   F
persistentoriginal_inv_freq)rB   rC   max_position_embeddingsmax_seq_len_cachedoriginal_max_seq_lenr`   rope_parametersra   compute_default_rope_parametersr   attention_scalingregister_bufferclone)rI   r`   devicerope_init_fnr_   rK   r;   r<   rC   {   s   


zCsmRotaryEmbedding.__init__rn   ztorch.deviceseq_lenrA   ztorch.Tensorc                 C   sZ   | j d }t| ddp| j| j }d}d|tjd|dtjdj|tjd|   }||fS )	a  
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        
rope_thetahead_dimNg      ?r   r    rO   rn   rO   )	ri   getattrrJ   num_attention_headsr7   arangeint64rP   r[   )r`   rn   rp   basedimattention_factorr_   r;   r;   r<   rj      s   
&z2CsmRotaryEmbedding.compute_default_rope_parametersc           
      C   s   | j d d d d f  |jd dd|j}|d d d d d f  }t|jjtr6|jjdkr6|jjnd}t	|dd+ | |  
dd}tj||fdd	}| | j }| | j }	W d    n1 slw   Y  |j|jd
|	j|jd
fS )Nr   rM   r"   mpscpuF)device_typeenabledr    rz   rs   )r_   r[   expandrX   rP   rn   
isinstancetypestrr   	transposer7   catcosrk   sinrO   )
rI   xposition_idsinv_freq_expandedposition_ids_expandedr~   freqsembr   r   r;   r;   r<   rW      s   0&zCsmRotaryEmbedding.forwardN)NNN)r3   r4   r5   r7   r\   r9   r#   rC   staticmethodr   intr:   r[   rj   no_gradr   rW   r]   r;   r;   rK   r<   r^   x   s&   
 

r^   c                       $   e Zd Z fddZdd Z  ZS )CsmMLPc                    sx   t    || _|j| _|j| _tj| j| j|jd| _tj| j| j|jd| _	tj| j| j|jd| _
t|j | _d S )Nbias)rB   rC   r`   rJ   intermediate_sizerD   Linearmlp_bias	gate_projup_proj	down_projr   
hidden_actact_fnrI   r`   rK   r;   r<   rC      s   
zCsmMLP.__init__c                 C   s$   |  | | || | }|S r   )r   r   r   r   )rI   r   r   r;   r;   r<   rW      s    zCsmMLP.forwardr3   r4   r5   rC   rW   r]   r;   r;   rK   r<   r      s    
r   c                 C   sH   | dd| j d d f }| d| j d d df }tj| |fddS )z*Rotates half the hidden dims of the input..NrM   r    r   )rX   r7   r   )r   x1x2r;   r;   r<   rotate_half   s   r   rotary_pos_embc                 C   sD   | |}| |}| | t| |  }|| t||  }||fS )a  Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    )	unsqueezer   )qkr   r   unsqueeze_dimq_embedk_embedr;   r;   r<   apply_rotary_pos_emb   s
   

r   r+   n_reprA   c                 C   s^   | j \}}}}|dkr| S | dddddddddf |||||} | ||| ||S )z
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    r"   N)rX   r   reshape)r+   r   batchnum_key_value_headsslenrr   r;   r;   r<   	repeat_kv   s
   0r           modulequerykeyvalueattention_maskscalingdropoutkwargsc                 K   s   t || j}t || j}	t||dd| }
|d ur |
| }
tjj|
dtjd	|j
}
tjj|
|| jd}
t|
|	}|dd }||
fS )Nr    r   rM   )rz   rO   )ptrainingr"   )r   num_key_value_groupsr7   matmulr   rD   
functionalsoftmaxrQ   rP   rO   r   r   
contiguous)r   r   r   r   r   r   r   r   
key_statesvalue_statesattn_weightsattn_outputr;   r;   r<   eager_attention_forward   s   
r   c                       s   e Zd ZdZdedef fddZ				ddejde	ejejf dB d	ejdB d
e
dB dejdB dee de	ejejf fddZ  ZS )CsmAttentionz=Multi-headed attention from 'Attention Is All You Need' paperr`   	layer_idxc                    s   t    || _|| _t|d|j|j | _|j|j | _	| jd | _
|j| _d| _tj|j|j| j |jd| _tj|j|j| j |jd| _tj|j|j| j |jd| _tj|j| j |j|jd| _d S )Nrr   g      Tr   )rB   rC   r`   r   ru   rJ   rv   rr   r   r   r   attention_dropout	is_causalrD   r   attention_biasq_projk_projv_projo_projrI   r`   r   rK   r;   r<   rC     s(   
zCsmAttention.__init__Nr+   position_embeddingsr   r*   cache_positionr   rA   c                 K   s  |j d d }g |d| jR }| ||dd}	| ||dd}
| ||dd}|\}}t|	|
||\}	}
|d urW|||d}||
|| j	|\}
}t
| jjt}|| |	|
||f| jskdn| j| jd|\}}|jg |dR   }| |}||fS )NrM   r"   r    )r   r   r   r   )r   r   )rX   rr   r   viewr   r   r   r   updater   r   get_interfacer`   _attn_implementationr   r   r   r   r   r   r   )rI   r+   r   r   r*   r   r   input_shapehidden_shapequery_statesr   r   r   r   cache_kwargsattention_interfacer   r   r;   r;   r<   rW   *  s8   	

zCsmAttention.forwardNNNN)r3   r4   r5   r6   r#   r   rC   r7   r\   r:   r   
LongTensorr   r   rW   r]   r;   r;   rK   r<   r     s,    r   c                       s   e Zd Zdedef fddZ						ddejdejdB d	ejdB d
e	dB de
dB dejdB deejejf dB dee dejfddZ  ZS )CsmDecoderLayerr`   r   c                    sR   t    |j| _t||d| _t|| _t|j|jd| _	t|j|jd| _
d S )N)r`   r   r@   )rB   rC   rJ   r   	self_attnr   mlpr>   rms_norm_epsinput_layernormpost_attention_layernormr   rK   r;   r<   rC   W  s   

zCsmDecoderLayer.__init__NFr+   r   r   r*   	use_cacher   r   r   rA   c              
   K   s^   |}	|  |}| jd|||||||d|\}}
|	| }|}	| |}| |}|	| }|S )N)r+   r   r   r*   r   r   r   r;   )r   r   r   r   )rI   r+   r   r   r*   r   r   r   r   residual_r;   r;   r<   rW   a  s&   




zCsmDecoderLayer.forward)NNNFNN)r3   r4   r5   r#   r   rC   r7   r\   r   r   boolr:   r   r   rW   r]   r;   r;   rK   r<   r   V  s6    	
r   z[
    The bare Csm Model outputting raw hidden-states without any specific head on top.
    c                       s`   e Zd ZU eed< dZdZdZdgZdgZ	dZ
dZdZdZeedZe  fdd	Z  ZS )
CsmPreTrainedModelr`   model)audiotextTr   r*   )r+   r,   c                    sz   t  | t|tr$|j}t|d D ]}tj|jd| j	j
d qd S t|tr;t|jt| j	j| j	j  d S d S )Nr"   r   )rS   std)rB   _init_weightsr   CsmCodebooksHeadnum_codebooksrangeinitnormal_rG   r`   initializer_rangeCsmBackboneModelEmbeddingscopy_audio_tokens_offsetsr7   rw   
vocab_size)rI   r   r   irK   r;   r<   r     s   

$z CsmPreTrainedModel._init_weights)r3   r4   r5   r#   r9   base_model_prefixinput_modalitiessupports_gradient_checkpointing_no_split_modules_skip_keys_device_placement_supports_flash_attn_supports_sdpa_can_compile_fullgraph_supports_attention_backendr   r   _can_record_outputsr7   r   r   r]   r;   r;   rK   r<   r     s    
 r   c                       s   e Zd ZU eed<  fddZeee								dde	j
dB de	jdB de	jdB de	j
dB d	edB d
e	jdB dedB de	j
dB dee deeB fddZ  ZS )CsmDepthDecoderModelr`   c                    s   t     j| _ j| _t j j  j| _	t
 fddt jD | _t j jd| _t d| _d| _tj j jdd| _|   d S )Nc                       g | ]}t  |qS r;   r   .0r   r`   r;   r<   
<listcomp>      z1CsmDepthDecoderModel.__init__.<locals>.<listcomp>r   r	  Fr   )rB   rC   pad_token_idpadding_idxr   rD   	Embeddingr   backbone_hidden_sizeembed_tokens
ModuleListr   num_hidden_layerslayersr>   rJ   r   normr^   
rotary_embgradient_checkpointingr   inputs_embeds_projector	post_initr   rK   r	  r<   rC     s   zCsmDepthDecoderModel.__init__N	input_idsbackbone_last_hidden_stater   r   r*   inputs_embedsr   r   r   rA   c	              
   K   s  |durt  std d}|du |duA rtd|r&|du r&t| jd}|du rV|dur2| nd}
|dur=|jd n|jd }|durI|jn|j}t	j
|
|
| |d}|du rt	j|d dd}|| j }| || }|d dk}|dur||dddf< n
t  s|rtd	 | |}t| j|||||d
}|}|d}| j||d}| jd| jj D ]}||f||||||d|	}q| |}t||r|dS ddS )aJ  
        backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
            The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
            is provided in the `input_ids` argument.
        NzCustom `position_ids` were provided but will be ignored. CSM depth decoder automatically determines position_ids from `cache_position` and as it requires them to be identical across the batch, the provided position_ids will be ignored.z;You must specify exactly one of input_ids or inputs_embeds.r	  r   r"   rn   )minzvWhen the first codebook token is provided, `backbone_last_hidden_state` should also be provided for correct inference.r`   r  r   r   r*   r   r   )r   r   r*   r   r   r   last_hidden_stater*   )r   loggerwarning_once
ValueErrorr	   r`   get_seq_lengthrX   rn   r7   rw   clampr   r  warningr  r   r   r  r  r  r  r   )rI   r  r  r   r   r*   r  r   r   r   past_seen_tokensinputs_seq_lengthrn   codebook_idxsoffsetinput_ids_are_first_codebookcausal_maskr+   r   decoder_layerr;   r;   r<   rW     sr   


	

zCsmDepthDecoderModel.forward)NNNNNNNN)r3   r4   r5   r$   r9   rC   r   r   r   r7   r   r8   r\   r   r   r   r   r:   r   rW   r]   r;   r;   rK   r<   r    sF   
 	
r  c                       s&   e Zd Z fddZdddZ  ZS )r   c                    s0   t    || _tt| jd ||| _d S )Nr"   )rB   rC   r   rD   rE   r7   emptyrG   )rI   rJ   r   r   rK   r;   r<   rC     s   
 zCsmCodebooksHead.__init__Nc                    sf   |d u rj d }| jt|  n	|d }| j|   fddt j d D tjddS )Nr"   c              	      s2   g | ]}t jd d |d d f  | jqS r   )rD   r   linearT)r  codebook_idxcodebook_weightr+   r;   r<   r
  !  s    $z,CsmCodebooksHead.forward.<locals>.<listcomp>r   r   )rX   rG   r7   rw   r   stack)rI   r+   r   
seq_lengthr*  r;   r3  r<   rW     s   

zCsmCodebooksHead.forwardr   r   r;   r;   rK   r<   r     s    r   a$  
    The CsmDepthDecoder Model transformer, with a [`CsmCodebooksHead`] on top,
    which can be seen a position-specific language modeling head, allowing to use a different linear layer for each codebook
    (e.g. position 0 is the first codebook and uses the first codebook head, etc.)
    c                       s  e Zd ZdZdZdZ fddZee										dde	j
dB de	jdB de	jdB de	j
dB d	edB d
e	jdB de	j
dB dedB de	j
dB dee	jB dee deeB fddZ				dde	j
d	edB de	j
dB d
e	jdB de	j
dB f
 fddZ  ZS )CsmDepthDecoderForCausalLMNc                    s>   t  | t|| _|j| _t|j|j|j| _| 	  d S r   )
rB   rC   r  r   r   r   rJ   r   codebooks_headr  r   rK   r;   r<   rC   6  s
   
z#CsmDepthDecoderForCausalLM.__init__r   r  r  r   r   r*   r  labelsr   r   logits_to_keepr   rA   c                 K   s   | j d||||||||	d|}|d }t|
tr+|
dkr$tdd}n	t|
 d}n|
}| |dd|ddf |	durA|	| nd}| }d}|durg|dddf  }| jd|d| jj|d|}t	|||j
|j|jdS )	a  
        backbone_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, backbone_hidden_size)`, *optional*):
            The last hidden state of the backbone model. Such input is required when the first codebook token (the one generated by the backbone model)
            is provided in the `input_ids` argument.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        )r  r  r   r   r*   r  r   r   r   r"   N.)r)   r9  r   shift_labels)r(   r)   r*   r+   r,   r;   )r   r   r   slicer8  r   loss_functionr`   r   r   r*   r+   r,   )rI   r  r  r   r   r*   r  r9  r   r   r:  r   outputsr+   slice_indicesr)   r(   r;  r;   r;   r<   rW   ?  sJ   	
&z"CsmDepthDecoderForCausalLM.forwardc           	         sH   t  j|||||fi |}|d d dk}|s|d |d |S )Nr   r   r  r   )rB   prepare_inputs_for_generationpop)	rI   r  r*   r   r  r   r   model_inputsis_first_generation_steprK   r;   r<   r@    s   	


z8CsmDepthDecoderForCausalLM.prepare_inputs_for_generation)
NNNNNNNNNr   r   )r3   r4   r5   _tied_weights_keys_tp_plan_pp_planrC   r   r   r7   r   r8   r\   r   r   r   r   r   r:   r   rW   r@  r]   r;   r;   rK   r<   r7  *  sr    		
Er7  c                       r   )r   c                    sD   t    t|j|j |j| _| jdt	
|j|j dd d S )Nr   Frc   )rB   rC   rD   r  r   r   rJ   embed_audio_tokensrl   r7   rw   r   rK   r;   r<   rC     s
   

z#CsmBackboneModelEmbeddings.__init__c                 C   s    |  || j }|jdd}|S )Nr    r   )rG  r   sum)rI   r  r  r;   r;   r<   rW     s   z"CsmBackboneModelEmbeddings.forwardr   r;   r;   rK   r<   r     s    r   c                       s   e Zd Z fddZeee							ddejdB dej	dB dejdB de
dB dejdB d	ejdB d
edB dee defddZ  ZS )CsmBackboneModelc                    sv   t     j| _ j| _t | _t fddt	 j
D | _t j jd| _t d| _d| _|   d S )Nc                    r  r;   r  r  r	  r;   r<   r
    r  z-CsmBackboneModel.__init__.<locals>.<listcomp>r   r	  F)rB   rC   r  r  r   r   r  rD   r  r   r  r  r>   rJ   r   r  r^   r  r  r  r   rK   r	  r<   rC     s   
zCsmBackboneModel.__init__Nr  r   r   r*   r  r   r   r   rA   c              
   K   s   |du |duA rt d|du r| |}|r!|du r!t| jd}|du r<|dur-| nd}	tj|jd |jd|	 }|du rE|	d}t
| j|||||d}
|}| j||d}| jd| jj D ]}||f|
|||||d	|}qb| |}t||d
S )a&  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
            1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
            requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.

            2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        Nz:You must specify exactly one of input_ids or inputs_embedsr	  r   r"   r  r  r  )r   r   r   r*   r   r   r   )r$  r  r	   r`   r%  r7   rw   rX   rn   r   r   r  r  r  r  r   )rI   r  r   r   r*   r  r   r   r   r(  r-  r+   r   r.  r;   r;   r<   rW     sP   

	
zCsmBackboneModel.forward)NNNNNNN)r3   r4   r5   rC   r   r   r   r7   r   r\   r   r8   r   r   r   r   rW   r]   r;   r;   rK   r<   rI    s>    	
rI  z
    The Csm model consists of two llama-like auto-regressive transformer models: a backbone model that predicts the first codebook token and a depth decoder that predicts the other codebook tokens.
    c                       s~  e Zd ZddiZ fddZdd Zdd Ze fd	d
Z fddZ					d"de
jdB de
jdB de
jdB de
jdB de
jdB f
ddZ				d"de
jdedB de
jdB de
jdB de
jdB f
 fddZee											d#de
jdB de
jdB de
jdB de
jdB de
jdB dedB de
jdB de
jdB dedB de
jdB dee
jB dee deeB fd d!Z  ZS )$CsmForConditionalGenerationz5backbone_model.embed_tokens.embed_audio_tokens.weightz'depth_decoder.model.embed_tokens.weightc                    sp   t  | |j| _tj|j|jdd| _t|j|j| _	t
|| _t|j| _t|j| _|   d S )NFr   )rB   rC   r   rD   r   rJ   lm_headr  text_vocab_sizeembed_text_tokensrI  _from_configbackbone_modelr7  depth_decoder_configdepth_decoderr!   from_configcodec_configcodec_modelr  r   rK   r;   r<   rC     s   z$CsmForConditionalGeneration.__init__c                 C   s   | j jS r   rO  r  rY   r;   r;   r<   get_input_embeddings  s   z0CsmForConditionalGeneration.get_input_embeddingsc                 C   s   || j _d S r   rU  )rI   r   r;   r;   r<   set_input_embeddings  s   z0CsmForConditionalGeneration.set_input_embeddingsc                    s   | ddrt j|i |\}}n	t j|i |}d t  fddt|j D }t|jjddi| |D ]
}t	|j |  q?d|v rR||fS |S )Noutput_loading_infoFdepth_decoder_c                    s(   i | ]\}}|  r|d  |qS r   )
startswith)r  attrr   prefix
prefix_lenr;   r<   
<dictcomp>(  s    z?CsmForConditionalGeneration.from_pretrained.<locals>.<dictcomp>_from_model_config)
getrB   from_pretrainedlenvarsgeneration_configitemsrQ  r   delattr)clsargsr   r   loading_infodepth_decoder_attrsr[  rK   r\  r<   rb    s   z+CsmForConditionalGeneration.from_pretrainedc                    sV   d}| j j }|dd  | D ]\}}t| j|| | qt j|i | d S )NrY  transformers_version)rQ  re  to_diff_dictrA  rf  setattrrB   save_pretrained)rI   ri  r   r]  rk  r[  r   rK   r;   r<   ro  9  s   z+CsmForConditionalGeneration.save_pretrainedNr  input_valuesinput_values_cutoffsr9  rA   c                    sF  |  |}|durtj|d}||dk  }||dk }tj| |jd	t
|d}||dk }t j g }t||D ]?\}	}
|
|
dk }
t|
jd d D ]+}|
| }|
|d  }|	d||f }| j|d}|jdd}||d  qUqBtdd	 |D  t fd
d|D }| j|}W d   n1 sw   Y  | jj}||k}| j|}|| ||< tjdd| jjf|jtjd| jj }| j|d}|| jj k}|!|" d||< |dur|d!dd| jj}|| ||< |||< |dkj#dd}d||d |d ddf< |}||dS )a  
        Merges the input_ids and input_values to produce a single inputs_embeds tensor:
        1 - Infers the codec model on the input_values to retrieve codebook token.
        2 - Embeds codebook tokens and places them at the correct positions in the inputs_embeds tensor.
        3 - If labels are provided, expands them to match codebook dimensions and position the target codebook tokens in the inputs_embeds tensor.

        Args:
            input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                The input ids to embed.
            input_values (`torch.Tensor` of shape `(batch_size, channels, audio_sequence_length)`):
                The audio input values to embed.
            input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`):
                The cutoffs of the audio input values relative to its batch index, padded with -1 when no audio.
        Nr"   r   r   r  rM   r"   .c                 s   s    | ]}|j d  V  qdS )r   N)rX   r  elr;   r;   r<   	<genexpr>s  s    zQCsmForConditionalGeneration._merge_input_ids_with_input_values.<locals>.<genexpr>c                    s,   g | ]}t j|d d d  |jd   fqS )r   )rD   r   padrX   rs  max_audio_framesr;   r<   r
  u  s   , zRCsmForConditionalGeneration._merge_input_ids_with_input_values.<locals>.<listcomp>rt   iTas_tuple)r  r9  )$rM  rD   r   rv  diffr7   rw   maxrn   r   rc  r   r   zipr   rX   rT  encodeaudio_codesr   appendr5  get_audio_codes_maskr`   audio_token_idrO  r  rF   r   longcodebook_eos_token_idsqueezeaudio_eos_token_idrepeatrH  nonzero)rI   r  rp  rq  r9  r  audio_lengthsinput_values_maskaudio_tokens_listbatch_input_valuesbatch_input_values_cutoffsr   	start_idxend_idxaudio_batchcodec_outputscodebook_idsbatched_audio_token_idsaudio_codes_maskr  audio_token_maskaudio_embedsaudio_eos_frame_idsaudio_eos_embedsaudio_eos_token_masklabels_expanded depth_decoder_ignore_frames_idxsr;   rw  r<   "_merge_input_ids_with_input_valuesC  s\   




z>CsmForConditionalGeneration._merge_input_ids_with_input_valuesr*   r   r  r   c           	         s   t  jd	|||||d|}|d ur>|jdkr>|dd u r>| j||d|d|dd}||d |d d d |S )
N)r  r*   r   r  r   r    r  rp  rq  r9  )r  rp  rq  r9  )r  r9  r  r;   )rB   r@  ndimra  r  r   )	rI   r  r*   r   r  r   r   rB  merged_inputsrK   r;   r<   r@    s(   	 	z9CsmForConditionalGeneration.prepare_inputs_for_generationr   r   r   r:  r   c                 K   s  |dur|j dkr| ||||}|d }|d }d}| jd||||||	|
d|}|d }t|tr:t| dn|}| |dd|ddf }d}d}d}d}|dur|dddddf }| jd||| jj	d|}|ddddddf d	kj
d
d }|| dd| jjd f }tjj|ddd}|jdd}||d |d d ddf }|| }| jd|||	d|d|}|j}|| }t|||||j|j|j|dur|jnd|dur|jnd|dur|jnd|dur|jdS ddS )a  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks) or (batch_size, sequence_length)`):
            1. (batch_size, sequence_length): corresponds to the input sequence prepared with the processor from the text prompt. Such input
            requires `input_values` to be provided so that audio can be encoded in codebook tokens and then merged with the text tokens.

            2. (batch_size, sequence_length, num_codebooks): codebook tokens generated during the autoregressive decoding. Such input is not meant to be used by end users.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        input_values_cutoffs (`torch.Tensor` of shape `(batch_size, max_num_audio)`, *optional*):
            Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
            If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
            where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
            the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[config.audio_token_id, -100, -101]`.
            Requires targeted `input_values` to be provided as audio tokens will be inferred from it using the `codec_model`.
            - `config.audio_token_id` indicates an audio frames (considering sequence length elements as frames)
            - `-100` will be ignored in the loss computation
            - `-101` indicates the audio frame will be used only for the backbone model (using the first codebook token as labels)

            Such labels can be prepared using `output_labels=True` when calling [`CsmProcessor`].
        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            Kept for compatibility. Does not support another value than:
            1. `0`, which is equivalent to keeping all logits, used in the training regime
            2. `1`, which is equivalent to keeping only the last logit, used in the generation regime

        Example:

        ```python
        >>> import torch
        >>> from transformers import CsmForConditionalGeneration, AutoProcessor
        >>> from datasets import load_dataset, Audio

        >>> model_id = "sesame/csm-1b"
        >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        >>> processor = AutoProcessor.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
        >>> # ensure the audio is 24kHz
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))

        >>> conversation = []
        >>> # prepare a conversation with text and corresponding audio
        >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
        ...     conversation.append(
        ...         {
        ...             "role": f"{speaker_id}",
        ...             "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
        ...         }
        ...     )

        >>> inputs = processor.apply_chat_template(
        ...     conversation,
        ...     tokenize=True,
        ...     return_dict=True,
        ...     output_labels=True,
        ... ).to(torch_device)

        >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
        >>> output = model(**inputs)
        >>> output.loss.backward()
        ```Nr    r  r9  )r  r   r   r*   r  r   r   r   )r)   r9  r   r"   r{  rM   r   .rr  )r   Try  )r  r  r   return_dictr9  )r(   r2   r-   r)   r*   r+   r,   r.   r/   r0   r1   r;   )r  r  rO  r   r   r<  rK  r=  r`   r   allr   rD   r   rv  r  rQ  r(   r'   r*   r+   r,   r)   )rI   r  rp  r   rq  r   r*   r  r9  r   r   r:  r   r  backbone_outputsbackbone_hidden_statesr?  backbone_logitsr(   r2   r-   depth_decoder_outputsbackbone_labels
train_maskdepth_decoder_input_ids
train_idxsbackbone_last_hidden_statesdepth_decoder_labelsr;   r;   r<   rW     s   S
(	z#CsmForConditionalGeneration.forwardr   )NNNNNNNNNNr   )r3   r4   r5   rD  rC   rV  rW  classmethodrb  ro  r7   r\   r  r   r   r8   r@  r   r   r   r   r   r   r:   r'   rW   r]   r;   r;   rK   r<   rJ    s    

U	
rJ  )r   rI  r  r7  rJ  )r"   )r   )Scollections.abcr   dataclassesr   typingr   r7   torch.nnrD    r   r   activationsr   cache_utilsr   r	   
generationr
   integrationsr   r   r   masking_utilsr   modeling_layersr   modeling_outputsr   r   modeling_rope_utilsr   r   modeling_utilsr   r   processing_utilsr   utilsr   r   r   r   r   utils.genericr   r   utils.import_utilsr   utils.output_capturingr   autor!   configuration_csmr#   r$   generation_csmr%   
get_loggerr3   r"  r'   Moduler>   r^   r   r   r   r\   r   r   r[   r   r   r   r   r  r   r7  r   rI  rJ  __all__r;   r;   r;   r<   <module>   s   
*A
F-ki[  M