o
    	۷i                     @   s  d Z ddlZddlmZ ddlmZ ddlmZmZ ddl	Z	ddl	m
Z
 ddlmZmZmZ dd	lmZmZ e r?dd
lmZ ddlmZ ddlmZ ddlmZmZmZmZmZmZmZm Z m!Z! ddl"m#Z# ddl$m%Z% ddlm&Z& ddl'm(Z( e&)e*Z+dd Z,dd Z-dd Z.G dd de
j/Z0G dd de
j/Z1G dd de
j/Z2G dd  d e
j/Z3G d!d" d"e
j/Z4G d#d$ d$e
j/Z5G d%d& d&eZ6G d'd( d(e
j/Z7G d)d* d*e
j/Z8G d+d, d,e
j/Z9G d-d. d.e
j/Z:G d/d0 d0e
j/Z;G d1d2 d2e
j/Z<G d3d4 d4e
j/Z=eG d5d6 d6e#Z>eed7d8G d9d: d:eZ?eG d;d< d<e>Z@ed=d8G d>d? d?e>ZAeG d@dA dAe>ZBedBd8G dCdD dDe>ZCedEd8G dFdG dGe>ZDeG dHdI dIe>ZEeG dJdK dKe>ZFeG dLdM dMe>ZGg dNZHdS )OzPyTorch FNet model.    N)	dataclass)partial)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )auto_docstringis_scipy_available)linalg)ACT2FN)GradientCheckpointingLayer)	BaseModelOutputBaseModelOutputWithPoolingMaskedLMOutputModelOutputMultipleChoiceModelOutputNextSentencePredictorOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)apply_chunking_to_forward)logging   )
FNetConfigc                 C   s:   | j d }|d|d|f }| tj} td| ||S )z4Applies 2D matrix multiplication to 3D input arrays.r   Nzbij,jk,ni->bnk)shapetypetorch	complex64einsum)xmatrix_dim_onematrix_dim_two
seq_length r'   \/home/ubuntu/vllm_env/lib/python3.10/site-packages/transformers/models/fnet/modeling_fnet.py_two_dim_matmul7   s   
r)   c                 C   s   t | ||S N)r)   )r#   r$   r%   r'   r'   r(   two_dim_matmul@      r+   c                 C   s4   | }t t| jdd D ]
}tjj||d}q|S )z
    Applies n-dimensional Fast Fourier Transform (FFT) to input array.

    Args:
        x: Input n-dimensional array.

    Returns:
        n-dimensional Fourier transform of input n-dimensional array.
    r   N)axis)reversedrangendimr    fft)r#   outr-   r'   r'   r(   fftnE   s   
r3   c                       s*   e Zd ZdZ fddZdddZ  ZS )FNetEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    s   t    tj|j|j|jd| _t|j|j| _	t|j
|j| _tj|j|jd| _t|j|j| _t|j| _| jdt|jddd | jdtj| j tjddd d S )	N)padding_idxepsposition_ids)r   F)
persistenttoken_type_idsdtype)super__init__r   	Embedding
vocab_sizehidden_sizepad_token_idword_embeddingsmax_position_embeddingsposition_embeddingstype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsLinear
projectionDropouthidden_dropout_probdropoutregister_bufferr    arangeexpandzerosr8   sizelongselfconfig	__class__r'   r(   r?   X   s   

