o
    i_                  	   @   s  d dl Z d dlmZ d dlmZmZmZ d dlZd dlm	Z	 ddl
mZmZ ddlmZmZ ddlmZmZmZ ddlmZ dd	lmZmZ dd
lmZmZ ddlmZ ddlmZm Z  ddl!m"Z" ddl#m$Z$m%Z%m&Z&m'Z' ddl(m)Z) ddl*m+Z+ ddl,m-Z-m.Z.m/Z/m0Z0m1Z1m2Z2m3Z3m4Z4m5Z5 ddl6m7Z7m8Z8m9Z9m:Z: ddl;m<Z< e'=e>Z?G dd de+eZ@G dd deZAG dd de:ZBG dd de7ZCG dd de	jDZEG dd  d e/ZFG d!d" d"e2ZGG d#d$ d$e3ZHG d%d& d&e-ZIG d'd( d(eZJdZKG d)d* d*e1ZLd+eMd,eeMeMeMeMgeNf fd-d.ZOG d/d0 d0e0ZPG d1d2 d2e.ZQG d3d4 d4e	jRZSd5eejT d6eejT d7eMd,ee fd8d9ZUG d:d; d;e9ZVG d<d= d=e8ZWG d>d? d?eLZXG d@dA dAeeLZYg dBZZdS )C    N)Callable)AnyOptionalUnion   )CacheDynamicCache)PretrainedConfiglayer_type_validation)create_causal_maskcreate_masks_for_generate!create_sliding_window_causal_mask)FlashAttentionKwargs) GenericForSequenceClassificationGradientCheckpointingLayer)BaseModelOutputWithPast SequenceClassifierOutputWithPast)rope_config_validation)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)TransformersKwargsauto_docstringcan_return_tuplelogging)deprecate_kwarg   )Gemma2Config)	Gemma2AttentionGemma2ForCausalLM	Gemma2MLPGemma2ModelGemma2PreTrainedModelGemma2RMSNormGemma2RotaryEmbeddingapply_rotary_pos_embeager_attention_forward)PaligemmaCausalLMOutputWithPast!PaliGemmaForConditionalGenerationPaliGemmaModelPaligemmaModelOutputWithPast)SiglipVisionConfigc                   @   sT   e Zd ZdZdZ									
																				dddZdS )Gemma3TextConfiga,!  
    This is the configuration class to store the configuration of a [`Gemma3TextModel`]. It is used to instantiate an Gemma3Text
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma3Text-7B.
    e.g. [google/gemma3_text-7b](https://huggingface.co/google/gemma3_text-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 262208):
            Vocabulary size of the Gemma3Text model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Gemma3TextModel`]
        hidden_size (`int`, *optional*, defaults to 2304):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 9216):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 26):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 4):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 131072):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        query_pre_attn_scalar (`float`, *optional*, defaults to 256):
            Scaling factor used on the attention scores
        sliding_window (`int`, *optional*, defaults to 4096):
            In Gemma3Text, every other layer uses sliding window attention. This is the size of the sliding window.
        layer_types (`list`, *optional*):
            Attention pattern for each layer.
        final_logit_softcapping (`float`, *optional*):
            Scaling factor when applying tanh softcapping on the logits.
        attn_logit_softcapping (`float`, *optional*):
            Scaling factor when applying tanh softcapping on the attention scores.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings used in global attention. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        rope_local_base_freq (float, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings for local attention.
        use_bidirectional_attention (`bool`, *optional*, defaults to `False`): If True, the model will attend to all
            text tokens instead of using a causal mask. This does not change behavior for vision tokens.

    ```python
    >>> from transformers import Gemma3TextModel, Gemma3TextConfig
    >>> # Initializing a Gemma3Text gemma3_text-7b style configuration
    >>> configuration = Gemma3TextConfig()
    >>> # Initializing a model from the gemma3_text-7b style configuration
    >>> model = Gemma3TextModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    gemma3_text@   	   $              gelu_pytorch_tanh   {Gz?ư>Tr      r       .AF           N     @c                    s  t jd||||d| | _|	 _| _| _| _| _| _| _	|
 _
| _| _| _| _| _| _| _| _| _| _| _| _|rV jd d  _| _| _t  |dd _ jd u ry fddt jD  _t j j d S )	N)pad_token_idbos_token_ideos_token_idtie_word_embeddingsr   r9   sliding_window_pattern   c                    s&   g | ]}t |d   j rdndqS )r9   sliding_attentionfull_attention)bool_sliding_window_pattern).0iself f/home/ubuntu/veenaModal/venv/lib/python3.10/site-packages/transformers/models/gemma3/modular_gemma3.py
<listcomp>   s    z-Gemma3TextConfig.__init__.<locals>.<listcomp>rL   )r	   __init__
vocab_sizemax_position_embeddingshidden_sizeintermediate_sizenum_hidden_layersnum_attention_headshead_dimnum_key_value_headsinitializer_rangerms_norm_eps	use_cache
rope_thetaattention_biasattention_dropouthidden_activationquery_pre_attn_scalarsliding_windowfinal_logit_softcappingattn_logit_softcappinglayer_typesuse_bidirectional_attentionrope_local_base_freqrope_scalingr   getrG   ranger
   )rK   rP   rR   rS   rT   rU   rW   rV   r^   rQ   rX   rY   rZ   r>   r@   r?   rA   r[   r\   r]   r_   r`   rc   ra   rb   rf   re   rd   kwargsrL   rJ   rM   rO      sP   

zGemma3TextConfig.__init__)r.   r/   r0   r1   r2   r3   r4   r5   r6   r7   r8   Tr   r9   r   Tr:   Fr;   r4   r<   NNNNr=   F)__name__
__module____qualname____doc__
model_typerO   rL   rL   rL   rM   r,   :   s>    vr,   c                       s   e Zd ZdZdZddddZeedZ					
			dde	e
eeeef f  de	e
eeeef f  dededededef fddZ  ZS )Gemma3Configa  
    This is the configuration class to store the configuration of a [`Gemma3ForConditionalGeneration`]. It is used to instantiate an
    Gemma3ForConditionalGeneration according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the PaliGemma-2B.

    e.g. [google/gemma-3-4b](https://huggingface.co/google/gemma-3-4b)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`Union[Gemma3TextConfig, dict]`, *optional*):
            The config object of the text backbone.
        vision_config (`Union[AutoConfig, dict]`,  *optional*):
            Custom vision config or dict.
        mm_tokens_per_image (`int`, *optional*, defaults to 256):
            The number of tokens per image embedding.
        boi_token_index (`int`, *optional*, defaults to 255999):
            The begin-of-image token index to wrap the image prompt.
        eoi_token_index (`int`, *optional*, defaults to 256000):
            The end-of-image token index to wrap the image prompt.
        image_token_index (`int`, *optional*, defaults to 262144):
            The image token index to encode the image prompt.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.


    Example:

    ```python
    >>> from transformers import Gemma3ForConditionalGeneration, Gemma3Config, SiglipVisionConfig, Gemma3TextConfig

    >>> # Initializing a Siglip-like vision config
    >>> vision_config = SiglipVisionConfig()

    >>> # Initializing a Gemma3 Text config
    >>> text_config = Gemma3TextConfig()

    >>> # Initializing a Gemma3 gemma-3-4b style configuration
    >>> configuration = Gemma3Config(vision_config, text_config)

    >>> # Initializing a model from the gemma-3-4b style configuration
    >>> model = Gemma3TextConfig(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```gemma3image_token_indexboi_token_indexeoi_token_index)image_token_idboi_token_ideoi_token_id)text_configvision_configNr4         r7   rw   rx   mm_tokens_per_imagerX   c           	         s   |d u rt  }td nt|trt di |}t|tr&tdi |}n|d u r2t }td || _|| _|| _|| _	|| _
|| _|| _t jdi | d S )Nz@text_config is None, using default Gemma3TextConfig text config.zFvision_config is None, using default SiglipVisionConfig vision config.rL   )r,   loggerinfo
isinstancedictr+   rw   rx   r|   rr   rs   rq   rX   superrO   )	rK   rw   rx   r|   rr   rs   rq   rX   ri   	__class__rL   rM   rO   <  s$   


zGemma3Config.__init__)NNr4   ry   rz   r{   r7   )rj   rk   rl   rm   rn   attribute_mapr,   r+   sub_configsr   r   r   strr   intfloatrO   __classcell__rL   rL   r   rM   ro      s@    0ro   c                   @      e Zd ZdS )Gemma3ModelOutputWithPastNrj   rk   rl   rL   rL   rL   rM   r   ^      r   c                   @   r   )Gemma3CausalLMOutputWithPastNr   rL   rL   rL   rM   r   b  r   r   c                	       sH   e Zd ZdZddedededef fddZd	ejf fd
dZ	  Z
S )Gemma3TextScaledWordEmbeddingz\
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
          ?num_embeddingsembedding_dimpadding_idxembed_scalec                    s*   t  ||| | jdt|dd d S )Nr   F)
persistent)r   rO   register_buffertorchtensor)rK   r   r   r   r   r   rL   rM   rO   k  s   z&Gemma3TextScaledWordEmbedding.__init__	input_idsc                    s   t  || j| jj S N)r   forwardr   toweightdtype)rK   r   r   rL   rM   r   o  s   z%Gemma3TextScaledWordEmbedding.forward)r   )rj   rk   rl   rm   r   r   rO   r   Tensorr   r   rL   rL   r   rM   r   f  s     r   c                       s"   e Zd Zdef fddZ  ZS )	Gemma3MLPconfigc                       t  | d S r   r   rO   rK   r   r   rL   rM   rO   t     zGemma3MLP.__init__rj   rk   rl   r,   rO   r   rL   rL   r   rM   r   s  s    r   c                       s(   e Zd Zddedef fddZ  ZS )Gemma3RMSNormr8   dimepsc                    s   t  j||d d S )Nr   r   r   )rK   r   r   r   rL   rM   rO   y  s   zGemma3RMSNorm.__init__)r8   )rj   rk   rl   r   r   rO   r   rL   rL   r   rM   r   x  s     r   c                       s$   e Zd Zddef fddZ  ZS )Gemma3RotaryEmbeddingNr   c                    r   r   r   )rK   r   devicer   rL   rM   rO   ~  r   zGemma3RotaryEmbedding.__init__r   r   rL   rL   r   rM   r   }  s    r   c                       s   e Zd Zdedef fddZedddd				dd
ejdejde	ej de	e
 de	ej dee deeje	ej e	eej  f fddZ  ZS )Gemma3Attentionr   	layer_idxc                    sd   |j | dk| _t || | jr|jnd | _| jj | _t|j	|j
d| _t|j	|j
d| _d S )NrD   r   )rc   
is_slidingr   rO   r`   r   rd   	is_causalr   rV   rY   q_normk_normrK   r   r   r   rL   rM   rO     s   zGemma3Attention.__init__past_key_valuepast_key_values4.58new_nameversionNhidden_statesposition_embeddingsattention_maskcache_positionri   returnc                 K   s<  |j d d }g |d| jR }| ||dd}	| ||dd}
| ||dd}| |	}	| |
}
|\}}t	|	|
||\}	}
|d ura|||d}|
|
|| j|\}
}t}| jjdkrot| jj }|| |	|
||f| jr|| jnd| j| jd|\}}|jg |dR   }| |}||fS )Nr9   r   )sincosr   eagerr;   )dropoutscalingr`   )shaperV   q_projview	transposek_projv_projr   r   r%   updater   r&   r   _attn_implementationr   trainingr]   r   r`   reshape
contiguouso_proj)rK   r   r   r   r   r   ri   input_shapehidden_shapequery_states
key_statesvalue_statesr   r   cache_kwargsattention_interfaceattn_outputattn_weightsrL   rL   rM   r     s>   


	

