o
    wi8                    @   sX  d Z ddlZddlZddlmZmZ ddlZddlmZ ddlm	Z	 ddl
mZ ddlmZ dd	lmZmZmZ dd
lmZ ddlmZ ddlmZ ddlmZmZmZ ddlmZ ddlmZm Z  ddl!m"Z"m#Z#m$Z$m%Z%m&Z& ddl'm(Z( e# rddl)m*Z* ddl+m,Z, e&-e.Z/dZ0zddl1m2Z2 dZ0e/3d W n e4y   Y n e5y   e/6d Y nw G dd dej7Z8e0se2Z8G dd dej7Z9G dd dej7Z:G d d! d!ej7Z;G d"d# d#ej7Z<G d$d% d%ej7Z=G d&d' d'ej7Z>G d(d) d)eZ?e"G d*d+ d+eZ@G d,d- d-e@ZAG d.d/ d/ej7ZBe"d0d1G d2d3 d3e@eZCd3d+gZDdS )4zPyTorch Pop2Piano model.    N)OptionalUnion)nn)CrossEntropyLoss)GenerationConfig   )ACT2FN)CacheDynamicCacheEncoderDecoderCache)GenerationMixin)AttentionMaskConverter)GradientCheckpointingLayer)BaseModelOutput)BaseModelOutputWithPastAndCrossAttentionsSeq2SeqLMOutput)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)auto_docstringis_torch_flex_attn_availableis_torch_fx_proxyis_torchdynamo_compilinglogging   )Pop2PianoConfig)	BlockMask)make_flex_block_causal_maskT)FusedRMSNormFzVDiscovered apex.normalization.FusedRMSNorm - will use it instead of Pop2PianoLayerNormzIDiscovered apex but it failed to load, falling back to Pop2PianoLayerNormc                       s&   e Zd Zd fdd	Zdd Z  ZS )Pop2PianoLayerNormư>c                    s&   t    tt|| _|| _dS )zj
        Construct a layernorm module in the Pop2Piano style. No bias and no subtraction of mean.
        N)super__init__r   	Parametertorchonesweightvariance_epsilon)selfhidden_sizeeps	__class__ m/home/ubuntu/sommelier/.venv/lib/python3.10/site-packages/transformers/models/pop2piano/modeling_pop2piano.pyr"   A   s   

zPop2PianoLayerNorm.__init__c                 C   s\   | tjdjddd}|t|| j  }| jjtj	tj
fv r)| | jj}| j| S )N   T)keepdim)tor$   float32powmeanrsqrtr'   r&   dtypefloat16bfloat16)r(   hidden_statesvariancer-   r-   r.   forwardI   s
   
zPop2PianoLayerNorm.forward)r    )__name__
__module____qualname__r"   r<   __classcell__r-   r-   r+   r.   r   @   s    r   c                       *   e Zd Zdef fddZdd Z  ZS )Pop2PianoDenseActDenseconfigc                    sT   t    tj|j|jdd| _tj|j|jdd| _t|j	| _
t|j | _d S NFbias)r!   r"   r   Lineard_modeld_ffwiwoDropoutdropout_ratedropoutr   dense_act_fnactr(   rC   r+   r-   r.   r"   _   s
   