zFNetEmbeddings.__init__Nc                 C   s   |d ur	|  }n|  d d }|d }|d u r$| jd d d |f }|d u rNt| drC| jd d d |f }||d |}|}ntj|tj| jjd}|d u rW| 	|}| 
|}	||	 }
| |}|
|7 }
| |
}
| |
}
| |
}
|
S )Nr9   r   r;   r   r=   device)rT   r8   hasattrr;   rR   r    rS   rU   r\   rD   rH   rF   rI   rL   rO   )rW   	input_idsr;   r8   inputs_embedsinput_shaper&   buffered_token_type_ids buffered_token_type_ids_expandedrH   
embeddingsrF   r'   r'   r(   forwardn   s,   







zFNetEmbeddings.forward)NNNN)__name__
__module____qualname____doc__r?   rd   __classcell__r'   r'   rY   r(   r4   U   s    r4   c                       ,   e Zd Z fddZdd Zdd Z  ZS )FNetBasicFourierTransformc                    s   t    | | d S r*   )r>   r?   _init_fourier_transformrV   rY   r'   r(   r?         
z"FNetBasicFourierTransform.__init__c                 C   s   |j sttjjdd| _d S |jdkrLt rB| dtj	t
|jtjd | dtj	t
|jtjd tt| j| jd| _d S td t| _d S t| _d S )	N)r      dim   dft_mat_hiddenr<   dft_mat_seq)r$   r%   zpSciPy is needed for DFT matrix calculation and is not found. Using TPU optimized fast fourier transform instead.)use_tpu_fourier_optimizationsr   r    r1   r3   fourier_transformrE   r   rP   tensorr   dftrB   r!   tpu_short_seq_lengthr+   rs   rr   r   warningrV   r'   r'   r(   rl      s$   



z1FNetBasicFourierTransform._init_fourier_transformc                 C   s   |  |j}|fS r*   )ru   real)rW   hidden_statesoutputsr'   r'   r(   rd      s   z!FNetBasicFourierTransform.forward)re   rf   rg   r?   rl   rd   ri   r'   r'   rY   r(   rk      s    rk   c                       $   e Zd Z fddZdd Z  ZS )FNetBasicOutputc                    s"   t    tj|j|jd| _d S Nr6   )r>   r?   r   rI   rB   rJ   rV   rY   r'   r(   r?      s   
zFNetBasicOutput.__init__c                 C   s   |  || }|S r*   )rI   rW   r{   input_tensorr'   r'   r(   rd      s   zFNetBasicOutput.forwardre   rf   rg   r?   rd   ri   r'   r'   rY   r(   r~          r~   c                       r}   )FNetFourierTransformc                    s"   t    t|| _t|| _d S r*   )r>   r?   rk   rW   r~   outputrV   rY   r'   r(   r?      s   

zFNetFourierTransform.__init__c                 C   s$   |  |}| |d |}|f}|S Nr   )rW   r   )rW   r{   self_outputsfourier_outputr|   r'   r'   r(   rd      s   
zFNetFourierTransform.forwardr   r'   r'   rY   r(   r          r   c                       2   e Zd Z fddZdejdejfddZ  ZS )FNetIntermediatec                    sD   t    t|j|j| _t|jt	rt
|j | _d S |j| _d S r*   )r>   r?   r   rK   rB   intermediate_sizedense
isinstance
hidden_actstrr   intermediate_act_fnrV   rY   r'   r(   r?      s
   
zFNetIntermediate.__init__r{   returnc                 C      |  |}| |}|S r*   )r   r   rW   r{   r'   r'   r(   rd         

zFNetIntermediate.forwardre   rf   rg   r?   r    Tensorrd   ri   r'   r'   rY   r(   r      s    r   c                       s8   e Zd Z fddZdejdejdejfddZ  ZS )
FNetOutputc                    sB   t    t|j|j| _tj|j|jd| _t	|j
| _d S r   )r>   r?   r   rK   r   rB   r   rI   rJ   rM   rN   rO   rV   rY   r'   r(   r?      s   
zFNetOutput.__init__r{   r   r   c                 C   s&   |  |}| |}| || }|S r*   )r   rO   rI   r   r'   r'   r(   rd      s   

zFNetOutput.forwardr   r'   r'   rY   r(   r      s    $r   c                       rj   )	FNetLayerc                    s:   t    |j| _d| _t|| _t|| _t|| _	d S Nr   )
r>   r?   chunk_size_feed_forwardseq_len_dimr   fourierr   intermediater   r   rV   rY   r'   r(   r?      s   


