o
    iK                     @   s  d dl mZmZ d dlZd dlmZ d dlZd dlm	Z
 d dl	Zd dlmZmZmZ d dlmZ d dlmZmZ d dlmZ ddlmZmZmZmZmZmZmZ dd	lmZm Z m!Z!m"Z"m#Z# dd
l$m%Z%m&Z&m'Z'm(Z( ddl)m*Z* e(+e,Z-dZ.dZ/ej0j1G dd de%Z2dZ3dZ4G dd dej5Z6G dd dej5Z7G dd dej5Z8G dd dej5Z9G dd dej5Z:G dd dej5Z;G dd  d ej5Z<G d!d" d"ej5Z=G d#d$ d$ej5Z>G d%d& d&e Z?G d'd( d(ej5Z@e&d)e3G d*d+ d+e?ZAe!eAe.ee/ G d,d- d-ej5ZBe&d.e3G d/d0 d0e?ZCd1ZDe#eCe4Ed2eD  e"eCe2e/d3 G d4d5 d5ej5ZFe&d6e3G d7d8 d8e?ZGe!eGe.ee/d9d: G d;d< d<ej5ZHe&d=e3G d>d? d?e?ZIe!eIe.ee/ G d@dA dAej5ZJe&dBe3G dCdD dDe?ZKe#eKe4EdE e!eKe.ee/ G dFdG dGej5ZLe&dHe3G dIdJ dJe?ZMe!eMe.ee/ G dKdL dLej5ZNe&dMe3G dNdO dOe?ZOe!eOe.ee/ g dPZPdS )Q    )CallableOptionalN)
FrozenDictfreezeunfreeze)dot_product_attention_weights)flatten_dictunflatten_dict)lax   )FlaxBaseModelOutputFlaxBaseModelOutputWithPoolingFlaxMaskedLMOutputFlaxMultipleChoiceModelOutput FlaxQuestionAnsweringModelOutputFlaxSequenceClassifierOutputFlaxTokenClassifierOutput)ACT2FNFlaxPreTrainedModelappend_call_sample_docstring append_replace_return_docstringsoverwrite_call_docstring)ModelOutputadd_start_docstrings%add_start_docstrings_to_model_forwardlogging   )AlbertConfigzalbert/albert-base-v2r   c                   @   sZ   e Zd ZU dZdZejed< dZejed< dZ	e
eej  ed< dZe
eej  ed< dS )FlaxAlbertForPreTrainingOutputaB  
    Output type of [`FlaxAlbertForPreTraining`].

    Args:
        prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        sop_logits (`jnp.ndarray` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Nprediction_logits
sop_logitshidden_states
attentions)__name__
__module____qualname____doc__r   jnpndarray__annotations__r    r!   r   tupler"    r+   r+   c/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/albert/modeling_flax_albert.pyr   6   s   
 r   a  

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading, saving and converting weights from PyTorch models)

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`AlbertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
a  
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

c                   @   sB   e Zd ZU dZeed< ejZejed< dd Z	dde
fdd	Zd
S )FlaxAlbertEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.configdtypec                 C   s   t j| jj| jjtj jj| jjdd| _	t j| jj
| jjtj jj| jjdd| _t j| jj| jjtj jj| jjdd| _t j| jj| jd| _t j| jjd| _d S )N)stddev)embedding_initepsilonr/   rate)nnEmbedr.   
vocab_sizeembedding_sizejaxinitializersnormalinitializer_rangeword_embeddingsmax_position_embeddingsposition_embeddingstype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsr/   Dropouthidden_dropout_probdropoutselfr+   r+   r,   setup   s"   zFlaxAlbertEmbeddings.setupTdeterministicc           	      C   sX   |  |d}| |d}| |d}|| | }| |}| j||d}|S )Ni4rK   )r>   astyper@   rB   rC   rG   )	rI   	input_idstoken_type_idsposition_idsrK   inputs_embedsposition_embedsrB   r!   r+   r+   r,   __call__   s   
zFlaxAlbertEmbeddings.__call__NT)r#   r$   r%   r&   r   r)   r'   float32r/   rJ   boolrT   r+   r+   r+   r,   r-      s   
 r-   c                   @   s>   e Zd ZU eed< ejZejed< dd Zdde	fdd	Z
d
S )FlaxAlbertSelfAttentionr.   r/   c                 C   s   | j j| j j dkrtdtj| j j| jtjj	| j j
d| _tj| j j| jtjj	| j j
d| _tj| j j| jtjj	| j j
d| _tj| j jtjj	| j j
| jd| _tj| j j| jd| _tj| j jd| _d S )Nr   z`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`                    : {self.config.num_attention_heads})r/   kernel_initrY   r/   r2   r4   )r.   hidden_sizenum_attention_heads
ValueErrorr6   Denser/   r:   r;   r<   r=   querykeyvaluedenserC   rD   rE   rF   rG   rH   r+   r+   r,   rJ      s4   zFlaxAlbertSelfAttention.setupTFoutput_attentionsc                 C   s  | j j| j j }| ||jd d | j j|f }| ||jd d | j j|f }| ||jd d | j j|f }|d urmtj	|dd}t
|dkt|jd| jt|jt| jj| j}	nd }	d }
|s~| j jdkr~| d}
t|||	|
| j jd|| jd d	}td	||}||jd d d
 }| |}| j||d}| || }|r||f}|S |f}|S )N   )axisr   g        rG   T)biasdropout_rngdropout_ratebroadcast_dropoutrK   r/   	precisionz...hqk,...khd->...qhd)rM   )r.   r[   r\   r_   reshapeshapera   r`   r'   expand_dimsr
   selectfullrN   r/   finfominattention_probs_dropout_probmake_rngr   einsumrb   rG   rC   )rI   r!   attention_maskrK   rc   head_dimquery_statesvalue_states
