"""PyTorch Donut Swin Transformer model.

This implementation is identical to a regular Swin Transformer, without final layer norm on top of the final hidden
states."""

import collections.abc
import math
from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import nn

from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
from ...utils import ModelOutput, auto_docstring, logging, torch_int
from .configuration_donut_swin import DonutSwinConfig


logger = logging.get_logger(__name__)


@dataclass
@auto_docstring(
    custom_intro="""
    DonutSwin encoder's outputs, with potential hidden states and attentions.
    """
)
class DonutSwinEncoderOutput(ModelOutput):
    r"""
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None


@dataclass
@auto_docstring(
    custom_intro="""
    DonutSwin model's outputs that also contains a pooling of the last hidden states.
    """
)
class DonutSwinModelOutput(ModelOutput):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None


@dataclass
@auto_docstring(
    custom_intro="""
    DonutSwin outputs for image classification.
    """
)
class DonutSwinImageClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None


def window_partition(input_feature, window_size):
    """
    Partitions the given input into windows.
    """
    batch_size, height, width, num_channels = input_feature.shape
    input_feature = input_feature.view(
        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
    )
    windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
    return windows


def window_reverse(windows, window_size, height, width):
    """
    Merges windows to produce higher resolution features.
    """
    num_channels = windows.shape[-1]
    windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
    windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
    return windows


class DonutSwinEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.
    """

    def __init__(self, config, use_mask_token=False):
        super().__init__()

        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
        else:
            self.position_embeddings = None

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor],
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> tuple[torch.Tensor]:
        _, num_channels, height, width = pixel_values.shape
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
            else:
                embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions


class DonutSwinPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.embed_dim
        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def maybe_pad(self, pixel_values, height, width):
        if width % self.patch_size[1] != 0:
            pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        if height % self.patch_size[0] != 0:
            pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
            pixel_values = nn.functional.pad(pixel_values, pad_values)
        return pixel_values

    def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor, tuple[int]]:
        _, num_channels, height, width = pixel_values.shape
        # pad the input to be divisible by self.patch_size, if needed
        pixel_values = self.maybe_pad(pixel_values, height, width)
        embeddings = self.projection(pixel_values)
        _, _, height, width = embeddings.shape
        output_dimensions = (height, width)
        embeddings = embeddings.flatten(2).transpose(1, 2)

        return embeddings, output_dimensions


class DonutSwinPatchMerging(nn.Module):
    """
    Patch Merging Layer.

    Args:
        input_resolution (`tuple[int]`):
            Resolution of input feature.
        dim (`int`):
            Number of input channels.
        norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
            Normalization layer class.
    """

    def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def maybe_pad(self, input_feature, height, width):
        should_pad = (height % 2 == 1) or (width % 2 == 1)
        if should_pad:
            pad_values = (0, 0, 0, width % 2, 0, height % 2)
            input_feature = nn.functional.pad(input_feature, pad_values)

        return input_feature

    def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor:
        height, width = input_dimensions
        # `dim` is height * width
        batch_size, dim, num_channels = input_feature.shape

        input_feature = input_feature.view(batch_size, height, width, num_channels)
        # pad input to be divisible by width and height, if needed
        input_feature = self.maybe_pad(input_feature, height, width)
        # group each 2x2 neighborhood of patches; every slice is [batch_size, height/2, width/2, num_channels]
        input_feature_0 = input_feature[:, 0::2, 0::2, :]
        input_feature_1 = input_feature[:, 1::2, 0::2, :]
        input_feature_2 = input_feature[:, 0::2, 1::2, :]
        input_feature_3 = input_feature[:, 1::2, 1::2, :]
        # [batch_size, height/2 * width/2, 4*num_channels]
        input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
        input_feature = input_feature.view(batch_size, -1, 4 * num_channels)

        input_feature = self.norm(input_feature)
        input_feature = self.reduction(input_feature)

        return input_feature


def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


class DonutSwinDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"