zGemma3Attention.forward)NN)rj   rk   rl   r,   r   rO   r   r   r   r   r   
LongTensorr   r   tupler   r   rL   rL   r   rM   r     s(    
r   c                       s   e Zd Zdedef fddZedddd							
	
		ddejdejdejde	ej de	ej
 de	e de	e de	e de	ej
 deeje	eejejf  f fddZ  ZS )Gemma3DecoderLayerr   r   c                    s   t    || _|j| _|| _|j| | _t||d| _t	|| _
t| j|jd| _t| j|jd| _t| j|jd| _t| j|jd| _d S )N)r   r   r   )r   rO   r   rR   r   rc   attention_typer   	self_attnr   mlpr   rY   input_layernormpost_attention_layernormpre_feedforward_layernormpost_feedforward_layernormr   r   rL   rM   rO     s   

zGemma3DecoderLayer.__init__r   r   r   r   NFr   position_embeddings_globalposition_embeddings_localr   position_idsoutput_attentionsrZ   r   r   c
                 K   s   |}|  |}| jjr|}n|}| jd||||||||	d|
\}}| |}|| }|}| |}| |}| |}|| }|f}|rK||f7 }|S )N)r   r   r   r   r   r   rZ   r   rL   )r   r   r   r   r   r   r   )rK   r   r   r   r   r   r   r   rZ   r   ri   residualr   self_attn_weightsoutputsrL   rL   rM   r     s8   
	





