o
    i                     @   s  d Z ddlmZmZmZ ddlZddlZddlmZ ddl	m  m
Z ddlmZmZ ddlmZ ddlmZmZ ddlmZ dd	lmZ dd
lmZmZmZ ddlmZmZ ddlm Z m!Z! ddl"m#Z#m$Z$m%Z%m&Z&m'Z'm(Z(m)Z)m*Z* ddl+m,Z,m-Z-m.Z.m/Z/m0Z0 e1e2Z3dKde4fddZ5	dLdej6de4de4de4fddZ7dej6dej8de4de9de4d ej6fd!d"Z:d#dej;fd$ej6d%e4d&e4d'e<d(e4d)ej=d e>ej6ej6f fd*d+Z?d,ej6d-ee4 d ej6fd.d/Z@G d0d1 d1e'ZAG d2d3 d3e(ZBG d4d5 d5e ZCG d6d7 d7e%ZDG d8d9 d9e)ZEG d:d; d;e&ZFeG d<d= d=e$ZGG d>d? d?eGZHG d@dA dAeGZIG dBdC dCeGZJG dDdE dEeGZKG dFdG dGeGZLG dHdI dIe#ZMg dJZNdS )Mz<Blt modular model, inheriting from Mllama where appropriate.    )CallableOptionalUnionN   )CacheDynamicCache)create_causal_mask)BaseModelOutputWithPastCausalLMOutputWithPast)ALL_ATTENTION_FUNCTIONS)Unpack)TransformersKwargsauto_docstringlogging)OutputRecordercheck_model_inputs   )Cohere2RotaryEmbeddingrotate_half)MllamaForCausalLMMllamaPreTrainedModelMllamaSelfAttentionDecoderLayerMllamaTextCrossAttentionMllamaTextMLPMllamaTextRMSNormMllamaTextSelfAttentioneager_attention_forward   )	BltConfigBltGlobalTransformerConfigBltLocalDecoderConfigBltLocalEncoderConfigBltPatcherConfigʚ;primec                 C   sD   t j|t j| jd}t j| jd | jd}|| }t j| | ddS )a  
    A polynomial rolling hash algorithm that converts sequences
    of tokens into hash values. The hash is computed as:
        hash = (token_0 * prime^0 + token_1 * prime^1 + ... + token_n * prime^n)

    The rolling hash allows the model to efficiently
    identify and encode recurring byte-level patterns in the input text.

    Args:
        token_tensor (torch.Tensor): [batch_size, seq_len, group_size] containing token IDs to hash
        prime (int): Prime number used as the base for the polynomial hash.

    Returns:
        torch.Tensor: Hash values of shape [batch_size, seq_len] where each value
                     represents the hash of the corresponding token group

    Example:
        >>> tokens = torch.tensor([[1, 2, 3], [4, 5, 6]])
        >>> hashes = rolling_polynomial_hash(tokens, prime=31)
        >>> # hash[0] = 1*31^0 + 2*31^1 + 3*31^2
        >>> # hash[1] = 4*31^0 + 5*31^1 + 6*31^2
    dtypedevicer'   dim)torchtensorint64r'   arangeshapesum)token_tensorr$   prime_tensorpowersprime_powers r6   W/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/blt/modular_blt.pyrolling_polynomial_hash9   s   r8   0u  	token_ids
group_sizemax_hashc                 C   s   t  4 | j\}}t j||d t j| jd}t j|| gdd}|d|d}t||}	|	| }
W d   |
S 1 s;w   Y  |
S )z1Hash token groups and map to range [0, max_hash].r   r%   r*   N)	r,   no_gradr0   zerosr.   r'   catunfoldr8   )r:   r;   r$   r<   
batch_sizeseq_lenpaddingpadded_tokenswindowshasheshash_valuesr6   r6   r7   byte_group_hash_functionV   s   




rH   local_encoder_tokensencoder_hash_tok_embedding$encoder_hash_byte_group_nb_functionsencoder_hash_byte_group_sizeencoder_hash_byte_group_vocabreturnc                 C   sp   g d}| | }d}t|D ]&}	||	t|  }
|D ]}t| ||
|}|||  }|||7 }|d7 }qq|S )z=Compute token embeddings enhanced with hash-based embeddings.)r#   l   21A ioYl   vt l   . l   }g l   Au l   0 l   T l   AK l   | r   r   )embed_tokensrangelenrH   )rI   local_encoderrJ   rK   rL   rM   primes
embeddingsembedding_idxfunc_nbr$   r;   hash_idsoffset_hash_idsr6   r6   r7   compute_hash_embeddingsh   s   


