o
    i                     @   sd  d Z ddlZddlZddlmZmZ ddlZddlmZ ddlm	Z	 ddl
mZ ddlmZ dd	lmZmZmZ dd
lmZ ddlmZ ddlmZ ddlmZmZmZ ddlmZ ddlmZm Z  ddl!m"Z"m#Z#m$Z$m%Z%m&Z& ddl'm(Z( ddl)m*Z* e# rddl+m,Z, ddl-m.Z. e&/e0Z1dZ2zddl3m4Z4 dZ2e15d W n e6y   Y n e7y   e18d Y nw G dd dej9Z:e2se4Z:G dd dej9Z;G dd  d ej9Z<G d!d" d"ej9Z=G d#d$ d$ej9Z>G d%d& d&ej9Z?G d'd( d(ej9Z@G d)d* d*eZAe"G d+d, d,eZBG d-d. d.eBZCG d/d0 d0ej9ZDe"d1d2G d3d4 d4eBeZEd4d,gZFdS )5zPyTorch Pop2Piano model.    N)OptionalUnion)nn)CrossEntropyLoss)GenerationConfig   )ACT2FN)CacheDynamicCacheEncoderDecoderCache)GenerationMixin)AttentionMaskConverter)GradientCheckpointingLayer)BaseModelOutput)BaseModelOutputWithPastAndCrossAttentionsSeq2SeqLMOutput)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)auto_docstringis_torch_flex_attn_availableis_torch_fx_proxyis_torchdynamo_compilinglogging)deprecate_kwarg   )Pop2PianoConfig)	BlockMask)make_flex_block_causal_maskT)FusedRMSNormFzVDiscovered apex.normalization.FusedRMSNorm - will use it instead of Pop2PianoLayerNormzIDiscovered apex but it failed to load, falling back to Pop2PianoLayerNormc                       s&   e Zd Zd fdd	Zdd Z  ZS )Pop2PianoLayerNormư>c                    s&   t    tt|| _|| _dS )zj
        Construct a layernorm module in the Pop2Piano style. No bias and no subtraction of mean.
        N)super__init__r   	Parametertorchonesweightvariance_epsilon)selfhidden_sizeeps	__class__ m/home/ubuntu/veenaModal/venv/lib/python3.10/site-packages/transformers/models/pop2piano/modeling_pop2piano.pyr#   B   s   

zPop2PianoLayerNorm.__init__c                 C   s\   | tjdjddd}|t|| j  }| jjtj	tj
fv r)| | jj}| j| S )N   T)keepdim)tor%   float32powmeanrsqrtr(   r'   dtypefloat16bfloat16)r)   hidden_statesvariancer.   r.   r/   forwardJ   s
   
zPop2PianoLayerNorm.forward)r!   )__name__
__module____qualname__r#   r=   __classcell__r.   r.   r,   r/   r    A   s    r    c                       *   e Zd Zdef fddZdd Z  ZS )Pop2PianoDenseActDenseconfigc                    sT   t    tj|j|jdd| _tj|j|jdd| _t|j	| _
t|j | _d S NFbias)r"   r#   r   Lineard_modeld_ffwiwoDropoutdropout_ratedropoutr   dense_act_fnactr)   rD   r,   r.   r/   r#   `   s
   
zPop2PianoDenseActDense.__init__c                 C   sl   |  |}| |}| |}t| jjtjr/|j| jjjkr/| jjjtj	kr/|
| jjj}| |}|S N)rK   rQ   rO   
isinstancerL   r'   r%   Tensorr8   int8r3   )r)   r;   r.   r.   r/   r=   g   s   



zPop2PianoDenseActDense.forwardr>   r?   r@   r   r#   r=   rA   r.   r.   r,   r/   rC   _   s    rC   c                       rB   )Pop2PianoDenseGatedActDenserD   c                    sj   t    tj|j|jdd| _tj|j|jdd| _tj|j|jdd| _t	|j
| _t|j | _d S rE   )r"   r#   r   rH   rI   rJ   wi_0wi_1rL   rM   rN   rO   r   rP   rQ   rR   r,   r.   r/   r#   w   s   
z$Pop2PianoDenseGatedActDense.__init__c                 C   sz   |  | |}| |}|| }| |}t| jjtjr6|j	| jjj	kr6| jjj	tj
kr6|| jjj	}| |}|S rS   )rQ   rY   rZ   rO   rT   rL   r'   r%   rU   r8   rV   r3   )r)   r;   hidden_geluhidden_linearr.   r.   r/   r=      s   


