o
    پiv                     @   s  d Z ddlZddlmZ ddlmZmZmZmZ ddl	Z	ddl
mZ ddlmZmZ ddlmZmZmZmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZmZm Z  ddl!m"Z" dgZ#eG dd dej$Z%dddZ&G dd dej$Z'G dd dej$Z(G dd dej$Z)G dd dej$Z*G dd dej$Z+G dd dej$Z,dd Z-dd!d"Z.dd$d%Z/ei d&e/d'd(d)d*e/d'd+d)d,e/d'd-d.d/d0e/d'd1d)d2e/d'd3d)d4e/d'd5d.d/d6e/d'd7d)d8e/d'd9d)d:e/d'd;d.d/d<e/d'd=d)d>e/d'd?d)d@e/d'dAd.d/dBe/d'dCd)dDe/d'dEd)dFe/d'dGd.d/dHe/d'dId)dJe/d'dKd)i dLe/d'dMd.d/dNe/d'dOd)dPe/d'dQd)dRe/d'dSd.d/dTe/d'dUd)dVe/d'dWd)dXe/d'dYd.d/dZe/d'd[d)d\e/d'd]d)d^e/d'd_d.d/d`e/d'dad)dbe/d'dcd)dde/d'ded.d/dfe/d'dgd)dhe/d'did)dje/d'dkd.d/dle/d'dmd)e/d'dnd)e/d'dod.d/e/d'dpd)e/d'dqd)e/d'drd.d/e/d'dsd)e/d'dtd)e/d'dud.d/dvZ0eddwe,fdxdyZ1eddwe,fdzd{Z2eddwe,fd|d}Z3eddwe,fd~dZ4eddwe,fddZ5eddwe,fddZ6eddwe,fddZ7eddwe,fddZ8eddwe,fddZ9eddwe,fddZ:eddwe,fddZ;eddwe,fddZ<eddwe,fddZ=eddwe,fddZ>eddwe,fddZ?eddwe,fddZ@eddwe,fddZAeddwe,fddZBeddwe,fddZCeddwe,fddZDeddwe,fddZEeddwe,fddZFeddwe,fddZGeddwe,fddZHeddwe,fddZIeddwe,fddZJeddwe,fddZKeddwe,fddZLe eMi dd*dd,dd2dd4dd8dd:dd>dd@ddDddFddJddLddPddRddVddXdd\d^dbdddhdjdddddddǜ dS )a]   Cross-Covariance Image Transformer (XCiT) in PyTorch

Paper:
    - https://arxiv.org/abs/2106.09681

Same as the official implementation, with some minor adaptations, original copyright below
    - https://github.com/facebookresearch/xcit/blob/master/xcit.py

Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
    N)partial)ListOptionalTupleUnionIMAGENET_DEFAULT_MEANIMAGENET_DEFAULT_STD)DropPathtrunc_normal_	to_2tupleuse_fused_attnMlp   )build_model_with_cfg)feature_take_indices)register_notrace_module)
checkpoint)register_modelgenerate_default_cfgsregister_model_deprecations)	ClassAttnXcitc                       s8   e Zd ZdZd fdd	Zdeded	efd
dZ  ZS )PositionalEncodingFourierz
    Positional encoding relying on a fourier kernel matching the one used in the "Attention is all you Need" paper.
    Based on the official XCiT code
        - https://github.com/facebookresearch/xcit/blob/master/xcit.py
           '  c                    sH   t    tj|d |dd| _dtj | _|| _|| _	|| _
d| _d S )N   r   )kernel_sizeư>)super__init__nnConv2dtoken_projectionmathpiscaletemperature
hidden_dimdimeps)selfr)   r*   r(   	__class__ D/home/ubuntu/.local/lib/python3.10/site-packages/timm/models/xcit.pyr!   )   s   

z"PositionalEncodingFourier.__init__BHWc              	   C   s(  | j jj}| j jj}tjd|d |dtjd	dd|}tjd|d |dtj	d|d}||d d dd d d f | j
  | j }||d d d d dd f | j
  | j }tj| j|dtj}| jdtj|ddd | j  }|d d d d d d d f | }	|d d d d d d d f | }