zFNetLayer.__init__c                 C   s0   |  |}|d }t| j| j| j|}|f}|S r   )r   r   feed_forward_chunkr   r   )rW   r{   self_fourier_outputsr   layer_outputr|   r'   r'   r(   rd      s   
zFNetLayer.forwardc                 C   s   |  |}| ||}|S r*   )r   r   )rW   r   intermediate_outputr   r'   r'   r(   r     s   
zFNetLayer.feed_forward_chunk)re   rf   rg   r?   rd   r   ri   r'   r'   rY   r(   r      s    r   c                       s&   e Zd Z fddZdddZ  ZS )FNetEncoderc                    s:   t     | _t fddt jD | _d| _d S )Nc                    s   g | ]}t  qS r'   )r   ).0_rX   r'   r(   
<listcomp>  s    z(FNetEncoder.__init__.<locals>.<listcomp>F)	r>   r?   rX   r   
ModuleListr/   num_hidden_layerslayergradient_checkpointingrV   rY   r   r(   r?   
  s   
 
zFNetEncoder.__init__FTc                 C   sr   |rdnd }t | jD ]\}}|r||f }||}|d }q|r&||f }|s3tdd ||fD S t||dS )Nr'   r   c                 s   s    | ]	}|d ur|V  qd S r*   r'   )r   vr'   r'   r(   	<genexpr>  s    z&FNetEncoder.forward.<locals>.<genexpr>)last_hidden_stater{   )	enumerater   tupler   )rW   r{   output_hidden_statesreturn_dictall_hidden_statesilayer_modulelayer_outputsr'   r'   r(   rd     s   


zFNetEncoder.forward)FTr   r'   r'   rY   r(   r   	  s    r   c                       r   )
FNetPoolerc                    s*   t    t|j|j| _t | _d S r*   )r>   r?   r   rK   rB   r   Tanh
activationrV   rY   r'   r(   r?   &  s   
zFNetPooler.__init__r{   r   c                 C   s(   |d d df }|  |}| |}|S r   )r   r   )rW   r{   first_token_tensorpooled_outputr'   r'   r(   rd   +  s   

zFNetPooler.forwardr   r'   r'   rY   r(   r   %  s    r   c                       r   )FNetPredictionHeadTransformc                    sV   t    t|j|j| _t|jtrt	|j | _
n|j| _
tj|j|jd| _d S r   )r>   r?   r   rK   rB   r   r   r   r   r   transform_act_fnrI   rJ   rV   rY   r'   r(   r?   6  s   
z$FNetPredictionHeadTransform.__init__r{   r   c                 C   s"   |  |}| |}| |}|S r*   )r   r   rI   r   r'   r'   r(   rd   ?  s   


z#FNetPredictionHeadTransform.forwardr   r'   r'   rY   r(   r   5  s    	r   c                       s.   e Zd Z fddZdd Zd	ddZ  ZS )
FNetLMPredictionHeadc                    sH   t    t|| _t|j|j| _t	t
|j| _| j| j_d S r*   )r>   r?   r   	transformr   rK   rB   rA   decoder	Parameterr    rS   biasrV   rY   r'   r(   r?   G  s
   

zFNetLMPredictionHead.__init__c                 C   r   r*   )r   r   r   r'   r'   r(   rd   R  r   zFNetLMPredictionHead.forwardr   Nc                 C   s,   | j jjjdkr| j| j _d S | j j| _d S )Nmeta)r   r   r\   r   rW   r'   r'   r(   _tie_weightsW  s   z!FNetLMPredictionHead._tie_weights)r   N)re   rf   rg   r?   rd   r   ri   r'   r'   rY   r(   r   F  s    r   c                       r}   )FNetOnlyMLMHeadc                    s   t    t|| _d S r*   )r>   r?   r   predictionsrV   rY   r'   r(   r?   a  rm   zFNetOnlyMLMHead.__init__c                 C      |  |}|S r*   )r   )rW   sequence_outputprediction_scoresr'   r'   r(   rd   e     
zFNetOnlyMLMHead.forwardr   r'   r'   rY   r(   r   `  r   r   c                       r}   )FNetOnlyNSPHeadc                    s   t    t|jd| _d S Nrn   )r>   r?   r   rK   rB   seq_relationshiprV   rY   r'   r(   r?   l  s   
zFNetOnlyNSPHead.__init__c                 C   r   r*   )r   )rW   r   seq_relationship_scorer'   r'   r(   rd   p  r   zFNetOnlyNSPHead.forwardr   r'   r'   rY   r(   r   k  r   r   c                       r}   )FNetPreTrainingHeadsc                    s(   t    t|| _t|jd| _d S r   )r>   r?   r   r   r   rK   rB   r   rV   rY   r'   r(   r?   w  s   

zFNetPreTrainingHeads.__init__c                 C   s   |  |}| |}||fS r*   )r   r   )rW   r   r   r   r   r'   r'   r(   rd   |  s   

zFNetPreTrainingHeads.forwardr   r'   r'   rY   r(   r   v  r   r   c                   @   s&   e Zd ZU eed< dZdZdd ZdS )FNetPreTrainedModelrX   fnetTc                 C   s   t |tjr |jjjd| jjd |jdur|jj	  dS dS t |tj
rC|jjjd| jjd |jdurA|jj|j 	  dS dS t |tjrX|jj	  |jjd dS dS )zInitialize the weightsg        )meanstdNg      ?)r   r   rK   weightdatanormal_rX   initializer_ranger   zero_r@   r5   rI   fill_)rW   moduler'   r'   r(   _init_weights  s   

z!FNetPreTrainedModel._init_weightsN)re   rf   rg   r   __annotations__base_model_prefixsupports_gradient_checkpointingr   r'   r'   r'   r(   r     s
   
 r   z0
    Output type of [`FNetForPreTraining`].
    )custom_introc                   @   s^   e Zd ZU dZdZeej ed< dZ	eej ed< dZ
eej ed< dZeeej  ed< dS )FNetForPreTrainingOutputa  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
        before SoftMax).
    Nlossprediction_logitsseq_relationship_logitsr{   )re   rf   rg   rh   r   r   r    FloatTensorr   r   r   r{   r   r'   r'   r'   r(   r     s   
 r   c                       s   e Zd ZdZd fdd	Zdd Zdd Ze												dd
ee	j
 dee	j
 dee	j
 dee	j dee dee deeef fddZ  ZS )	FNetModelz

    The model can behave as an encoder, following the architecture described in [FNet: Mixing Tokens with Fourier
    Transforms](https://huggingface.co/papers/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.

    Tc                    sD   t  | || _t|| _t|| _|rt|nd| _| 	  dS )zv
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        N)
r>   r?   rX   r4   rc   r   encoderr   pooler	post_init)rW   rX   add_pooling_layerrY   r'   r(   r?     s   

zFNetModel.__init__c                 C   s   | j jS r*   rc   rD   r   r'   r'   r(   get_input_embeddings  s   zFNetModel.get_input_embeddingsc                 C   s   || j _d S r*   r   )rW   valuer'   r'   r(   set_input_embeddings  r,   zFNetModel.set_input_embeddingsNr^   r;   r8   r_   r   r   r   c                 C   sv  |d ur|n| j j}|d ur|n| j j}|d ur |d ur td|d ur-| }|\}}	n|d ur>| d d }|\}}	ntd| j jrT|	dkrT| j j|	krTtd|d ur[|jn|j}
|d u rt| j	dr}| j	j
d d d |	f }|||	}|}n	tj|tj|
d}| j	||||d}| j|||d	}|d
 }| jd ur| |nd }|s||f|dd   S t|||jdS )NzDYou cannot specify both input_ids and inputs_embeds at the same timer9   z5You have to specify either input_ids or inputs_embedsrq   zThe `tpu_short_seq_length` in FNetConfig should be set equal to the sequence length being passed to the model when using TPU optimizations.r;   r[   )r^   r8   r;   r_   )r   r   r   r   )r   pooler_outputr{   )rX   r   use_return_dict
ValueErrorrT   rt   rx   r\   r]   rc   r;   rR   r    rS   rU   r   r   r   r{   )rW   r^   r;   r8   r_   r   r   r`   
batch_sizer&   r\   ra   rb   embedding_outputencoder_outputsr   r   r'   r'   r(   rd     s\   

zFNetModel.forward)T)NNNNNN)re   rf   rg   rh   r?   r   r   r   r   r    
LongTensorr   boolr   r   r   rd   ri   r'   r'   rY   r(   r     s6    
r   z
    FNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    c                       s   e Zd ZddgZ fddZdd Zdd Ze																dd
ee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee dee deeef fddZ  ZS )FNetForPreTrainingcls.predictions.decoder.biascls.predictions.decoder.weightc                    ,   t  | t|| _t|| _|   d S r*   )r>   r?   r   r   r   clsr   rV   rY   r'   r(   r?   !     