z#Pop2PianoDenseGatedActDense.forwardrW   r.   r.   r,   r/   rX   v   s    rX   c                       rB   )Pop2PianoLayerFFrD   c                    sJ   t    |jrt|| _nt|| _t|j|jd| _	t
|j| _d S )Nr+   )r"   r#   is_gated_actrX   DenseReluDenserC   r    rI   layer_norm_epsilon
layer_normr   rM   rN   rO   rR   r,   r.   r/   r#      s   

zPop2PianoLayerFF.__init__c                 C   s&   |  |}| |}|| | }|S rS   )rb   r`   rO   )r)   r;   forwarded_statesr.   r.   r/   r=      s   

zPop2PianoLayerFF.forwardrW   r.   r.   r,   r/   r]      s    
r]   c                       sz   e Zd Z		ddedee f fddZdd ZedddZ	dddZ
edddd									dddZ  ZS )Pop2PianoAttentionFNrD   	layer_idxc                    s  t    |j| _|| _|j| _|j| _|j| _|j| _|j	| _
|j| _| j
| j | _|| _|d u r@| jr@td| jj d tj| j| jdd| _tj| j| jdd| _tj| j| jdd| _tj| j| jdd| _| jrxt| j| j
| _t | _d| _d S )NzInstantiating a decoder z without passing `layer_idx` is not recommended and will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` when creating this class.FrF   )r"   r#   
is_decoderhas_relative_attention_biasrelative_attention_num_bucketsrelative_attention_max_distancerI   d_kvkey_value_proj_dim	num_headsn_headsrN   rO   	inner_dimre   loggerwarning_oncer-   r>   r   rH   qkvo	Embeddingrelative_attention_biassetpruned_headsgradient_checkpointingr)   rD   rg   re   r,   r.   r/   r#      s.   

zPop2PianoAttention.__init__c                 C   s   t |dkrd S t|| j| j| j\}}t| j|| _t| j|| _t| j|| _t| j	|dd| _	| jt | | _| j| j | _
| j|| _d S )Nr   r   dim)lenr   rm   rk   rx   r   rq   rr   rs   rt   rn   union)r)   headsindexr.   r.   r/   prune_heads   s   zPop2PianoAttention.prune_headsT       c                 C   s   d}|r|d }|| dk tj| 7 }t| } n
t| t|  } |d }| |k }|t|  | t||  ||   tj }t|t	||d }|t
|| |7 }|S )a  
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        r   r0   r   )r3   r%   longabsmin
zeros_likelogfloatmath	full_likewhere)relative_positionbidirectionalnum_bucketsmax_distancerelative_buckets	max_exactis_smallrelative_position_if_larger.   r.   r/   _relative_position_bucket   s*   z,Pop2PianoAttention._relative_position_bucketc           
      C   s   |du r	| j jj}|du rtj|tj|ddddf }n|dddf |}tj|tj|ddddf }|| }| j|| j | j	| j
d}|  |}	|	g dd}	|	S )z%Compute binned relative position biasN)r8   device)r   r   r   )r0   r   r   r   )rv   r'   r   r%   aranger   r3   r   rf   rh   ri   permute	unsqueeze)
r)   query_length
key_lengthr   cache_positioncontext_positionmemory_positionr   relative_position_bucketvaluesr.   r.   r/   compute_bias  s    
 
zPop2PianoAttention.compute_biaspast_key_valuepast_key_values4.58new_nameversionc                 C   s  |j dd \}}|du}| |}||d| j| jdd}d}t|tr8|j	| j
}|r4|j}n|j}n|}|r>|n|}|rW|durW|rW|j| j
 j}|j| j
 j}nJ| |}| |}||d| j| jdd}||d| j| jdd}|dur|s|
