o
    ig                     @   s  d dl mZmZmZ d dlZd dlmZ d dlZd dl	m
Z d dlmZmZmZ d dlmZmZ d dlmZ d dlmZmZ d dlmZ dd	lmZmZ dd
lmZmZmZmZ ddl m!Z!m"Z"m#Z# ddl$m%Z%m&Z&m'Z' e#(e)Z*dZ+dZ,dZ-dZ.ej/j0G dd de!Z1ej/j0G dd de!Z2G dd dej3Z4G dd dej3Z5G dd dej3Z6G dd dej3Z7G dd dej3Z8G d d! d!ej3Z9G d"d# d#ej3Z:G d$d% d%ej3Z;G d&d' d'ej3Z<G d(d) d)eZ=G d*d+ d+eZ>G d,d- d-eZ?G d.d/ d/ej3Z@G d0d1 d1e=ZAd2ZBeeAe,eB  eeAee&d3 G d4d5 d5ej3ZCG d6d7 d7e=ZDd8ZEeeDe,eE  eeDe1e&d3 G d9d: d:ej3ZFG d;d< d<e>ZGd=ZHeeGe-eH  eeGee'd3 G d>d? d?ej3ZIe"e+G d@dA dAe?ZJdBZKeeJe.eK  eeJe2e%d3 g dCZLdS )D    )AnyOptionalUnionN)
FrozenDictfreezeunfreeze)combine_masksmake_causal_mask)dot_product_attention_weights)flatten_dictunflatten_dict)lax   )FlaxBaseModelOutputFlaxBaseModelOutputWithPooling)ACT2FNFlaxPreTrainedModel append_replace_return_docstringsoverwrite_call_docstring)ModelOutputadd_start_docstringslogging   )
CLIPConfigCLIPTextConfigCLIPVisionConfiga  

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
a~  
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
aA  
    Args:
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
a  
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
c                   @   sb   e Zd ZU dZdZejed< dZejed< dZ	e
eejdf  ed< dZe
eejdf  ed< dS )FlaxCLIPTextModelOutputaJ  
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`jnp.ndarray` of shape `(batch_size, output_dim`):
            The text embeddings obtained by applying the projection layer to the pooled output of
            [`FlaxCLIPTextModel`].
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Ntext_embedslast_hidden_state.hidden_states
attentions)__name__
__module____qualname____doc__r   jnpndarray__annotations__r   r   r   tupler     r)   r)   _/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/clip/modeling_flax_clip.pyr      s   
 r   c                   @   st   e Zd ZU dZdZejed< dZejed< dZ	ejed< dZ
ejed< dZeed< dZeed< d	ee fd
dZdS )FlaxCLIPOutputah  
    Args:
        logits_per_image:(`jnp.ndarray` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text:(`jnp.ndarray` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds(`jnp.ndarray` of shape `(batch_size, output_dim`):
            The text embeddings obtained by applying the projection layer to the pooled output of
            [`FlaxCLIPTextModel`].
        image_embeds(`jnp.ndarray` of shape `(batch_size, output_dim`):
            The image embeddings obtained by applying the projection layer to the pooled output of
            [`FlaxCLIPVisionModel`].
        text_model_output(`FlaxBaseModelOutputWithPooling`):
            The output of the [`FlaxCLIPTextModel`].
        vision_model_output(`FlaxBaseModelOutputWithPooling`):
            The output of the [`FlaxCLIPVisionModel`].
    Nlogits_per_imagelogits_per_textr   image_embedstext_model_outputvision_model_outputreturnc                    s   t  fdd  D S )Nc                 3   s.    | ]}|d vr | nt  | V  qdS ))r/   r0   N)getattrto_tuple).0kselfr)   r*   	<genexpr>   s
    
z*FlaxCLIPOutput.to_tuple.<locals>.<genexpr>)r(   keysr6   r)   r6   r*   r3      s   zFlaxCLIPOutput.to_tuple)r!   r"   r#   r$   r,   r%   r&   r'   r-   r   r.   r/   r   r0   r(   r   r3   r)   r)   r)   r*   r+      s   
 r+   c                   @   6   e Zd ZU eed< ejZejed< dd Zdd Z	dS )FlaxCLIPVisionEmbeddingsconfigdtypec              	   C   s   | j j}| j j}| j j}| dtjjjdd|f| _	tj