zFNetForPreTraining.__init__c                 C   
   | j jjS r*   r   r   r   r   r'   r'   r(   get_output_embeddings*     
z(FNetForPreTraining.get_output_embeddingsc                 C      || j j_|j| j j_d S r*   r   r   r   r   rW   new_embeddingsr'   r'   r(   set_output_embeddings-     
z(FNetForPreTraining.set_output_embeddingsNr^   r;   r8   r_   labelsnext_sentence_labelr   r   r   c	                 C   s   |dur|n| j j}| j||||||d}	|	dd \}
}| |
|\}}d}|durP|durPt }||d| j j|d}||dd|d}|| }|sg||f|	dd  }|dure|f| S |S t||||	jdS )aH  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring) Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FNetForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
        >>> model = FNetForPreTraining.from_pretrained("google/fnet-base")
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```Nr;   r8   r_   r   r   rn   r9   )r   r   r   r{   )	rX   r   r   r   r   viewrA   r   r{   )rW   r^   r;   r8   r_   r  r  r   r   r|   r   r   r   r   
total_lossloss_fctmasked_lm_lossnext_sentence_lossr   r'   r'   r(   rd   1  s4   %	zFNetForPreTraining.forwardNNNNNNNN)re   rf   rg   _tied_weights_keysr?   r   r  r   r   r    r   r   r   r   r   rd   ri   r'   r'   rY   r(   r     sB    		

r   c                       s   e Zd ZddgZ fddZdd Zdd Ze														dd
ee	j
 dee	j
 dee	j
 dee	j
 dee	j
 dee dee deeef fddZ  ZS )FNetForMaskedLMr   r   c                    r   r*   )r>   r?   r   r   r   r   r   rV   rY   r'   r(   r?   {  r   zFNetForMaskedLM.__init__c                 C   r   r*   r   r   r'   r'   r(   r     r   z%FNetForMaskedLM.get_output_embeddingsc                 C   r  r*   r  r  r'   r'   r(   r    r  z%FNetForMaskedLM.set_output_embeddingsNr^   r;   r8   r_   r  r   r   r   c                 C   s   |dur|n| j j}| j||||||d}|d }	| |	}
d}|dur5t }||
d| j j|d}|sK|
f|dd  }|durI|f| S |S t||
|jdS )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        Nr	  r   r9   rn   r   logitsr{   )	rX   r   r   r   r   r
  rA   r   r{   )rW   r^   r;   r8   r_   r  r   r   r|   r   r   r  r  r   r'   r'   r(   rd     s&   	
zFNetForMaskedLM.forwardNNNNNNN)re   rf   rg   r  r?   r   r  r   r   r    r   r   r   r   r   rd   ri   r'   r'   rY   r(   r  w  s<    	
	r  zT
    FNet Model with a `next sentence prediction (classification)` head on top.
    c                          e Zd Z fddZe							ddeej deej deej deej deej d	ee d
ee de	e
ef fddZ  ZS )FNetForNextSentencePredictionc                    r   r*   )r>   r?   r   r   r   r   r   rV   rY   r'   r(   r?     r   z&FNetForNextSentencePrediction.__init__Nr^   r;   r8   r_   r  r   r   r   c                 K   s   d|v rt dt |d}|dur|n| jj}| j||||||d}	|	d }
| |
}d}|durBt }||	dd|	d}|sX|f|	dd  }|durV|f| S |S t
|||	jdS )	a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, FNetForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
        >>> model = FNetForNextSentencePrediction.from_pretrained("google/fnet-base")
        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```r  zoThe `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.Nr	  r   r9   rn   r  )warningswarnFutureWarningpoprX   r   r   r   r   r
  r   r{   )rW   r^   r;   r8   r_   r  r   r   kwargsr|   r   seq_relationship_scoresr  r  r   r'   r'   r(   rd     s:   $
	
z%FNetForNextSentencePrediction.forwardr  )re   rf   rg   r?   r   r   r    r   r   r   r   r   rd   ri   r'   r'   rY   r(   r    s6    	

r  z
    FNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    c                       r  )FNetForSequenceClassificationc                    J   t  | |j| _t|| _t|j| _t	|j
|j| _|   d S r*   r>   r?   
num_labelsr   r   r   rM   rN   rO   rK   rB   
classifierr   rV   rY   r'   r(   r?     s   
z&FNetForSequenceClassification.__init__Nr^   r;   r8   r_   r  r   r   r   c                 C   sh  |dur|n| j j}| j||||||d}|d }	| |	}	| |	}
d}|dur| j jdu rS| jdkr9d| j _n| jdkrO|jtj	ksJ|jtj
krOd| j _nd| j _| j jdkrqt }| jdkrk||
 | }n+||
|}n%| j jdkrt }||
d| j|d}n| j jdkrt }||
|}|s|
f|dd  }|dur|f| S |S t||
|jd	S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr	  r   
regressionsingle_label_classificationmulti_label_classificationr9   rn   r  )rX   r   r   rO   r!  problem_typer   r=   r    rU   intr	   squeezer   r
  r   r   r{   )rW   r^   r;   r8   r_   r  r   r   r|   r   r  r   r  r   r'   r'   r(   rd   "  sF   	



"


z%FNetForSequenceClassification.forwardr  )re   rf   rg   r?   r   r   r    r   r   r   r   r   rd   ri   r'   r'   rY   r(   r    s6    
	r  c                       r  )FNetForMultipleChoicec                    s@   t  | t|| _t|j| _t|j	d| _
|   d S r   )r>   r?   r   r   r   rM   rN   rO   rK   rB   r!  r   rV   rY   r'   r(   r?   a  s
   
zFNetForMultipleChoice.__init__Nr^   r;   r8   r_   r  r   r   r   c                 C   sF  |dur|n| j j}|dur|jd n|jd }|dur%|d|dnd}|dur4|d|dnd}|durC|d|dnd}|durV|d|d|dnd}| j||||||d}	|	d }
| |
}
| |
}|d|}d}|durt }|||}|s|f|	dd  }|dur|f| S |S t	|||	j
dS )a[  
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        Nr   r9   r	  rn   r  )rX   r   r   r
  rT   r   rO   r!  r   r   r{   )rW   r^   r;   r8   r_   r  r   r   num_choicesr|   r   r  reshaped_logitsr   r  r   r'   r'   r(   rd   k  s:   )	


zFNetForMultipleChoice.forwardr  )re   rf   rg   r?   r   r   r    r   r   r   r   r   rd   ri   r'   r'   rY   r(   r(  _  s6    

	r(  c                       r  )FNetForTokenClassificationc                    r  r*   r  rV   rY   r'   r(   r?     s   
z#FNetForTokenClassification.__init__Nr^   r;   r8   r_   r  r   r   r   c                 C   s   |dur|n| j j}| j||||||d}|d }	| |	}	| |	}
d}|dur9t }||
d| j|d}|sO|
f|dd  }|durM|f| S |S t||
|j	dS )z
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        Nr	  r   r9   rn   r  )
rX   r   r   rO   r!  r   r
  r   r   r{   )rW   r^   r;   r8   r_   r  r   r   r|   r   r  r   r  r   r'   r'   r(   rd     s(   	

z"FNetForTokenClassification.forwardr  )re   rf   rg   r?   r   r   r    r   r   r   r   r   rd   ri   r'   r'   rY   r(   r,    s6    
	r,  c                       s   e Zd Z fddZe								ddeej deej deej deej deej d	eej d
ee dee de	e
ef fddZ  ZS )FNetForQuestionAnsweringc                    s<   t  | |j| _t|| _t|j|j| _| 	  d S r*   )
r>   r?   r   r   r   r   rK   rB   
qa_outputsr   rV   rY   r'   r(   r?     s
   
z!FNetForQuestionAnswering.__init__Nr^   r;   r8   r_   start_positionsend_positionsr   r   r   c	                 C   s>  |d ur|n| j j}| j||||||d}	|	d }
| |
}|jddd\}}|d }|d }d }|d ur|d urt| dkrL|d}t| dkrY|d}|d}|	d|}|	d|}t
|d}|||}|||}|| d }|s||f|	dd   }|d ur|f| S |S t||||	jdS )	Nr	  r   r   r9   ro   )ignore_indexrn   )r   start_logits
end_logitsr{   )rX   r   r   r.  splitr'  
contiguouslenrT   clampr   r   r{   )rW   r^   r;   r8   r_   r/  r0  r   r   r|   r   r  r2  r3  r  ignored_indexr  
start_lossend_lossr   r'   r'   r(   rd     sB   	







z FNetForQuestionAnswering.forwardr  )re   rf   rg   r?   r   r   r    r   r   r   r   r   rd   ri   r'   r'   rY   r(   r-    s<    	

r-  )
r  r(  r  r   r-  r  r,  r   r   r   )Irh   r  dataclassesr   	functoolsr   typingr   r   r    r   torch.nnr   r   r	   utilsr   r   scipyr   activationsr   modeling_layersr   modeling_outputsr   r   r   r   r   r   r   r   r   modeling_utilsr   pytorch_utilsr   r   configuration_fnetr   
get_loggerre   loggerr)   r+   r3   Moduler4   rk   r~   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r(  r,  r-  __all__r'   r'   r'   r(   <module>   s   ,
	=&
eY>UI[9D