zGemma3DecoderLayer.forward)NNNFFN)rj   rk   rl   r,   r   rO   r   r   r   r   r   r   rF   r   FloatTensorr   r   rL   rL   r   rM   r     s<    	
r   c                   @   s    e Zd ZdZg dZdd ZdS )Gemma3PreTrainedModel )r   SiglipVisionEmbeddingsSiglipEncoderLayer#SiglipMultiheadAttentionPoolingHeadc                 C   sF   t | | t|tr|jj  d S d|jjv r!|j	j  d S d S )NRMSNorm)
r   _init_weightsr   Gemma3MultiModalProjectormm_input_projection_weightdatazero_r   rj   r   )rK   modulerL   rL   rM   r     s   
z#Gemma3PreTrainedModel._init_weightsN)rj   rk   rl   base_model_prefix_no_split_modulesr   rL   rL   rL   rM   r     s    r   r`   r   c              
      s&   dt dt dt dt dtf
 fdd}|S )zA
    Enables a bidirectional mask within the sliding window.
    	batch_idxhead_idxq_idxkv_idxr   c                    s   t ||  k S )zA token can attend to any other token if their absolute distance is within
        the (exclusive) sliding window size (distance < sliding_window).)abs)r   r   r   r   r`   rL   rM   
inner_mask  s   z1_bidirectional_window_overlay.<locals>.inner_maskr   rF   )r`   r  rL   r   rM   _bidirectional_window_overlay  s   "r  c                       s   e Zd ZU eed< def fddZ									ddeej deej	 deej dee
 d	eej d