key_statesattention_biasrj   attn_weightsattn_outputprojected_attn_outputlayernormed_attn_outputoutputsr+   r+   r,   rT      sR   




z FlaxAlbertSelfAttention.__call__NTFr#   r$   r%   r   r)   r'   rV   r/   rJ   rW   rT   r+   r+   r+   r,   rX      s
   
 rX   c                   @   sF   e Zd ZU eed< ejZejed< dd Z		dde	de	fd	d
Z
dS )FlaxAlbertLayerr.   r/   c                 C   s   t | j| jd| _tj| jjtjj	| jj
| jd| _t| jj | _tj| jjtjj	| jj
| jd| _tj| jj| jd| _tj| jjd| _d S )Nr/   rZ   r2   r4   )rX   r.   r/   	attentionr6   r^   intermediate_sizer:   r;   r<   r=   ffnr   
hidden_act
activationr[   
ffn_outputrC   rD   full_layer_layer_normrE   rF   rG   rH   r+   r+   r,   rJ     s   zFlaxAlbertLayer.setupTFrK   rc   c           	      C   sp   | j ||||d}|d }| |}| |}| |}| j||d}| || }|f}|r6||d f7 }|S )NrK   rc   r   rM   r   )r   r   r   r   rG   r   )	rI   r!   ry   rK   rc   attention_outputsattention_outputr   r   r+   r+   r,   rT   )  s   


zFlaxAlbertLayer.__call__Nr   r   r+   r+   r+   r,   r     s   
 r   c                   @   sL   e Zd ZU eed< ejZejed< dd Z			dde	de	d	e	fd
dZ
dS )FlaxAlbertLayerCollectionr.   r/   c                         fddt  jjD  _d S )Nc                    s"   g | ]}t  jt| jd qS ))namer/   )r   r.   strr/   .0irH   r+   r,   
<listcomp>F  s    z3FlaxAlbertLayerCollection.setup.<locals>.<listcomp>)ranger.   inner_group_numlayersrH   r+   rH   r,   rJ   E  s   

zFlaxAlbertLayerCollection.setupTFrK   rc   output_hidden_statesc                 C   sz   d}d}t | jD ] \}}	|	||||d}
|
d }|r"||
d f }|r)||f }q	|f}|r4||f }|r;||f }|S )Nr+   r   r   r   )	enumerater   )rI   r!   ry   rK   rc   r   layer_hidden_stateslayer_attentionslayer_indexalbert_layerlayer_outputr   r+   r+   r,   rT   J  s*   


