o
    iJy                     @   sR  d Z ddlZddlZddlmZ ddlmZ ddl	Z	ddl
mZ ddlmZmZmZ ddlmZ ddlmZmZ ddlmZmZmZ dd	lmZmZmZmZ dd
lm Z m!Z! ddl"m#Z# dZ$dZ%G dd dej&Z'G dd dej&Z(G dd dej&Z)G dd dej&Z*G dd dej&Z+ej,fddZ-G dd dej&Z.G dd dej&Z/G dd  d ej&Z0G d!d" d"ej&Z1G d#d$ d$ej&Z2G d%d& d&ej&Z3G d'd( d(ej&Z4G d)d* d*eZ5G d+d, d,ej&Z6e d-e$G d.d/ d/e5Z7d0Z8ee7e8 ee7ee#d1 G d2d3 d3ej&Z9e d4e$G d5d6 d6e5Z:d7Z;ee:e; ee:ee#d1 g d8Z<dS )9zFlax DINOv2 model.    N)Optional)
FrozenDictfreezeunfreeze)dot_product_attention_weights)flatten_dictunflatten_dict   )FlaxBaseModelOutputFlaxBaseModelOutputWithPoolingFlaxSequenceClassifierOutput)ACT2FNFlaxPreTrainedModel append_replace_return_docstringsoverwrite_call_docstring)add_start_docstrings%add_start_docstrings_to_model_forward   )Dinov2Configa  

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`Dinov2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
a
  
    Args:
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`Dinov2ImageProcessor.__call__`]
            for details.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
c                   @   6   e Zd ZU eed< ejZejed< dd Zdd Z	dS )FlaxDinov2PatchEmbeddingsconfigdtypec                 C   s   | j j}| j j}t|tjjr|n||f}t|tjjr|n||f}|d |d  |d |d   }|| _| j j| _t	j
| j j||d| jtj	j| j jd ddd| _d S )Nr   r   VALID   fan_intruncated_normal)kernel_sizestridespaddingr   kernel_init)r   
image_size
patch_size
isinstancecollectionsabcIterablenum_patchesnum_channelsnnConvhidden_sizer   jaxinitializersvariance_scalinginitializer_range
projection)selfr!   r"   r'    r2   c/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/dinov2/modeling_flax_dinov2.pysetup_   s"    
zFlaxDinov2PatchEmbeddings.setupc                 C   sF   |j d }|| jkrtd| |}|j \}}}}t||d|fS )NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.)shaper(   
ValueErrorr0   jnpreshape)r1   pixel_valuesr(   
embeddings
batch_size_channelsr2   r2   r3   __call__t   s   


z"FlaxDinov2PatchEmbeddings.__call__N
__name__
__module____qualname__r   __annotations__r8   float32r   r4   r?   r2   r2   r2   r3   r   [   s
   
 r   c                   @   sD   e Zd ZU dZeed< ejZejed< dd Z	dd Z
dd	d
ZdS )FlaxDinov2Embeddingsz7Construct the CLS token, position and patch embeddings.r   r   c                 C   s   |  dtjj| jjd dddd| jjf| _| jj	r3|  dtjj| jjd ddd| jjf| _
t| j| jd| _| jj}|  dtjj| jjd ddd|d | jjf| _tj| jjd	| _d S )
N	cls_tokenr   r   r   r   
mask_tokenr   position_embeddingsrate)paramr,   r)   r-   r.   r   r/   r+   rG   use_mask_tokenrH   r   r   patch_embeddingsr'   rJ   Dropouthidden_dropout_probdropout)r1   r'   r2   r2   r3   r4      s&   
zFlaxDinov2Embeddings.setupc              	   C   s  |j d d }|j d d }||kr||kr|S |d d df }|d d dd f }	|j d }
||j }||j }|d |d }}|	dtt|tt||
f}	t|	d}	|	j}t	|t| }t	|t| }tj
||gtj	d}tj
ddgtj	d}tjj|	tj	|	j d |	j d ||fd||d	d
d}	|	|}	t|	d|j d d|
f}	t|	|j d ddf}t||j d ddf}tj||fddS )Nr   r   r5   g?)r   r	   r   r   rI           )r   r	   bicubicF)r6   spatial_dimsscaletranslationmethod	antialiasr   r   r	   r   axis)r6   r"   r9   intmathsqrtr8   	transposer   rE   arrayr,   imagescale_and_translateastypetileconcatenate)r1   r   hidden_statesheightwidthrJ   r'   num_positionsclass_pos_embedpatch_pos_embeddimhwtarget_dtypenew_height_rationew_width_ratiorV   rW   patch_pos_embed_expandedclass_pos_embed_expandedr2   r2   r3   interpolate_pos_encoding   sB   




	z-FlaxDinov2Embeddings.interpolate_pos_encodingTc           	      C   s   |j d }| jjj}|j d |j d }}| ||}t| j|d| jj	f}tj
||fdd}|| | j|||| j }| j||d}|S )Nr   r   r   r[   deterministic)r6   rO   r0   r   rd   r8   broadcast_torG   r   r+   rf   ru   rJ   rR   )	r1   r:   rw   r<   rp   rh   ri   r;   
cls_tokensr2   r2   r3   r?      s   

zFlaxDinov2Embeddings.__call__NT)rA   rB   rC   __doc__r   rD   r8   rE   r   r4   ru   r?   r2   r2   r2   r3   rF      s   
 (rF   c                   @   B   e Zd ZU eed< ejZejed< dd Zdde	de	fd	d
Z
dS )FlaxDinov2SelfAttentionr   r   c                 C   s   | j j| j j dkrtdtj| j j| jtjjj	| j j
d ddd| j jd| _tj| j j| jtjjj	| j j
d ddd| j jd| _tj| j j| jtjjj	| j j
d ddd| j jd| _d S )Nr   z`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`: {self.config.num_attention_heads}r   r   r   )modedistribution)r   r    use_bias)r   r+   num_attention_headsr7   r)   Denser   r,   r-   r.   r/   qkv_biasquerykeyvaluer1   r2   r2   r3   r4      s8   zFlaxDinov2SelfAttention.setupTFrw   output_attentionsc              
   C   s   | j j| j j }| ||jd d | j j|f }| ||jd d | j j|f }| ||jd d | j j|f }d }|sP| j jdkrP| 	d}t
|||| j jd|| jd d}	td|	|}
|
|
jd d d }
|rz|
|	f}|S |
f}|S )Nr   rS   rR   T)dropout_rngdropout_ratebroadcast_dropoutrw   r   	precisionz...hqk,...khd->...qhd)r5   )r   r+   r   r   r9   r6   r   r   attention_probs_dropout_probmake_rngr   r   r8   einsum)r1   rg   rw   r   head_dimquery_statesvalue_states
key_statesr   attn_weightsattn_outputoutputsr2   r2   r3   r?      s:   



z FlaxDinov2SelfAttention.__call__NTFrA   rB   rC   r   rD   r8   rE   r   r4   boolr?   r2   r2   r2   r3   r}      s
   
  r}   c                   @   s>   e Zd ZU eed< ejZejed< dd Zd
de	fddZ
d	S )FlaxDinov2SelfOutputr   r   c                 C   sD   t j| jjtj j| jjd dd| jd| _	t j
| jjd| _d S )Nr   r   r   r    r   rK   )r)   r   r   r+   r,   r-   r.   r/   r   denserP   rQ   rR   r   r2   r2   r3   r4   !  s   zFlaxDinov2SelfOutput.setupTrw   c                 C   s   |  |}| j||d}|S )Nrv   )r   rR   )r1   rg   input_tensorrw   r2   r2   r3   r?   +  s   
zFlaxDinov2SelfOutput.__call__Nrz   r   r2   r2   r2   r3   r     s
   
 
r   c                   @   s>   e Zd ZU eed< ejZejed< dd Zdde	fdd	Z
d
S )FlaxDinov2Attentionr   r   c                 C   s(   t | j| jd| _t| j| jd| _d S NrI   )r}   r   r   	attentionr   outputr   r2   r2   r3   r4   6  s   zFlaxDinov2Attention.setupTFr   c                 C   sD   | j |||d}|d }| j|||d}|f}|r ||d f7 }|S )Nrw   r   r   rv   r   )r   r   )r1   rg   rw   r   attn_outputsr   r   r2   r2   r3   r?   :  s   zFlaxDinov2Attention.__call__Nr   r   r2   r2   r2   r3   r   2  s
   
 r   c                 C   s   t ||| S N)r8   ones)r   r6   rV   r   r2   r2   r3   ones_with_scaleG  s   r   c                   @   r   )FlaxDinov2LayerScaler   r   c                 C   s8   | j j| dtjjj| j jf | _| j| j j | _d S )Nlambda1)	r   layerscale_valuerM   r,   r)   r-   r   r+   r   r   r2   r2   r3   r4   O  s   
zFlaxDinov2LayerScale.setupc                 C   s
   | j | S r   )r   r1   rg   r2   r2   r3   r?   W  s   
zFlaxDinov2LayerScale.__call__Nr@   r2   r2   r2   r3   r   K  s
   
 r   c                   @   s6   e Zd ZU dZeed< ejjdde	e
 fddZdS )	FlaxDinov2DropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).rL   Trw   c           	      C   sv   | j dkr|S d| j  }|r|S |jd fd|jd   }| d}|tjj|||jd }t	|}|| | }|S )NrS   g      ?r   )r   r   droppath)r6   r   )
rL   r6   ndimr   r,   randomuniformr   r8   floor)	r1   inputsrw   	keep_probr6   rngrandom_tensorbinary_tensorr   r2   r2   r3   r?   a  s   



zFlaxDinov2DropPath.__call__Nrz   )rA   rB   rC   r{   floatrD   r)   modulecompactr   r   r?   r2   r2   r2   r3   r   \  s
   
 r   c                   @   r   )FlaxDinov2MLPr   r   c                 C   s   t j| jj| jj tj j| jjd dd| j	d| _
t j| jjtj j| jjd dd| j	d| _t| jjtrBt| jj | _d S | jj| _d S )Nr   r   r   r   )r)   r   r   r+   	mlp_ratior,   r-   r.   r/   r   fc1fc2r#   
hidden_actstrr   actr   r2   r2   r3   r4   u  s"   zFlaxDinov2MLP.setupc                 C   s"   |  |}| |}| |}|S r   )r   r   r   r   r2   r2   r3   r?     s   


zFlaxDinov2MLP.__call__Nr@   r2   r2   r2   r3   r   q  s
   
 r   c                   @   r   )FlaxDinov2SwiGLUFFNr   r   c                 C   s   t | jj| jj }t |d d d d d }tjd| tjj| jj	d dd| j
d| _tj| jjtjj| jj	d dd| j
d| _d S )Nr   r	         r   r   r   )r]   r   r+   r   r)   r   r,   r-   r.   r/   r   
weights_inweights_out)r1   hidden_featuresr2   r2   r3   r4     s    zFlaxDinov2SwiGLUFFN.setupc                 C   s6   |  |}tj|ddd\}}t|| }| |S )Nr   r5   r[   )r   r8   splitr)   silur   )r1   rg   x1x2hiddenr2   r2   r3   r?     s   

zFlaxDinov2SwiGLUFFN.__call__Nr@   r2   r2   r2   r3   r     s
   
 r   c                   @   r|   )FlaxDinov2Layerr   r   c                 C   s   t j| jj| jd| _t| j| jd| _t| j| jd| _	t
| jj| _t j| jj| jd| _| jjr=t| j| jd| _n	t| j| jd| _t| j| jd| _d S )Nepsilonr   rI   )r)   	LayerNormr   layer_norm_epsr   norm1r   r   r   layer_scale1r   drop_path_rate	drop_pathnorm2use_swiglu_ffnr   mlpr   layer_scale2r   r2   r2   r3   r4     s   zFlaxDinov2Layer.setupTFrw   r   c                 C   s|   | j | |||d}|d }| |}|dd  }| || }| |}| |}| |}| || }|f| }|S )Nr   r   r   )r   r   r   r   r   r   r   )r1   rg   rw   r   self_attention_outputsattention_outputr   layer_outputr2   r2   r3   r?     s   




zFlaxDinov2Layer.__call__Nr   r   r2   r2   r2   r3   r     s
   
 r   c                	   @   R   e Zd ZU eed< ejZejed< dd Z				dde	de	d	e	d
e	fddZ
dS )FlaxDinov2LayerCollectionr   r   c                    s     fddt  jjD  _d S )Nc                    s"   g | ]}t  jt| jd qS ))namer   )r   r   r   r   ).0ir   r2   r3   
<listcomp>  s    z3FlaxDinov2LayerCollection.setup.<locals>.<listcomp>)ranger   num_hidden_layerslayersr   r2   r   r3   r4     s   

zFlaxDinov2LayerCollection.setupTFrw   r   output_hidden_statesreturn_dictc                 C   s   |rdnd }|r
dnd }t | jD ]\}}	|r||f7 }|	|||d}
|
d }|r0||
d f7 }q|r8||f7 }|f}|sFtdd |D S t|||dS )Nr2   r   r   r   c                 s   s    | ]	}|d ur|V  qd S r   r2   )r   vr2   r2   r3   	<genexpr>  s    z5FlaxDinov2LayerCollection.__call__.<locals>.<genexpr>)last_hidden_staterg   
attentions)	enumerater   tupler
   )r1   rg   rw   r   r   r   all_attentionsall_hidden_statesr   layerlayer_outputsr   r2   r2   r3   r?     s$   

z"FlaxDinov2LayerCollection.__call__NTFFTr   r2   r2   r2   r3   r     "   
 r   c                	   @   r   )FlaxDinov2Encoderr   r   c                 C   s   t | j| jd| _d S r   )r   r   r   r   r   r2   r2   r3   r4     s   zFlaxDinov2Encoder.setupTFrw   r   r   r   c                 C   s   | j |||||dS )Nrw   r   r   r   )r   )r1   rg   rw   r   r   r   r2   r2   r3   r?     s   zFlaxDinov2Encoder.__call__Nr   r   r2   r2   r2   r3   r   
  s"   
 r   c                       s   e Zd ZU dZeZdZdZdZe	j
ed< ddejdfded	ed
ejdef fddZddejjdededefddZeed						ddee dejjdedee dee dee fddZ  ZS )FlaxDinov2PreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    dinov2r:   Nmodule_classr   Tr   seedr   _do_initc                    sL   | j d||d|}|d u rd|j|j|jf}t j||||||d d S )Nr   r   r   )input_shaper   r   r   r2   )r   r!   r(   super__init__)r1   r   r   r   r   r   kwargsr   	__class__r2   r3   r   -  s   	z"FlaxDinov2PreTrainedModel.__init__r   r   paramsreturnc                 C   s   t j|| jd}tj|\}}tj|\}}|||d}| jj||ddd }	|d urOtt	|	}	tt	|}| j
D ]}
|	|
 ||
< q<t | _
tt|S |	S )NrI   )r   rR   r   F)r   r   )r8   zerosr   r,   r   r   r   initr   r   _missing_keyssetr   r   )r1   r   r   r   r:   
params_rngr   droppath_rngrngsrandom_paramsmissing_keyr2   r2   r3   init_weights;  s   
z&FlaxDinov2PreTrainedModel.init_weightszbatch_size, sequence_lengthFr   trainr   r   r   c           
   	   C   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}t|d}i }|d ur:tj|\}}	||d< |	|d< | j	j
d|pB| jitj|tjd| ||||dS )NrZ   rR   r   r   rI   )r  )r   r   r   r   r8   r`   r,   r   r   r   applyr   ra   rE   )
r1   r:   r   r   r  r   r   r   r  r  r2   r2   r3   r?   O  s&   z"FlaxDinov2PreTrainedModel.__call__r   )NNFNNN)rA   rB   rC   r{   r   config_classbase_model_prefixmain_input_namer   r)   ModulerD   r8   rE   r]   r   r   r   r,   r   PRNGKeyr   r   r  r   DINOV2_INPUTS_DOCSTRINGformatr   dictr?   __classcell__r2   r2   r   r3   r   "  sP   
  r   c                	   @   r   )FlaxDinov2Moduler   r   c                 C   s>   t | j| jd| _t| j| jd| _tj| jj| jd| _	d S )NrI   r   )