rY   F	patch_idsnum_patchessequence_lengthpatches_as_queriescross_attn_kr&   c                 C   s"  | j \}}| j}|r-|| }	|}
tj||ddd|||}| d|||}n"|}	|| }
| d|||}tj||ddd|||}||k}|rWdnd}|j||d}||	|
f}|j |krutd|j  d| |d}d|| }|	|tj
t|j}|S )	aR  
    Prepare cross-attention mask for patch-based attention, following mllama's robust approach.

    This function creates masks that control which patches can attend to which other patches,
    with support for query/key role swapping and cross-attention multipliers.

    Args:
        patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids.
        num_patches (int): Total number of patches.
        sequence_length (int): Length of the sequence.
        patches_as_queries (bool): If True, patches are used as queries, otherwise as keys.
        cross_attn_k (int): Cross-attention multiplier for repeating patches.
        dtype (torch.dtype): Data type for the output mask.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len]
    r)   r   r(   r   r*   zCross attention mask shape z doesn't match expected g      ?)r0   r'   r,   r/   	unsqueezeexpandrepeat_interleave
ValueErrortomasked_fillboolfinfomin)rZ   r[   r\   r]   r^   r&   rA   rB   r'   q_lenkv_lenq_patch_idskv_patch_idscross_attention_mask
repeat_dimexpected_shapeinverted_cross_attn_maskr6   r6   r7   #_prepare_patch_cross_attention_mask   s<   

"


rp   patch_lengthsmax_patch_lengthc                 C   s2  |du r| S |  d}g }| D ],}g }||dk D ]}| }t||\}}||g|  |r5|| q|| qtdd |D }	tj||	f| j| j	d}
t
|D ]\}}|rmtj|| j| j	d|
|dt|f< qU|
dkjdd |
jd k r|
dkjdd   d }|
ddd|f }
|
S )a  
    Splits patch lengths into smaller segments if they exceed `max_patch_length`.
    Pads the result to uniform length across the batch.

    Args:
        patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
        max_patch_length (int, optional): Maximum allowed length per patch.

    Returns:
        torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
    Nr   c                 s   s    | ]}t |V  qd S N)rQ   ).0splitsr6   r6   r7   	<genexpr>   s    z(process_patch_lengths.<locals>.<genexpr>r%   r*   r   )sizeitemdivmodextendappendmaxr,   r>   r&   r'   	enumerater-   rQ   anyr1   r0   nonzero)rq   rr   rA   	processedseqru   lengthfull_chunks	remaindermax_lenpaddedilast_nonzeror6   r6   r7   process_patch_lengths   s0   

$ r   c                   @      e Zd ZdS )BltMLPN__name__
__module____qualname__r6   r6   r6   r7   r         r   c                   @   r   )
BltRMSNormNr   r6   r6   r6   r7   r     r   r   c                   @   r   )BltRotaryEmbeddingNr   r6   r6   r6   r7   r     r   r   c                       s"   e Zd Zdef fddZ  ZS )BltTransformerLayer	layer_idxc                    sJ   t    t||d| _t|| _t|j|jd| _	t|j|jd| _
d S )N)configr   eps)super__init__BltSelfAttention	self_attnr   mlpr   hidden_sizerms_norm_epsinput_layernormpost_attention_layernormselfr   r   	__class__r6   r7   r     s
   

zBltTransformerLayer.__init__)r   r   r   intr   __classcell__r6   r6   r   r7   r     s    r   c                	       sR   e Zd Zdedef fddZ			ddejdejd	ejd
ef fddZ	  Z
S )r   r   r   c                    s   t  || d| _d S )NT)r   r   	is_causalr   r   r6   r7   r     s   
zBltSelfAttention.__init__FNhidden_statesattention_maskposition_embeddings	use_cachec              	      s    t  jd||||||d|S )N)r   r   r   r   past_key_valuescache_positionr6   )r   forward)r   r   r   r   r   r   r   kwargsr   r6   r7   r   #  s   
zBltSelfAttention.forward)FNN)r   r   r   r   r   r   r,   Tensorre   r   r   r6   r6   r   r7   r     s    	r   c                       s|   e Zd ZdZddededee f fddZ				ddej	d	eej	 d
ee
 deej	 deej dee fddZ  ZS )BltCrossAttentionz<Cross-attention module for Blt, following transformers styleNr   r   r   c                    s8   t    d| _t| j|jd| _t| j|jd| _d S )NFr   )r   r   r   r   r   r   q_normk_norm)r   r   r   r   r   r6   r7   r   ;  s   