z"FlaxAlbertLayerCollection.__call__NTFFr   r+   r+   r+   r,   r   A  s   
 	r   c                   @   s\   e Zd ZU eed< ejZejed< dZe	e
 ed< dd Z			dd	ed
edefddZdS )FlaxAlbertLayerCollectionsr.   r/   Nr   c                 C   s   t | j| jd| _d S )Nr   )r   r.   r/   albert_layersrH   r+   r+   r,   rJ   q  s   z FlaxAlbertLayerCollections.setupTFrK   rc   r   c                 C   s   | j |||||d}|S NrK   rc   r   )r   )rI   r!   ry   rK   rc   r   r   r+   r+   r,   rT   t  s   z#FlaxAlbertLayerCollections.__call__r   )r#   r$   r%   r   r)   r'   rV   r/   r   r   r   rJ   rW   rT   r+   r+   r+   r,   r   l  s   
 r   c                	   @   R   e Zd ZU eed< ejZejed< dd Z				dde	de	d	e	d
e	fddZ
dS )FlaxAlbertLayerGroupsr.   r/   c                    r   )Nc                    s(   g | ]}t  jt|t| jd qS ))r   r   r/   )r   r.   r   r/   r   rH   r+   r,   r     s    z/FlaxAlbertLayerGroups.setup.<locals>.<listcomp>)r   r.   num_hidden_groupsr   rH   r+   rH   r,   rJ     s   

zFlaxAlbertLayerGroups.setupTFrK   rc   r   return_dictc                 C   s   |rdnd }|r|fnd }t | jjD ]-}	t|	| jj| jj  }
| j|
 |||||d}|d }|r9||d  }|r@||f }q|sOtdd |||fD S t|||dS )Nr+   r   r   rn   c                 s   s    | ]	}|d ur|V  qd S Nr+   )r   vr+   r+   r,   	<genexpr>  s    z1FlaxAlbertLayerGroups.__call__.<locals>.<genexpr>)last_hidden_stater!   r"   )r   r.   num_hidden_layersintr   r   r*   r   )rI   r!   ry   rK   rc   r   r   all_attentionsall_hidden_statesr   	group_idxlayer_group_outputr+   r+   r,   rT     s,   	
zFlaxAlbertLayerGroups.__call__NTFFTr   r+   r+   r+   r,   r     s"   
 
r   c                	   @   r   )FlaxAlbertEncoderr.   r/   c                 C   s<   t j| jjtj j| jj| jd| _	t
| j| jd| _d S )NrZ   r   )r6   r^   r.   r[   r:   r;   r<   r=   r/   embedding_hidden_mapping_inr   albert_layer_groupsrH   r+   r+   r,   rJ     s   zFlaxAlbertEncoder.setupTFrK   rc   r   r   c                 C   s   |  |}| j|||||dS r   )r   r   )rI   r!   ry   rK   rc   r   r   r+   r+   r,   rT     s   
	zFlaxAlbertEncoder.__call__Nr   r   r+   r+   r+   r,   r     s"   
 r   c                   @   sT   e Zd ZU eed< ejZejed< ej	j
jZedejf ed< dd Zd
dd	ZdS )FlaxAlbertOnlyMLMHeadr.   r/   .	bias_initc                 C   sn   t j| jj| jd| _t| jj | _t j	| jj
| jd| _	t j| jj| jdd| _| d| j| jjf| _d S )Nr   r2   F)r/   use_biasri   )r6   r^   r.   r9   r/   rb   r   r   r   rC   rD   r8   decoderparamr   ri   rH   r+   r+   r,   rJ     s
   zFlaxAlbertOnlyMLMHead.setupNc                 C   sX   |  |}| |}| |}|d ur | jdd|jii|}n| |}|| j7 }|S )Nparamskernel)rb   r   rC   r   applyTri   )rI   r!   shared_embeddingr+   r+   r,   rT     s   




zFlaxAlbertOnlyMLMHead.__call__r   )r#   r$   r%   r   r)   r'   rV   r/   r:   r6   r;   zerosr   r   npr(   rJ   rT   r+   r+   r+   r,   r     s   
 r   c                   @   s8   e Zd ZU eed< ejZejed< dd Zd	ddZ	dS )
