"""PyTorch TrOCR decoder model (based on RoBERTa)."""

import math
from typing import Optional, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_trocr import TrOCRConfig


logger = logging.get_logger(__name__)


class TrOCRLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # TrOCR offsets the embedding ids by 2, so two extra rows are allocated in the table
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(
        self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None
    ):
        """`input_ids' shape is expected to be [bsz x seqlen]."""
        if position_ids is None:
            bsz, seq_len = input_ids.shape[:2]
            position_ids = torch.arange(
                past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
            ).expand(bsz, -1)
        else:
            position_ids = position_ids.unsqueeze(0)

        return super().forward(position_ids + self.offset)


class TrOCRScaledWordEmbedding(nn.Embedding):
    """
    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor):
        return super().forward(input_ids) * self.embed_scale


class TrOCRSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = self.get_embedding(num_positions, embedding_dim, padding_idx)
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
        description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad the last dimension for odd embedding sizes
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        bsz, seq_len = input_ids.size()
        # Create the position ids from the input token ids. Any padded tokens remain padded.
        position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
            input_ids.device
        )

        # expand embeddings if needed
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            self.weights = self.get_embedding(max_pos, self.embedding_dim, self.padding_idx)
        self.weights = self.weights.to(self._float_tensor)

        x = self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()

        return x

    def create_position_ids_from_input_ids(
        self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0
    ):
        """
        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
        symbols are ignored. This is modified from fairseq's `utils.make_positions`.
        """
        mask = input_ids.ne(padding_idx).int()
        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
        return incremental_indices.long() + padding_idx


class TrOCRAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper."""

    def __init__(
        self,
        config,
        embed_dim: int,
        num_heads: int,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_cross_attention: bool = False,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.layer_idx = layer_idx

        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling

        is_updated = False
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    # after the first generated id, we can subsequently re-use all key/value_states from cache
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                curr_past_key_value = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k, v from the cross-attention cache
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_states = self.k_proj(current_states)
            value_states = self.v_proj(current_states)
            key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # flag the current layer's cross-attention cache as filled so it can be reused
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        query_states = query_states.reshape(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # reshape twice so that the returned attention weights keep their gradient
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped


class TrOCRDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: TrOCRConfig, layer_idx=None):
        super().__init__()
        self.embed_dim = config.hidden_size

        self.self_attn = TrOCRAttention(
            config,
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            layer_idx=layer_idx,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        if config.is_decoder:
            self.encoder_attn = TrOCRAttention(
                config,
                embed_dim=self.embed_dim,
                num_heads=config.decoder_attention_heads,
                kdim=config.cross_attention_hidden_size,
                vdim=config.cross_attention_hidden_size,
                dropout=config.attention_dropout,
                is_decoder=True,
                is_cross_attention=True,
                layer_idx=layer_idx,
            )
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
        cache_position: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size *(decoder_attention_heads,)*.
            past_key_values (`Cache`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Self-attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-attention block
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states

            hidden_states, cross_attn_weights = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states
            hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully connected
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs


@auto_docstring
class TrOCRPreTrainedModel(PreTrainedModel):
    config: TrOCRConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TrOCRDecoderLayer"]

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class TrOCRDecoder(TrOCRPreTrainedModel):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TrOCRDecoderLayer`]

    Args:
        config: TrOCRConfig
    """

    def __init__(self, config: TrOCRConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

        self.embed_tokens = TrOCRScaledWordEmbedding(
            config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
        )

        if config.use_learned_position_embeddings:
            self.embed_positions = TrOCRLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
        else:
            self.embed_positions = TrOCRSinusoidalPositionalEmbedding(
                config.max_position_embeddings + self.padding_idx + 1,
                config.hidden_size,
                self.padding_idx,
            )

        if config.layernorm_embedding:
            self.layernorm_embedding = nn.LayerNorm(config.hidden_size)
        else:
            self.layernorm_embedding = None

        self.layers = nn.ModuleList([TrOCRDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
                on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input = input_ids
            input_ids = input_ids.view(-1, input.shape[-1])
        elif inputs_embeds is not None:
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
            )
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                if encoder_hidden_states is not None
                else DynamicCache(config=self.config)
            )
        if isinstance(past_key_values, tuple):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self.config.use_learned_position_embeddings:
            embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
        else:
            embed_pos = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)

        hidden_states = inputs_embeds + embed_pos

        if self.layernorm_embedding is not None:
            hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        input_shape = input.shape

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
        )

        # expand encoder attention mask [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            encoder_attention_mask = _prepare_4d_attention_mask(
                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != len(self.layers):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # LayerDrop: randomly skip whole decoder layers during training
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )


@auto_docstring(
    custom_intro="""
    The TrOCR Model with a language modeling head. Can be used for summarization.
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """
)
class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.decoder = TrOCRDecoder(config)

    def forward(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)


@auto_docstring(
    custom_intro="""
    The TrOCR Decoder with a language modeling head. Can be used as the decoder part of [`EncoderDecoderModel`] and
    [`VisionEncoderDecoder`].
    """
)
class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["output_projection.weight"]

    def __init__(self, config):
        config.is_decoder = True
        config.is_encoder_decoder = False
        super().__init__(config)
        self.model = TrOCRDecoderWrapper(config)

        self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value

    def get_output_embeddings(self):
        return self.output_projection

    def set_output_embeddings(self, new_embeddings):
        self.output_projection = new_embeddings

    def set_decoder(self, decoder):
        self.model.decoder = decoder

    def get_decoder(self):
        return self.model.decoder

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import (
        ...     TrOCRConfig,
        ...     TrOCRProcessor,
        ...     TrOCRForCausalLM,
        ...     ViTConfig,
        ...     ViTModel,
        ...     VisionEncoderDecoderModel,
        ... )
        >>> import requests
        >>> from PIL import Image

        >>> # TrOCR is a decoder model and should be used within a VisionEncoderDecoderModel
        >>> # init vision2text model with random weights
        >>> encoder = ViTModel(ViTConfig())
        >>> decoder = TrOCRForCausalLM(TrOCRConfig())
        >>> model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder)

        >>> # If you want to start from the pretrained model, load the checkpoint with `VisionEncoderDecoderModel`
        >>> processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
        >>> model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

        >>> # load image from the IAM dataset
        >>> url = "https://fki.tic.heia-fr.ch/static/img/a01-122-02.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
        >>> pixel_values = processor(image, return_tensors="pt").pixel_values
        >>> text = "industry, ' Mr. Brown commented icily. ' Let us have a"

        >>> # training
        >>> model.config.decoder_start_token_id = processor.tokenizer.eos_token_id
        >>> model.config.pad_token_id = processor.tokenizer.pad_token_id
        >>> model.config.vocab_size = model.config.decoder.vocab_size

        >>> labels = processor.tokenizer(text, return_tensors="pt").input_ids
        >>> outputs = model(pixel_values, labels=labels)
        >>> loss = outputs.loss
        >>> round(loss.item(), 2)
        5.30

        >>> # inference
        >>> generated_ids = model.generate(pixel_values)
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> generated_text
        'industry, " Mr. Brown commented icily. " Let us have a'
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        logits = self.output_projection(outputs[0])

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


__all__ = ["TrOCRForCausalLM", "TrOCRPreTrainedModel"]
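

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the upstream module): it exercises
# TrOCRForCausalLM end to end with a deliberately tiny, arbitrary config so it
# runs in seconds on CPU. The hyperparameter values below are illustrative
# assumptions, not values from any released checkpoint.
if __name__ == "__main__":
    config = TrOCRConfig(
        vocab_size=128,
        d_model=64,
        decoder_layers=2,
        decoder_attention_heads=4,
        decoder_ffn_dim=256,
        max_position_embeddings=64,
    )
    model = TrOCRForCausalLM(config).eval()

    # Random token ids stand in for tokenized OCR transcriptions.
    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    with torch.no_grad():
        # Reusing the inputs as labels just exercises the loss path; note that
        # this head computes the loss without shifting the labels internally.
        outputs = model(input_ids=input_ids, labels=input_ids)
    print(outputs.loss.item(), outputs.logits.shape)  # scalar loss, torch.Size([2, 8, 128])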