|||f||fdd| jtjj d| _|| d | _| jd }tj||tjj d	| _tjtjd
|ddd
d| _d S )Nclass_embedding{Gz?)stddevVALIDF)kernel_sizestridespaddinguse_biasr=   kernel_init   r   embedding_initr   i4r=   axis)r<   hidden_size
image_size
patch_sizeparamjaxnninitializersnormalr>   Convr=   patch_embeddingnum_patchesEmbedposition_embeddingr%   expand_dimsarangeposition_ids)r7   	embed_dimrO   rP   num_positionsr)   r)   r*   setup   s"   


 zFlaxCLIPVisionEmbeddings.setupc           	      C   sv   |  |}|j\}}}}t|||| |f}tj| jdd}t||ddf}tj||gdd}|| | j	 }|S )Nr   r   rL   r   )
rW   shaper%   reshaper[   r>   tileconcatenaterZ   r]   )	r7   pixel_valuespatch_embeds
batch_sizeheightwidthchannelsclass_embeds
embeddingsr)   r)   r*   __call__   s   
z!FlaxCLIPVisionEmbeddings.__call__N)
r!   r"   r#   r   r'   r%   float32r=   r`   rn   r)   r)   r)   r*   r;      s
   
 r;   c                   @   r:   )FlaxCLIPTextEmbeddingsr<   r=   c                 C   sh   | j j}tj| j j|tjj d| _tj| j j	|tjj d| _
tjtjd| j j	dddd| _d S )NrH   r   rJ   rK   ra   rL   )r<   rN   rS   rY   
vocab_sizerR   rT   rU   token_embeddingmax_position_embeddingsrZ   r%   r[   r\   r]   )r7   r^   r)   r)   r*   r`     s   zFlaxCLIPTextEmbeddings.setupc                 C   s,   |  |d}| |d}|| }|S )NrJ   )rr   astyperZ   )r7   	input_idsr]   input_embedsposition_embedsrm   r)   r)   r*   rn     s   zFlaxCLIPTextEmbeddings.__call__N)
r!   r"   r#   r   r'   r%   ro   r=   r`   rn   r)   r)   r)   r*   rp     s
   
 rp   c                   @   s`   e Zd ZU eeef ed< ejZ	ej	ed< dd Z
dd Zdd Z			
	ddedefddZd	S )FlaxCLIPAttentionr<   r=   c                 C   s  | j j| _| j j| _| j| j | _| j| j | jkr'td| j d| j d| jd | _| j j| _	t
j| j| jtj
jdd| _t
j| j| jtj
jdd| _t
j| j| jtj
jdd| _t
j| j| jtj
jdd| _t| j t| _| jrttjd| j jfdd	| _d S d S )
Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      {Gz?r=   rF   r   rJ   rK   )r<   rN   r^   num_attention_heads	num_headshead_dim
ValueErrorscaleattention_dropoutdropoutrS   Denser=   rR   rT   rU   k_projv_projq_projout_proj
isinstancer   causalr	   r%   onesrs   causal_maskr6   r)   r)   r*   r`   "  s&   



     zFlaxCLIPAttention.setupc                 C   s    | |jd d | j| jf S NrG   )rc   rb   r|   r}   r7   r   r)   r)   r*   _split_heads7  s    zFlaxCLIPAttention._split_headsc                 C   s   | |jd d | jf S r   )rc   rb   r^   r   r)   r)   r*   _merge_heads:  s   zFlaxCLIPAttention._merge_headsNTFdeterministicoutput_attentionsc              
   C   s  |  |}| |}| |}| |}| |}| |}d }| jrA|jd |jd }	}
| jd d d d |
|	 |
d |
f }|d urX|d urXtj|dd}t	||dd}n|d ur_|}n|d urjtj|dd}|d urt
|dkt|jd| jt|jt| jj| j}nd }d }|s| jdkr| d}t||||| j|| jd d	}td
||}| |}| |}|r||f}|S |f}|S )Nr   )rL   rJ   rK   r   g        r   )biasdropout_rngdropout_rater   r=   	precisionz...hqk,...khd->...qhd)r   r   r   r   r   rb   r   r%   r[   r   r   selectfullrt   r=   finfominr   make_rngr
   einsumr   r   )r7   r   attention_maskr   r   querykeyvaluecausal_attention_maskquery_length