nd}
|||| j
d|
i\}}|rt|trd|j| j
< t||dd}|du r|j d	 }|dur|n|
d d }| jstjd| j||f|j|jd
}| jr| jrd|_n| j|||j|
d}|dddd| dddf }|dur|ddddddd|j d	 f }|| }| jr2t|j d }d|t| j< |dd|  f }n|}||7 }t!j"j#|$ dd%|}t!j"j&|| j&| jd}|durY|| }t||}|dd' }||d| j(}| )|}||f}|	r||f }|S )z
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        Nr0   r1   r   Fr   Tr   )r   r8   )r   r   r   r{   )ptraining)*shaperq   viewrm   rk   	transposerT   r   
is_updatedgetre   cross_attention_cacheself_attention_cachelayerskeysr   rr   rs   updater%   matmulrg   zerosr   r8   ry   r   requires_gradr   rx   r&   listboolr   
functionalsoftmaxr   type_asrO   
contiguousrn   rt   )r)   r;   maskkey_value_statesposition_biasr   layer_head_maskr   	use_cacheoutput_attentionsr   
batch_size
seq_lengthis_cross_attentionquery_statesr   curr_past_key_valuecurrent_states
key_statesvalue_statesscoresr   real_seq_lengthcausal_maskposition_bias_maskedattn_weightsattn_outputoutputsr.   r.   r/   r=     s|   






"
&


zPop2PianoAttention.forwardFN)Tr   r   )NN)	NNNNNNFFN)r>   r?   r@   r   r   intr#   r   staticmethodr   r   r   r=   rA   r.   r.   r,   r/   rd      s.    #
/rd   c                       sN   e Zd Zddee f fddZedddd								dd
dZ  ZS )Pop2PianoLayerSelfAttentionFNre   c                    s>   t    t|||d| _t|j|jd| _t	|j
| _d S )Nrg   re   r^   )r"   r#   rd   SelfAttentionr    rI   ra   rb   r   rM   rN   rO   rz   r,   r.   r/   r#     s   
z$Pop2PianoLayerSelfAttention.__init__r   r   r   r   c	              
   C   sL   |  |}	| j|	|||||||d}
|| |
d  }|f|
dd   }|S )N)r   r   r   r   r   r   r   r   r   )rb   r   rO   )r)   r;   attention_maskr   r   r   r   r   r   normed_hidden_statesattention_outputr   r.   r.   r/   r=     s   

