o
    iT                     @   s  d dl Z d dlmZ d dlmZmZ d dlZd dlmZmZ ddl	m
Z
mZ ddlmZ ddlmZ dd	lmZmZmZmZ dd
lmZmZmZ ddlmZmZmZ ddlmZ ddlm Z m!Z! ddl"m#Z#m$Z$m%Z%m&Z&m'Z'm(Z(m)Z)m*Z*m+Z+ ddl,m-Z-m.Z.m/Z/m0Z0m1Z1m2Z2 ddl3m4Z4m5Z5 e6e7Z8G dd de$Z9dd Z:dd Z;G dd dej<Z=G dd de*Z>G dd de+Z?G dd  d e#Z@G d!d" d"e&ZAG d#d$ d$e(ZBG d%d& d&e'ZCG d'd( d(e%ZDG d)d* d*e)ZEeG d+d, d,eZFG d-d. d.eFZGG d/d0 d0ej<ZHG d1d2 d2ej<ZIG d3d4 d4ej<ZJeeG d5d6 d6eZKG d7d8 d8ej<ZLG d9d: d:ej<ZMG d;d< d<e1ZNG d=d> d>e2ZOG d?d@ d@e/ZPG dAdB dBe-ZQG dCdD dDe.ZRG dEdF dFe0ZSG dGdH dHeSZTG dIdJ dJeSeZUg dKZVdS )L    N)	dataclass)OptionalUnion)Tensornn   )CacheDynamicCache)GenerationMixin)create_causal_mask)BaseModelOutputWithPast,BaseModelOutputWithPoolingAndCrossAttentionsCausalLMOutputWithPastModelOutput)ModuleUtilsMixinPreTrainedModelget_parameter_dtype)auto_docstringcan_return_tuplelogging)deprecate_kwarg)OutputRecordercheck_model_inputs   )	EsmAttentionEsmEmbeddings
EsmEncoderEsmIntermediateEsmLayer	EsmOutput	EsmPoolerEsmSelfAttentionEsmSelfOutput)LlamaAttentionLlamaDecoderLayerLlamaMLPLlamaPreTrainedModelLlamaRMSNormLlamaRotaryEmbedding   )EvollaConfigSaProtConfigc                       s   e Zd Z fddZ  ZS )EvollaSaProtEmbeddingsc                    s   t  | d | _d S N)super__init__position_idsselfconfig	__class__ ]/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/evolla/modular_evolla.pyr/   B   s   
zEvollaSaProtEmbeddings.__init__)__name__
__module____qualname__r/   __classcell__r6   r6   r4   r7   r,   A   s    r,   c                 C   s&   | j ddd\}}tj| |fddS )Nr   dim)chunktorchcat)xx1x2r6   r6   r7   rotate_half_esmH   s   rE   c                 C   s`   |d d d d d | j d d d f }|d d d d d | j d d d f }| | t| |  S )N)shaperE   )rB   cossinr6   r6   r7   apply_rotary_pos_emb_esmM   s   &&rJ   c                       sb   e Zd ZU dZejed< def fddZdddZ	d	ejd
ejde
ejejf fddZ  ZS )EvollaSaProtRotaryEmbeddingz
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.
    inv_freqr>   c                    sP   t    ddtjd|dtjd |   }| d| d | _d | _d | _	d S )N      ?i'  r   r   dtyperL   )
r.   r/   r@   arangeint64floatregister_buffer_seq_len_cached_cos_cached_sin_cached)r2   r>   rL   r4   r6   r7   r/   ]   s   
$
z$EvollaSaProtRotaryEmbedding.__init__r   c                 C   s   |j | }|| jks| jj|jkrU|| _tj|j | |jd| j}t|| j}tj	||fdd
|j}| d d d d d d f | _| d d d d d d f | _| j| jfS )Ndevicer<   r=   )rG   rT   rU   rX   r@   rP   type_asrL   outerrA   torH   rI   rV   )r2   rB   seq_dimensionseq_lentfreqsembr6   r6   r7   _update_cos_sin_tablesg   s   
z2EvollaSaProtRotaryEmbedding._update_cos_sin_tablesqkreturnc                 C   sJ   | j |dd\| _| _t|| j| jj|jdt|| j| jj|jdfS )NrF   )r\   rN   )ra   rU   rV   rJ   r[   rO   )r2   rb   rc   r6   r6   r7   forwardw   s   z#EvollaSaProtRotaryEmbedding.forward)r   )r8   r9   r:   __doc__r@   r   __annotations__intr/   ra   tuplere   r;   r6   r6   r4   r7   rK   T   s   
 


.rK   c                   @   s   e Zd ZdddZdS )EvollaSaProtSelfAttentionNFc                 C   s:  t j|  || _|j|j dkr#t|ds#td|j d|j d|j| _t|j|j | _	| j| j	 | _
t |j| j
| _t |j| j
| _t |j| j
| _|j| _|p]t|dd| _d | _| jdksl| jd	kr~|j| _t d
|j d | j	| _n| jdkrt| j	d| _|j| _|| _d| _| jo| | _d S )Nr   embedding_sizezThe hidden size (z6) is not a multiple of the number of attention heads ()position_embedding_typeabsoluterelative_keyrelative_key_queryr   r)   rotaryr=   rM   )r   Moduler/   r3   hidden_sizenum_attention_headshasattr
ValueErrorrh   attention_head_sizeall_head_sizeLinearquerykeyvalueattention_probs_dropout_probdropoutgetattrrm   rotary_embeddingsmax_position_embeddings	Embeddingdistance_embeddingrK   
is_decoder	layer_idxscaling	is_causal)r2   r3   rm   r   is_cross_attentionr6   r6   r7   r/      s8   

z"EvollaSaProtSelfAttention.__init__)NNF)r8   r9   r:   r/   r6   r6   r6   r7   rj      s    rj   c                   @      e Zd ZdS )EvollaSaProtSelfOutputNr8   r9   r:   r6   r6   r6   r7   r          r   c                   @   r   )EvollaSaProtAttentionNr   r6   r6   r6   r7   r      r   r   c                   @   r   )EvollaSaProtIntermediateNr   r6   r6   r6   r7   r      r   r   c                   @   r   )EvollaSaProtOutputNr   r6   r6   r6   r7   r      r   r   c                   @   r   )EvollaSaProtLayerNr   r6   r6   r6   r7   r      r   r   c                   @   r   )EvollaSaProtEncoderNr   r6   r6   r6   r7   r      r   r   c                   @   r   )EvollaSaProtPoolerNr   r6   r6   r6   r7   r      r   r   c                   @   sT   e Zd ZU eed< dgZdZdZdZe	e
edddge
edddgdZd	d
 ZdS )EvollaSaProtPreTrainedModelr3   r   Tr)   	attention)index
layer_namecrossattention)hidden_states
attentionscross_attentionsc                 C   s   | j j}t|tjr"|jjjd|d |jdur |jj	  dS dS t|tj
rC|jjjd|d |jdurA|jj|j 	  dS dS t|tjrX|jj	  |jjd dS dS )zInitialize the weights        meanstdNrM   )r3   initializer_range
isinstancer   ry   weightdatanormal_biaszero_r   padding_idx	LayerNormfill_r2   moduler   r6   r6   r7   _init_weights   s   

z)EvollaSaProtPreTrainedModel._init_weightsN)r8   r9   r:   r+   rg   _no_split_modules_supports_flash_attn_supports_sdpa_supports_attention_backendr   r   rj   _can_record_outputsr   r6   r6   r6   r7   r      s   
 r   c                       s   e Zd Zdef fddZdd Zdd Zdd	 Ze	
dde	e
j de	e
j deee
j ef fddZ	
	
ddedee de	e
j de	e
j def
ddZ  ZS )EvollaSaProtProteinEncoderr3   c                    s$   t  | t|| _t|| _d S r-   )r.   r/   r,   
embeddingsr   encoderr1   r4   r6   r7   r/      s   
z#EvollaSaProtProteinEncoder.__init__c                 C   s   | j jS r-   r   word_embeddingsr2   r6   r6   r7   get_input_embeddings   s   z/EvollaSaProtProteinEncoder.get_input_embeddingsc                 C   s   || j _d S r-   r   r2   r|   r6   r6   r7   set_input_embeddings      z/EvollaSaProtProteinEncoder.set_input_embeddingsc                 C   s*   |  D ]\}}| jj| j| qdS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr   layerr   prune_heads)r2   heads_to_pruner   headsr6   r6   r7   _prune_heads   s   z'EvollaSaProtProteinEncoder._prune_headsN	input_idsattention_maskrd   c                 C   sv   |  }|\}}|j}|d u rtj||f|d}| j||d}| ||}| j||d}	|	d }
t|
|	j|	j	|	j
dS )NrW   r   r   )r   r   )last_hidden_stater   r   r   )sizerX   r@   onesr   get_extended_attention_maskr   r   r   r   r   )r2   r   r   input_shape
batch_size
seq_lengthrX   inputs_embedsextended_attention_maskencoder_outputssequence_outputr6   r6   r7   re      s   z"EvollaSaProtProteinEncoder.forwardr   rX   rO   c                 C   s   |du rt | }| dkr| jjs|durtdt | dkr1|dddddddf }n+| dkrP| jjrCt|||}n|ddddddf }nt	d| d|j
 d|j|d}d	| t|j }|S )
a  
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        Nr   zNThe `device` argument is deprecated and will be removed in v5 of Transformers.r   z!Wrong shape for input_ids (shape z) or attention_mask (shape rl   rN   rM   )r   r>   r3   r   warningswarnFutureWarningr   *create_extended_attention_mask_for_decoderrv   rG   r[   r@   finfomin)r2   r   r   rX   rO   r   r6   r6   r7   r     s*   	z6EvollaSaProtProteinEncoder.get_extended_attention_maskr-   )NN)r8   r9   r:   r+   r/   r   r   r   r   r   r@   r   r   ri   r   re   rh   rX   rO   r   r;   r6   r6   r4   r7   r      s6    r   c                       s&   e Zd Zd fdd	Zdd Z  ZS )!EvollaSequenceCompressorAttention@      c                    sx   t    |d | _|| _|| }t|| _t|| _tj||dd| _	tj||d dd| _
tj||dd| _d S )N      Fr   r   )r.   r/   scaler   r   r   
norm_medianorm_latentsry   to_qto_kvto_out)r2   r>   dim_headr   	inner_dimr4   r6   r7   r/   G  s   

z*EvollaSequenceCompressorAttention.__init__c                 C   s  |  |}| |}| j}| |}tj||fdd}| |jddd\}}||	d|	d|d
dddd}||	d|	d|d
dddd}||	d|	d|d
dddd}|| j }t||dd}	|	|	jddd	  }	|	j\}
}}}t|||j}|d
d
d
d
d
d
f }|d
d
d
d
d
d
f }|| }|	d|  d}	|	jdd}t||}|
dddd}||	d|	dd}| |S )z
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D);  n2: num of latent tokens
        rF   r=   r   r<   r   r)   r   Tr>   keepdimNg     )r   r   r   r   r@   rA   r   r?   viewr   permuter   matmul	transposeamaxdetachrG   r   r[   rX   masked_fillboolsoftmaxreshaper   )r2   rB   latentsmaskhrb   kv_inputrc   vsimbsnhskdokdr   mask_expones_expattnoutr6   r6   r7   re   T  s2   




(((

z)EvollaSequenceCompressorAttention.forward)r   r   r8   r9   r:   r/   re   r;   r6   r6   r4   r7   r   F  s    r   c                       s&   e Zd Zd fdd	Zdd Z  ZS )EvollaFeedForward   c                    sT   t    t|| }t|| _tj||dd| _t | _	tj||dd| _
d S NFr   )r.   r/   rh   r   r   normry   fc1GELU
activationfc2)r2   r>   multr   r4   r6   r7   r/     s   

zEvollaFeedForward.__init__c              	   C   s   |  | | | |S r-   )r  r  r  r   )r2   rB   r6   r6   r7   re     s   zEvollaFeedForward.forward)r   r   r6   r6   r4   r7   r     s    	r   c                       s*   e Zd Zdef fddZdd Z  ZS )!EvollaSequenceCompressorResamplerr3   c              
      s   t    |jj}|j| _tjt	| j|dd| _
tg | _t|jD ]}| jtt||j|jdt||jdg q%t|j| _t||j| _d S )NT)requires_grad)r>   r   r   )r>   r  )r.   r/   protein_encoder_configrs   resampler_num_latentsnum_latentsr   	Parameterr@   randnr   
ModuleListlayersrangeresampler_depthappendr   resampler_dim_headresampler_headsr   resampler_ff_multr   r   ry   protein_projector)r2   r3   protein_repr_dim_r4   r6   r7   r/     s"   

z*EvollaSequenceCompressorResampler.__init__c                 C   s   |j d }|j \}}t|| j|j}tj||fdd}t|| jj}| jd  |ddd }||j	}| j
D ]\}	}
|	|||| }|
|| }q=| |}| |S )Nr   r)   r=   r<   )rG   r@   r   r
  r[   rX   rA   r   r   rO   r  r  r   )r2   embedsr   br   r  latent_maskr   r   r   fftransformed_featurer6   r6   r7   re     s   



z)EvollaSequenceCompressorResampler.forward)r8   r9   r:   r*   r/   re   r;   r6   r6   r4   r7   r    s    r  c                   @   sf   e Zd ZU dZeej ed< dZeej ed< dZ	ee
ejdf  ed< dZee
ejdf  ed< dS )EvollaProteinEncoderModelOutputNsequence_compressor_outputr   .r   r   )r8   r9   r:   r  r   r@   FloatTensorrg   r   r   ri   r   r6   r6   r6   r7   r    s
   
 r  c                       s<   e Zd Zdef fddZedejdejfddZ	  Z
S )EvollaProteinEncoderr3   c                    s(   t    t|jd| _t|d| _d S )Nr3   )r.   r/   r   r  modelr  sequence_compressor_resamplerr1   r4   r6   r7   r/     s   
zEvollaProteinEncoder.__init__r   r   c                 K   s.   | j ||d}|j}| ||}t||jdS )Nr   )r  r   )r"  r   r#  r  )r2   r   r   kwargsprotein_outputprotein_embedssequence_reprr6   r6   r7   re     s   zEvollaProteinEncoder.forward)r8   r9   r:   r*   r/   r   r@   
LongTensorr  re   r;   r6   r6   r4   r7   r     s     r   c                       sl   e Zd Z			ddee dee dee f fddZdd Zed	d
dd							dddZ  Z	S )#EvollaSequenceAlignerCrossAttentionNprotein_encoder_dimstructure_encoder_dimmsa_encoder_dimc                    st  t    |j| _|j| _| jd | _t| j| j | _| j| j | _|j}|j	}|j
}t| j| j| _|d urJt|| j| _t|| j| _nd | _d | _|d uret|| j| _t|| j| _nd | _d | _|d urt|| j| _t|| j| _nd | _d | _t| j| _t|| _tj| j| j|d| _t| j|| _ttdg| _ttdg| _d S )Nr   r   r   ) r.   r/   rs   rt   r   rh   rw   rx   $aligner_attention_probs_dropout_probaligner_enable_biasaligner_ffn_multr   ry   rz   key_proteinvalue_proteinkey_structurevalue_structurekey_msa	value_msaEvollaRMSNormattention_normDropoutr~   out_projr   r  r  r@   tensorgate_attentiongate_ffw)r2   r3   r*  r+  r,  r}   enable_biasffn_multr4   r6   r7   r/     s>   
z,EvollaSequenceAlignerCrossAttention.__init__c	                 C   s  |||g}	dd |	D }	|	st dtj|	dd}	| |}
| |
}
| jdur=| jdur=||}| |}| |}nd}d}| jdur[| j	dur[||}| |}| 	|}nd}d}| j
dury| jdury||}| 
|}| |}nd}d}|||g}dd |D }tj|dd}|||g}dd |D }tj|dd}|
 dd	 | j| jf }|
j| d
ddd}
| dd	 | j| jf }|j| d
ddd}| dd	 | j| jf }|j| d
ddd}|
| j }
|du rt|d
|d|j}|ddddddf |	ddddddf  }t|
|d	d}||jd	dd  }|d|  t|jj}tjd	d|}t||}|d
ddd }| dd | j f }|j| }| !|}|S )z
        query_states: text
        key_value_states: protein
        query_states: [bs, query_seq_len, dim]
        key_value_states: [bs, kv_seq_len, dim]
        query_attn_mask: [bs, query_seq_len]
        kv_attn_mask: [bs, kv_seq_len]
        c                 S      g | ]}|d ur|qS r-   r6   .0r  r6   r6   r7   
<listcomp>      zGEvollaSequenceAlignerCrossAttention.cross_attention.<locals>.<listcomp>z=At least one modality should be provided for cross attention.r)   r=   Nc                 S   r?  r-   r6   r@  r6   r6   r7   rB  C  rC  c                 S   r?  r-   r6   r@  r6   r6   r7   rB  G  rC  r<   r   r   r   rF   Tr   )"rv   r@   rA   r7  rz   r0  r1  r[   r2  r3  r4  r5  r   rt   rw   r   r   r   r   rX   r   r   r   r   r   r   r   rO   r   r   Softmax
contiguousrx   r9  )r2   query_statesprotein_key_value_statesstructure_key_value_statesmsa_key_value_statesquery_attn_maskprotein_kv_attn_maskstructure_kv_attn_maskmsa_kv_attn_maskkv_attn_maskquery_layerkey_layer_proteinvalue_layer_proteinkey_layer_structurevalue_layer_structurekey_layer_msavalue_layer_msa	key_layervalue_layernew_query_layer_shapenew_key_layer_shapenew_value_layer_shaper   attn_weightsattention_scoresattention_probscontext_layernew_context_layer_shaper6   r6   r7   cross_attention  s|   












 0

z3EvollaSequenceAlignerCrossAttention.cross_attentionpast_key_valuepast_key_values4.58new_nameversionc              
   C   s  |d ur&|j \}}}|d u r%t|||	j|	j||fdj |j}nd }|d urN|j \}}}|d u rMt|||	j|
j||fdj |j}nd }|d urv|j \}}}|d u rut|||	j|j||fdj |j}nd }|}|d ur| s|d ur| s|d ur| r|}| j||||||||d}t	| j
| }|| }|}| |t	| j }|| }|S )N)r   )rF  rG  rH  rI  rJ  rK  rL  rM  )rG   r@   r   r[   rX   expandTanyr`  tanhr;  r  r<  )r2   rF  protein_kv_statesstructure_kv_statesmsa_kv_statesrJ  rK  rL  rM  protein_batch_maskstructure_batch_maskmsa_batch_maskrb  r   protein_kv_seq_lenr>   structure_kv_seq_lenmsa_kv_seq_lenr   residualr6   r6   r7   re   w  sf   z+EvollaSequenceAlignerCrossAttention.forward)NNNNNNNNNN)
r8   r9   r:   r   rh   r/   r`  r   re   r;   r6   r6   r4   r7   r)    s*    3pr)  c                   @   r   )r6  Nr   r6   r6   r6   r7   r6    r   r6  c                   @   r   )EvollaRotaryEmbeddingNr   r6   r6   r6   r7   rv    r   rv  c                   @   r   )	EvollaMLPNr   r6   r6   r6   r7   rw    r   rw  c                   @   r   )EvollaAttentionNr   r6   r6   r6   r7   rx    r   rx  c                       s   e Zd Zdedef fddZedddd							
																ddejde	ejejf de
ej de
ej de
e de
e de
ej de
ej de
ej de
ej de
ej de
ej de
ej de
ej fddZ  ZS )EvollaDecoderLayerr3   r   c                    sD   t  || |d t|j|j d dkr t||jd| _d S d S )Nr)   r   )r*  )r.   r/   maxnum_hidden_layersaligner_num_add_layersr)  rs   adapter)r2   r3   r   r4   r6   r7   r/     s   zEvollaDecoderLayer.__init__ra  rb  rc  rd  NFr   position_embeddingsr   r0   	use_cachecache_positionrk  rl  rm  rn  ro  rp  rJ  c              
   K   s   |}|  |}| jd|||||||d|\}}|| }|}| |}| |}|| }t| dr?| j|||	|
||||d}|S )N)r   r   r0   rb  r  r  r~  r}  )rF  rk  rl  rm  rJ  rn  ro  rp  r6   )input_layernorm	self_attnpost_attention_layernormmlpru   r}  )r2   r   r~  r   r0   rb  r  r  rk  rl  rm  rn  ro  rp  rJ  r$  rt  r  r6   r6   r7   re     s<   





zEvollaDecoderLayer.forward)NNNFNNNNNNNN)r8   r9   r:   r*   rh   r/   r   r@   r   ri   r   r(  r   r   re   r;   r6   r6   r4   r7   ry    sX    	
ry  c                   @   s(   e Zd ZdZdZdZg dZdd ZdS )EvollaPreTrainedModelF)ry  r  r)  c                 C   sj   | j j}t| | t|tr#|j  |j  |j	j
jd d S t|tr3|jjjd|d d S d S )NrM   r   r   )r3   r   r   r   r   r)  r;  r   r<  r7  r   r   r   r  r   r   r   r6   r6   r7   r     s   



z#EvollaPreTrainedModel._init_weightsN)r8   r9   r:   r   _supports_flex_attnr   r   r   r6   r6   r6   r7   r    s    r  c                !       s   e Zd Zdef fddZdd Zdd Zee													dd	e	e
j d
e	e
j de	e
j de	e de	e
j de	e de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j deeef fddZ  ZS )EvollaModelr3   c                    s   t     j| _ j| _t| j j| j| _t	 d| _
t fddt jD | _t j jd| _t d| _t dd| _|   d S )Nr!  c                    s   g | ]}t  |d qS ))r3   r   )ry  )rA  r   r!  r6   r7   rB  1  s    z(EvollaModel.__init__.<locals>.<listcomp>)epsgradient_checkpointingF)r.   r/   pad_token_idr   
vocab_sizer   r   rs   embed_tokensr   protein_encoderr  r  r{  r  r6  rms_norm_epsr   rv  
rotary_embr   r  	post_initr1   r4   r!  r7   r/   *  s   

zEvollaModel.__init__c                 C   s   | j S r-   r  r   r6   r6   r7   r   ?  s   z EvollaModel.get_input_embeddingsc                 C   s
   || _ d S r-   r  r   r6   r6   r7   r   B     
z EvollaModel.set_input_embeddingsNr   r   r0   rb  r   r  r  protein_input_idsprotein_attention_maskstructure_feats	msa_featsro  rp  rd   c                 K   sJ  |du |duA rt d|du r| |}|r!|du r!t| jd}|du r=|dur-| nd}tj|||jd  |jd}|du rF|	d}d}d}|durj|	durj| j
||	d}|j}tjdg|jd  |jd}t| j||||d	}|}| ||}| jD ]}||f||||||||
|||||d
|}q| |}t||d}|S )a;  
        protein_input_ids (torch.LongTensor):
            The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
        protein_attention_mask (torch.Tensor):
            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
        structure_feats (torch.FloatTensor):
            The input IDs for purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
        msa_feats (torch.FloatTensor):
            The input IDs for purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
        structure_batch_mask (torch.Tensor):
            The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `structure_feats`. Dummpy input for now.
        msa_batch_mask (torch.Tensor):
            The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummpy input for now.
        Nz:You must specify exactly one of input_ids or inputs_embedsr!  r   r)   rW   r   T)r3   input_embedsr   r  rb  )r   r0   rb  r  r  r~  rk  rl  rm  rn  ro  rp  rJ  )r   rb  )rv   r  r	   r3   get_seq_lengthr@   rP   rG   rX   	unsqueezer  r  r:  r   r  r  r   r   )r2   r   r   r0   rb  r   r  r  r  r  r  r  ro  rp  r$  past_seen_tokensprotein_featsrn  protein_outputscausal_maskr   r~  decoder_layeroutputr6   r6   r7   re   E  sr   !



zEvollaModel.forward)NNNNNNNNNNNNN)r8   r9   r:   r*   r/   r   r   r   r   r   r@   r(  r   r   r  r   r   ri   r   re   r;   r6   r6   r4   r7   r  )  s`    	