key_lengthattention_biasr   attn_weightsattn_outputoutputsr)   r)   r*   rn   =  sX   





&


zFlaxCLIPAttention.__call__)NTF)r!   r"   r#   r   r   r   r'   r%   ro   r=   r`   r   r   boolrn   r)   r)   r)   r*   rx     s   
 rx   c                   @   s>   e Zd ZU eeef ed< ejZ	ej	ed< dd Z
dd ZdS )FlaxCLIPMLPr<   r=   c                 C   sV   t | jj | _tj| jj| jtjj	
dd| _tj| jj| jtjj	
dd| _d S )Nry   rz   )r   r<   
hidden_actactivation_fnrS   r   intermediate_sizer=   rR   rT   rU   fc1rN   fc2r6   r)   r)   r*   r`   }  s   &zFlaxCLIPMLP.setupc                 C   s"   |  |}| |}| |}|S N)r   r   r   r   r)   r)   r*   rn     s   


zFlaxCLIPMLP.__call__N)r!   r"   r#   r   r   r   r'   r%   ro   r=   r`   rn   r)   r)   r)   r*   r   y  s
   
 	r   c                   @   sN   e Zd ZU eeef ed< ejZ	ej	ed< dd Z
		ddedefd	d
ZdS )FlaxCLIPEncoderLayerr<   r=   c                 C   T   t | j| jd| _tj| jj| jd| _t| j| jd| _	tj| jj| jd| _
d S NrK   )epsilonr=   )rx   r<   r=   	self_attnrS   	LayerNormlayer_norm_epslayer_norm1r   mlplayer_norm2r6   r)   r)   r*   r`        zFlaxCLIPEncoderLayer.setupTFr   r   c                 C   sn   |}|  |}| j||||d}|d }|| }|}| |}| |}|| }|f}|r5||dd  7 }|S )N)r   r   r   r   r   r   )r   r   r   r   )r7   r   r   r   r   residualattn_outputsr   r)   r)   r*   rn     s$   


zFlaxCLIPEncoderLayer.__call__N)TFr!   r"   r#   r   r   r   r'   r%   ro   r=   r`   r   rn   r)   r)   r)   r*   r     s   
 
r   c                	   @   \   e Zd ZU eeef ed< ejZ	ej	ed< dd Z
					dded	ed
edefddZdS )FlaxCLIPLayerCollectionr<   r=   c                    s     fddt  jjD  _d S )Nc                    s"   g | ]}t  jt| jd qS ))namer=   )r   r<   strr=   )r4   ir6   r)   r*   
<listcomp>  s    z1FlaxCLIPLayerCollection.setup.<locals>.<listcomp>)ranger<   num_hidden_layerslayersr6   r)   r6   r*   r`     s   

zFlaxCLIPLayerCollection.setupNTFr   r   output_hidden_statesreturn_dictc                 C   s   |rdnd }|r
dnd }| j D ]}	|r||f7 }|	||||d}
|
d }|r-||
d f7 }q|r5||f7 }|f}|sCtdd |D S t|||dS )Nr)   )r   r   r   r   c                 s   s    | ]	}|d ur|V  qd S r   r)   )r4   vr)   r)   r*   r8     s    z3FlaxCLIPLayerCollection.__call__.<locals>.<genexpr>)r   r   r    )r   r(   r   )r7   r   r   r   r   r   r   all_attentionsall_hidden_stateslayerlayer_outputsr   r)   r)   r*   rn     s(   	


z FlaxCLIPLayerCollection.__call__NTFFTr   r)   r)   r)   r*   r     s$   
 	r   c                	   @   r   )FlaxCLIPEncoderr<   r=   c                 C      t | j| jd| _d S NrK   )r   r<   r=   r   r6   r)   r)   r*   r`        zFlaxCLIPEncoder.setupNTFr   r   r   r   c                 C   s   | j ||||||dS )N)r   r   r   r   r   r   )r   )r7   inputs_embedsr   r   r   r   r   r)   r)   r*   rn     s   	zFlaxCLIPEncoder.__call__r   r   r)   r)   r)   r*   r     s$   
 r   c                	   @   R   e Zd ZU eed< ejZejed< dd Z				dde	de	d	e	d
