o
    i                     @   s  d dl mZmZ d dlZd dlmZ d dlZd dlm	Z
 d dl	Zd dlmZmZmZ d dlmZ d dlmZmZ ddlmZmZmZmZ ddlmZmZmZmZ dd	lm Z m!Z! d
dl"m#Z# ej$j%G dd deZ&dZ'dZ(de)e*e*f de
j+fddZ,e
j-fddZ.G dd dej/Z0G dd dej/Z1G dd dej/Z2G dd dej/Z3G dd dej/Z4G d d! d!ej/Z5G d"d# d#ej/Z6G d$d% d%ej/Z7G d&d' d'ej/Z8G d(d) d)ej/Z9G d*d+ d+ej/Z:G d,d- d-ej/Z;G d.d/ d/eZ<G d0d1 d1ej/Z=G d2d3 d3ej/Z>e d4e'G d5d6 d6e<Z?d7Z@ee?e@ ee?e&e#d8 G d9d: d:ej/ZAe d;e'G d<d= d=e<ZBd>ZCeeBeC eeBee#d8 G d?d@ d@ej/ZDe dAe'G dBdC dCe<ZEdDZFeeEeF eeEee#d8 g dEZGdS )F    )CallableOptionalN)
FrozenDictfreezeunfreeze)dot_product_attention_weights)flatten_dictunflatten_dict   )FlaxBaseModelOutputFlaxBaseModelOutputWithPoolingFlaxMaskedLMOutputFlaxSequenceClassifierOutput)ACT2FNFlaxPreTrainedModel append_replace_return_docstringsoverwrite_call_docstring)add_start_docstrings%add_start_docstrings_to_model_forward   )
BeitConfigc                   @   s   e Zd ZdZdS )FlaxBeitModelOutputWithPoolinga  
    Class for outputs of [`FlaxBeitModel`].

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
            Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
            *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
            will be returned.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
            the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    N)__name__