zBltCrossAttention.__init__r   cross_attention_statesr   r   r   r   c                 K   sl  |  \}}}	| |}
| |
}
|
||| j| jdd}
|d ur`| |}| |}| 	|}||d| j
| jdd}||d| j
| jdd}|d ur_|||| jd|i\}}n|d dkrv|j| j j|j| j j}}ntdt}| jjdkrt| jj }|| |
|||f| jsdn| j| jd	|\}}|||d }| |}|| }||fS )
Nr   r   r(   r   r   z^Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!eagerg        )dropoutscaling)rw   r   q_projview	num_headshead_dim	transposer   k_projv_projnum_key_value_headsupdater   layerskeysvaluesrb   r   r   _attn_implementationr   trainingr   r   reshape
contiguouso_proj)r   r   r   r   r   r   r   bszrh   _query_states
key_statesvalue_statesattention_interfaceattn_outputattn_weightsr6   r6   r7   r   A  sR   	







zBltCrossAttention.forwardrs   NNNN)r   r   r   __doc__r   r   r   r   r,   r   r   
LongTensorr   r   r   r   r6   r6   r   r7   r   8  s(     	r   c                   @   s^   e Zd ZU eed< dZdZdZdgZe	e
ddde	eddddZd	d
 Zdd Zdd ZdS )BltPreTrainedModelr   Fr   r   local_decoderindex
layer_namer   )r   
attentionsc                 C      t dNzNo need to inherit it!AttributeErrorr   moduler6   r6   r7   _init_weights     z BltPreTrainedModel._init_weightsc                 C   r   r   r   r   r6   r6   r7   _update_causal_mask  r   z&BltPreTrainedModel._update_causal_maskc                 C   r   r   r   r   r6   r6   r7   5_prepare_4d_causal_attention_mask_with_cache_position  r   zHBltPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_positionN)r   r   r   r   __annotations___supports_attention_backend_supports_flash_attn_supports_flex_attn_no_split_modulesr   r   r   _can_record_outputsr   r   r   r6   r6   r6   r7   r   w  s   
 r   c                       s   e Zd ZU eed< deedddiZdef fddZ										dd	e	e
j d
e	e
j de	e
j de	e
j de	e
j de	e de	e
j de	e
j de	e de	e
j dee fddZdd Z  ZS )BltLocalEncoderr   encoder_attentionsr   rR   r   c                    s   t    d| _ | _t fddt jD | _t	 d| _
tj j j j dd| _t j j| _t | _ jrD jnd}t|D ]}| jt | jd qJ|   d S )NFc                       g | ]}t  |qS r6   r   rt   r   r   r6   r7   
<listcomp>      z,BltLocalEncoder.__init__.<locals>.<listcomp>r   in_featuresout_featuresbiasr   r   r   r   )r   r   gradient_checkpointingr   nn
ModuleListrP   num_hidden_layersr   r   
rotary_embLinearr   r^   patch_embedding_projection	Embedding
vocab_sizerO   cross_attn_layerscross_attn_all_layersr{   r   	post_initr   r   layers_to_addr   r   r   r7   r     s(   

zBltLocalEncoder.__init__N	input_idsinputs_embedspatch_embedsr   position_idsr   r   encoder_attention_maskr[   rZ   r   c                 K   sD  |d u r	|  |}|jd }tj|| jj| jd}|d u r/tj|jd |jd	d
|d}| ||}tj|| jj| jd}t| jD ]V\}}||f||||d|}|t| jd ksc| jjr| ||	|
}| |}|||jd | jj | jj}| jjr|nd}| j| d|||d|\}}|| }qE|}||fS )	Nr   pr   r   r)   r(   r   r   r   r   r   r   r   r6   )rO   r0   Fr   r   r   r,   r/   r'   r_   r`   r   r}   r   rQ   r  patch_reducer   r   r^   r   r   )r   r  r  r  r   r  r   r   r	  r[   rZ   r   rA   r   r   idxlayerr   cross_attention_outputr   encoder_cross_statesr6   r6   r7   r     sL   

"


zBltLocalEncoder.forwardc                 C   sz   |j d }|j d }|ddd|j d }tj|||f|j|jd}|j|d|ddd}|ddd|ddf }|S )	a  
        Reduce variable length patches to single embedding per patch
        Note: this works with variable number of patches for different sequences in the batch
        It handles variable length patches by assuming that patch_lengths will be 0 for any
        extra patches on the *right*. Since there can be a variable number of patches
        this function also return the number of patches for each sequence in the batch.
        Any embeddings on the right that are not allocated to a patch
        (i.e. if the sum(patch_lengths[i]) < seq_len for any i)
        will be sent to a dummy patch, which is trimmed before returning.
        r   r(   r%   r   amaxF)srcr+   r   reduceinclude_selfN)r0   r_   r`   r,   r>   r&   r'   scatter_reduce)r   r   max_num_patchesrZ   rA   embedding_dimreduced_embeddingsr6   r6   r7   r    s   

zBltLocalEncoder.patch_reduce
NNNNNNNNNN)r   r   r   r!   r   r   r   r   r   r   r,   r   r   r   r   r   r   r   r  r   r6   r6   r   r7   r     sN   
 	

6r   c                       s   e Zd ZU eed< def fddZe								ddeej	 deej
 deej
 deej
 d	eej	 d
ee deej	 deej
 dee fddZ  ZS )BltLocalDecoderr   c                    s   t    d| _ | _d| _t fddt jD | _	t
 d| _tj j j j dd| _t j jd| _t | _ jrG jnd}t|D ]}| jt | jd	 qM|   d S )
NFTc                    r   r6   r   r   r   r6   r7   r   	  r   z,BltLocalDecoder.__init__.<locals>.<listcomp>r   r   r   r   r   )r   r   r   r   cross_attn_decoderr   r   rP   r   r   r   r   r   hidden_size_globalr   r^   r   r   r   normr   r  r{   r   r  r  r   r   r7   r     s*   

zBltLocalDecoder.__init__Nr  r  r  r   r  r   r   r	  r   c	                 K   s  |j d }
|}| |}||
|j d | jj | jj}|d ur'| js'|| }|d u r=tj|j d |j	d
d|
d}| ||}tj|| jj| jd}t| jD ]-\}}|dks_| jjrs| j| d|||d|	\}}|| }||f||||d|	}qS| |}|S )	Nr   r   r)   r(   r
  r  r  r6   )r0   r   r   r   r^   r   r  r,   r/   r'   r_   r`   r   r  r   r   r}   r   r  r   r   )r   r  r  r  r   r  r   r   r	  r   rA   r   r   r   r  r  r   logitsr6   r6   r7   r     sF   