z#Pop2PianoLayerSelfAttention.forwardr   )NNNNFFN	r>   r?   r@   r   r   r#   r   r=   rA   r.   r.   r,   r/   r     s    r   c                       sP   e Zd Zddee f fddZedddd										dd
dZ  ZS )Pop2PianoLayerCrossAttentionNre   c                    s>   t    t|d|d| _t|j|jd| _t	|j
| _d S )NFr   r^   )r"   r#   rd   EncDecAttentionr    rI   ra   rb   r   rM   rN   rO   )r)   rD   re   r,   r.   r/   r#     s   
z%Pop2PianoLayerCrossAttention.__init__r   r   r   r   Fc                 C   sP   |  |}| j|||||||||	|
d
}|| |d  }|f|dd   }|S )N)	r   r   r   r   r   r   r   r   r   r   r   )rb   r   rO   )r)   r;   r   r   r   r   r   r   r   r   r   r   r   layer_outputr   r.   r.   r/   r=     s    
z$Pop2PianoLayerCrossAttention.forwardrS   )NNNNFNFNr   r.   r.   r,   r/   r     s    r   c                       sX   e Zd Zddee f fddZedddd												
	dddZ  ZS )Pop2PianoBlockFNre   c                    s`   t    |j| _t | _| jt|||d | jr&| jt||d | jt	| d S )Nr   )re   )
r"   r#   rf   r   
ModuleListlayerappendr   r   r]   rz   r,   r.   r/   r#     s   

zPop2PianoBlock.__init__r   r   r   r   Tc                 C   s  | j d |||||	|
||d}|d }|dd  }|jtjkr@tt| t|jjd t|jj}tj	|| |d}| j
oF|d u}|r| j d ||||||	|d d |
|d	}|d }|jtjkrtt| t|jjd t|jj}tj	|| |d}||dd   }| j d |}|jtjkrtt| t|jjd t|jj}tj	|| |d}|f}|| S )Nr   )r   r   r   r   r   r   r   r   i  )r   maxr1   )r   r   r   r   r   r   r   r   )r   r8   r%   r9   r   isinfanyfinfor   clamprf   )r)   r;   r   r   encoder_hidden_statesencoder_attention_maskencoder_decoder_position_biasr   cross_attn_layer_head_maskr   r   r   return_dictr   self_attention_outputsattention_outputsclamp_valuedo_cross_attentioncross_attention_outputsr   r.   r.   r/   r=     sh   

zPop2PianoBlock.forwardr   )NNNNNNNNFFTNr   r.   r.   r,   r/   r     s     r   c                   @   sB   e Zd ZU eed< dZdZdZdZdgZ	dgZ
dd Zd	d
 ZdS )Pop2PianoPreTrainedModelrD   transformerFTr   rL   c                 C   s  | j j}t|tr|jj|d  dS t|tr'|jjjj	d|d d dS t|t
rS|jjjj	d|d d t|drO| j jsQ|jjjj	d|d d dS dS dS t|tr|jjjj	d|| j jd  d t|jdr{|jjdur{|jjj  |jjjj	d|| j jd  d t|jdr|jjdur|jjj  dS dS dS t|tr|jjjj	d|| j jd  d t|jdr|jjdur|jjj  |jjjj	d|| j jd  d t|jdr|jjdur|jjj  |jjjj	d|| j jd  d t|jdr|jjdur|jjj  dS dS dS t|tr}| j j}| j j}| j j}|jjjj	d||| d  d |jjjj	d||d  d |jjjj	d||d  d |jjjj	d||| d  d |j r|j!jjj	d||d  d dS dS dS )zInitialize the weights      ?        )r6   stdlm_head      rG   N)"rD   initializer_factorrT   r    r'   datafill_Pop2PianoConcatEmbeddingToMel	embeddingnormal_!Pop2PianoForConditionalGenerationsharedhasattrtie_word_embeddingsr   rC   rK   rI   rG   zero_rL   rJ   rX   rY   rZ   rd   rj   rl   rq   rr   rs   rt   rg   rv   )r)   modulefactorrI   rk   rm   r.   r.   r/   _init_weightsN  sR   



        
z&Pop2PianoPreTrainedModel._init_weightsc                 C   s   | j j}| j j}|d u rtdt|r1t|jd d d |}tj||dd df gdd}n|	|j}|dd df 
 |ddd f< ||d< |d u rStd||d	k| |S )
Nzoself.model.config.decoder_start_token_id has to be defined. In Pop2Piano it is usually set to the pad_token_id.r1   )r   .r{   r   ).r   z1self.model.config.pad_token_id has to be defined.)rD   decoder_start_token_idpad_token_id
ValueErrorr   r%   fullr   cat	new_zerosclonemasked_fill_)r)   	input_idsr	  r
  shifted_input_idsr.   r.   r/   _shift_right|  s      z%Pop2PianoPreTrainedModel._shift_rightN)r>   r?   r@   r   __annotations__base_model_prefixis_parallelizablesupports_gradient_checkpointing_can_compile_fullgraph_no_split_modules_keep_in_fp32_modulesr  r  r.   r.   r.   r/   r   C  s   
 .r   c                       s   e Zd Zd fdd	Zdd Z													dddZ	dd	eejd
f dejdejde	de
f
ddZed	ejdededejdejdefddZ  ZS )Pop2PianoStackNc                    sx   t    || _ j| _t fddt jD | _t	 j
 jd| _t j| _|   d| _d | _d| _d S )Nc                    s"   g | ]}t  t|d k|dqS )r   r   )r   r   ).0irD   r.   r/   
<listcomp>  s    z+Pop2PianoStack.__init__.<locals>.<listcomp>r^   F)r"   r#   embed_tokensrf   r   r   range
num_layersblockr    rI   ra   final_layer_normrM   rN   rO   	post_initmodel_parallel
device_mapry   )r)   rD   r   r,   r  r/   r#     s   

zPop2PianoStack.__init__c                 C   s
   || _ d S rS   )r   r)   new_embeddingsr.   r.   r/   set_input_embeddings     
z#Pop2PianoStack.set_input_embeddingsc           %      C   s"  |	d ur|	n| j j}	|
d ur|
n| j j}
|d ur|n| j j}|d ur$|n| j j}|d urB|d urB| jr5dnd}td| d| d|d urS| }|d|d }n|d ur`| d d }n| jrednd}td| d| d	| j	r| j
r|	rtd
 d}	|d u r| jd u rtd| |}|\}}|	du r| jstd|  d| jr|	r|d u r| j jrtt| j dt| j d}nt| j d}n| jsd }|d ur| nd}|d u rtj||| |jd}|d u rt s|| }tj|||jd}| j jr| |||t|tr|jn||
}n|d d d d d d f }|j|jd}d| t|jj }| jrW|d urW| \}}}||f}|d u rQtj||jd}| |}nd }| || j j }| || j j }|rndnd }|
rudnd }|
r| jrdnd }d }d }| !|}t"| j#D ]T\} }!||  }"||  }#|r||f }|!|||||||"|#||	|
|d}$|$d }|$d }| jr|d ur|$|
rdnd }|
r||$d f }| jr||$d f }q| $|}| !|}|r||f }|st%dd |||||fD S t&|||||dS )Ndecoder_ zYou cannot specify both zinput_ids and zinputs_embeds at the same timer1   zYou have to specify either zinput_ids or inputs_embedszZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fz<You have to initialize the model with valid token embeddingsTz)`use_cache` can only be set to `True` if z is used as a decoderr  r   r   )r8   r   r.   )r   r   r   r   r   r   r   r   r0      c                 s   s    | ]	}|d ur|V  qd S rS   r.   )r  rs   r.   r.   r/   	<genexpr>R  s    z)Pop2PianoStack.forward.<locals>.<genexpr>)last_hidden_stater   r;   
attentionscross_attentions)'rD   r   r   output_hidden_statesuse_return_dictrf   r  sizer   ry   r   ro   rp   r   is_encoder_decoderr   r
   get_seq_lengthr%   r   r   r   r&   _update_causal_maskrT   r   r3   r8   r   r   invert_attention_maskget_head_maskr"  rO   	enumerater#  r$  tupler   )%r)   r  r   r   r   r.  	head_maskcross_attn_head_maskr   r   r   r5  r   r   err_msg_prefixinput_shaper   r   past_key_values_lengthmask_seq_lengthr   encoder_batch_sizeencoder_sequence_length_encoder_hidden_shapeencoder_extended_attention_maskall_hidden_statesall_attentionsall_cross_attentionsr   r   r;   r  layer_moduler   r   layer_outputsr.   r.   r/   r=     s   