__module____qualname____doc__ r   r   _/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/beit/modeling_flax_beit.pyr   ,   s    r   a  

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`BeitConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
a  
    Args:
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
window_sizereturnc                 C   s  d| d  d d| d  d  d }t | d }t | d }t t j||dd}t |d}|dddddf |dddddf  }t |d	}|dddddf  | d d 7  < |dddddf  | d d 7  < |dddddf  d| d  d 9  < t j| d | d  d fd |jd
}|d|ddddf< |d |dddf< |d |dddf< |d |d< t	
|S )zP
    get pair-wise relative position index for each token inside the window
       r   r   r
   ij)indexing)r    N)r   r    r   shapedtyper#   )r   r   )nparangestackmeshgridreshape	transposezerosr&   sumjnparray)r   num_relative_distancecoords_hcoords_wcoordscoords_flattenrelative_coordsrelative_position_indexr   r   r   relative_position_index_initw   s    $,&&*&
r8   c                 C   s   t ||| S N)r/   ones)keyr%   scaler&   r   r   r   ones_with_scale   s   r=   c                   @   s6   e Zd ZU dZeed< ejjdde	e
 fddZdS )	FlaxBeitDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).rateTdeterministicc           	      C   sv   | j dkr|S d| j  }|r|S |jd fd|jd   }| d}|tjj|||jd }t	|}|| | }|S )N        g      ?r   )r   r   droppathr$   )
r?   r%   ndimmake_rngjaxrandomuniformr&   r/   floor)	selfinputsr@   	keep_probr%   rngrandom_tensorbinary_tensoroutputr   r   r   __call__   s   



zFlaxBeitDropPath.__call__NT)r   r   r   r   float__annotations__nnmodulecompactr   boolrP   r   r   r   r   r>      s
   
 r>   c                   @   6   e Zd ZU eed< ejZejed< dd Zdd Z	dS )FlaxBeitPatchEmbeddingsconfigr&   c              	   C   s~   | j j| _| j j}| j j}|| ||  }|| || f}|| _|| _tj| j j||f||fd| j	t
jj| j jd| _d S )NVALID)kernel_sizestridespaddingr&   kernel_init)rZ   num_channels
image_size
patch_sizenum_patchespatch_shaperT   Convhidden_sizer&   rE   initializersnormalinitializer_range
projection)rI   ra   rb   rc   rd   r   r   r   setup   s   
zFlaxBeitPatchEmbeddings.setupc                 C   sF   |j d }|| jkrtd| |}|j \}}}}t||d|fS )Nr#   zeMake sure that the channel dimension of the pixel values match with the one set in the configuration.)r%   r`   
ValueErrorrj   r/   r+   )rI   pixel_valuesr`   
embeddings
batch_size_channelsr   r   r   rP      s   


z FlaxBeitPatchEmbeddings.__call__N
r   r   r   r   rS   r/   float32r&   rk   rP   r   r   r   r   rY      s
   
 rY   c                   @   s<   e Zd ZU dZeed< ejZejed< dd Z	d
dd	Z
dS )FlaxBeitEmbeddingsz7Construct the CLS token, position and patch embeddings.rZ   r&   c                 C   s   |  dtjjdd| jjf| _| jjr"|  dtjjdd| jjf| _t	| j| j
d| _| jj}| jjrD|  dtjjd|d | jjf| _tj| jjd| _d S )N	cls_tokenr   
mask_tokenr&   position_embeddingsr?   )paramrT   rg   r-   rZ   rf   ru   use_mask_tokenrv   rY   r&   patch_embeddingsrc    use_absolute_position_embeddingsrx   Dropouthidden_dropout_probdropout)rI   rc   r   r   r   rk      s   zFlaxBeitEmbeddings.setupNTc                 C   s   |  |}|j\}}}t| j|d| jjf}||j}|d urDt| j	||| jjf}	|	|j}	tj
|dd}
|d|
  |	|
  }tj||fdd}| jjrZ|| j|j }| j||d}|S )Nr   r#   axisr@   )r|   r%   r/   broadcast_toru   rZ   rf   astyper&   rv   expand_dimsconcatenater}   rx   r   )rI   rm   bool_masked_posr@   rn   ro   seq_lenrp   
cls_tokensmask_tokenswr   r   r   rP      s   
zFlaxBeitEmbeddings.__call__)NT)r   r   r   r   r   rS   r/   rs   r&   rk   rP   r   r   r   r   rt      s   
 rt   c                   @   sF   e Zd ZU eed< eeef ed< ejZ	ej	ed< dd Z
dd ZdS )	FlaxBeitRelativePositionBiasrZ   r   r&   c                 C   sT   d| j d  d d| j d  d  d }| dtjj|| jjf| _t| j | _	d S )Nr    r   r   r
   relative_position_bias_table)
r   rz   rT   rg   r-   rZ   num_attention_headsr   r8   r7   )rI   r1   r   r   r   rk      s   (
z"FlaxBeitRelativePositionBias.setupc                 C   sZ   | j d}| jd | jd  d | jd | jd  d df}| j| |}t|dS )Nr#   r   r   )r    r   r   )r7   r+   r   r   r/   r,   )rI   indexr%   relative_position_biasr   r   r   rP     s   2z%FlaxBeitRelativePositionBias.__call__N)r   r   r   r   rS   tupleintr/   rs   r&   rk   rP   r   r   r   r   r      s   
 r   c                   @   sT   e Zd ZU eed< eeef ed< ejZ	ej	ed< dd Z
	dd	ed
efddZdS )FlaxBeitSelfAttentionrZ   r   r&   c                 C   s   | j j| j j dkrt| j dstd| j j d| j j dtj| j j| jtjj	
| j jd| _tj| j j| jtjj	
| j jdd| _tj| j j| jtjj	
| j jd| _| jrit| j | j| jd	| _d S d | _d S )
Nr   embedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .)r&   r_   F)r&   r_   use_biasr   r&   )rZ   rf   r   hasattrrl   rT   Denser&   rE   rg   rh   ri   queryr;   valuer   r   r   rI   r   r   r   rk     s:   zFlaxBeitSelfAttention.setupNTFr@   output_attentionsc                 C   sN  | j j| j j }| ||jd d | j j|f }| ||jd d | j j|f }| ||jd d | j j|f }d }	|sP| j jdkrP| 	d}	t
jd| jd}
| jd urkt
|  d}
|
|j}
|d urw|
||
j }
t|||
|	| j jd|| jd d	}t
d||}||jd d d	 }|r||f}|S |f}|S )
Nr    rA   r   rw   r   T)biasdropout_rngdropout_ratebroadcast_dropoutr@   r&   	precisionz...hqk,...khd->...qhd)r#   )rZ   rf   r   r   r+   r%   r   r;   attention_probs_dropout_probrD   r/   r0   r&   r   r   r   r   einsum)rI   hidden_statesr   r@   r   head_dimquery_statesvalue_states
key_statesr   attention_biasattn_weightsattn_outputoutputsr   r   r   rP   -  sH   




zFlaxBeitSelfAttention.__call__NTFr   r   r   r   rS   r   r   r/   rs   r&   rk   rW   rP   r   r   r   r   r     s   
 !r   c                   @   >   e Zd ZU eed< ejZejed< dd Zd
de	fddZ
d	S )FlaxBeitSelfOutputrZ   r&   c                 C   <   t j| jjtj j| jj| jd| _	t j
| jjd| _d S Nr_   r&   ry   rT   r   rZ   rf   rE   rg   rh   ri   r&   denser~   r   r   r   r   r   r   rk   a     zFlaxBeitSelfOutput.setupTr@   c                 C      |  |}| j||d}|S Nr   r   r   rI   r   r@   r   r   r   rP   i  s   
zFlaxBeitSelfOutput.__call__NrQ   r   r   r   r   rS   r/   rs   r&   rk   rW   rP   r   r   r   r   r   ]  
   
 r   c                   @   sP   e Zd ZU eed< eeef ed< ejZ	ej	ed< dd Z
	dd	efd
dZdS )FlaxBeitAttentionrZ   r   r&   c                 C   s,   t | j| j| jd| _t| j| jd| _d S )Nrw   )r   rZ   r   r&   	attentionr   rO   r   r   r   r   rk   t  s   zFlaxBeitAttention.setupNTFr   c                 C   sD   | j ||||d}|d }| j||d}|f}|r ||d f7 }|S Nr@   r   r   r   r   )r   rO   )rI   r   r   r@   r   attn_outputsr   r   r   r   r   rP   x  s   zFlaxBeitAttention.__call__r   r   r   r   r   r   r   o  s   
 r   c                   @   rX   )FlaxBeitIntermediaterZ   r&   c                 C   s8   t j| jjtj j| jj| jd| _	t
| jj | _d S )Nr   )rT   r   rZ   intermediate_sizerE   rg   rh   ri   r&   r   r   
hidden_act
activationr   r   r   r   rk     s   zFlaxBeitIntermediate.setupc                 C   s   |  |}| |}|S r9   )r   r   )rI   r   r   r   r   rP     s   

zFlaxBeitIntermediate.__call__Nrr   r   r   r   r   r     s
   
 r   c                   @   r   )FlaxBeitOutputrZ   r&   c                 C   r   r   r   r   r   r   r   rk     r   zFlaxBeitOutput.setupTr@   c                 C   r   r   r   r   r   r   r   rP     s   
zFlaxBeitOutput.__call__NrQ   r   r   r   r   r   r     r   r   c                   @   s\   e Zd ZU eed< eeef ed< eed< ej	Z
ej
ed< dd Z		dd
edefddZdS )FlaxBeitLayerrZ   r   drop_path_rater&   c                 C   s   t | j| j| jd| _t| j| jd| _t| j| jd| _t	j
| jj| jd| _t| jd| _t	j
| jj| jd| _| jj| _| jdkr^| dt| jj| j| _| dt| jj| j| _d S d | _d | _d S )Nrw   epsilonr&   ry   r   lambda_1lambda_2)r   rZ   r   r&   r   r   intermediater   rO   rT   	LayerNormlayer_norm_epslayernorm_beforer>   r   	drop_pathlayernorm_afterlayer_scale_init_valueinit_valuesrz   r=   rf   r   r   r   r   r   r   rk     s   


zFlaxBeitLayer.setupNTFr@   r   c           	      C   s   | j | ||||d}|d }| jd ur| j|j| }| j||d| }| |}| |}| j||d}| j	d urF| j	|j| }| j||d| }|f}|r[||d f7 }|S r   )
r   r   r   r   r&   r   r   r   rO   r   )	rI   r   r   r@   r   self_attention_outputsattention_outputlayer_outputr   r   r   r   rP     s(   



zFlaxBeitLayer.__call__r   )r   r   r   r   rS   r   r   rR   r/   rs   r&   rk   rW   rP   r   r   r   r   r     s   
 r   c                	   @   s   e Zd ZU eed< eeef ed< ee ed< e	g e
jf ed< e
jZe
jed< dd Z						dd
edededefddZdS )FlaxBeitLayerCollectionrZ   r   drop_path_ratesr   r&   c                    s     fddt  jjD  _d S )Nc              	      s:   g | ]}t  j jjr jnd  j| t| jdqS )N)r   r   namer&   )r   rZ   use_relative_position_biasr   r   strr&   ).0ir   r   r   
<listcomp>  s    z1FlaxBeitLayerCollection.setup.<locals>.<listcomp>)rangerZ   num_hidden_layerslayersr   r   r   r   rk     s   

zFlaxBeitLayerCollection.setupTFr@   r   output_hidden_statesreturn_dictc                 C   s   |rdnd }|r
dnd }t | jD ]+\}}	|r||f7 }| jd ur%|  nd }
|	||
||d}|d }|r<||d f7 }q|rD||f7 }|f}|sRtdd |D S t|||dS )Nr   r   r   r   c                 s   s    | ]	}|d ur|V  qd S r9   r   )r   vr   r   r   	<genexpr>  s    z3FlaxBeitLayerCollection.__call__.<locals>.<genexpr>)last_hidden_stater   
attentions)	enumerater   r   r   r   )rI   r   r@   r   r   r   all_attentionsall_hidden_statesr   layerr   layer_outputsr   r   r   r   rP     s*   

z FlaxBeitLayerCollection.__call__NTFFT)r   r   r   r   rS   r   r   listrR   r   r/   ndarrayrs   r&   rk   rW   rP   r   r   r   r   r     s(   
 r   c                	   @   sb   e Zd ZU eed< eeef ed< ejZ	ej	ed< dd Z
				dded	ed
edefddZdS )FlaxBeitEncoderrZ   r   r&   c                 C   sd   | j jrt| j | j| jd| _ttd| j j	| j j
}t| j | j|| j jr)| jnd | jd| _d S )N)rZ   r   r&   r   )r   r   r   r&   )rZ   !use_shared_relative_position_biasr   r   r&   r   r   r'   linspacer   r   r   r   )rI   r   r   r   r   rk   (  s   zFlaxBeitEncoder.setupTFr@   r   r   r   c                 C   s   | j |||||dS )Nr@   r   r   r   )r   )rI   r   r@   r   r   r   r   r   r   rP   :  s   zFlaxBeitEncoder.__call__Nr   r   r   r   r   r   r   #  s$   
 r   c                       s   e Zd ZU dZeZdZdZdZe	j
ed< ddejdfded	ed
ejdef fddZddejjdededefddZeed							ddee dejjdedee dee dee fddZ  ZS )FlaxBeitPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    beitrm   Nmodule_classr   TrZ   seedr&   _do_initc                    sL   | j d||d|}|d u rd|j|j|jf}t j||||||d d S )N)rZ   r&   r   )input_shaper   r&   r   r   )r   ra   r`   super__init__)rI   rZ   r   r   r&   r   kwargsrU   	__class__r   r   r   V  s   	z FlaxBeitPreTrainedModel.__init__rL   r   paramsr   c                 C   s   t j|| jd}tj|\}}tj|\}}|||d}| jj||ddd }	|d urOtt	|	}	tt	|}| j
D ]}
|	|
 ||