"


zBltLocalDecoder.forwardNNNNNNNN)r   r   r   r    r   r   r   r   r,   r   r   r   r   r   r   r   r6   r6   r   r7   r     s>   
 	
r  c                       s   e Zd ZU eed< deedddiZdef fddZ				dd	e	j
d
ee	j
 dee	j dee dee	j dee fddZ  ZS )BltGlobalTransformerr   global_attentionsr   global_transformerr   c                    s   t  | || _t | _t|jD ]}| jt	|| qt
|d| _t|dd d ur9tj|j|jdd| _nt | _|   d S )Nr   encoder_cross_output_sizeFr   )r   r   r   r   r   r   rP   r   r{   r   r   r   getattrr   r&  r   token_embedding_projectionIdentityr  r   r   r6   r7   r   U  s   



zBltGlobalTransformer.__init__Ninput_embedsr   r  r   r   r   c                 K   s   |j \}}}	| |}
tj|
| jj| jd}
|d u r,tj|j d |jd	d
|d}| |
|}t| jD ]\}}||
f||||d|}
q7|
S )Nr
  r   r)   r   r(   r  )r0   r)  r  r   r   r   r,   r/   r'   r_   r`   r   r}   r   )r   r+  r   r  r   r   r   rA   rB   r   r   r   r   r  r6   r6   r7   r   g  s&   	
"zBltGlobalTransformer.forwardr   )r   r   r   r   r   r   r   r   r   r,   r   r   r   r   r   r   r   r   r6   r6   r   r7   r#  O  s,   
 r#  c                       s   e Zd ZU eed< def fddZ										ddeej deej	 deej dee
 d	eej d