ee dee dee deej dee defddZ  ZS )Gemma3TextModelr   c                    sX   t  | t|j|j| j| jjd d| _t	|}|j
|_ddi|_t|d| _d S )N      ?)r   	rope_typedefaultr   )r   rO   r   rP   rR   r   r   embed_tokenscopydeepcopyre   r[   rf   r   rotary_emb_localr   r   rL   rM   rO   &  s   

zGemma3TextModel.__init__Nr   r   r   r   inputs_embedsrZ   r   output_hidden_statesr   ri   r   c
                 K   s  |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}|d u |d uA r*td| jr9| jr9|r9td d}|d u rB| 	|}|rQ|d u rQ| jsQt
| j d}|	d u rm|d ur]| nd}tj|||jd  |jd}	|d u rv|	d}t| }ts| j |||	||d}| }| j jrd	d
 |d< t| j j|d< tdi |tdi |d}|}| ||}| ||}|rdnd }|rdnd }| jd | j j D ]*}|r||f7 }||f||||j |||||	d|
}|d }|r||d f7 }q| |}|r||f7 }t||||dS )N:You must specify exactly one of input_ids or inputs_embedszX`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.Fr  r   r9   r   r   input_embedsr   r   r   r   c                  W   s   t jdt jdS )NT)r   )r   r   rF   )argsrL   rL   rM   <lambda>p  s    z)Gemma3TextModel.forward.<locals>.<lambda>or_mask_functionrE   rD   rL   )r   r   r   r   r   r   rZ   r   )last_hidden_stater   r   
attentions) r   r   r  rZ   
ValueErrorgradient_checkpointingr   r}   warning_oncer	  r   get_seq_lengthr   aranger   r   	unsqueezer   r   r
  rd   r  r`   r   r   
rotary_embr  layersrT   r   normr   )rK   r   r   r   r   r  rZ   r   r  r   ri   past_seen_tokenscausal_mask_mappingmask_kwargssliding_mask_kwargsr   r   r   all_hidden_statesall_self_attnsdecoder_layerlayer_outputsrL   rL   rM   r   5  s   