< q<t | _
tt|S |	S )Nrw   )r  r   rB   F)r   r  )r/   r-   r&   rE   rF   splitrU   initr   r   _missing_keyssetr   r	   )rI   rL   r   r  rm   
params_rngr   droppath_rngrngsrandom_paramsmissing_keyr   r   r   init_weightsd  s   
z$FlaxBeitPreTrainedModel.init_weightszbatch_size, sequence_lengthFr   trainr   r   r   c	              
   C   s   |d ur|n| j j}|d ur|n| j j}|d ur|n| j j}t|d}i }	|d ur:tj|\}}
||	d< |
|	d< | j	j
d|pB| jitj|tjd|| ||||	dS )N)r   r    r
   r   r   rB   r  rw   )r	  )rZ   r   r   r   r/   r,   rE   rF   r  rU   applyr  r0   rs   )rI   rm   r   r  r   r  r   r   r   r	  r  r   r   r   rP   x  s(   z FlaxBeitPreTrainedModel.__call__r9   )NNNFNNN)r   r   r   r   r   config_classbase_model_prefixmain_input_namer   rT   ModulerS   r/   rs   r   r&   rW   r   rE   rF   PRNGKeyr   r   r  r   BEIT_INPUTS_DOCSTRINGformatr   dictrP   __classcell__r   r   r   r   r   K  sR   
  	r   c                   @   rX   )FlaxBeitPoolerrZ   r&   c                 C   s&   | j jrtj| j j| jd| _d S d S )Nr   )rZ   use_mean_poolingrT   r   r   r&   	layernormr   r   r   r   rk     s   zFlaxBeitPooler.setupc                 C   sN   | j jr|d d dd d d f }| tj|dd}|S |d d df }|S )Nr   r   r   )rZ   r  r  r/   mean)rI   r   patch_tokenspooled_outputr   r   r   rP     s   zFlaxBeitPooler.__call__Nrr   r   r   r   r   r    s
   
 r  c                	   @   s`   e Zd ZU eed< ejZejed< dZe	ed< dd Z
					dd	e	d
e	de	de	fddZdS )FlaxBeitModulerZ   r&   Tadd_pooling_layerc                 C   sp   t | j| jd| _t| j| jjj| jd| _| jjs%t	j
| jj| jd| _| jr3t| j| jd| _d S d | _d S )Nrw   r   r   )rt   rZ   r&   rn   r   r|   rd   encoderr  rT   r   r   r  r  r  poolerr   r   r   r   rk     s   &zFlaxBeitModule.setupNFr@   r   r   r   c           
      C   s   | j |||d}| j|||||d}|d }| jjs| |}| jr'| |nd }	|sB|	d u r8|f|dd   S ||	f|dd   S t||	|j|j	dS )Nr   r   r   r   )r   pooler_outputr   r   )
rn   r   rZ   r  r  r  r!  r   r   r   )
rI   rm   r   r@   r   r   r   r   r   pooledr   r   r   rP     s,   	
zFlaxBeitModule.__call__)NTFFT)r   r   r   r   rS   r/   rs   r&   r  rW   rk   rP   r   r   r   r   r    s&   
 r  z^The bare Beit Model transformer outputting raw hidden-states without any specific head on top.c                   @      e Zd ZeZdS )FlaxBeitModelN)r   r   r   r  r   r   r   r   r   r%        r%  a  
    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxBeitModel
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
    >>> model = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
)output_typer  c                   @   J   e Zd ZU eed< ejZejed< dd Z						d
de	fdd	Z
dS )$FlaxBeitForMaskedImageModelingModulerZ   r&   c                 C   sT   t | jd| jd| _tj| jj| jd| _tj| jj	t
jj| jj| jd| _d S )NF)r  r&   r   r   )r  rZ   r&   r   rT   r   r   r  r   
vocab_sizerE   rg   rh   ri   lm_headr   r   r   r   rk     s   z*FlaxBeitForMaskedImageModelingModule.setupNTr@   c                 C   s   |d ur|n| j j}| j||||||d}|d }| |}| |d d dd f }	|s8|	f|dd   }
|
S t|	|j|jdS )Nr   r   r   r    logitsr   r   )rZ   use_return_dictr   r  r+  r   r   r   )rI   rm   r   r@   r   r   r   r   sequence_outputprediction_scoresrO   r   r   r   rP     s(   		
z-FlaxBeitForMaskedImageModelingModule.__call__NNTNNNr   r   r   r   r   r)    s   
 r)  zYBeit Model transformer with a 'language' modeling head on top (to predict visual tokens).c                   @   r$  )FlaxBeitForMaskedImageModelingN)r   r   r   r)  r   r   r   r   r   r2  9  r&  r2  a?  
    bool_masked_pos (`numpy.ndarray` of shape `(batch_size, num_patches)`):
        Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, BeitForMaskedImageModeling
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
    >>> model = BeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    ```