ee deej dee dee dee dee fddZe		dddZ  ZS )
BltPatcherr   c                    s   t  | t| jd| _t | _t| jj	D ]}| j
t| j| qt| jj| jj| _t| jj| jjd| _tj| jj| jjdd| _d S )Nr   r   Fr'  )r   r   r   r   r   r   r   r   rP   r   r{   r   r   r   r   rO   r   r   r   r   lm_headr   r   r6   r7   r     s   
zBltPatcher.__init__Nr  r   r  r   r  r   r   
patch_size	thresholdrr   r   c                 K   sB  |d u |d uA rt d|d u r| |}|r|d u rt }|d u r:|d ur*| nd}tj|||jd  |jd}|d u rC|d}t	| j
|||||d}|}| ||}| jD ]	}||||d}qY| | |}tjj|d }|jd d \}}|d ur| j||||	d	}ntj||f|j|jd
}t||
}|||fS )N:You must specify exactly one of input_ids or inputs_embedsr   r   r)   r   r+  r   r   r   r  )r   r   )r!  r   )	entropiesr\   r.  r/  r%   )rb   rO   r   get_seq_lengthr,   r/   r0   r'   r_   r   r   r   r   r-  r   distributionsCategoricalentropypatch_lengths_from_entropiesonesr&   r   )r   r  r   r  r   r  r   r   r.  r/  rr   r   past_seen_tokenscausal_maskr   r   r  r!  prediction_entropiesrA   r\   rq   r6   r6   r7   r     sP   

	


zBltPatcher.forwardc                 C   sP  | j d }tjddgtj| jdd|d}|j d }| ddddf } | |k}|j d }tj|| jdd|d}	t	|	|}
tj
|	|
gdd}tj
|| gdd}|| ||}|jdd }|ddd|f }tj
||| fdd}t	|ddddf |d }tj
|ddddf d |fdd}|| d }|S )z
        Computes patch lengths from token entropies.

        Depending on whether a threshold is provided, the function uses either:
        - Thresholding the entropy values (when `threshold` is set).
        r   r   r%   Nr)   r(   r*   )r0   r,   r-   longr'   r_   repeatr/   r`   	full_liker?   r   r1   r|   )r2  r\   r.  r/  rA   init_tokensoffset
patch_maskrB   token_indicessentinelpadded_indicespadded_maskpatch_startsmax_valid_patchespatch_start_ids
last_token
patch_endsrq   r6   r6   r7   r7    s&   
$

 &z'BltPatcher.patch_lengths_from_entropiesr  )NN)r   r   r   r"   r   r   r   r,   r   r   r   FloatTensorre   r   floatr   r   r   staticmethodr7  r   r6   r6   r   r7   r,    sP   
 	

Ar,  c                       s   e Zd Zdef fddZe								ddeej deej	 deej	 deej d	ee
 d
eej dee deej dee defddZdd Zdd Zdej	dedej	fddZ  ZS )BltModelr   c                    s   t  | d| _|| _t|j| _t|j| _	t
|j| _|jt|j }|j| }t||jj| _| jjrOt|j| _| j  | j D ]}d|_qHnd | _|   d S )NF)r   r   r   r   r   encoder_configrR   r#  global_configr%  r  decoder_configr   rK   rQ   rL   rM   r   r   r   rJ   patch_in_forwardr,  patcher_configpatchereval
parametersrequires_gradr  )r   r   num_embeddingstotal_vocab_sizeparamr   r6   r7   r     s"   