e	fddZ
dS )FlaxCLIPTextTransformerr<   r=   c                 C   sH   t | j| jd| _t| j| jd| _tj| jj| jd| _	| jj
| _
d S r   )rp   r<   r=   rm   r   encoderrS   r   r   final_layer_normeos_token_idr6   r)   r)   r*   r`     s   zFlaxCLIPTextTransformer.setupTFr   r   r   r   c                 C   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}| j||d}| j||||||d}	|	d }
| |
}
| jdkrO|
t	|
j
d |jddf }n|
t	|
j
d || jkjddf }|sn|
|f|	dd   S t|
||	j|	jdS )	N)ru   r]   )r   r   r   r   r   r   r   rG   rL   r   r   pooler_outputr   r    )r<   r   r   use_return_dictrm   r   r   r   r%   r\   rb   argmaxr   r   r    )r7   ru   r   r]   r   r   r   r   r   encoder_outputsr   pooled_outputr)   r)   r*   rn     s8   
	

" z FlaxCLIPTextTransformer.__call__NTFFTr!   r"   r#   r   r'   r%   ro   r=   r`   r   rn   r)   r)   r)   r*   r      s"   
 r   c                   @   sL   e Zd ZU eed< ejZejed< dd Z					dde	de	fd	d
Z
dS )FlaxCLIPVisionTransformerr<   r=   c                 C   r   r   )r;   r<   r=   rm   rS   r   r   pre_layrnormr   r   post_layernormr6   r)   r)   r*   r`   F  r   zFlaxCLIPVisionTransformer.setupNTr   r   c           
      C   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}| |}| |}| j|||||d}|d }|d d dd d f }	| |	}	|sR||	f|dd   S t||	|j	|j
dS )N)r   r   r   r   r   r   r   r   )r<   r   r   r   rm   r   r   r   r   r   r    )
r7   rf   r   r   r   r   r   r   r   r   r)   r)   r*   rn   L  s0   


z"FlaxCLIPVisionTransformer.__call__)NTNNTr!   r"   r#   r   r'   r%   ro   r=   r`   r   rn   r)   r)   r)   r*   r   B  s   
 r   c                       s   e Zd ZU eZdZejed< dde	j
dfdedede	jd	ef fd
dZddejjdededefddZ								ddee dejjdedee dee dee fddZ  ZS )FlaxCLIPTextPreTrainedModelNmodule_classr   r   r   Tr<   seedr=   _do_initc                    s2   | j d||d|}t j||||||d d S )Nr<   r=   input_shaper   r=   r   r)   )r   super__init__r7   r<   r   r   r=   r   kwargsmodule	__class__r)   r*   r   x  s   	z$FlaxCLIPTextPreTrainedModel.__init__rngr   paramsr1   c                 C   s   t j|dd}t t t |jd |}t |}tj	|\}}||d}	| j
|	|||d }
|d urYtt|
}
tt|}| jD ]}|
| ||< qFt | _tt|S |
S )NrJ   rK   r   r   r   r   )r%   zerosbroadcast_tor\   
atleast_2drb   	ones_likerR   randomsplitr   initr   r   _missing_keyssetr   r   )r7   r   r   r   ru   r]   r   
params_rngr   rngsrandom_paramsmissing_keyr)   r)   r*   init_weights  s   


z(FlaxCLIPTextPreTrainedModel.init_weightsFr   trainr   r   r   c
                 C   s   |d ur|n| j j}|d ur|n| j j}|	d ur|	n| j j}	|d u r2ttt|jd |j}|d u r;t	|}i }
|d urE||
d< | j
jd|pM| jitj|ddtj|ddtj|dd| |||	|
d	S )Nr   r   r   rJ   rK   r  )r<   r   r   r   r%   r   r\   r   rb   r   r   applyr   array)r7   ru   r   r]   r   r   r  r   r   r   r  r)   r)   r*   rn     s,    
z$FlaxCLIPTextPreTrainedModel.__call__r   NNNNFNNN)r!   r"   r#   r   config_classr   rS   Moduler'   r%   ro   intr=   r   r   rR   r   PRNGKeyr(   r   r  r   dictrn   __classcell__r)   r)   r   r*   r   t  sL   
  	