zGemma3TextModel.forward	NNNNNNNNN)rj   rk   rl   r,   __annotations__rO   r   r   r   r   r   r   rF   r   r   r   r   r   rL   rL   r   rM   r  #  sF   
 	
r  c                       s0   e Zd ZU eed< dZdef fddZ  ZS )Gemma3ForCausalLMr   language_modelc                    s   t  | t|| _d S r   )r   rO   r  modelr   r   rL   rM   rO     s   zGemma3ForCausalLM.__init__)rj   rk   rl   r,   r+  r   rO   r   rL   rL   r   rM   r,    s   
 r,  c                       s2   e Zd Zdef fddZdejfddZ  ZS )r   r   c                    s   t    tt|jj|jj| _	t
|jj|jjd| _t|jj|jj | _t|jd | _| j| j | _tj| j| jd| _d S )Nr   r  )kernel_sizestride)r   rO   nn	Parameterr   zerosrx   rR   rw   r   r   layer_norm_epsmm_soft_emb_normr   
image_size
patch_sizepatches_per_imager|   tokens_per_sider/  	AvgPool2davg_poolr   r   rL   rM   rO     s   
z"Gemma3MultiModalProjector.__init__vision_outputsc           	      C   sv   |j \}}}|dd}|||| j| j}| }| |}|d}|dd}| |}t	|| j
}||S )Nr9   r   )r   r   r   r8  r   r;  flattenr5  r   matmulr   type_as)	rK   r<  
batch_size_
seq_lengthreshaped_vision_outputspooled_vision_outputsnormed_vision_outputsprojected_vision_outputsrL   rL   rM   r     s   



z!Gemma3MultiModalProjector.forward)	rj   rk   rl   ro   rO   r   r   r   r   rL   rL   r   rM   r     s    r   token_type_idsimage_group_idstokens_per_imagec              
      s4   du rdS dt dt dt dt dtf
 fdd}|S )	z
    This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
    not start and end indices.
    Nr   r   r   r   r   c           	         s   t |jd k |d}| |f }t |jd k |d} | |f }t | jd k |d}| |f dk|dk@ } | |f |k}||@ S )Nr9   r   r   )r   wherer   )	r   r   r   r   safe_idxtoken_type_ids_at_kv_idximage_group_ids_at_kv_idxis_image_blocksame_image_blockrH  rG  rL   rM   r    s   z0token_type_ids_mask_function.<locals>.inner_maskr  )rG  rH  rI  r  rL   rP  rM   token_type_ids_mask_function  s   
$rQ  c                !       s   e Zd ZdZdef fddZdejdejfddZd	d
 Z	e