zBltModel.__init__Nr  rq   r   r  r   r  r   r   r   rN   c	                 K   s  |d u |d uA rt d|d ur|}
|j\}}}n|j\}}t|| j| j| jj| jj| jj}
|d u r| jj	dkr^| j
d ur^|d u rFt d| j
|| jj| jj| jj| jj|jd\}}}n%|d ure|jn|j}|d uro|jn|j}ttj||d f||d| jj}| ||}|d u r|d ur| nd}tj|||
jd  |
jd}|d u r|d}t| j|
||||d	}t||jd |d
| jj|
jd}| jd||
||||jd |d|	\}}|||jd d}tjd|jd |jd}|d}t| j|d |d d d	}| jd|||d|	}| |d d dd f |}t||jd |d| jj|
jd}| jd||||||||d|	}t||dS )Nr0  r6  z0input_ids is required for entropy-based patching)r.  r/  rr   patching_batch_sizer'   r   r%   r   r)   r1  T)rZ   r[   r\   r]   r^   r&   )r  r  r   r  r	  r[   rZ   r(   )r+  r   r  F)r  r  r  r   r  r   r   r	  )last_hidden_stater   r6   )rb   r0   rY   rR   rJ   r   rK   rL   rM   patching_moderT  r.  patching_thresholdrr   r[  r'   r&   r   r,   r8  _patch_ids_from_lengthsr3  r/   r_   r   rp   r^   r   r%  r   r	   )r   r  rq   r   r  r   r  r   r   r   encoder_embedsrA   r\   r   r'   r&   rZ   r9  r:  cross_attn_mask_encencoder_hidden_statesr  global_cache_positionglobal_position_idsglobal_causal_maskglobal_hidden_statesdecoder_patch_idscross_attn_mask_decoutputr6   r6   r7   r   "  s   
		
	


		zBltModel.forwardc                 C   s   | j jS rs   rR   rO   )r   r6   r6   r7   get_input_embeddings  r   zBltModel.get_input_embeddingsc                 C   s   || j _d S rs   rj  )r   valuer6   r6   r7   set_input_embeddings  s   zBltModel.set_input_embeddingsrB   c                 C   s|   |j d }tjtj|d|j|jd|jddd d d df gdd}tj||jd}|d|ddkj	ddd S )Nr   r   r%   r(   r*   r)   )
r0   r,   r?   r>   r&   r'   cumsumr/   r_   r1   )r   rq   rB   rA   rF  token_positionsr6   r6   r7   r_    s   
&z BltModel._patch_ids_from_lengthsr"  )r   r   r   r   r   r   r   r,   r   r   r   rK  re   r   r   r	   r   rk  rm  r   r_  r   r6   r6   r   r7   rN    sH    	
 "rN  c                       s  e Zd ZU eed< dZdZdgZdef fddZ												dd	e	e
j d
e	e
j de	e
j de	e
j de	e
j de	ee
je
jf  de	eeee
j f  de	e
j de	e
j de	e de	e
j deee
jf dee deeef fddZ  ZS )BltForCausalLMr   Fmodelzlm_head.weightc                    sB   t  | |j| _t|| _tj|jj|jdd| _	| 
  d S )NFr'  )r   r   r   rN  rq  r   r   rQ  r   r-  r  )r   r   r   r6   r7   r     s
   
zBltForCausalLM.__init__Nr   r  r   r  r   rl   full_text_row_masked_out_maskr   r  labelsr   r   logits_to_keepr   rN   c                 K   s   | j d||||||||
|d	|}|j}t|tr t| d n|}| |d d |d d f  }d }|	d urD| j||	| jfi |}t	|||j
|j|jdS )N)	r  r   r  rl   rr  r   r  r   r   )lossr!  r   r   r   r6   )rq  r\  
isinstancer   slicer-  rL  loss_functionr   r
   r   r   r   )r   r  r   r  r   rl   rr  r   r  rs  r   r   rt  r   outputsr   slice_indicesr!  ru  r6   r6   r7   r     s4   
 zBltForCausalLM.forward)NNNNNNNNNNNr   )r   r   r   r   r   _can_compile_fullgraphbase_model_prefix_tied_weights_keysr   r   r,   r   r   tupler   r   listrK  re   r   r   r   r
   r   r   r6   r6   r   r7   rp    s^   
 		

rp  )r   rN  r,  rp  )r#   )r   r#   r9   )Or   typingr   r   r   r,   torch.distributionstorch.nnr   torch.nn.functional
functionalr  cache_utilsr   r   masking_utilsr   modeling_outputsr	   r
   modeling_utilsr   processing_utilsr   utilsr   r   r   utils.genericr   r   cohere2.modeling_cohere2r   r   mllama.modeling_mllamar   r   r   r   r   r   r   r   configuration_bltr   r   r    r!   r"   
get_loggerr   loggerr   r8   r   rH   r   r  rY   float32re   r&   r~  rp   r   r   r   r   r   r   r   r   r   r  r#  r,  rN  rp  __all__r6   r6   r6   r7   <module>   s   (

	

*
N,
?sO5 
 *<