zPop2PianoStack.forwardFr   r   input_tensorr   r   r   c                 C   s:  | j jdkr|d ur|dk r|S d S | j jdkr&t|tjr$t|}|S |d ur.| nd}|d ur7|jnd}| j jdkrO|sO|sOt	j
|||| jdrOd S |j}|jd }	|r^| }
nt|tjri|jd	 n||	 d }
| j||	|
|||jd d
}| j jdkr|d ur|jjdv r|st|j}t	||}|S )Nflash_attention_2r   flex_attentionr   Fsdpa)r.  rC  is_trainingr   r1   )sequence_lengthtarget_lengthr8   r   r   )cudaxpunpu)rD   _attn_implementationr   rT   r%   rU   r   r9  is_compileabler   _ignore_causal_mask_sdpar   r8   r   get_max_cache_shape5_prepare_4d_causal_attention_mask_with_cache_positionr   typer   r   _unmask_unattended)r)   r   rO  r   r   r   past_seen_tokensusing_compilable_cacher8   rT  rU  r   	min_dtyper.   r.   r/   r:  f  sT   




z"Pop2PianoStack._update_causal_maskrT  rU  r8   r   c                 K   sD  | dur|   dkr| }|S t|j}tj||f|||jd}|dkr+tj|dd}|tj||jd|ddk9 }|ddddddf 	|ddd}| dur|
 }| jd }	|ddddddd|	f | ddddddf |j }
|
dk}
|ddddddd|	f |
||ddddddd|	f< |S )	aM  
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        Nr0  )
fill_valuer8   r   r   )diagonalr/  r1   r   )r|   r%   r   r   r  r   triur   reshapeexpandr  r   r3   masked_fill)r   rT  rU  r8   r   r   kwargsr   rb  mask_lengthpadding_maskr.   r.   r/   r]    s,    $
6  zDPop2PianoStack._prepare_4d_causal_attention_mask_with_cache_positionrS   )NNNNNNNNNNNNN)F)r>   r?   r@   r#   r*  r=   r   r%   rU   r	   r   r:  r   r   r8   r]  rA   r.   r.   r,   r/   r    sX    
 :