zPop2PianoDenseActDense.__init__c                 C   sl   |  |}| |}| |}t| jjtjr/|j| jjjkr/| jjjtj	kr/|
| jjj}| |}|S N)rJ   rP   rN   
isinstancerK   r&   r$   Tensorr7   int8r2   )r(   r:   r-   r-   r.   r<   f   s   



zPop2PianoDenseActDense.forwardr=   r>   r?   r   r"   r<   r@   r-   r-   r+   r.   rB   ^   s    rB   c                       rA   )Pop2PianoDenseGatedActDenserC   c                    sj   t    tj|j|jdd| _tj|j|jdd| _tj|j|jdd| _t	|j
| _t|j | _d S rD   )r!   r"   r   rG   rH   rI   wi_0wi_1rK   rL   rM   rN   r   rO   rP   rQ   r+   r-   r.   r"   v   s   
z$Pop2PianoDenseGatedActDense.__init__c                 C   sz   |  | |}| |}|| }| |}t| jjtjr6|j	| jjj	kr6| jjj	tj
kr6|| jjj	}| |}|S rR   )rP   rX   rY   rN   rS   rK   r&   r$   rT   r7   rU   r2   )r(   r:   hidden_geluhidden_linearr-   r-   r.   r<   ~   s   


z#Pop2PianoDenseGatedActDense.forwardrV   r-   r-   r+   r.   rW   u   s    rW   c                       rA   )Pop2PianoLayerFFrC   c                    sJ   t    |jrt|| _nt|| _t|j|jd| _	t
|j| _d S )Nr*   )r!   r"   is_gated_actrW   DenseReluDenserB   r   rH   layer_norm_epsilon
layer_normr   rL   rM   rN   rQ   r+   r-   r.   r"      s   

zPop2PianoLayerFF.__init__c                 C   s&   |  |}| |}|| | }|S rR   )ra   r_   rN   )r(   r:   forwarded_statesr-   r-   r.   r<      s   

zPop2PianoLayerFF.forwardrV   r-   r-   r+   r.   r\      s    
r\   c                       sl   e Zd Z		ddedee f fddZdd ZedddZ	dddZ
									dddZ  ZS )Pop2PianoAttentionFNrC   	layer_idxc                    s  t    |j| _|| _|j| _|j| _|j| _|j| _|j	| _
|j| _| j
| j | _|| _|d u r@| jr@td| jj d tj| j| jdd| _tj| j| jdd| _tj| j| jdd| _tj| j| jdd| _| jrxt| j| j
| _t | _d| _d S )NzInstantiating a decoder z without passing `layer_idx` is not recommended and will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` when creating this class.FrE   )r!   r"   
is_decoderhas_relative_attention_biasrelative_attention_num_bucketsrelative_attention_max_distancerH   d_kvkey_value_proj_dim	num_headsn_headsrM   rN   	inner_dimrd   loggerwarning_oncer,   r=   r   rG   qkvo	Embeddingrelative_attention_biassetpruned_headsgradient_checkpointingr(   rC   rf   rd   r+   r-   r.   r"      s.   

zPop2PianoAttention.__init__c                 C   s   t |dkrd S t|| j| j| j\}}t| j|| _t| j|| _t| j|| _t| j	|dd| _	| jt | | _| j| j | _
| j|| _d S )Nr   r   dim)lenr   rl   rj   rw   r   rp   rq   rr   rs   rm   union)r(   headsindexr-   r-   r.   prune_heads   s   zPop2PianoAttention.prune_headsT       c                 C   s   d}|r|d }|| dk tj| 7 }t| } n
t| t|  } |d }| |k }|t|  | t||  ||   tj }t|t	||d }|t
|| |7 }|S )a  
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        r   r/   r   )r2   r$   longabsmin
zeros_likelogfloatmath	full_likewhere)relative_positionbidirectionalnum_bucketsmax_distancerelative_buckets	max_exactis_smallrelative_position_if_larger-   r-   r.   _relative_position_bucket   s*   z,Pop2PianoAttention._relative_position_bucketc           
      C   s   |du r	| j jj}|du rtj|tj|ddddf }n|dddf |}tj|tj|ddddf }|| }| j|| j | j	| j
d}|  |}	|	g dd}	|	S )z%Compute binned relative position biasN)r7   device)r   r   r   )r/   r   r   r   )ru   r&   r   r$   aranger   r2   r   re   rg   rh   permute	unsqueeze)
r(   query_length
key_lengthr   cache_positioncontext_positionmemory_positionr   relative_position_bucketvaluesr-   r-   r.   compute_bias
  s    
 
zPop2PianoAttention.compute_biasc                 C   s  |j dd \}}|du}| |}||d| j| jdd}|dur4|j| j}|r1|j	}n|j
}|r8|n|}|rO|durO|rO|j| j }|j| j }nE| |}| |}||d| j| jdd}||d| j| jdd}|dur|s}|
nd}
|||| jd|
i\}}|rd|j| j< t||dd}|du r|j d }|dur|n|
d d }| jstjd| j||f|j|jd	}| jr| jrd|_n| j|||j|
d
}|dddd| dddf }|dur|ddddddd|j d f }|| }| jr%t|j d }d|t| j< |dd| f }n|}||7 }tjj |! dd"|}tjj#|| j#| jd}|durL|| }t||}|dd$ }||d| j%}| &|}|||f}|	rt||f }|S )z
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        Nr/   r0   r   r   Tr   )r   r7   )r   r   r   rz   )ptraining)'shaperp   viewrl   rj   	transpose
is_updatedgetrd   cross_attention_cacheself_attention_cache	key_cachevalue_cacherq   rr   updater$   matmulrf   zerosr   r7   rx   r   requires_gradr   rw   r%   listboolr   
functionalsoftmaxr   type_asrN   
contiguousrm   rs   )r(   r:   maskkey_value_statesposition_biaspast_key_valuelayer_head_maskr   	use_cacheoutput_attentionsr   
batch_size
seq_lengthis_cross_attentionquery_statesr   curr_past_key_valuecurrent_states
key_statesvalue_statesscoresr   real_seq_lengthcausal_maskposition_bias_maskedattn_weightsattn_outputoutputsr-   r-   r.   r<     sx   





"
&



zPop2PianoAttention.forwardFN)Tr   r   )NN)	NNNNNNFFN)r=   r>   r?   r   r   intr"   r   staticmethodr   r   r<   r@   r-   r-   r+   r.   rc      s,    #
/rc   c                       s@   e Zd Zddee f fddZ							d	ddZ  ZS )
Pop2PianoLayerSelfAttentionFNrd   c                    s>   t    t|||d| _t|j|jd| _t	|j
| _d S )Nrf   rd   r]   )r!   r"   rc   SelfAttentionr   rH   r`   ra   r   rL   rM   rN   ry   r+   r-   r.   r"     s   
z$Pop2PianoLayerSelfAttention.__init__c	              
   C   sL   |  |}	| j|	|||||||d}