class DonutSwinSelfAttention(nn.Module):
    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})")

        self.num_attention_heads = num_heads
        self.attention_head_size = int(dim / num_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.window_size = (
            window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
        )

        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
        )

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        batch_size, dim, num_channels = hidden_states.shape
        hidden_shape = (batch_size, dim, -1, self.attention_head_size)

        query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
        key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
        value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # add the learned relative position bias of the window
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        relative_position_bias = relative_position_bias.view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attention_scores = attention_scores + relative_position_bias.unsqueeze(0)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all windows in DonutSwinLayer.get_attn_mask)
            mask_shape = attention_mask.shape[0]
            attention_scores = attention_scores.view(
                batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
            ) + attention_mask.unsqueeze(1).unsqueeze(0)
            attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


class DonutSwinSelfOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class DonutSwinAttention(nn.Module):
    def __init__(self, config, dim, num_heads, window_size):
        super().__init__()
        self.self = DonutSwinSelfAttention(config, dim, num_heads, window_size)
        self.output = DonutSwinSelfOutput(config, dim)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class DonutSwinIntermediate(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class DonutSwinOutput(nn.Module):
    def __init__(self, config, dim):
        super().__init__()
        self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class DonutSwinLayer(nn.Module):
    def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.shift_size = shift_size
        self.window_size = config.window_size
        self.input_resolution = input_resolution
        self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.attention = DonutSwinAttention(config, dim, num_heads, window_size=self.window_size)
        self.drop_path = DonutSwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.intermediate = DonutSwinIntermediate(config, dim)
        self.output = DonutSwinOutput(config, dim)

    def set_shift_and_window_size(self, input_resolution):
        if min(input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = torch_int(0)
            self.window_size = (
                torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
            )

    def get_attn_mask(self, height, width, dtype, device):
        if self.shift_size > 0:
            # calculate attention mask for shifted-window self-attention
            img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
            height_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            width_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            count = 0
            for height_slice in height_slices:
                for width_slice in width_slices:
                    img_mask[:, height_slice, width_slice, :] = count
                    count += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        return attn_mask

    def maybe_pad(self, hidden_states, height, width):
        pad_right = (self.window_size - width % self.window_size) % self.window_size
        pad_bottom = (self.window_size - height % self.window_size) % self.window_size
        pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
        hidden_states = nn.functional.pad(hidden_states, pad_values)
        return hidden_states, pad_values

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if not always_partition:
            self.set_shift_and_window_size(input_dimensions)
        else:
            pass
        height, width = input_dimensions
        batch_size, _, channels = hidden_states.size()
        shortcut = hidden_states

        hidden_states = self.layernorm_before(hidden_states)
        hidden_states = hidden_states.view(batch_size, height, width, channels)

        # pad hidden_states to multiples of window size
        hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
        _, height_pad, width_pad, _ = hidden_states.shape

        # cyclic shift
        if self.shift_size > 0:
            shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_hidden_states = hidden_states

        # partition windows
        hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
        hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
        attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device)

        attention_outputs = self.attention(hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions)
        attention_output = attention_outputs[0]

        attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
        shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)

        # reverse cyclic shift
        if self.shift_size > 0:
            attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            attention_windows = shifted_windows

        was_padded = pad_values[3] > 0 or pad_values[5] > 0
        if was_padded:
            attention_windows = attention_windows[:, :height, :width, :].contiguous()

        attention_windows = attention_windows.view(batch_size, height * width, channels)

        hidden_states = shortcut + self.drop_path(attention_windows)

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = hidden_states + self.output(layer_output)

        layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
        return layer_outputs


class DonutSwinStage(GradientCheckpointingLayer):
    def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
        super().__init__()
        self.config = config
        self.dim = dim
        self.blocks = nn.ModuleList(
            [
                DonutSwinLayer(
                    config=config,
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    drop_path_rate=drop_path[i],
                    shift_size=0 if (i % 2 == 0) else config.window_size // 2,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
        else:
            self.downsample = None

        self.pointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        always_partition: Optional[bool] = False,
    ) -> tuple[torch.Tensor]:
        height, width = input_dimensions
        for i, layer_module in enumerate(self.blocks):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition)

            hidden_states = layer_outputs[0]

        hidden_states_before_downsampling = hidden_states
        if self.downsample is not None:
            height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
            output_dimensions = (height, width, height_downsampled, width_downsampled)
            hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
        else:
            output_dimensions = (height, width, height, width)

        stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)

        if output_attentions:
            stage_outputs += layer_outputs[1:]
        return stage_outputs