e													ddeej deej deej deej dee deej deej deej deej dee dee dee dee deeef fddZ  ZS )Gemma3ModelFr   c                    s   t  | | `d S r   )r   rO   text_config_dtyper   r   rL   rM   rO     s   zGemma3Model.__init__pixel_valuesr   c                 C   s   | j |dj}| |}|S )a  
        Projects the last hidden state from the vision model into language model space.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
               The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        )rT  )vision_towerr  multi_modal_projector)rK   rT  r<  image_featuresrL   rL   rM   get_image_features  s   

zGemma3Model.get_image_featuresc                 K      t dNzWe don't want to inherit itAttributeErrorrK   super_kwargsrL   rL   rM   _update_causal_mask     zGemma3Model._update_causal_maskNr   r   r   r   rG  r   r  labelsrZ   r   r  return_dictc                 K   sf  |d u |d uA rt d|d ur|n| jj}|d ur|n| jj}|d ur&|n| jj}|d urD| jj| jkrD|| jjk}| }d||< n|}|d u rP|  |}|d u rl|d ur\|	 nd}t
j|||jd  |jd}|d ur| |}||j|j}| j|||d}|||}t| }ts| j |||||d}|
 p|d u p|j p|d u}|d ur|r|dk|j}|tjj|dddd d d d	f  @ }t
j| dd
d }t
||t
j|d	|jd}t||j|| jj|d< t di |t!di |d}| j"d|||||
||d|d	|}t#|j$|
r!|j%nd |j&|j'|d ur/|dS d dS )Nr  r   r9   r  )r  rW  r  r9   r   valuer   r   r  r  T)	r   r   r   r  rZ   r   r  rb  r   )r  r   r   r  image_hidden_statesrL   )(r  r   r   r  use_return_dictrt   rP   cloneget_input_embeddingsr  r   r  r   r   rX  r   r   get_placeholder_maskmasked_scatterr   r   get_text_configis_initializedr1  
functionalpadcumsumr   rJ  	full_likerQ  r|   r   r   r-  r   r  r   r   r  )rK   r   rT  r   r   r   rG  r   r  ra  rZ   r   r  rb  	lm_kwargsspecial_image_maskllm_input_idsr"  rW  r#  r$  
is_prefillis_imagenew_image_startrH  r   rL   rL   rM   r     s   

(
zGemma3Model.forward)NNNNNNNNNNNNN)rj   rk   rl   accepts_loss_kwargsro   rO   r   r   rX  r_  r   r   r   r   r   r   rF   r   r   r   r   r   rL   rL   r   rM   rR    sb    	

rR  c                "       sH  e Zd ZdZe														ddeej deej deej	 deej dee
 d	eej d
eej deej deej dee dee dee dee deeej	f deeef fddZ										d fdd	Zdd Ze	d dedej	deej	 d
ej	dee
 deej	 d	eej	 defddZ  ZS )!Gemma3ForConditionalGenerationFNr   r   rT  r   r   r   rG  r   r  ra  rZ   r   r  rb  logits_to_keepr   c                 K   s  |dur|n| j j}|dur|n| j j}|dur|n| j j}| jd||||||||
|	||||d|}|d }t|trCt| dn|}| |dd|ddf }d}|	dur|	 }|dddddf }|	dddf }|dur|dd|j
d  df |j}|||jdk  }|||jdk  }n| }| }t }|d| j jj}|d|j}|||}|s|f|dd  }|dur|f| S |S t|||j|j|j|jdS )	a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration

        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
        >>> processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

        >>> messages = [
        ...     {
        ...         "role": "system",
        ...         "content": [
        ...             {"type": "text", "text": "You are a helpful assistant."}
        ...         ]
        ...     },
        ...     {
        ...         "role": "user", "content": [
        ...             {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
        ...             {"type": "text", "text": "Where is the cat standing?"},
        ...         ]
        ...     },
        ... ]

        >>> inputs = processor.apply_chat_template(
        ...     messages,
        ...     tokenize=True,
        ...     return_dict=True,
        ...     return_tensors="pt",
        ...     add_generation_prompt=True
        ... )
        >>> # Generate
        >>> generate_ids = model.generate(**inputs)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
        ```
        N)r   rT  rG  r   r   r   r  rZ   ra  r   r  rb  r   r   .r   r9   )losslogitsr   r   r  rg  rL   )r   r   r  rh  r.  r   r   slicelm_headr   r   r   r   r   r1  CrossEntropyLossr   rw   rP   r   r   r   r  rg  )rK   r   rT  r   r   r   rG  r   r  ra  rZ   r   r  rb  r{  rs  r   r   slice_indicesr}  r|  shift_logitsshift_labelsshift_attention_maskloss_fctflat_logitsflat_labelsoutputrL   rL   rM   r     sd   @$