r   c                       s   e Zd ZU eZdZdZeje	d< dde
jdfdedee ded	e
jd
ef
 fddZddejjdededefddZ						ddee dejjdedee dee dee fddZ  ZS )FlaxCLIPVisionPreTrainedModelrf   Nr   r   Tr<   r   r   r=   r   c                    sJ   |d u rd|j |j df}| jd||d|}t j||||||d d S )Nr   r   r   r   r)   )rO   r   r   r   r   r   r)   r*   r     s   	z&FlaxCLIPVisionPreTrainedModel.__init__r   r   r1   c           
      C   s   t j||}t j|\}}||d}| j||d }|d urCtt|}tt|}| jD ]}	||	 ||	< q0t	 | _t
t|S |S )Nr   r   )rR   r   rU   r   r   r   r   r   r   r   r   r   )
r7   r   r   r   rf   r  r   r  r  r  r)   r)   r*   r    s   

z*FlaxCLIPVisionPreTrainedModel.init_weightsFr   r  r   r   r   c           	   	   C   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}t|d}i }|d ur.||d< | jjd|p6| jitj	|tj
d| ||||dS )Nr   rG   r   r   r   r   rK   r  )r<   r   r   r   r%   	transposer   r  r   r	  ro   )	r7   rf   r   r   r  r   r   r   r  r)   r)   r*   rn     s"   
z&FlaxCLIPVisionPreTrainedModel.__call__r   )NNFNNN)r!   r"   r#   r   r  main_input_namer   rS   r  r'   r%   ro   r   r(   r  r=   r   r   rR   r   r  r   r  r  rn   r  r)   r)   r   r*   r    sN   
  r  c                       s  e Zd ZU eZdZejed< dde	j
dfdedee dede	jd	ef
 fd
dZddejjdededefddZ								ddee dejjdedee dee dee fddZ					ddee dejjfddZ	d dee dejjfddZ  ZS )!FlaxCLIPPreTrainedModelNr   r   Tr<   r   r   r=   r   c                    sR   |d u rdd|j j|j jdff}| jd||d|}t j||||||d d S )Nr   r   r   r   r   r)   )vision_configrO   r   r   r   r   r   r)   r*   r     s   	z FlaxCLIPPreTrainedModel.__init__r   r   r1   c                 C   s   t j|d dd}t t t |jd |d }t |}tj	||d }tj
|\}}	||	d}
| j|
||||d }|d urgtt|}tt|}| jD ]}|| ||< qTt | _tt|S |S )Nr   rJ   rK   r   r   r   r   )r%   r   r   r\   r   rb   r   rR   r   rU   r   r   r   r   r   r   r   r   r   )r7   r   r   r   ru   r]   r   rf   r  r   r  r  r  r)   r)   r*   r    s   "


z$FlaxCLIPPreTrainedModel.init_weightsFr   r  r   r   r   c                 C   s   |d ur|n| j j}|	d ur|	n| j j}	|
d ur|
n| j j}
|d u r2ttt|jd |j}|d u r;t	|}t
|d}i }|d urK||d< | jjd|pS| jitj|ddtj|tjdtj|ddtj|dd| ||	|
|d
S )Nr   r  r   r   rJ   rK   r  )r<   r   r   r   r%   r   r\   r   rb   r   r  r   r  r   r	  ro   )r7   ru   rf   r   r]   r   r   r  r   r   r   r  r)   r)   r*   rn   4  s0    
z FlaxCLIPPreTrainedModel.__call__c           	   	   C   s   |du rt t t |jd |j}|du rt |}i }|dur'||d< dd }| jjd|p3| jit j	|ddt j	|ddt j	|dd| ||d	S )
at  
        Args:
            input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)

        Returns:
            text_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The text embeddings obtained by applying
            the projection layer to the pooled output of [`FlaxCLIPTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, FlaxCLIPModel

        >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
        >>> text_features = model.get_text_features(**inputs)
        ```Nr   r   c                 S   s(   | j ||||d}|d }| |}|S )N)ru   r   r]   r   r   )