FlaxAlbertSOPHeadr.   r/   c                 C   s&   t | jj| _t jd| jd| _d S )Nrd   r   )r6   rE   r.   classifier_dropout_probrG   r^   r/   
classifierrH   r+   r+   r,   rJ     s   zFlaxAlbertSOPHead.setupTc                 C   s   | j ||d}| |}|S )NrM   )rG   r   )rI   pooled_outputrK   logitsr+   r+   r,   rT     s   
zFlaxAlbertSOPHead.__call__NrU   )
r#   r$   r%   r   r)   r'   rV   r/   rJ   rT   r+   r+   r+   r,   r     s
   
 r   c                       s   e Zd ZU dZeZdZdZej	e
d< ddejdfded	ed
edejdef
 fddZddejjd	ededefddZeed									ddee dejjdedee dee dee fddZ  ZS )FlaxAlbertPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    albertNmodule_class)r   r   r   Tr.   input_shapeseedr/   _do_initc                    s2   | j d||d|}t j||||||d d S )Nr.   r/   )r   r   r/   r   r+   )r   super__init__)rI   r.   r   r   r/   r   kwargsmodule	__class__r+   r,   r     s   	z"FlaxAlbertPreTrainedModel.__init__rngr   returnc                 C   s   t j|dd}t |}t t t |jd |}t |}tj	
|\}}	||	d}
| jj|
||||ddd }|d uratt|}tt|}| jD ]}|| ||< qNt | _tt|S |S )NrL   r   rn   )r   rG   F)r   r   )r'   r   
zeros_likebroadcast_toarange
atleast_2drp   	ones_liker:   randomsplitr   initr   r   _missing_keyssetr   r	   )rI   r   r   r   rO   rP   rQ   ry   
params_rngrj   rngsrandom_paramsmissing_keyr+   r+   r,   init_weights  s&   



z&FlaxAlbertPreTrainedModel.init_weightsbatch_size, sequence_lengthFrj   trainrc   r   r   c                 C   s   |d ur|n| j j}|	d ur|	n| j j}	|
d ur|
n| j j}
|d u r't|}|d u r;ttt|j	d |j	}|d u rDt
|}i }|d urN||d< | jjd|pV| jitj|ddtj|ddtj|ddtj|dd| ||	|
|d
S )Nrn   rG   r   rL   r   )r   )r.   rc   r   r   r'   r   r   r   r   rp   r   r   r   r   array)rI   rO   ry   rP   rQ   r   rj   r   rc   r   r   r   r+   r+   r,   rT   *  s2   
 
z"FlaxAlbertPreTrainedModel.__call__r   )	NNNNNFNNN)r#   r$   r%   r&   r   config_classbase_model_prefixr   r6   Moduler)   r'   rV   r*   r   r/   rW   r   r:   r   PRNGKeyr   r   r   ALBERT_INPUTS_DOCSTRINGformatr   dictrT   __classcell__r+   r+   r   r,   r     sX   
  	
r   c                   @   sv   e Zd ZU eed< ejZejed< dZe	ed< dd Z
						dd	eej d
eej de	de	de	de	fddZdS )FlaxAlbertModuler.   r/   Tadd_pooling_layerc                 C   sn   t | j| jd| _t| j| jd| _| jr/tj| jj	t
jj| jj| jdd| _tj| _d S d | _d | _d S )Nr   pooler)rY   r/   r   )r-   r.   r/   
embeddingsr   encoderr   r6   r^   r[   r:   r;   r<   r=   r   tanhpooler_activationrH   r+   r+   r,   rJ   `  s   
zFlaxAlbertModule.setupNFrP   rQ   rK   rc   r   r   c	                 C   s   |d u r	t |}|d u rt t t |jd |j}| j||||d}	| j|	|||||d}
|
d }	| jrI| 	|	d d df }| 
|}nd }|sd|d u rZ|	f|
dd   S |	|f|
dd   S t|	||
j|
jdS )Nrn   rM   rK   rc   r   r   r   r   )r   pooler_outputr!   r"   )r'   r   r   r   r   rp   r   r   r   r   r   r   r!   r"   )rI   rO   ry   rP   rQ   rK   rc   r   r   r!   r   pooledr+   r+   r,   rT   o  s8   
 zFlaxAlbertModule.__call__)NNTFFT)r#   r$   r%   r   r)   r'   rV   r/   r   rW   rJ   r   r   r(   rT   r+   r+   r+   r,   r   [  s0   
 	r   z`The bare Albert Model transformer outputting raw hidden-states without any specific head on top.c                   @      e Zd ZeZdS )FlaxAlbertModelN)r#   r$   r%   r   r   r+   r+   r+   r,   r     s    r   c                	   @   r   )FlaxAlbertForPreTrainingModuler.   r/   c                 C   s:   t | j| jd| _t| j| jd| _t| j| jd| _d S )Nr   )r   r.   r/   r   r   predictionsr   sop_classifierrH   r+   r+   r,   rJ        z$FlaxAlbertForPreTrainingModule.setupTFrK   rc   r   r   c	              
   C   s   | j ||||||||d}	| jjr| j jd d d d }