Dr  c                       s(   e Zd ZdZ fddZdd Z  ZS )r   z'Embedding Matrix for `composer` tokens.c                    s"   t    tj|j|jd| _d S )N)num_embeddingsembedding_dim)r"   r#   r   ru   composer_vocab_sizerI   r   rR   r,   r.   r/   r#     s   
z&Pop2PianoConcatEmbeddingToMel.__init__c                 C   s.   || }|  |d}tj||gdd}|S )Nr   r{   )r   r   r%   r  )r)   featureindex_valueembedding_offsetindex_shiftedcomposer_embeddingr.  r.   r.   r/   r=     s   z%Pop2PianoConcatEmbeddingToMel.forward)r>   r?   r@   __doc__r#   r=   rA   r.   r.   r,   r/   r     s    r   zA
    Pop2Piano Model with a `language modeling` head on top.
    )custom_introc                *       s  e Zd Zg dZdef fddZdd Zdd Zd	d
 Z	d*de	j
dededee	j
 fddZe																		d+dee	j dee	j
 dee	j dee	j dee	j
 dee	j
 dee	j deeee	j   dee dee	j
 dee	j
 dee	j
 dee	j dee dee dee d ee d!ee	j d"eee	j
 ef f&d#d$Ze	 		%	d, fd&d'	Zde	jfd(d)Z  ZS )-r   )zencoder.embed_tokens.weightzdecoder.embed_tokens.weightzlm_head.weightrD   c                    s   t  | || _|j| _t|j|j| _t	|| _
t|}d|_d|_d|_t|| j| _t|}d|_d|_|j|_t|| j| _tj|j|jdd| _|   d S )NFTrF   )r"   r#   rD   rI   	model_dimr   ru   
vocab_sizer  r   mel_conditionercopydeepcopyrf   r   tie_encoder_decoderr  encodernum_decoder_layersr"  decoderrH   r   r%  )r)   rD   encoder_configdecoder_configr,   r.   r/   r#     s"   


z*Pop2PianoForConditionalGeneration.__init__c                 C      | j S rS   )r  r)   r.   r.   r/   get_input_embeddings     z6Pop2PianoForConditionalGeneration.get_input_embeddingsc                 C   s"   || _ | j| | j| d S rS   )r  r|  r*  r~  r(  r.   r.   r/   r*    s   z6Pop2PianoForConditionalGeneration.set_input_embeddingsc                 C   r  rS   )r|  r  r.   r.   r/   get_encoder  r  z-Pop2PianoForConditionalGeneration.get_encoderNinput_featurescomposergeneration_configr   c                 C   s   |j }||vrtdt|  d| || }tj|| jd}||jd }t	|
 }| j|||d}|dur_d||dddf   < tj|dddf dd	|gd	d
}||fS |dfS )a  
        This method is used to concatenate mel conditioner tokens at the front of the input_features in order to
        control the type of MIDI token generated by the model.

        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                input features extracted from the feature extractor.
            composer (`str`):
                composer token which determines the type of MIDI tokens to be generated.
            generation_config (`~generation.GenerationConfig`):
                The generation is used to get the composer-feature_token pair.
            attention_mask (``, *optional*):
                For batched generation `input_features` are padded to have the same shape across all examples.
                `attention_mask` helps to determine which areas were padded and which were not.
                - 1 for tokens that are **not padded**,
                - 0 for tokens that are **padded**.
        zPlease choose a composer from z. Composer received - r/  r   )ro  rp  rq  Nr   r1   r   )axis)composer_to_feature_tokenr  r   r   r%   tensorr   repeatr   r   r   rx  r   concatenater   )r)   r  r  r  r   r  composer_valuerq  r.   r.   r/   get_mel_conditioner_outputs  s&   &z=Pop2PianoForConditionalGeneration.get_mel_conditioner_outputsr  decoder_input_idsdecoder_attention_maskr?  decoder_head_maskr@  encoder_outputsr   r.  decoder_inputs_embedslabelsr   r   r5  r   r   returnc                 C   s  |dur|n| j j}|dur|n| j j}|