|| |
d  }|f|
dd   }|S )N)r   r   r   r   r   r   r   r   r   )ra   r   rN   )r(   r:   attention_maskr   r   r   r   r   r   normed_hidden_statesattention_outputr   r-   r-   r.   r<     s   

z#Pop2PianoLayerSelfAttention.forwardr   )NNNNFFNr=   r>   r?   r   r   r"   r<   r@   r-   r-   r+   r.   r     s    r   c                       sB   e Zd Zddee f fddZ								d	ddZ  ZS )
Pop2PianoLayerCrossAttentionNrd   c                    s>   t    t|d|d| _t|j|jd| _t	|j
| _d S )NFr   r]   )r!   r"   rc   EncDecAttentionr   rH   r`   ra   r   rL   rM   rN   )r(   rC   rd   r+   r-   r.   r"     s   
z%Pop2PianoLayerCrossAttention.__init__Fc                 C   sP   |  |}| j|||||||||	|
d
}|| |d  }|f|dd   }|S )N)	r   r   r   r   r   r   r   r   r   r   r   )ra   r   rN   )r(   r:   r   r   r   r   r   r   r   r   r   r   r   layer_outputr   r-   r-   r.   r<     s    
z$Pop2PianoLayerCrossAttention.forwardrR   )NNNNFNFNr   r-   r-   r+   r.   r     s    
r   c                       sJ   e Zd Zd	dee f fddZ												d
ddZ  ZS )Pop2PianoBlockFNrd   c                    s`   t    |j| _t | _| jt|||d | jr&| jt||d | jt	| d S )Nr   )rd   )
r!   r"   re   r   
ModuleListlayerappendr   r   r\   ry   r+   r-   r.   r"     s   

zPop2PianoBlock.__init__Tc                 C   s  | j d |||||	|
||d}|d d \}}	|dd  }|jtjkrDtt| t|jjd t|jj}tj	|| |d}| j
oJ|d u}|r| j d ||||||	|d d |
|d	}|d d \}}	|jtjkrtt| t|jjd t|jj}tj	|| |d}||dd   }| j d |}|jtjkrtt| t|jjd t|jj}tj	|| |d}|f}|
r||	f | }|S || }|S )	Nr   )r   r   r   r   r   r   r   r/   i  )r   maxr   r0   )r   r   r   r   r   r   r   r   )r   r7   r$   r8   r   isinfanyfinfor   clampre   )r(   r:   r   r   encoder_hidden_statesencoder_attention_maskencoder_decoder_position_biasr   cross_attn_layer_head_maskr   r   r   return_dictr   self_attention_outputsattention_outputsclamp_valuedo_cross_attentioncross_attention_outputsr   r-   r-   r.   r<     sn   

zPop2PianoBlock.forwardr   )NNNNNNNNFFTNr   r-   r-   r+   r.   r     s    r   c                   @   s@   e Zd ZeZdZdZdZdZdZ	dgZ
dgZdd Zdd	 Zd
S )Pop2PianoPreTrainedModeltransformerFTr   rK   c                 C   s  | j j}t|tr|jj|d  dS t|tr'|jjjj	d|d d dS t|t
rS|jjjj	d|d d t|drO| j jsQ|jjjj	d|d d dS dS dS t|tr|jjjj	d|| j jd  d t|jdr{|jjdur{|jjj  |jjjj	d|| j jd  d t|jdr|jjdur|jjj  dS dS dS t|tr|jjjj	d|| j jd  d t|jdr|jjdur|jjj  |jjjj	d|| j jd  d t|jdr|jjdur|jjj  |jjjj	d|| j jd  d t|jdr|jjdur|jjj  dS dS dS t|tr}| j j}| j j}| j j}|jjjj	d||| d  d |jjjj	d||d  d |jjjj	d||d  d |jjjj	d||| d  d |j r|j!jjj	d||d  d dS dS dS )zInitialize the weights      ?        )r5   stdlm_head      rF   N)"rC   initializer_factorrS   r   r&   datafill_Pop2PianoConcatEmbeddingToMel	embeddingnormal_!Pop2PianoForConditionalGenerationsharedhasattrtie_word_embeddingsr   rB   rJ   rH   rF   zero_rK   rI   rW   rX   rY   rc   ri   rk   rp   rq   rr   rs   rf   ru   )r(   modulefactorrH   rj   rl   r-   r-   r.   _init_weightsH  sR   



        
z&Pop2PianoPreTrainedModel._init_weightsc                 C   s   | j j}| j j}|d u rtdt|r1t|jd d d |}tj||dd df gdd}n|	|j}|dd df 
 |ddd f< ||d< |d u rStd||d	k| |S )
Nzoself.model.config.decoder_start_token_id has to be defined. In Pop2Piano it is usually set to the pad_token_id.r0   )r   .rz   r   ).r   z1self.model.config.pad_token_id has to be defined.)rC   decoder_start_token_idpad_token_id
ValueErrorr   r$   fullr   cat	new_zerosclonemasked_fill_)r(   	input_idsr  r  shifted_input_idsr-   r-   r.   _shift_rightv  s      z%Pop2PianoPreTrainedModel._shift_rightN)r=   r>   r?   r   config_classbase_model_prefixis_parallelizablesupports_gradient_checkpointing_supports_cache_class_supports_static_cache_no_split_modules_keep_in_fp32_modulesr  r  r-   r-   r-   r.   r   =  s    .r   c                       s   e Zd Zd fdd	Zdd Zdd Z													ddd	Z	
ddeej	df dej	dej	de
def
ddZedej	dededejdej	defddZ  ZS )Pop2PianoStackNc                    sx   t    || _ j| _t fddt jD | _t	 j
 jd| _t j| _|   d| _d | _d| _d S )Nc                    s"   g | ]}t  t|d k|dqS )r   r   )r   r   ).0irC   r-   r.   
<listcomp>  s    z+Pop2PianoStack.__init__.<locals>.<listcomp>r]   F)r!   r"   embed_tokensre   r   r   range
num_layersblockr   rH   r`   final_layer_normrL   rM   rN   	post_initmodel_parallel
device_maprx   )r(   rC   r  r+   r  r.   r"     s   

zPop2PianoStack.__init__c                 C      | j S rR   r  r(   r-   r-   r.   get_input_embeddings     z#Pop2PianoStack.get_input_embeddingsc                 C   
   || _ d S rR   r$  r(   new_embeddingsr-   r-   r.   set_input_embeddings     
z#Pop2PianoStack.set_input_embeddingsc           )      C   s  |	d ur|	n| j j}	|
d ur|
n| j j}
|d ur|n| j j}|d ur$|n| j j}|d urB|d urB| jr5dnd}td| d| d|d urS| }|d|d }n|d ur`| d d }n| jrednd}td| d| d	| j	r| j
r|	rtd
 d}	|d u r| jd u rtd| |}|\}}|	du r| jstd|  dd}d}| jr|	s|d urt|trt|tsd}t|t }n#t|tsd}td t|}n|d u rtt t }n| jsd }|d ur| nd}|d u rtj||| |jd}|d u rt s|| }tj|||jd}| j jr0| ||||d ur+|jnd |
}n|d d d d d d f }|j|jd}d| t|jj }| jru|d uru| \}}}||f}|d u rotj||jd}| |}nd }|  || j j!}|  || j j!}|rdnd }|
rdnd }|
r| jrdnd }d }d } | "|}!t#| j$D ]k\}"}#||" }$||" }%|r||!f }|#|!||||| |$|%||	|
|d}&|	du r|&d d d |&dd   }&|&d d \}!}'|&d }| jr|d ur|&|
r dnd } |
r||&d f }| jr||&d f }q| %|!}!| "|!}!|r,||!f }|	r1|'nd }(|r9|j}(|r@|& }(|sQt'dd |!|(|||fD S t(|!|(|||dS ) Ndecoder_ zYou cannot specify both zinput_ids and zinputs_embeds at the same timer0   zYou have to specify either zinput_ids or inputs_embedszZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fz<You have to initialize the model with valid token embeddingsTz)`use_cache` can only be set to `True` if z is used as a decoderzPassing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.r   r   )r7   r   r-   )r   r   r   r   r   r   r   rR   r/      r      c                 s   s    | ]	}|d ur|V  qd S rR   r-   )r  rr   r-   r-   r.   	<genexpr>b  s    z)Pop2PianoStack.forward.<locals>.<genexpr>)last_hidden_statepast_key_valuesr:   
attentionscross_attentions))rC   r   r   output_hidden_statesuse_return_dictre   r  sizer   rx   r   rn   ro   r  rS   r	   r   r
   from_legacy_cacheget_seq_lengthr$   r   r   r   r%   _update_causal_maskr   r2   r7   r   r   invert_attention_maskget_head_maskr  rN   	enumerater  r  to_legacy_cachetupler   ))r(   r  r   r   r   r/  	head_maskcross_attn_head_maskr5  r   r   r8  r   r   err_msg_prefixinput_shaper   r   return_legacy_cachereturn_self_attention_cachepast_key_values_lengthmask_seq_lengthr   encoder_batch_sizeencoder_sequence_length_encoder_hidden_shapeencoder_extended_attention_maskall_hidden_statesall_attentionsall_cross_attentionsr   r   r:   r  layer_moduler   r   layer_outputsnext_decoder_cache
next_cacher-   r-   r.   r<     s  











zPop2PianoStack.forwardFr   r   input_tensorr   r5  r   c                 C   s:  | j jdkr|d ur|dk r|S d S | j jdkr&t|tjr$t|}|S |d ur.| nd}|d ur7|jnd}| j jdkrO|sO|sOt	j
|||| jdrOd S |j}|jd }	|r^| }
nt|tjri|jd	 n||	 d }
| j||	|
|||jd d
}| j jdkr|d ur|jjdv r|st|j}t	||}|S )Nflash_attention_2r   flex_attentionr   Fsdpa)r/  rI  is_trainingr   r0   )sequence_lengthtarget_lengthr7   r   r   )cudaxpunpu)rC   _attn_implementationr   rS   r$   rT   r   r<  is_compileabler   _ignore_causal_mask_sdpar   r7   r   get_max_cache_shape5_prepare_4d_causal_attention_mask_with_cache_positionr   typer   r   _unmask_unattended)r(   r   rW  r   r5  r   past_seen_tokensusing_compilable_cacher7   r\  r]  r   	min_dtyper-   r-   r.   r=  v  sT   




z"Pop2PianoStack._update_causal_maskr\  r]  r7   r   c                 K   sD  | dur|   dkr| }|S t|j}tj||f|||jd}|dkr+tj|dd}|tj||jd|ddk9 }|ddddddf 	|ddd}| dur|
 }| jd }	|ddddddd|	f | ddddddf |j }
|
dk}
|ddddddd|	f |
||ddddddd|	f< |S )	aM  
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        Nr1  )
fill_valuer7   r   r   )diagonalr0  r0   r   )r{   r$   r   r   r  r   triur   reshapeexpandr	  r   r2   masked_fill)r   r\  r]  r7   r   r   kwargsr   rj  mask_lengthpadding_maskr-   r-   r.   re    s,    $
6  zDPop2PianoStack._prepare_4d_causal_attention_mask_with_cache_positionrR   )NNNNNNNNNNNNN)F)r=   r>   r?   r"   r&  r+  r<   r   r$   rT   r	   r   r=  r   r   r7   re  r@   r-   r-   r+   r.   r    sZ    
 L
Dr  c                       s(   e Zd ZdZ fddZdd Z  ZS )r   z'Embedding Matrix for `composer` tokens.c                    s"   t    tj|j|jd| _d S )N)num_embeddingsembedding_dim)r!   r"   r   rt   composer_vocab_sizerH   r   rQ   r+   r-   r.   r"     s   
z&Pop2PianoConcatEmbeddingToMel.__init__c                 C   s.   || }|  |d}tj||gdd}|S )Nr   rz   )r   r   r$   r  )r(   featureindex_valueembedding_offsetindex_shiftedcomposer_embeddingr/  r-   r-   r.   r<     s   z%Pop2PianoConcatEmbeddingToMel.forward)r=   r>   r?   __doc__r"   r<   r@   r-   r-   r+   r.   r     s    r   zA
    Pop2Piano Model with a `language modeling` head on top.
    )custom_introc                *       s  e Zd Zg dZdef fddZdd Zdd Zd	d
 Zdd Z	dd Z
dd Z	d2dejdededeej fddZe																		d3deej deej deej deej deej deej deej deeeej   deeeej   d eej deej d!eej d"eej d#ee d$ee d%ee d&ee d'eej d(eeej ef f&d)d*Ze 		+	d4 fd,d-	Zd"ejfd.d/Zd0d1 Z  ZS )5r   )zencoder.embed_tokens.weightzdecoder.embed_tokens.weightzlm_head.weightrC   c                    s   t  | || _|j| _t|j|j| _t	|| _
t|}d|_d|_d|_t|| j| _t|}d|_d|_|j|_t|| j| _tj|j|jdd| _|   d S )NFTrE   )r!   r"   rC   rH   	model_dimr   rt   
vocab_sizer   r   mel_conditionercopydeepcopyre   r   is_encoder_decoderr  encodernum_decoder_layersr  decoderrG   r   r   )r(   rC   encoder_configdecoder_configr+   r-   r.   r"   	  s"   


z*Pop2PianoForConditionalGeneration.__init__c                 C   r#  rR   )r   r%  r-   r-   r.   r&  $  r'  z6Pop2PianoForConditionalGeneration.get_input_embeddingsc                 C   s"   || _ | j| | j| d S rR   )r   r  r+  r  r)  r-   r-   r.   r+  '  s   z6Pop2PianoForConditionalGeneration.set_input_embeddingsc                 C   r(  rR   r   r)  r-   r-   r.   set_output_embeddings,  r,  z7Pop2PianoForConditionalGeneration.set_output_embeddingsc                 C   r#  rR   r  r%  r-   r-   r.   get_output_embeddings/  r'  z7Pop2PianoForConditionalGeneration.get_output_embeddingsc                 C   r#  rR   )r  r%  r-   r-   r.   get_encoder2  r'  z-Pop2PianoForConditionalGeneration.get_encoderc                 C   r#  rR   )r  r%  r-   r-   r.   get_decoder5  r'  z-Pop2PianoForConditionalGeneration.get_decoderNinput_featurescomposergeneration_configr   c                 C   s   |j }|| vrtdt|  d| || }tj|| jd}||jd }t	|
 }| j|||d}|durad||dddf   < tj|dddf dd	|gd	d
}||fS |dfS )a  
        This method is used to concatenate mel conditioner tokens at the front of the input_features in order to
        control the type of MIDI token generated by the model.

        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                input features extracted from the feature extractor.
            composer (`str`):
                composer token which determines the type of MIDI tokens to be generated.
            generation_config (`~generation.GenerationConfig`):
                The generation is used to get the composer-feature_token pair.
            attention_mask (``, *optional*):
                For batched generation `input_features` are padded to have the same shape across all examples.
                `attention_mask` helps to determine which areas were padded and which were not.
                - 1 for tokens that are **not padded**,
                - 0 for tokens that are **padded**.
        zPlease choose a composer from z. Composer received - r0  r   )rw  rx  ry  Nr   r0   r   )axis)composer_to_feature_tokenkeysr  r   r$   tensorr   repeatr   r   r   r  r   concatenater   )r(   r  r  r  r   r  composer_valuery  r-   r-   r.   get_mel_conditioner_outputs8  s&   &z=Pop2PianoForConditionalGeneration.get_mel_conditioner_outputsr  decoder_input_idsdecoder_attention_maskrC  decoder_head_maskrD  encoder_outputsr5  r/  decoder_inputs_embedslabelsr   r   r8  r   r   returnc                 C   s  |dur|n| j j}|dur|n| j j}|
dur |dur td|dur*|
du r*|}
|du r;| j|||
||||d}n$|r_t|ts_t|d t|dkrP|d ndt|dkr[|d ndd}|d }|durt|du rt|du rt| |}| j	||||	|||||||||d}|d }| j j
r|| jd	  }| |}d}|durtd
d}||d|d|d}|s|f|dd  | }|dur|f| S |S t|||j|j|j|j|j|j|jd	S )a`  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Pop2Piano is a model with relative position embeddings
            so you should be able to pad the inputs on both the right and the left. Indices can be obtained using
            [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail.
            [What are input IDs?](../glossary#input-ids) To know more on how to prepare `input_ids` for pretraining
            take a look a [Pop2Piano Training](./Pop2Piano#training).
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
            [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
            [What are decoder input IDs?](../glossary#decoder-input-ids) Pop2Piano uses the `pad_token_id` as the
            starting token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
            `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Does the same task as `inputs_embeds`. If `inputs_embeds` is not present but `input_features` is present
            then `input_features` will be considered as `inputs_embeds`.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`
        NzSBoth `inputs_embeds` and `input_features` received! Please provide only one of them)r  r   r/  rC  r   r8  r   r   r   r/   )r4  r:   r6  )r  r   r/  r5  r   r   rC  rD  r   r   r8  r   r   r   r  )ignore_indexr0   )	losslogitsr5  decoder_hidden_statesdecoder_attentionsr7  encoder_last_hidden_stater   encoder_attentions)rC   r   r9  r  r  rS   r   r|   r  r  r   r~  r   r   r   r:  r   r5  r:   r6  r7  r4  )r(   r  r   r  r  rC  r  rD  r  r5  r/  r  r  r  r   r   r8  r   r   r:   decoder_outputssequence_output	lm_logitsr  loss_fctoutputr-   r-   r.   r<   i  s|   8	


z)Pop2PianoForConditionalGeneration.forward	composer1c                    s   |du r| j }|jd	i | t|dstdt|j| jjkr1td| jj dt|j d| j||||d\}}t	 j
d	d|||d|S )
a  
        Generates token ids for midi outputs.

        <Tip warning={true}>

        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
        model's default generation configuration. You can override any `generation_config` by passing the corresponding
        parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. For an overview of generation
        strategies and code examples, check out the [following guide](./generation_strategies).

        </Tip>

        Parameters:
            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                This is the featurized version of audio generated by `Pop2PianoFeatureExtractor`.
            attention_mask:
                For batched generation `input_features` are padded to have the same shape across all examples.
                `attention_mask` helps to determine which areas were padded and which were not.
                - 1 for tokens that are **not padded**,
                - 0 for tokens that are **padded**.
            composer (`str`, *optional*, defaults to `"composer1"`):
                This value is passed to `Pop2PianoConcatEmbeddingToMel` to generate different embeddings for each
                `"composer"`. Please make sure that the composet value is present in `composer_to_feature_token` in
                `generation_config`. For an example please see
                https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.json .
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            kwargs:
                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
        Return:
            [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
            or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
                Since Pop2Piano is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
                [`~utils.ModelOutput`] types are:
                    - [`~generation.GenerateEncoderDecoderOutput`],
                    - [`~generation.GenerateBeamEncoderDecoderOutput`]
        Nr  z`composer_to_feature_token` was not found! Please refer to https://huggingface.co/sweetcocoa/pop2piano/blob/main/generation_config.jsonand parse a dict like that.ztconfig.composer_vocab_size must be same as the number of keys in generation_config.composer_to_feature_token! Found z vs .)r  r   r  r  )inputsr/  r   r  r-   )r  r   r   r  r|   r  rC   rv  r  r!   generate)r(   r  r   r  r  rq  r+   r-   r.   r    s:   6

z*Pop2PianoForConditionalGeneration.generatec                 C   s
   |  |S rR   )r  )r(   r  r-   r-   r.   %prepare_decoder_input_ids_from_labelsK  r,  zGPop2PianoForConditionalGeneration.prepare_decoder_input_ids_from_labelsc              	   C   s   |d u rt d |S d}|D ]M}d}|D ]}||d||jf }q|d j|d jkr@td|d j d|d j dt|t|krWtdt| dt| d||f }q|S )	NzHYou might want to consider setting `use_cache=True` to speed up decodingr-   r   z%reordered_layer_past_states[0] shape z  and layer_past_states[0] shape z mismatchedz&length of reordered_layer_past_states z! and length of layer_past_states )rn   warningindex_selectr2   r   r   r  r|   )r(   r5  beam_idxreordered_decoder_pastlayer_past_statesreordered_layer_past_stateslayer_past_stater-   r-   r.   _reorder_cacheN  s(   
z0Pop2PianoForConditionalGeneration._reorder_cacherR   )NNNNNNNNNNNNNNNNNN)Nr  N) r=   r>   r?   _tied_weights_keysr   r"   r&  r+  r  r  r  r  r$   FloatTensorstrr   r   r  r   
LongTensor
BoolTensorrT   rB  r   r   r   r<   no_gradr  r  r  r@   r-   r-   r+   r.   r     s    
1	
 Yr   )Er|  r  r   typingr   r   r$   r   torch.nnr   transformers.generationr   activationsr   cache_utilsr	   r
   r   
generationr   modeling_attn_mask_utilsr   modeling_layersr   modeling_outputsr   r   r   modeling_utilsr   pytorch_utilsr   r   utilsr   r   r   r   r   configuration_pop2pianor   !torch.nn.attention.flex_attentionr   integrations.flex_attentionr   
get_loggerr=   rn   _load_pop2piano_layer_normapex.normalizationr   infoImportError	Exceptionr  Moduler   rB   rW   r\   rc   r   r   r   r   r  r   r   __all__r-   r-   r-   r.   <module>   sr   

 f%'fS  d  i