c                   @   r(  )$FlaxBeitForImageClassificationModulerZ   r&   c                 C   s>   t | j| jdd| _tj| jjtjj	| jj
| jd| _d S )NT)rZ   r&   r  r   )r  rZ   r&   r   rT   r   
num_labelsrE   rg   rh   ri   
classifierr   r   r   r   rk   d  s   z*FlaxBeitForImageClassificationModule.setupNTr@   c                 C   sf   |d ur|n| j j}| j|||||d}|d }| |}	|s*|	f|dd   }
|
S t|	|j|jdS )Nr   r   r    r,  )rZ   r.  r   r5  r   r   r   )rI   rm   r   r@   r   r   r   r   r  r-  rO   r   r   r   rP   l  s$   	
z-FlaxBeitForImageClassificationModule.__call__r1  r   r   r   r   r   r3  `  s   
 
r3  z
    Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
    hidden states of the patch tokens) e.g. for ImageNet.
    c                   @   r$  )FlaxBeitForImageClassificationN)r   r   r   r3  r   r   r   r   r   r6    s    r6  aM  
    Returns:

    Example:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxBeitForImageClassification
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
    >>> model = FlaxBeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_class_idx = logits.argmax(-1).item()
    >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
    ```
)r6  r2  r%  r   )Htypingr   r   flax
flax.linenlinenrT   rE   	jax.numpynumpyr/   r'   flax.core.frozen_dictr   r   r   flax.linen.attentionr   flax.traverse_utilr   r	   modeling_flax_outputsr   r   r   r   modeling_flax_utilsr   r   r   r   utilsr   r   configuration_beitr   struct	dataclassr   BEIT_START_DOCSTRINGr  r   r   r   r8   rs   r=   r  r>   rY   rt   r   r   r   r   r   r   r   r   r   r   r  r  r%  FLAX_BEIT_MODEL_DOCSTRINGr)  r2  FLAX_BEIT_MLM_DOCSTRINGr3  r6  FLAX_BEIT_CLASSIF_DOCSTRING__all__r   r   r   r   <module>   s   # )U=7(S3
2
-