dur |dur td|dur*|
du r*|}
|du r;| j|||
||||d}n$|r_t|ts_t|d t|dkrP|d ndt|dkr[|d ndd}|d }|durt|du rt|du rt| |}| j	||||	|||||||||d}|d }| j j
r|| jd	  }| |}d}|durtd
d}||d|d|d}|s|f|dd  | }|dur|f| S |S t|||j|j|j|j|j|j|jd	S )a2
  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Pop2Piano is a model with relative position embeddings
            so you should be able to pad the inputs on both the right and the left. Indices can be obtained using
            [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail.
            [What are input IDs?](../glossary#input-ids) To know more on how to prepare `input_ids` for pretraining
            take a look a [Pop2Piano Training](./Pop2Piano#training).
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
            [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
            [What are decoder input IDs?](../glossary#decoder-input-ids) Pop2Piano uses the `pad_token_id` as the
            starting token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
            `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`
        NzSBoth `inputs_embeds` and `input_features` received! Please provide only one of them)r  r   r.  r?  r   r5  r   r   r   r0   )r2  r;   r3  )r  r   r.  r   r   r   r?  r@  r   r   r5  r   r   r   r  )ignore_indexr1   )	losslogitsr   decoder_hidden_statesdecoder_attentionsr4  encoder_last_hidden_stater   encoder_attentions)rD   r   r6  r  r|  rT   r   r}   r  r~  r  rv  r   r   r   r7  r   r   r;   r3  r4  r2  )r)   r  r   r  r  r?  r  r@  r  r   r.  r  r  r  r   r   r5  r   r   r;   decoder_outputssequence_output	lm_logitsr  loss_fctoutputr.   r.   r/   r=   P  s|   5	


z)Pop2PianoForConditionalGeneration.forward	composer1c                    s   |du r| j }|jd	i | t|dstdt|j| jjkr1td| jj dt|j d| j||||d\}}t	 j
d	d|||d|S )
a  
        Generates token ids for midi outputs.

        <Tip warning={true}>

        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
        model's default generation configuration. You can override any `generation_config` by passing the corresponding
        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation
        strategies and code examples, check out the [following guide](./generation_strategies).

        </Tip>

        Parameters:
            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                This is the featurized version of audio generated by `Pop2PianoFeatureExtractor`.
            attention_mask:
                For batched generation `input_features` are padded to have the same shape across all examples.
                `attention_mask` helps to determine which areas were padded and which were not.
                - 1 for tokens that are **not padded**,
                - 0 for tokens that are **padded**.
            composer (`str`, *optional*, defaults to `"composer1"`):
                This value is passed to `Pop2PianoConcatEmbeddingToMel` to generate different embeddings for each
                `"composer"`. Please make sure that the composer value is present in `composer_to_feature_token` in
                `generation_config`. For an example please see
                https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json .
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            kwargs:
                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
                Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~utils.ModelOutput`] types are:
                    - [`~generation.GenerateEncoderDecoderOutput`],
                    - [`~generation.GenerateBeamEncoderDecoderOutput`]
        Nr  z`composer_to_feature_token` was not found! Please refer to https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.jsonand parse a dict like that.ztconfig.composer_vocab_size must be same as the number of keys in generation_config.composer_to_feature_token! Found z vs .)r  r   r  r  )inputsr.  r   r  r.   )r  r   r  r  r}   r  rD   rn  r  r"   generate)r)   r  r   r  r  ri  r,   r.   r/   r    s:   6

z*Pop2PianoForConditionalGeneration.generatec                 C   s
   |  |S rS   )r  )r)   r  r.   r.   r/   %prepare_decoder_input_ids_from_labels/  r+  zGPop2PianoForConditionalGeneration.prepare_decoder_input_ids_from_labelsrS   )NNNNNNNNNNNNNNNNNN)Nr  N)r>   r?   r@   _tied_weights_keysr   r#   r  r*  r  r%   FloatTensorstrr   r   r  r   
LongTensor
BoolTensorrU   r>  r	   r   r   r   r=   no_gradr  r  rA   r.   r.   r,   r/   r     s    
1	
 Yr   )Grt  ry  r   typingr   r   r%   r   torch.nnr   transformers.generationr   activationsr   cache_utilsr	   r
   r   
generationr   modeling_attn_mask_utilsr   modeling_layersr   modeling_outputsr   r   r   modeling_utilsr   pytorch_utilsr   r   utilsr   r   r   r   r   utils.deprecationr   configuration_pop2pianor   !torch.nn.attention.flex_attentionr   integrations.flex_attentionr   
get_loggerr>   ro   _load_pop2piano_layer_normapex.normalizationr   infoImportError	ExceptionwarningModuler    rC   rX   r]   rd   r   r   r   r   r  r   r   __all__r.   r.   r.   r/   <module>   st   

 k&(dS  N  ?