nd }
|	d }|	d }| j||
d}| j||d	}|sB||f|	d
d   S t|||	j|	jdS )Nr   r   r   r>   	embeddingr   r   r   rM   rd   )r   r    r!   r"   )	r   r.   tie_word_embeddings	variablesr  r  r   r!   r"   )rI   rO   ry   rP   rQ   rK   rc   r   r   r   r   r!   r   prediction_scores
sop_scoresr+   r+   r,   rT     s2   z'FlaxAlbertForPreTrainingModule.__call__Nr   r   r+   r+   r+   r,   r    "   
 	r  z
    Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
    `sentence order prediction (classification)` head.
    c                   @   r   )FlaxAlbertForPreTrainingN)r#   r$   r%   r  r   r+   r+   r+   r,   r        r  a  
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxAlbertForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
    >>> model = FlaxAlbertForPreTraining.from_pretrained("albert/albert-base-v2")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.prediction_logits
    >>> seq_relationship_logits = outputs.sop_logits
    ```
r   )output_typer   c                	   @   r   )FlaxAlbertForMaskedLMModuler.   r/   c                 C   s*   t | jd| jd| _t| j| jd| _d S )NF)r.   r   r/   r   )r   r.   r/   r   r   r  rH   r+   r+   r,   rJ     s   z!FlaxAlbertForMaskedLMModule.setupTFrK   rc   r   r   c	              
   C   s~   | j ||||||||d}	|	d }
| jjr"| j jd d d d }nd }| j|
|d}|s6|f|	dd   S t||	j|	jd	S )
Nr   r   r   r   r>   r  r  r   r   r!   r"   )r   r.   r  r  r  r   r!   r"   )rI   rO   ry   rP   rQ   rK   rc   r   r   r   r!   r   r   r+   r+   r,   rT     s,   z$FlaxAlbertForMaskedLMModule.__call__Nr   r   r+   r+   r+   r,   r  	  "   
 
	r  z4Albert Model with a `language modeling` head on top.c                   @   r   )FlaxAlbertForMaskedLMN)r#   r$   r%   r  r   r+   r+   r+   r,   r  ;  s    r  z
refs/pr/11)revisionc                	   @   r   ))FlaxAlbertForSequenceClassificationModuler.   r/   c                 C   sV   t | j| jd| _| jjd ur| jjn| jj}tj|d| _tj	| jj
| jd| _d S )Nr   r4   r   r   r.   r/   r   r   rF   r6   rE   rG   r^   
num_labelsr   rI   classifier_dropoutr+   r+   r,   rJ   I  s   z/FlaxAlbertForSequenceClassificationModule.setupTFrK   rc   r   r   c	              
   C   b   | j ||||||||d}	|	d }
| j|
|d}
| |
}|s(|f|	dd   S t||	j|	jdS )Nr   r   rM   rd   r  )r   rG   r   r   r!   r"   )rI   rO   ry   rP   rQ   rK   rc   r   r   r   r   r   r+   r+   r,   rT   V  (   
z2FlaxAlbertForSequenceClassificationModule.__call__Nr   r   r+   r+   r+   r,   r  E  s"   
 	r  z
    Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    c                   @   r   )#FlaxAlbertForSequenceClassificationN)r#   r$   r%   r  r   r+   r+   r+   r,   r  {  r  r  c                	   @   r   )!FlaxAlbertForMultipleChoiceModuler.   r/   c                 C   s:   t | j| jd| _tj| jjd| _tjd| jd| _	d S )Nr   r4   r   r   )
r   r.   r/   r   r6   rE   rF   rG   r^   r   rH   r+   r+   r,   rJ     r  z'FlaxAlbertForMultipleChoiceModule.setupTFrK   rc   r   r   c	              
   C   s   |j d }	|d ur|d|j d nd }|d ur!|d|j d nd }|d ur0|d|j d nd }|d ur?|d|j d nd }| j||||||||d}
|
d }| j||d}| |}|d|	}|so|f|
dd   S t||
j|
jdS )Nr   rn   r   rM   rd   r  )rp   ro   r   rG   r   r   r!   r"   )rI   rO   ry   rP   rQ   rK   rc   r   r   num_choicesr   r   r   reshaped_logitsr+   r+   r,   rT     s4   

z*FlaxAlbertForMultipleChoiceModule.__call__Nr   r   r+   r+   r+   r,   r    r  r  z
    Albert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    c                   @   r   )FlaxAlbertForMultipleChoiceN)r#   r$   r%   r  r   r+   r+   r+   r,   r    r  r  z(batch_size, num_choices, sequence_lengthc                	   @   r   )&FlaxAlbertForTokenClassificationModuler.   r/   c                 C   sX   t | j| jdd| _| jjd ur| jjn| jj}tj|d| _tj	| jj
| jd| _d S )NFr.   r/   r   r4   r   r  r  r+   r+   r,   rJ     s   z,FlaxAlbertForTokenClassificationModule.setupTFrK   rc   r   r   c	              
   C   r  )Nr   r   rM   r   r  )r   rG   r   r   r!   r"   )rI   rO   ry   rP   rQ   rK   rc   r   r   r   r!   r   r+   r+   r,   rT     r  z/FlaxAlbertForTokenClassificationModule.__call__Nr   r   r+   r+   r+   r,   r     s"   
 	r   z
    Albert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    c                   @   r   ) FlaxAlbertForTokenClassificationN)r#   r$   r%   r   r   r+   r+   r+   r,   r"    r  r"  c                	   @   r   )$FlaxAlbertForQuestionAnsweringModuler.   r/   c                 C   s.   t | j| jdd| _tj| jj| jd| _d S )NFr!  r   )r   r.   r/   r   r6   r^   r  
qa_outputsrH   r+   r+   r,   rJ   $  s   z*FlaxAlbertForQuestionAnsweringModule.setupTFrK   rc   r   r   c	              
   C   s   | j ||||||||d}	|	d }
| |
}tj|| jjdd\}}|d}|d}|s8||f|	dd   S t|||	j|	j	dS )Nr   r   rn   rg   r   )start_logits
end_logitsr!   r"   )
r   r$  r'   r   r.   r  squeezer   r!   r"   )rI   rO   ry   rP   rQ   rK   rc   r   r   r   r!   r   r%  r&  r+   r+   r,   rT   (  s.   


z-FlaxAlbertForQuestionAnsweringModule.__call__Nr   r   r+   r+   r+   r,   r#     r  r#  z
    Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    c                   @   r   )FlaxAlbertForQuestionAnsweringN)r#   r$   r%   r#  r   r+   r+   r+   r,   r(  Q  r  r(  )r   r   r  r  r  r  r"  r(  )Qtypingr   r   flax
flax.linenlinenr6   r:   	jax.numpynumpyr'   r   flax.core.frozen_dictr   r   r   flax.linen.attentionr   flax.traverse_utilr   r	   r
   modeling_flax_outputsr   r   r   r   r   r   r   modeling_flax_utilsr   r   r   r   r   utilsr   r   r   r   configuration_albertr   
get_loggerr#   logger_CHECKPOINT_FOR_DOC_CONFIG_FOR_DOCstruct	dataclassr   ALBERT_START_DOCSTRINGr   r   r-   rX   r   r   r   r   r   r   r   r   r   r   r  r  %FLAX_ALBERT_FOR_PRETRAINING_DOCSTRINGr   r  r  r  r  r  r  r   r"  r#  r(  __all__r+   r+   r+   r,   <module>   s   $	
#!(U,+/_F62
66
31