tj|	d d d d d d dd df  |	d d d d d d dd df  gdd	d
}	tj|
d d d d d d dd df  |
d d d d d d dd df  gdd	d
}
tj|
|	fd
d	dd
dd}|  ||}|	|dddS )Nr   )devicer   floor)rounding_moder      r*      )r$   weightr4   dtypetorcharangetofloat32	unsqueezerepeatr+   r'   r)   r(   divstacksincosflattencatpermute)r,   r1   r2   r3   r4   r<   y_embedx_embeddim_tpos_xpos_yposr/   r/   r0   forward2   s   

,&**   \\z!PositionalEncodingFourier.forward)r   r   r   )__name__
__module____qualname____doc__r!   intrP   __classcell__r/   r/   r-   r0   r   !   s    	r   c              
   C   s&   t jtj| |d|dddt|S )z3x3 convolution + batch normr:   r   F)r   stridepaddingbias)r=   r"   
Sequentialr#   BatchNorm2d)	in_planes
out_planesrW   r/   r/   r0   conv3x3D   s   r^   c                       s6   e Zd ZdZddddejf fdd	Zdd	 Z  ZS )
ConvPatchEmbedz<Image to Patch Embedding using multiple convolutional layers      r:   r   c                    s   t    t|}|d | |d |  }|| _|| _|| _|dkrPtjt	||d d| t	|d |d d| t	|d |d d| t	|d |d| _
d S |dkrwtjt	||d d| t	|d |d d| t	|d |d| _
d S d)Nr   r   ra      r   r8   z=For convolutional projection, patch size has to be in [8, 16])r    r!   r   img_size
patch_sizenum_patchesr=   r"   rZ   r^   proj)r,   rc   rd   in_chans	embed_dim	act_layerre   r-   r/   r0   r!   O   s2   

	
zConvPatchEmbed.__init__c                 C   s>   |  |}|jd |jd }}|ddd}|||ffS )Nr   r:   r   )rf   shaperG   	transpose)r,   xHpWpr/   r/   r0   rP   l   s   
zConvPatchEmbed.forward)	rQ   rR   rS   rT   r"   GELUr!   rP   rV   r/   r/   r-   r0   r_   L   s    r_   c                       s<   e Zd ZdZdejdf fdd	Zdedefdd	Z  Z	S )
LPIa  
    Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows to augment the
    implicit communication performed by the block diagonal scatter attention. Implemented using 2 layers of separable
    3x3 convolutions with GeLU and BatchNorm2d
    Nr:   c                    sb   t    |p|}|d }tjj|||||d| _| | _t|| _tjj|||||d| _	d S )Nr   )r   rX   groups)
r    r!   r=   r"   r#   conv1actr[   bnconv2)r,   in_featuresout_featuresri   r   rX   r-   r/   r0   r!   z   s   


zLPI.__init__r2   r3   c                 C   sj   |j \}}}|ddd||||}| |}| |}| |}| |}||||ddd}|S )Nr   r   r   )rj   rI   reshaperr   rs   rt   ru   )r,   rl   r2   r3   r1   NCr/   r/   r0   rP      s   



zLPI.forward)
rQ   rR   rS   rT   r"   ro   r!   rU   rP   rV   r/   r/   r-   r0   rp   s   s    rp   c                	       s@   e Zd ZdZdddddejejddf	 fdd	Zdd	 Z  Z	S )
ClassAttentionBlockzAClass Attention Layer as in CaiT https://arxiv.org/abs/2103.17239      @F              ?c                    s   t    |	|| _t|||||d| _|dkrt|nt | _|	|| _	t
|t|| ||d| _|
d urNt|
t| | _t|
t| | _nd\| _| _|| _d S )N	num_headsqkv_bias	attn_drop	proj_dropr}   rv   hidden_featuresri   drop)r~   r~   )r    r!   norm1r   attnr
   r"   Identity	drop_pathnorm2r   rU   mlp	Parameterr=   onesgamma1gamma2tokens_norm)r,   r*   r   	mlp_ratior   r   r   r   ri   