class DonutSwinEncoder(nn.Module):
    def __init__(self, config, grid_size):
        super().__init__()
        self.num_layers = len(config.depths)
        self.config = config
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
        self.layers = nn.ModuleList(
            [
                DonutSwinStage(
                    config=config,
                    dim=int(config.embed_dim * 2**i_layer),
                    input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
                    depth=config.depths[i_layer],
                    num_heads=config.num_heads[i_layer],
                    drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
                    downsample=DonutSwinPatchMerging if (i_layer < self.num_layers - 1) else None,
                )
                for i_layer in range(self.num_layers)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_dimensions: tuple[int, int],
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        output_hidden_states_before_downsampling: Optional[bool] = False,
        always_partition: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[tuple, DonutSwinEncoderOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_reshaped_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if output_hidden_states:
            batch_size, _, hidden_size = hidden_states.shape
            # rearrange b (h w) c -> b c h w
            reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
            reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
            all_hidden_states += (hidden_states,)
            all_reshaped_hidden_states += (reshaped_hidden_state,)

        for i, layer_module in enumerate(self.layers):
            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition)

            hidden_states = layer_outputs[0]
            hidden_states_before_downsampling = layer_outputs[1]
            output_dimensions = layer_outputs[2]

            input_dimensions = (output_dimensions[-2], output_dimensions[-1])

            if output_hidden_states and output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states_before_downsampling.shape
                # rearrange b (h w) c -> b c h w, using the original (not downsampled) height and width
                reshaped_hidden_state = hidden_states_before_downsampling.view(
                    batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
                )
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states_before_downsampling,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)
            elif output_hidden_states and not output_hidden_states_before_downsampling:
                batch_size, _, hidden_size = hidden_states.shape
                # rearrange b (h w) c -> b c h w
                reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
                reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
                all_hidden_states += (hidden_states,)
                all_reshaped_hidden_states += (reshaped_hidden_state,)

            if output_attentions:
                all_self_attentions += layer_outputs[3:]

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)

        return DonutSwinEncoderOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            reshaped_hidden_states=all_reshaped_hidden_states,
        )


@auto_docstring
class DonutSwinPreTrainedModel(PreTrainedModel):
    config: DonutSwinConfig
    base_model_prefix = "donut"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DonutSwinStage"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, DonutSwinEmbeddings):
            if module.mask_token is not None:
                module.mask_token.data.zero_()
            if module.position_embeddings is not None:
                module.position_embeddings.data.zero_()
        elif isinstance(module, DonutSwinSelfAttention):
            module.relative_position_bias_table.data.zero_()


@auto_docstring
class DonutSwinModel(DonutSwinPreTrainedModel):
    def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether to use a mask token for masked image modeling.
        """
        super().__init__(config)
        self.config = config
        self.num_layers = len(config.depths)
        self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))

        self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)

        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, DonutSwinModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed; 1.0 in head_mask indicates we keep the head
        head_mask = self.get_head_mask(head_mask, len(self.config.depths))

        embedding_output, input_dimensions = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            input_dimensions,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]

        pooled_output = None
        if self.pooler is not None:
            pooled_output = self.pooler(sequence_output.transpose(1, 2))
            pooled_output = torch.flatten(pooled_output, 1)

        if not return_dict:
            output = (sequence_output, pooled_output) + encoder_outputs[1:]
            return output

        return DonutSwinModelOutput(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
        )


@auto_docstring(
    custom_intro="""
    DonutSwin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.

    <Tip>

        Note that it's possible to fine-tune DonutSwin on higher resolution images than the ones it has been trained on, by
        setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
        position embeddings to the higher resolution.

    </Tip>
    """
)
class DonutSwinForImageClassification(DonutSwinPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.donut = DonutSwinModel(config)

        # Classifier head
        self.classifier = (
            nn.Linear(self.donut.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, DonutSwinImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.donut(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(labels, logits, self.config)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return DonutSwinImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            reshaped_hidden_states=outputs.reshaped_hidden_states,
        )


__all__ = ["DonutSwinModel", "DonutSwinPreTrainedModel", "DonutSwinForImageClassification"]