rF   r   r   r;   r   encoderr)   r   r   	layernormr   r2   r2   r3   r4   w  s   zFlaxDinov2Module.setupTFrw   r   r   r   c                 C   sz   | j ||d}| j|||||d}|d }| |}|d d dd d f }	|s3||	f}
|
|dd   S t||	|j|jdS )Nrv   r   r   r   )r   pooler_outputrg   r   )r;   r  r  r   rg   r   )r1   r:   rw   r   r   r   rg   encoder_outputssequence_outputpooled_outputhead_outputsr2   r2   r3   r?   |  s(   
zFlaxDinov2Module.__call__Nr   r   r2   r2   r2   r3   r  s  r   r  z`The bare Dinov2 Model transformer outputting raw hidden-states without any specific head on top.c                   @      e Zd ZeZdS )FlaxDinov2ModelN)rA   rB   rC   r  r   r2   r2   r2   r3   r     s    r   ar  
    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxDinov2Model
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    >>> model = FlaxDinov2Model.from_pretrained("facebook/dinov2-base")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
)output_typer  c                   @   sH   e Zd ZU eed< ejZejed< dd Z					d
de	fdd	Z
dS )&FlaxDinov2ForImageClassificationModuler   r   c                 C   sD   t | j| jd| _tj| jj| jtjj	| jj
d ddd| _d S )Nr   r   r   r   )r   r    )r  r   r   r   r)   r   
num_labelsr,   r-   r.   r/   
classifierr   r2   r2   r3   r4     s   z,FlaxDinov2ForImageClassificationModule.setupNTrw   c                 C   s   |d ur|n| j j}| j|||||d}|d }|d d df }|d d dd f }	tj||	jddgdd}
| |
}|sI|f|dd   }|S t||j|j	dS )Nr   r   r   r[   r5   r   )logitsrg   r   )
r   use_return_dictr   r8   rf   meanr$  r   rg   r   )r1   r:   rw   r   r   r   r   rg   rG   patch_tokenslinear_inputr%  r   r2   r2   r3   r?     s*   
z/FlaxDinov2ForImageClassificationModule.__call__)NTNNNr   r2   r2   r2   r3   r"    s   
 r"  z
    Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.
    c                   @   r  ) FlaxDinov2ForImageClassificationN)rA   rB   rC   r"  r   r2   r2   r2   r3   r*    s    r*  a  
    Returns:

    Example:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxDinov2ForImageClassification
    >>> from PIL import Image
    >>> import jax
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer")
    >>> model = FlaxDinov2ForImageClassification.from_pretrained("facebook/dinov2-base-imagenet1k-1-layer", from_pt=True)

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits

    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
    >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])
    ```
)r*  r   r   )=r{   collections.abcr$   r^   typingr   
flax.linenlinenr)   r,   	jax.numpynumpyr8   flax.core.frozen_dictr   r   r   flax.linen.attentionr   flax.traverse_utilr   r   modeling_flax_outputsr
   r   r   modeling_flax_utilsr   r   r   r   utilsr   r   configuration_dinov2r   DINOV2_START_DOCSTRINGr  r  r   rF   r}   r   r   rE   r   r   r   r   r   r   r   r   r   r  r   FLAX_VISION_MODEL_DOCSTRINGr"  r*  $FLAX_VISION_CLASSIFICATION_DOCSTRING__all__r2   r2   r2   r3   <module>   sj   #$VH0,Q+
3