text_modeltext_projection)r   ru   r   r]   r   text_outputsr   text_featuresr)   r)   r*   _get_features  s   
z@FlaxCLIPPreTrainedModel.get_text_features.<locals>._get_featuresr   rJ   rK   methodr  )
r%   r   r\   r   rb   r   r   r  r   r	  )	r7   ru   r   r]   r   r   r  r  r  r)   r)   r*   get_text_featuresa  s"   # 
z)FlaxCLIPPreTrainedModel.get_text_featuresc                 C   sV   t |d}i }|dur||d< dd }| jjd|p| jit j|t jd| ||dS )	a  
        Args:
            pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
                Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained
                using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.

        Returns:
            image_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`FlaxCLIPVisionModel`]

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, FlaxCLIPModel

        >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="np")

        >>> image_features = model.get_image_features(**inputs)
        ```r  Nr   c                 S   s$   | j ||d}|d }| |}|S )N)rf   r   r   )vision_modelvisual_projection)r   rf   r   vision_outputsr   image_featuresr)   r)   r*   r    s   
zAFlaxCLIPPreTrainedModel.get_image_features.<locals>._get_featuresr   rK   r  )r%   r  r   r  r   r	  ro   )r7   rf   r   r   r  r  r  r)   r)   r*   get_image_features  s   z*FlaxCLIPPreTrainedModel.get_image_featuresr   r
  )NNNNF)NNF)r!   r"   r#   r   r  r   rS   r  r'   r%   ro   r   r(   r  r=   r   r   rR   r   r  r   r  r  rn   r  r#  r  r)   r)   r   r*   r    sp   
  	

0
Dr  c                	   @   r   )FlaxCLIPTextModuler<   r=   c                 C   r   r   )r   r<   r=   r  r6   r)   r)   r*   r`     r   zFlaxCLIPTextModule.setupTFr   r   r   r   c              	   C   s   | j |||||||dS )Nru   r   r]   r   r   r   r   )r  )r7   ru   r   r]   r   r   r   r   r)   r)   r*   rn     s   
zFlaxCLIPTextModule.__call__Nr   r   r)   r)   r)   r*   r$    s"   
 r$  c                   @      e Zd ZeZdS )FlaxCLIPTextModelN)r!   r"   r#   r$  r   r)   r)   r)   r*   r'        r'  a'  
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxCLIPTextModel

    >>> model = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")

    >>> outputs = model(**inputs)
    >>> last_hidden_state = outputs.last_hidden_state
    >>> pooler_output = outputs.pooler_output  # pooled (EOS token) states
    ```
)output_typer  c                	   @   r   )%FlaxCLIPTextModelWithProjectionModuler<   r=   c                 C   s.   t | j| jd| _tj| jjd| jd| _d S )NrK   F)rE   r=   )r   r<   r=   r  rS   r   projection_dimr  r6   r)   r)   r*   r`     s   z+FlaxCLIPTextModelWithProjectionModule.setupTFr   r   r   r   c              	   C   s\   | j |||||||d}|d }	| |	}
|s#|
|d f|dd   S t|
|j|j|jdS )Nr%  r   r   rG   )r   r   r   r    )r  r  r   r   r   r    )r7   ru   r   r]   r   r   r   r   r  r   r   r)   r)   r*   rn     s&   