norm_layeretar   r-   r/   r0   r!      s   




zClassAttentionBlock.__init__c                 C   s   |  |}tj| ||d d dd f gdd}|| | j|  }| jr,| |}ntj| |d d ddf |d d dd f gdd}|}|d d ddf }| j| 	| }tj||d d dd f gdd}|| | }|S )Nr   r9   r   )
r   r=   rH   r   r   r   r   r   r   r   )r,   rl   x_norm1x_attnx_res	cls_tokenr/   r/   r0   rP      s   
(8"zClassAttentionBlock.forward)
rQ   rR   rS   rT   r"   ro   	LayerNormr!   rP   rV   r/   r/   r-   r0   r{      s    !r{   c                       sJ   e Zd ZU ejje ed< 	 d fdd	Zdd Z	ejj
d	d
 Z  ZS )XCA
fused_attnrb   Fr}   c                    sr   t    || _tdd| _tt|dd| _	tj
||d |d| _t|| _t
||| _t|| _d S )NT)experimentalr   r:   )rY   )r    r!   r   r   r   r"   r   r=   r   r(   LinearqkvDropoutr   rf   r   )r,   r*   r   r   r   r   r-   r/   r0   r!      s   
zXCA.__init__c           
      C   s  |j \}}}| |||d| j|| j ddddd}|d\}}}| jrItjj	j
|dd| j }tjj	j
|dd}tjj	j|||dd	}n,tjj	j
|dd}tjj	j
|dd}||d
d | j }	|	jdd}	| |	}	|	| }|dddd|||}| |}| |}|S )Nr:   r   r   r8   r   r5   r9   r~   )r'   )rj   r   rx   r   rI   unbindr   r=   r"   
functional	normalizer(   scaled_dot_product_attentionrk   softmaxr   rf   r   )
r,   rl   r1   ry   rz   r   qkvr   r/   r/   r0   rP      s"   .


zXCA.forwardc                 C   s   dhS )Nr(   r/   r,   r/   r/   r0   no_weight_decay      zXCA.no_weight_decay)rb   Fr}   r}   )rQ   rR   rS   r=   jitFinalbool__annotations__r!   rP   ignorer   rV   r/   r/   r-   r0   r      s   
 
r   c                       sD   e Zd Zdddddejejdf fdd	Zdedefd	d
Z  Z	S )XCABlockr|   Fr}   r~   c                    s   t    |	|| _t|||||d| _|dkrt|nt | _|	|| _	t
||d| _|	|| _t|t|| ||d| _t|
t| | _t|
t| | _t|
t| | _d S )Nr   r}   )rv   ri   r   )r    r!   r   r   r   r
   r"   r   r   norm3rp   local_mpr   r   rU   r   r   r=   r   r   gamma3r   )r,   r*   r   r   r   r   r   r   ri   r   r   r-   r/   r0   r!      s   



zXCABlock.__init__r2   r3   c              	   C   sh   ||  | j| | |  }||  | j| | |||  }||  | j| | 	|  }|S N)
r   r   r   r   r   r   r   r   r   r   )r,   rl   r2   r3   r/   r/   r0   rP     s    $ zXCABlock.forward)
rQ   rR   rS   r"   ro   r   r!   rU   rP   rV   r/   r/   r-   r0   r      s    r   c                       sd  e Zd ZdZ											
									
		d6 fdd	Zdd Zejjdd Z	ejjd7ddZ