r  c                       s   e Zd Z fddZdd Zdd Zee							ddee	j
 d	ee	j d
ee	j dee	j
 dee	j
 dee	j dee fddZ  ZS )EvollaForProteinText2Textc                    s@   t  | t|| _|j| _tj|j| jdd| _| 	  d S r   )
r.   r/   r  r"  r  r   ry   rs   lm_headr  r1   r4   r6   r7   r/     s
   
z"EvollaForProteinText2Text.__init__c                 C   s
   | j  S r-   )r"  r   r   r6   r6   r7   r     r  z.EvollaForProteinText2Text.get_input_embeddingsc                 C   s   | j |S r-   )r"  r   r   r6   r6   r7   r     r   z.EvollaForProteinText2Text.set_input_embeddingsNr   r   r   labelsr  r  r  c              	   K   sr   | j d||||||d|}	|	d }
| |
}d}|dur+| jd||| jd|}t|||	j|	j|	jd}|S )a,  
        protein_input_ids (torch.LongTensor):
            The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
        protein_attention_mask (torch.Tensor):
            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.

        Example:

        ```python
        >>> from transformers import EvollaProcessor, EvollaForProteinText2Text
        >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf")
        >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf")

        >>> protein_information = {
            "aa_seq": "your amino acid sequence",
            "foldseek": "your foldseek sequence",
        }
        >>> question = "What is the function of this protein?"
        >>> message = [
            {"role": "system", "content": "You are an AI expert that can answer any questions about protein."},
            {"role": "user", "content": question},
        ]

        >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest")
        >>> outputs = model.generate(**inputs)

        >>> print(processor.batch_decode(outputs, skip_special_tokens=True))
        ```)r   r   r   r  r  r  r   N)logitsr  r  )lossr  rb  r   r   r6   )r"  r  loss_functionr  r   rb  r   r   )r2   r   r   r   r  r  r  r  r$  outputsr   r  r  
lm_outputsr6   r6   r7   re     s.   *	
z!EvollaForProteinText2Text.forwardru  )r8   r9   r:   r/   r   r   r   r   r   r@   r(  r   r  r   re   r;   r6   r6   r4   r7   r    s8    r  )r  r  r  )Wr   dataclassesr   typingr   r   r@   r   r   cache_utilsr   r	   
generationr
   masking_utilsr   modeling_outputsr   r   r   r   modeling_utilsr   r   r   utilsr   r   r   utils.deprecationr   utils.genericr   r   esm.modeling_esmr   r   r   r   r   r   r    r!   r"   llama.modeling_llamar#   r$   r%   r&   r'   r(   configuration_evollar*   r+   
get_loggerr8   loggerr,   rE   rJ   rr   rK   rj   r   r   r   r   r   r   r   r   r   r   r   r  r  r   r)  r6  rv  rw  rx  ry  r  r  r  __all__r6   r6   r6   r7   <module>   sf   , 
,$f:* pB S