z.FlaxCLIPTextModelWithProjectionModule.__call__Nr   r   r)   r)   r)   r*   r*    s"   
 	r*  c                   @   r&  )FlaxCLIPTextModelWithProjectionN)r!   r"   r#   r*  r   r)   r)   r)   r*   r,  ;  r(  r,  a  
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxCLIPTextModelWithProjection

    >>> model = FlaxCLIPTextModelWithProjection.from_pretrained("openai/clip-vit-base-patch32")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")

    >>> outputs = model(**inputs)
    >>> text_embeds = outputs.text_embeds
    ```
c                	   @   r   )FlaxCLIPVisionModuler<   r=   c                 C   r   r   )r   r<   r=   r  r6   r)   r)   r*   r`   ]  r   zFlaxCLIPVisionModule.setupTFr   r   r   r   c                 C   s   | j |||||dS )Nrf   r   r   r   r   )r  )r7   rf   r   r   r   r   r)   r)   r*   rn   `  s   zFlaxCLIPVisionModule.__call__Nr   r   r)   r)   r)   r*   r-  Y  s"   
 r-  c                   @   r&  )FlaxCLIPVisionModelN)r!   r"   r#   r-  r   r)   r)   r)   r*   r/  q  r(  r/  a  
    Returns:

    Example:

    ```python
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, FlaxCLIPVisionModel

    >>> model = FlaxCLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(images=image, return_tensors="np")

    >>> outputs = model(**inputs)
    >>> last_hidden_state = outputs.last_hidden_state
    >>> pooler_output = outputs.pooler_output  # pooled CLS states
    ```
c                   @   sN   e Zd ZU eed< ejZejed< dd Z								d
de	fdd	Z
dS )FlaxCLIPModuler<   r=   c                    s    j j} j j} j j _|j _|j _t| jd _	t
| jd _tj j jtjjddd _tj j jtjjddd _ d fddg  _d S )NrK   r?   F)r=   rF   rE   logit_scalec                    s   t | jj S r   )r%   r   r<   logit_scale_init_value)_rb   r6   r)   r*   <lambda>  s    z&FlaxCLIPModule.setup.<locals>.<lambda>)r<   text_configr  r+  rN   text_embed_dimvision_embed_dimr   r=   r  r   r  rS   r   rR   rT   rU   r   r  rQ   r1  )r7   r5  r  r)   r6   r*   r`     s,   

zFlaxCLIPModule.setupNTr   c	              	   C   s   |d ur|n| j j}| j|||||d}	| j|||||||d}
|	d }| |}|
d }| |}|tjj|ddd }|tjj|ddd }t	| j
}t||j| }|j}|sd|||||
|	fS t|||||
|	dS )Nr.  r%  r   r   T)rM   keepdims)r,   r-   r   r.   r/   r0   )r<   r   r  r  r   r  r%   linalgnormexpr1  matmulTr+   )r7   ru   rf   r   r]   r   r   r   r   r!  r  r.   r   r1  r-   r,   r)   r)   r*   rn     sH   


zFlaxCLIPModule.__call__)NNNNTNNN)r!   r"   r#   r   r'   r%   ro   r=   r`   r   rn   r)   r)   r)   r*   r0    s   
 r0  c                   @   r&  )FlaxCLIPModelN)r!   r"   r#   r0  r   r)   r)   r)   r*   r>    s    r>  ai  
    Returns:

    Example:

    ```python
    >>> import jax
    >>> from PIL import Image
    >>> import requests
    >>> from transformers import AutoProcessor, FlaxCLIPModel

    >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> inputs = processor(
    ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="np", padding=True
    ... )

    >>> outputs = model(**inputs)
    >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
    >>> probs = jax.nn.softmax(logits_per_image, axis=1)  # we can take the softmax to get the label probabilities
    ```
)r>  r  r'  r   r,  r/  r  )Mtypingr   r   r   flax
flax.linenlinenrS   rR   	jax.numpynumpyr%   flax.core.frozen_dictr   r   r   r   r	   flax.linen.attentionr
   flax.traverse_utilr   r   r   modeling_flax_outputsr   r   modeling_flax_utilsr   r   r   r   utilsr   r   r   configuration_clipr   r   r   
get_loggerr!   loggerCLIP_START_DOCSTRINGCLIP_TEXT_INPUTS_DOCSTRINGCLIP_VISION_INPUTS_DOCSTRINGCLIP_INPUTS_DOCSTRINGstruct	dataclassr   r+   r  r;   rp   rx   r   r   r   r   r   r   r   r  r  r$  r'  FLAX_CLIP_TEXT_MODEL_DOCSTRINGr*  r,  .FLAX_CLIP_TEXT_MODEL_WITH_PROJECTION_DOCSTRINGr-  r/   FLAX_CLIP_VISION_MODEL_DOCSTRINGr0  r>  FLAX_CLIP_MODEL_DOCSTRING__all__r)   r)   r)   r*   <module>   s   
# $#&[*/B2OH N*[