ejjd8ddZejjdejfddZd9dedee fdd Z				!	d:d"ejd#eeeee f  d$ed%ed&ed'edeeej eejeej f f fd(d)Z	*		
d;d#eeee f d+ed,efd-d.Zd/d0 Zd7d1efd2d3Zd4d5 Z  ZS )<r   z
    Based on timm and DeiT code bases
    https://github.com/rwightman/pytorch-image-models/tree/master/timm
    https://github.com/facebookresearch/deit/
    r`   ra   r:     tokenr      r|   Tr}   Nr   r~   Fc                    s  t    |dv sJ t|}|d | dkr|d | dks#J dp+ttjdd p0tj || _ | _ | _	| _
|| _d| _t||| d| _|ttdd| _|rctd	| _nd
| _tj|d| _t 	
f
ddt|D | _fddt|D | _t 
f
ddt|D | _| _t| _|dkrt| j|nt | _ t!| jdd | "| j# d
S )a  
        Args:
            img_size (int, tuple): input image size
            patch_size (int): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            drop_rate (float): dropout rate after positional embedding, and in XCA/CA projection + MLP
            pos_drop_rate: position embedding dropout rate
            proj_drop_rate (float): projection dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate (constant across all layers)
            norm_layer: (nn.Module): normalization layer
            cls_attn_layers: (int) Depth of Class attention layers
            use_pos_embed: (bool) whether to use positional encoding
            eta: (float) layerscale initialization value
            tokens_norm: (bool) Whether to normalize all tokens or just the cls_token in the CA

        Notes:
            - Although `layer_norm` is user specifiable, there are hard-coded `BatchNorm2d`s in the local patch
              interaction (class LPI) and the patch embedding (class ConvPatchEmbed)
         avgr   r   z2`patch_size` should divide image dimensions evenlyr   )r+   F)rc   rd   rg   rh   ri   r   r9   N)pc                    s(   g | ]}t 	 d 
qS ))
r*   r   r   r   r   r   r   ri   r   r   )r   .0_)
ri   attn_drop_ratedrop_path_raterh   r   r   r   r   proj_drop_rater   r/   r0   
<listcomp>p      z!Xcit.__init__.<locals>.<listcomp>c                    s    g | ]}t  d | dqS )zblocks.)num_chs	reductionmoduledict)r   i)rh   rr/   r0   r   ~  s     c                    s(   g | ]}t  	d 
qS ))
r*   r   r   r   r   r   ri   r   r   r   )r{   r   )
ri   r   	drop_raterh   r   r   r   r   r   r   r/   r0   r     r   {Gz?std)$r    r!   r   r   r"   r   ro   num_classesnum_featureshead_hidden_sizerh   global_poolgrad_checkpointingr_   patch_embedr   r=   zerosr   r   	pos_embedr   pos_drop
ModuleListrangeblocksfeature_infocls_attn_blocksnorm	head_dropr   r   headr   apply_init_weights)r,   rc   rd   rg   r   r   rh   depthr   r   r   r   pos_drop_rater   r   r   ri   r   cls_attn_layersuse_pos_embedr   r   r-   )ri   r   r   r   rh   r   r   r   r   r   r   r   r   r0   r!   !  sJ   
2"
 
 

 zXcit.__init__c                 C   sP   t |tjr"t|jdd t |tjr$|jd ur&tj|jd d S d S d S d S )Nr   r   r   )
isinstancer"   r   r   r;   rY   init	constant_)r,   mr/   r/   r0   r     s   zXcit._init_weightsc                 C   s   ddhS )Nr   r   r/   r   r/   r/   r0   r     s   zXcit.no_weight_decayc                 C   s   t ddddgdS )Nz ^cls_token|pos_embed|patch_embedz^blocks\.(\d+))z^cls_attn_blocks\.(\d+)N)z^norm)i )stemr   r   r   )r,   coarser/   r/   r0   group_matcher  s
   zXcit.group_matcherc                 C   s
   || _ d S r   )r   )r,   enabler/   r/   r0   set_grad_checkpointing  s   
zXcit.set_grad_checkpointingreturnc                 C   s   | j S r   )r   r   r/   r/   r0   get_classifier  r   zXcit.get_classifierr   r   c                 C   sJ   || _ |d ur|dv sJ || _|dkrt| j|| _d S t | _d S )Nr   r   )r   r   r"   r   r   r   r   )r,   r   r   r/   r/   r0   reset_classifier  s
   *zXcit.reset_classifierNCHWrl   indicesr   
stop_early
output_fmtintermediates_onlyc                    s  |dv sJ d|dk}g }t t| j|\}	}
|j\ }}}| |\}\| jdurE|   d|jd ddd}|| }| |}t	j
 sQ|sU| j}n	| jd|
d  }t|D ]*\}}| jrvt	j
 svt||}n||}||	v r||r| |n| qb|r fd	d
|D }|r|S t	j| j dd|fdd}| jD ]}| jrt	j
 st||}q||}q| |}||fS )a   Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to all intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:

        )r   NLCz)Output format must be one of NCHW or NLC.r   Nr5   r   r   r   c                    s,   g | ]}|  d dddd qS )r5   r   r:   r   r   )rx   rI   
contiguous)r   yr1   rm   rn   r/   r0   r     s   , z.Xcit.forward_intermediates.<locals>.<listcomp>r9   )r   lenr   rj   r   r   rx   rI   r   r=   r   is_scripting	enumerater   r   appendr   rH   r   expandr   )r,   rl   r   r   r   r   r   rx   intermediatestake_indices	max_indexr   heightwidthpos_encodingr   r   blkr/   r   r0   forward_intermediates  s>   
(



zXcit.forward_intermediatesr   
prune_norm
prune_headc                 C   sT   t t| j|\}}| jd|d  | _|rt | _|r(t | _| dd |S )z@ Prune layers not required for specified intermediates.
        Nr   r   r   )	r   r   r   r"   r   r   r   r   r   )r,   r   r	  r
  r  r  r/   r/   r0   prune_intermediate_layers  s   

zXcit.prune_intermediate_layersc                 C   s   |j d }| |\}\}}| jd ur+| ||||d|j d ddd}|| }| |}| jD ]}| jrEtj	
 sEt||||}q3||||}q3tj| j|dd|fdd}| jD ]}| jrntj	
 snt||}q^||}q^| |}|S )Nr   r5   r   r   r9   )rj   r   r   rx   rI   r   r   r   r=   r   r   r   rH   r   r   r   r   )r,   rl   r1   rm   rn   r  r  r/   r/   r0   forward_features  s"   

(




zXcit.forward_features
pre_logitsc                 C   sX   | j r| j dkr|d d dd f jddn|d d df }| |}|r'|S | |S )Nr   r   r9   r   )r   meanr   r   )r,   rl   r  r/   r/   r0   forward_head*  s   6
zXcit.forward_headc                 C   s   |  |}| |}|S r   )r  r  )r,   rl   r/   r/   r0   rP   0  s   

zXcit.forward)r`   ra   r:   r   r   r   r   r   r|   Tr}   r}   r}   r}   r}   NNr   Tr~   FF)Tr   )NFFr   F)r   FT)rQ   rR   rS   rT   r!   r   r=   r   r   r   r   r   r"   Moduler   rU   r   strr   Tensorr   r   r   r   r  r  r  r  rP   rV   r/   r/   r-   r0   r     s    w

 
F
c              	   C   s0  d| v r| d } t |dd d u}dd | D }|D ]}|r)| || |dd< q| |= qd| v rd| v rt|j}t|D ]U}| d	| d
}|dd|jd }t	dD ]\}}	|| | d	| d|	 d< qY| d	| dd }
|
d ur|
dd}
t	dD ]\}}	|
| | d	| d|	 d< qq@| S )Nmodelr   c                 S   s   g | ]	}| d r|qS )r   )
startswith)r   r   r/   r/   r0   r   <  s    z(checkpoint_filter_fn.<locals>.<listcomp>zpos_embeder.z
pos_embed.z!cls_attn_blocks.0.attn.qkv.weightzcls_attn_blocks.0.attn.q.weightzcls_attn_blocks.z.attn.qkv.weightr:   r5   r   z.attn.z.weightz.attn.qkv.biasz.bias)
getattrpopreplace
state_dictr   r   r   rx   rj   r   )r  r  r   pos_embed_keysr   num_ca_blocksr   
qkv_weightj	subscriptr   r/   r/   r0   checkpoint_filter_fn6  s,   
r  Fc                 K   s2   | dd}tt| |ftt|ddd|}|S )Nout_indicesr:   getter)r   feature_cls)pretrained_filter_fnfeature_cfg)r  r   r   r  r   )variant
pretraineddefault_cfgkwargsr   r  r/   r/   r0   _create_xcitS  s   
r)  r   c                 K   s    | ddd dddt tddd|S )	Nr   )r:   r`   r`   r~   bicubicTzpatch_embed.proj.0.0r   )urlr   
input_size	pool_sizecrop_pctinterpolationfixed_input_sizer  r   
first_conv
classifierr   )r+  r(  r/   r/   r0   _cfg`  s   r3  zxcit_nano_12_p16_224.fb_in1kztimm/z<https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224.pth)	hf_hub_idr+  z!xcit_nano_12_p16_224.fb_dist_in1kzAhttps://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224_dist.pthz!xcit_nano_12_p16_384.fb_dist_in1kzAhttps://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_384_dist.pth)r:     r5  )r4  r+  r,  zxcit_tiny_12_p16_224.fb_in1kz<https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224.pthz!xcit_tiny_12_p16_224.fb_dist_in1kzAhttps://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224_dist.pthz!xcit_tiny_12_p16_384.fb_dist_in1kzAhttps://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_384_dist.pthzxcit_tiny_24_p16_224.fb_in1kz<https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224.pthz!xcit_tiny_24_p16_224.fb_dist_in1kzAhttps://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224_dist.pthz!xcit_tiny_24_p16_384.fb_dist_in1kzAhttps://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_384_dist.pthzxcit_small_12_p16_224.fb_in1kz=https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pthz"xcit_small_12_p16_224.fb_dist_in1kzBhttps://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224_dist.pthz"xcit_small_12_p16_384.fb_dist_in1kzBhttps://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_384_dist.pthzxcit_small_24_p16_224.fb_in1kz=https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224.pthz"xcit_small_24_p16_224.fb_dist_in1kzBhttps://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224_dist.pthz"xcit_small_24_p16_384.fb_dist_in1kzBhttps://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_384_dist.pthzxcit_medium_24_p16_224.fb_in1kz>https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224.pthz#xcit_medium_24_p16_224.fb_dist_in1kzChttps://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224_dist.pthz#xcit_medium_24_p16_384.fb_dist_in1kzChttps://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_384_dist.pthzxcit_large_24_p16_224.fb_in1kz=https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224.pthz"xcit_large_24_p16_224.fb_dist_in1kzBhttps://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224_dist.pthz"xcit_large_24_p16_384.fb_dist_in1kzBhttps://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_384_dist.pthzxcit_nano_12_p8_224.fb_in1kz;https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224.pthz xcit_nano_12_p8_224.fb_dist_in1kz@https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224_dist.pthz xcit_nano_12_p8_384.fb_dist_in1kz@https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_384_dist.pthzxcit_tiny_12_p8_224.fb_in1kz;https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224.pthz xcit_tiny_12_p8_224.fb_dist_in1kz@https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224_dist.pthz xcit_tiny_12_p8_384.fb_dist_in1kz@https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_384_dist.pthzxcit_tiny_24_p8_224.fb_in1kz;https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224.pthz xcit_tiny_24_p8_224.fb_dist_in1kz@https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224_dist.pthz xcit_tiny_24_p8_384.fb_dist_in1kz@https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_384_dist.pthzxcit_small_12_p8_224.fb_in1kz<https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224.pthz!xcit_small_12_p8_224.fb_dist_in1kzAhttps://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224_dist.pthz!xcit_small_12_p8_384.fb_dist_in1kzAhttps://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_384_dist.pthzxcit_small_24_p8_224.fb_in1kz<https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224.pthzAhttps://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224_dist.pthzAhttps://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_384_dist.pthz=https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224.pthzBhttps://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224_dist.pthzBhttps://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_384_dist.pthz<https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224.pthzAhttps://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224_dist.pthzAhttps://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_384_dist.pth)!xcit_small_24_p8_224.fb_dist_in1k!xcit_small_24_p8_384.fb_dist_in1kzxcit_medium_24_p8_224.fb_in1k"xcit_medium_24_p8_224.fb_dist_in1k"xcit_medium_24_p8_384.fb_dist_in1kzxcit_large_24_p8_224.fb_in1k!xcit_large_24_p8_224.fb_dist_in1k!xcit_large_24_p8_384.fb_dist_in1kr   c                 K   6   t ddddddd}td
d	| it |fi |}|S )Nra      r   r8   r~   Frd   rh   r   r   r   r   xcit_nano_12_p16_224r&  )r?  r   r)  r&  r(  
model_argsr  r/   r/   r0   r?    
   r?  c              	   K   s8   t dddddddd}tdd
| it |fi |}|S )Nra   r=  r   r8   r~   Fr5  )rd   rh   r   r   r   r   rc   xcit_nano_12_p16_384r&  )rD  r@  rA  r/   r/   r0   rD    s
   rD  c                 K   r<  )Nra      r   r8   r~   Tr>  xcit_tiny_12_p16_224r&  )rF  r@  rA  r/   r/   r0   rF     rC  rF  c                 K   r<  )Nra   rE  r   r8   r~   Tr>  xcit_tiny_12_p16_384r&  )rG  r@  rA  r/   r/   r0   rG    rC  rG  c                 K   r<  )Nra   r5  r   rb   r~   Tr>  xcit_small_12_p16_224r&  )rH  r@  rA  r/   r/   r0   rH    rC  rH  c                 K   r<  )Nra   r5  r   rb   r~   Tr>  xcit_small_12_p16_384r&  )rI  r@  rA  r/   r/   r0   rI    rC  rI  c                 K   r<  )Nra   rE     r8   h㈵>Tr>  xcit_tiny_24_p16_224r&  )rL  r@  rA  r/   r/   r0   rL     rC  rL  c                 K   r<  )Nra   rE  rJ  r8   rK  Tr>  xcit_tiny_24_p16_384r&  )rM  r@  rA  r/   r/   r0   rM  (  rC  rM  c                 K   r<  )Nra   r5  rJ  rb   rK  Tr>  xcit_small_24_p16_224r&  )rN  r@  rA  r/   r/   r0   rN  0  rC  rN  c                 K   r<  )Nra   r5  rJ  rb   rK  Tr>  xcit_small_24_p16_384r&  )rO  r@  rA  r/   r/   r0   rO  8  rC  rO  c                 K   r<  )Nra      rJ  rb   rK  Tr>  xcit_medium_24_p16_224r&  )rQ  r@  rA  r/   r/   r0   rQ  @  rC  rQ  c                 K   r<  )Nra   rP  rJ  rb   rK  Tr>  xcit_medium_24_p16_384r&  )rR  r@  rA  r/   r/   r0   rR  H  rC  rR  c                 K   6   t ddddddd}td	d| it |fi |}|S )
Nra   r   rJ  rK  Tr>  xcit_large_24_p16_224r&  )rT  r@  rA  r/   r/   r0   rT  P  rC  rT  c                 K   rS  )
Nra   r   rJ  rK  Tr>  xcit_large_24_p16_384r&  )rU  r@  rA  r/   r/   r0   rU  X  rC  rU  c                 K   r<  )Nrb   r=  r   r8   r~   Fr>  xcit_nano_12_p8_224r&  )rV  r@  rA  r/   r/   r0   rV  a  rC  rV  c                 K   r<  )Nrb   r=  r   r8   r~   Fr>  xcit_nano_12_p8_384r&  )rW  r@  rA  r/   r/   r0   rW  i  rC  rW  c                 K   r<  )Nrb   rE  r   r8   r~   Tr>  xcit_tiny_12_p8_224r&  )rX  r@  rA  r/   r/   r0   rX  q  rC  rX  c                 K   r<  )Nrb   rE  r   r8   r~   Tr>  xcit_tiny_12_p8_384r&  )rY  r@  rA  r/   r/   r0   rY  y  rC  rY  c                 K   rS  )
Nrb   r5  r   r~   Tr>  xcit_small_12_p8_224r&  )rZ  r@  rA  r/   r/   r0   rZ    rC  rZ  c                 K   rS  )
Nrb   r5  r   r~   Tr>  xcit_small_12_p8_384r&  )r[  r@  rA  r/   r/   r0   r[    rC  r[  c                 K   r<  )Nrb   rE  rJ  r8   rK  Tr>  xcit_tiny_24_p8_224r&  )r\  r@  rA  r/   r/   r0   r\    rC  r\  c                 K   r<  )Nrb   rE  rJ  r8   rK  Tr>  xcit_tiny_24_p8_384r&  )r]  r@  rA  r/   r/   r0   r]    rC  r]  c                 K   rS  )
Nrb   r5  rJ  rK  Tr>  xcit_small_24_p8_224r&  )r^  r@  rA  r/   r/   r0   r^    rC  r^  c                 K   rS  )
Nrb   r5  rJ  rK  Tr>  xcit_small_24_p8_384r&  )r_  r@  rA  r/   r/   r0   r_    rC  r_  c                 K   rS  )
Nrb   rP  rJ  rK  Tr>  xcit_medium_24_p8_224r&  )r`  r@  rA  r/   r/   r0   r`    rC  r`  c                 K   rS  )
Nrb   rP  rJ  rK  Tr>  xcit_medium_24_p8_384r&  )ra  r@  rA  r/   r/   r0   ra    rC  ra  c                 K   r<  )Nrb   r   rJ  ra   rK  Tr>  xcit_large_24_p8_224r&  )rb  r@  rA  r/   r/   r0   rb    rC  rb  c                 K   r<  )Nrb   r   rJ  ra   rK  Tr>  xcit_large_24_p8_384r&  )rc  r@  rA  r/   r/   r0   rc    rC  rc  xcit_nano_12_p16_224_distxcit_nano_12_p16_384_distxcit_tiny_12_p16_224_distxcit_tiny_12_p16_384_distxcit_tiny_24_p16_224_distxcit_tiny_24_p16_384_distxcit_small_12_p16_224_distxcit_small_12_p16_384_distxcit_small_24_p16_224_distxcit_small_24_p16_384_distxcit_medium_24_p16_224_distxcit_medium_24_p16_384_distxcit_large_24_p16_224_distxcit_large_24_p16_384_distxcit_nano_12_p8_224_distxcit_nano_12_p8_384_distxcit_tiny_12_p8_224_distr6  r7  r8  r9  r:  r;  )xcit_tiny_12_p8_384_distxcit_tiny_24_p8_224_distxcit_tiny_24_p8_384_distxcit_small_12_p8_224_distxcit_small_12_p8_384_distxcit_small_24_p8_224_distxcit_small_24_p8_384_distxcit_medium_24_p8_224_distxcit_medium_24_p8_384_distxcit_large_24_p8_224_distxcit_large_24_p8_384_dist)r   )FN)r   r  )NrT   r%   	functoolsr   typingr   r   r   r   r=   torch.nnr"   	timm.datar   r	   timm.layersr
   r   r   r   r   _builderr   	_featuresr   _features_fxr   _manipulater   	_registryr   r   r   caitr   __all__r  r   r^   r_   rp   r{   r   r   r   r  r)  r3  default_cfgsr?  rD  rF  rG  rH  rI  rL  rM  rN  rO  rQ  rR  rT  rU  rV  rW  rX  rY  rZ  r[  r\  r]  r^  r_  r`  ra  rb  rc  rQ   r/   r/   r/   r0   <module>   s   
"'4.&  

 #&),/258;>CFILORUX[^adgj
 	