z&Gemma3ForConditionalGeneration.forwardTc                    s>   t  j|f||||||	|
|d|}|d dkr||d< |S )N)r   r  r   r   r   rZ   r{  rG  r   rT  )r   prepare_inputs_for_generation)rK   r   r   r  r   r   rT  r   rG  rZ   r{  ra  ri   model_inputsr   rL   rM   r    s"   
z<Gemma3ForConditionalGeneration.prepare_inputs_for_generationc                 K   rY  rZ  r[  r]  rL   rL   rM   5_prepare_4d_causal_attention_mask_with_cache_position,  r`  zTGemma3ForConditionalGeneration._prepare_4d_causal_attention_mask_with_cache_positionr   r  c                 K   s   |   |||||d}|d urU|jd dkrU|dk|j}	|	tjj|	dddd d d df  @ }
tj|
	 ddd }t
|	|t|d}t||j|| j|d< td	i |S )
Nr  r9   rc  r   rd  r   rf  r  rL   )rm  r   r   r   r1  ro  rp  r   rq  r   rJ  rr  rQ  r|   r   )r   r  r   r   r   r   rG  ri   r$  rw  rx  rH  rL   rL   rM   r   /  s    	(z8Gemma3ForConditionalGeneration.create_masks_for_generate)NNNNNNNNNNNNNr   )
NNNNNNNTNNr   )rj   rk   rl   ry  r   r   r   r   r   r   r   rF   r   r   r   r   r   r  r  staticmethodr	   r   r   r   rL   rL   r   rM   rz    s    	

 $	rz  c                       s   e Zd ZddddZ fddZdd Zd	d
 Zee									dde	e
j de	e
j de	e
j de	e
j de	e de	e
j de	e
j de	e
j de	e dee defddZ  ZS )Gemma3ForSequenceClassificationzmodel.language_modelzmodel.vision_towerzmodel.multi_modal_projector)z^language_model.modelz^vision_towerz^multi_modal_projectorc                    sB   t  | |j| _t|| _tj|jj| jdd| _	| 
  d S )NF)bias)r   rO   
num_labelsrR  r.  r1  Linearrw   rR   score	post_initr   r   rL   rM   rO   [  s
   
z(Gemma3ForSequenceClassification.__init__c                 C   s
   | j  S r   )r.  rj  rJ   rL   rL   rM   rj  d  s   
z4Gemma3ForSequenceClassification.get_input_embeddingsc                 C   s   | j | d S r   )r.  set_input_embeddings)rK   re  rL   rL   rM   r  g  r   z4Gemma3ForSequenceClassification.set_input_embeddingsNr   rT  r   r   r   r  rG  ra  rZ   ri   r   c
              
   K   s6  | j |f|||||||	d|
}|j}| |}|dur#|jd }n|jd }| jjjdu r7|dkr7td| jjjdu rAd}n2|durg|| jjjk|j	t
j}t
j|jd |j	t
jd}|| d}nd}t| jj d |t
j||j	d	|f }d}|dur| j|||| jd
}t|||j|j|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        )r   rT  r   r   r  rG  rZ   Nr   r9   z=Cannot handle batch sizes > 1 if no padding token is defined.r   )r   r   z will not detect padding tokens in `inputs_embeds`. Results may be unexpected if using padding tokens in conjunction with `inputs_embeds.`r  )r}  ra  pooled_logitsr   )r|  r}  r   r   r  )r.  r  r  r   r   rw   r>   r  r   r   r   int32r  argmaxr}   r  r   rj   loss_functionr   r   r   r  )rK   r   rT  r   r   r   r  rG  ra  rZ   ri   transformer_outputsr   r}  r@  last_non_pad_tokennon_pad_masktoken_indicesr  r|  rL   rL   rM   r   j  sR   	

z'Gemma3ForSequenceClassification.forwardr*  )rj   rk   rl   _checkpoint_conversion_mappingrO   rj  r  r   r   r   r   r   r   r   r   rF   r   r   r   r   r   rL   rL   r   rM   r  T  sT    		
r  c                   @   s   e Zd ZU dZeed< dS )#Gemma3TextForSequenceClassificationz
    Gemma3TextForSequenceClassification is a text-only sequence classification model that works with Gemma3TextConfig.
    It uses the generic sequence classification implementation for efficiency and consistency.
    r   N)rj   rk   rl   rm   r,   r+  rL   rL   rL   rM   r    s   
 r  )	ro   r,   r   r  r,  rz  rR  r  r  )[r
  collections.abcr   typingr   r   r   r   torch.nnr1  cache_utilsr   r   configuration_utilsr	   r
   masking_utilsr   r   r   modeling_flash_attention_utilsr   modeling_layersr   r   modeling_outputsr   r   modeling_rope_utilsr   modeling_utilsr   r   processing_utilsr   utilsr   r   r   r   utils.deprecationr   gemma2.configuration_gemma2r   gemma2.modeling_gemma2r   r   r    r!   r"   r#   r$   r%   r&   paligemma.modeling_paligemmar'   r(   r)   r*   siglipr+   
get_loggerrj   r}   r,   ro   r   r   	Embeddingr   r   r   r   r   r   GEMMA3_START_DOCSTRINGr   r   rF   r  r  r,  Moduler   r   rQ  rR  rz  r  r  __all__rL   rL   rL   rM   <module>   sn   ,
 G^<B" 	$
!  Q^	