o
    ߥi                     @   s2  d dl mZ d dlmZ d dlZd dlZd dlmZ d dl	m  m
Z d dlm  mZ d dlmZ zd dlmZ W n eyE   dZY nw dddgd	dgd
dggddgd	dgd
dggg dg dg dg dg dg dg dg dg dg dg dg dg dg dg dg dg dg dg dg dg dg d g d!g d"gg d#g d$d%Z	&	'dJd(d)Zd*d+ ZG d,d- d-ejZG d.d/ d/ejZdKd1ed2efd3d4ZG d5d6 d6ejZdLd8d9ZG d:d; d;ejZdMd<d=Z d>d? Z!d@dA Z"dBdC Z#G dDdE dEejZ$G dFdG dGejZ%G dHdI dIejZ&dS )N    )OrderedDict)partialN)trunc_normal_)checkpoint_wrapper      g       @      )r      r
   r
   )r
   r
   r
   r
   )r   r
   r   r   )   r
   r
   r
   )   r
   r
   r
   )r   r
   r   r   )   r
   r
   r
   )   r
   r
   r
   )   r
   r
   r
   )	   r
   r
   r
   )
   r
   r
   r
   )   r
   r
   r
   )   r
   r
   r
   )   r
   r
   r
   )   r
   r
   r
   )   r
   r
   r
   )   r
   r
   r
   )   r
   r
   r
   )   r
   r
   r
   )   r
   r
   r
   )   r
   r
   r
   )r	   r
   r   r   )   r
   r
   r
   )   r
   r
   r
   )r   r   r   )r
   r   r   )depthdim_mulhead_mulpool_q_stridepool_kvq_kernelpool_kv_stride_adaptiveTFc           	         s   ddg}|r|dg7 }|   }|  D ]Y\ }t fdd|D rl|  }|jd |jd krdtj|d|jd ddd	d|jd d
d}|d|jd dd}|rctd	 |j |j n|}|
 | < q|S )N	rel_pos_h	rel_pos_w	rel_pos_tc                    s   g | ]}| v qS  r'   .0xkr'   \/home/ubuntu/.local/lib/python3.10/site-packages/modelscope/models/multi_modal/mplug/mvit.py
<listcomp>-       z-interpolate_rel_pos_embed.<locals>.<listcomp>r   r
   r   linearsizemodezInflate {}: {} -> {}: {})copyitemsanyshapeFinterpolatereshapepermuteprintformatclone)	state_dict_originstate_dict_modeltemporalverboserel_pos_embed_typesstate_dict_inflatedv2dv3drel_pos_resizedr'   r+   r-   interpolate_rel_pos_embed#   s.   
rI   c                    s  | d }t |d t |d }}tt| d D ]| d  d || d  d < qtt| d D ]| d  d || d  d < q6dd t|D }dd t|D }d	d t|D d
d t|D }tt| d D ] | d  dd  | d  d < | d || d  d < qu| d d ur| d  g | d< t| d D ]#t dkrÇ fddtt D  | d g   qtt| d D ] | d  dd  || d  d < | d || d  d < q|||||fS )Nr   r
   r   r   r    c                 S      g | ]}g qS r'   r'   r)   ir'   r'   r-   r.   G       z)_prepare_mvit_configs.<locals>.<listcomp>c                 S   rJ   r'   r'   rK   r'   r'   r-   r.   H   rM   c                 S   rJ   r'   r'   rK   r'   r'   r-   r.   I   rM   c                 S   rJ   r'   r'   rK   r'   r'   r-   r.   J   rM   r!   r"   r#   pool_kv_stridec                    s&   g | ]}t  |  |  d qS r
   )max)r)   d
_stride_kvrL   stride_qr'   r-   r.   U   s    )torchonesrangelenappend)cfgr   r   r    pool_qpool_kv	stride_kvr'   rR   r-   _prepare_mvit_configs?   s4   ""$
$r^   c                       s0   e Zd Zddejdf fdd	Zdd Z  ZS )MlpN        c                    sb   t    || _|p|}|p|}t||| _| | _t||| _| jdkr/t|| _	d S d S Nr`   )
super__init__	drop_ratennLinearfc1actfc2Dropoutdrop)selfin_featureshidden_featuresout_features	act_layerrd   	__class__r'   r-   rc   d   s   

zMlp.__init__c                 C   sJ   |  |}| |}| jdkr| |}| |}| jdkr#| |}|S ra   )rg   rh   rd   rk   ri   rl   r*   r'   r'   r-   forwardv   s   






zMlp.forward)__name__
__module____qualname__re   GELUrc   rt   __classcell__r'   r'   rq   r-   r_   b   s    r_   c                       s$   e Zd Z fddZdd Z  ZS )Permutec                    s   t    || _d S N)rb   rc   dims)rl   r|   rq   r'   r-   rc      s   

zPermute.__init__c                 C   s   |j | j S r{   )r<   r|   rs   r'   r'   r-   rt      s   zPermute.forward)ru   rv   rw   rc   rt   ry   r'   r'   rq   r-   rz      s    rz   r`   	drop_probtrainingc                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )z&
    Stochastic Depth per sample.
    r`   r
   r   rO   )dtypedevice)r8   ndimrU   randr   r   floor_div)r*   r}   r~   	keep_probr8   maskoutputr'   r'   r-   	drop_path   s   r   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )DropPathzYDrop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).Nc                    s   t t|   || _d S r{   )rb   r   rc   r}   )rl   r}   rq   r'   r-   rc      s   
zDropPath.__init__c                 C   s   t || j| jS r{   )r   r}   r~   rs   r'   r'   r-   rt      s   zDropPath.forwardr{   ru   rv   rw   __doc__rc   rt   ry   r'   r'   rq   r-   r      s    r   r
   c                 C   s   |s| S | |9 } |p|}|r0t d|  t d|  d|  t dt| |d  | |   t|t| |d  | | }|d|  k rI||7 }t|S )Nz
min width zwidth z	 divisor zother r   g?)r=   intrP   )width
multiplier	min_widthdivisorrC   	width_outr'   r'   r-   round_width   s   "r   c                       s6   e Zd ZdZ						d fdd		Zd
d Z  ZS )
PatchEmbedz
    PatchEmbed.
    r      r   r   r   r   r   r   Fc                    s4   t    |rtj}ntj}||||||d| _d S )N)kernel_sizestridepadding)rb   rc   re   Conv2dConv3dproj)rl   dim_indim_outkernelr   r   conv2dconv_functionrq   r'   r-   rc      s   
	zPatchEmbed.__init__c                 C   s"   |  |}|ddd|jfS )Nr   r
   )r   flatten	transposer8   rs   r'   r'   r-   rt      s   
zPatchEmbed.forward)r   r   r   r   r   Fr   r'   r'   rq   r-   r      s    r   c                 C   sn  |d u r| |fS | j }|dkrn|dkr| d} ntd| j |rE| d d d d d dd d f | d d d d dd d d f }} | j\}}}	}
|\}}}| || ||||
ddddd } || } | jd | jd | jd g}| jd | jd  | jd  }| |||
|dd} |rtj	|| fdd} |d ur|| } |dkr	 | |fS | 
d} | |fS )Nr   r   r
   zUnsupported input dimension r   r   dim)r   	unsqueezeNotImplementedErrorr8   r;   r<   
contiguousr   rU   catsqueeze)tensorpool	thw_shapehas_cls_embednorm
tensor_dimcls_tokBNLCTHWL_pooledr'   r'   r-   attention_pool   s:   
B
&
r   c                 C   s\   t |tr,| jd }||kr| S tj| d|dddd|dd}|d|ddS d S )Nr   r
   r0   r   r1   r2   )
isinstancer   r8   r9   r:   r;   r<   )rel_posrQ   ori_dnew_pos_embedr'   r'   r-   get_rel_pos   s   

r   c                  C   s\  |rdnd}|\}	}
}|\}}}t dt|
| d }t dt|| d }t||
 d}t|
| d}t|
dddf | t|dddf |  }||d | 7 }t|| d}t|| d}t|dddf | t|dddf |  }||d | 7 }t||}t||}||  }||  }|j\}}}}|dddd|df |||	|
||}td||}td||}| dddd|d|df 	|d|	|
|||||ddddddddddddddf  |ddddddddddddddf  	|d|	|
 | || | | dddd|d|df< | S )	z<
    Decomposed Spatial Relative Positional Embeddings.
    r
   r   r         ?Nzbythwc,hkc->bythwkzbythwc,wkc->bythwkr0   )
r   rP   rU   aranger   longr8   r;   einsumview) attnqr,   r   q_shapek_shaper$   r%   sp_idxq_tq_hq_wk_tk_hk_wdhdw	q_h_ratio	k_h_ratiodist_h	q_w_ratio	k_w_ratiodist_wRhRwr   n_headq_Nr   r_qrel_h_qrel_w_qr'   r'   r-   cal_rel_pos_spatial  sR   



*2..
r   c              
   C   s  |rdnd}|\}}}	|\}
}}t dt||
 d }t||}t|
| d}t||
 d}t|dddf | t|
dddf |  }||
d | 7 }||  }|j\}}}}|dddd|df |||||	|}|dddddd||| | |	 |}t	||
dd
dd}|||||	||
dddddd}| dddd|d|df |d	|||	|
|||ddddddddddddddf  |d	|| |	 |
| | | dddd|d|df< | S )
z2
    Temporal Relative Positional Embeddings.
    r
   r   r   r   Nr   r   r   r0   )r   rP   r   rU   r   r   r8   r;   r<   matmulr   r   )r   r   r   r   r   r&   r   r   r   r   r   r   r   dt	q_t_ratio	k_t_ratiodist_tRtr   r   r   r   r   relr'   r'   r-   cal_rel_pos_temporal=  s8   


*$$2.
r   c                       sH   e Zd Zdddddddejddddddddf fdd	Zd	d
 Z  ZS )MultiScaleAttentionr   Fr`   r
   r
   r
   Tconvc              	      s  t    || _|| _|| _|| _|| _|| }|d | _|| _dd |D }dd |D }|s2|rNt	j
|||d| _t	j
|||d| _t	j
|||d| _nt	j
||d |d| _t	
||| _|dkrjt	|| _t|dkrzt|	dkrzd	}t|dkrt|
dkrd	}|| _|d
v r|dkrt	jnt	j}t|dkr|||	|ddnd | _t|dkr|||
|ddnd | _t|dkr|||
|ddnd | _n|dks|dkr\|r|dkr|| n|}n
|dkr|| n|}t|dkrt	j||||	||ddnd | _t|dkr||nd | _t|dkr&t	j||||
||ddnd | _t|dkr4||nd | _t|dkrJt	j||||
||ddnd | _t|dkrX||nd | _ntd| || _ || _!| j r|d |d ksxJ |d }t|	dkr||	d  n|}t|
dkr||
d  n|}dt"|| d }t	#t$%||| _&t	#t$%||| _'|st(| j&dd t(| j'dd | j!rt	#t$%d|d  d || _)|| _*d S )Ng      c                 S      g | ]}t |d  qS r   r   )r)   r   r'   r'   r-   r.         z0MultiScaleAttention.__init__.<locals>.<listcomp>c                 S   r   r   r   )r)   kvr'   r'   r-   r.     r   )biasr   r`   r
   r'   )avgrP   rP   r   F	ceil_moder   conv_unshared)r   r   groupsr   zUnsupported model r   {Gz?std)+rb   rc   
pool_firstseparate_qkvrd   	num_headsr   scaler   re   rf   r   r,   vqkvr   rj   	proj_dropnpprodr4   	MaxPool3d	AvgPool3drX   r[   pool_kpool_vr   norm_qnorm_knorm_vr   rel_pos_spatialrel_pos_temporalrP   	ParameterrU   zerosr$   r%   r   r&   residual_pooling)rl   r   r   
input_sizer   qkv_biasrd   kernel_q	kernel_kvrT   r]   
norm_layerr   r4   r   r
  r  rel_pos_zero_initr  r   head_dim	padding_q
padding_kvpool_opdim_convr3   q_sizekv_size
rel_sp_dimrq   r'   r-   rc   g  s   








zMultiScaleAttention.__init__c              	   C   s  |j \}}}| jr)| jdkrd}n| j}||||ddddd}| } }}	ni| jdks0J | jsV| |||d| jdddddd}
|
d |
d |
d }}}	n<| } }}	| |||| jddddd}| 	|||| jddddd}| 
|	||| jddddd}	t|| j|| jt| dr| jnd d	\}}t|| j|| jt| d
r| jnd d	\}}t|	| j|| jt| dr| jnd d	\}	}| jr_| jrt|d nt|}| jrt|d nt|}| jrt|d nt|}|dddd||d}| |||| jddddd}|	dddd||d}	| 
|	||| jddddd}	|dddd||d}| 	|||| jddddd}|j d }|| j |dd }| jrt|||| j||| j| j}| jrt||| j||| j}|jdd}||	 }| j r| jr|d d d d dd d d f  |d d d d dd d d f 7  < n|| }|dd|d| j!}| "|}| j#dkr| $|}||fS )Nr   r
   r0   r   r   r   r   r  )r   r   r  r	  r   r`   )%r8   r   r4   r   r;   r<   r   r   r   r,   r   r   r[   r   hasattrr  r  r  r  r	  r  r  r   r   r
  r   r$   r%   r  r   r&   softmaxr  r   r   rd   r   )rl   r*   r   r   r   _fold_dimr   r,   r   r   r   r   v_shaper   k_Nv_Nr   r'   r'   r-   rt     s   







F

zMultiScaleAttention.forward)ru   rv   rw   re   	LayerNormrc   rt   ry   r'   r'   rq   r-   r   e  s(     r   c                       sV   e Zd Zdddddejejdddddddddddddddf fdd		Zd
d Z  ZS )MultiScaleBlock      @FNr`   r   r   Tc           !         sz  t    || _|| _||| _|| _dd |D }|}dd |D }|r'|n|}|| _t||fi d|d|d|d|d|d	|d
|d|d|d|d|d|d|d|d|d|d|| _|	dkrot	|	nt
 | _||| _t|| }|| _|d ur|dkr|| } n|} t||| |
|d| _||krt
||| _t|dkrt
j|||dd| _d S d | _d S )Nc                 S   s    g | ]}|d kr|d  n|qS rO   r'   )r)   sr'   r'   r-   r.     s     z,MultiScaleBlock.__init__.<locals>.<listcomp>c                 S   r   r   r   )r)   skipr'   r'   r-   r.     r   r   r  r  rd   r  r  rT   r]   r  r   r4   r   r
  r  r  r  r   r`   r
   )rm   rn   ro   rp   rd   r   Fr   )rb   rc   r   r   norm1dim_mul_in_attuse_grad_checkpointr   r   r   re   Identityr   norm2r   r   r_   mlprf   r   rX   r  	pool_skip)!rl   r   r   r   r  	mlp_ratior  qk_scalerd   r   rp   r  up_rater  r  rT   r]   r4   r   r   r
  r  r  r  r+  r   r,  kernel_skipstride_skippadding_skipatt_dimmlp_hidden_dimmlp_dim_outrq   r'   r-   rc   `  s   

	


zMultiScaleBlock.__init__c           	      C   s   |  |}| jrt| j||\}}n| ||\}}| jr)| j| jkr)| |}t|| j	|| j
d\}}|| | }| |}| jrLt| j|}n| |}| js_| j| jkr_| |}|| | }||fS )N)r   )r*  r,  
checkpointr   r+  r   r   r   r   r0  r   r   r.  r/  )	rl   r*   r   x_normx_blockthw_shape_newx_resr   x_mlpr'   r'   r-   rt     s(   






zMultiScaleBlock.forward)	ru   rv   rw   re   rx   r%  rc   rt   ry   r'   r'   rq   r-   r&  ^  s2    Wr&  c                       s   e Zd ZdZddddddg dg d	g d
dddddddddddddddddddf fdd	Zdd Zejjdd Z	dd Z
dd Zdd Z  ZS )MViTv2a  
    Improved Multiscale Vision Transformers for Classification and Detection
    Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik,
        Christoph Feichtenhofer*
    https://arxiv.org/abs/2112.01526
    Multiscale Vision Transformers
    Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik,
        Christoph Feichtenhofer*
    https://arxiv.org/abs/2104.11227
       `   i  r   r
   r   )r   r   r   )r   r   r   )r
   r   r   Nr`   r'  Tr   Fc           -   	      s  t    d}|| _|| _|| _|| _|| _|| _|| _|| _	|| _
|| _|| _ttjdd}|r>tt|||||	d| _n
t|||||	d| _||d  ||d  ||d  g}t|} dd	 td||D }!| jr}ttdd|| _| d }"n| }"| jrttd|"|| _| jr| jrttd| jd | jd  || _ttd| jd || _| jrttdd|| _nttd|"|| _|
d usJ t|
\}#}$}%}&}'}(|})|r||nd | _ t! | _"t#|D ]}*t$||$|* }|rt$||#|* t$||$|* d
}+nt$||#|*d  t$||$|*d  d
}+t%d'i d|d|+d|d|)d|d|d| jd|!|* d|dt&|%|*krL|%|* ng dt&|&|*krZ|&|* ng dt&|'|*krh|'|* ng dt&|(|*krv|(|* ng d|d| jd|d|d|d|d|d|d |d!d"},|rt|,d"d#},| j"'|, t&|'|* dkrd$d	 t(|)|'|* D })|+}q||| _)t* | _+| jr| jrt,| jd%d& t,| jd%d& | jrt,| jd%d& nt,| jd%d& | jrt,| jd%d& | -| j. d S )(Nr   gư>)eps)r   r   r   r   r   r   r
   r   c                 S   s   g | ]}|  qS r'   )itemr(   r'   r'   r-   r.   '  r/   z#MViTv2.__init__.<locals>.<listcomp>)r   r   r   r   r  r1  r  rd   r   r  r  r  rT   r]   r4   r   r   r
  r  r  r  r+  r   r,  F)offload_to_cpuc                 S   s   g | ]\}}|| qS r'   r'   )r)   r3   r   r'   r'   r-   r.   w  s    r   r   r'   )/rb   rc   img_sizenum_classes	embed_dimr   r   cls_embed_onuse_abs_poszero_decay_pos_clsr,  sep_pos_embedrd   r   re   r%  r   r   patch_embedr  r  rU   linspacer  r  	cls_token	pos_embed
patch_dimspos_embed_spatialpos_embed_temporalpos_embed_classr^   	norm_stem
ModuleListblocksrW   r   r&  rX   rY   zipr   r-  headr   apply_init_weights)-rl   rF  rH  rG  
num_framesr   r   patch_kernelpatch_stridepatch_paddingconfigdropout_ratedrop_path_rater1  r  r4   rI  rJ  r
  r  r  r  r+  r   rK  r   rU  rL  r,  in_chansr  rQ  num_patchesdprpos_embed_dimr   r    r[   r\   rT   r]   r  rL   r   attention_blockrq   r'   r-   rc     s:  
	
	





	


zMViTv2.__init__c                 C   s   t |tjr(tjj|jdd t |tjr$|jd ur&tj|jd d S d S d S t |tjr@tj|jd tj|jd d S d S )Nr   r   r   r   )	r   re   rf   initr   weightr   	constant_r%  )rl   mr'   r'   r-   r[    s   zMViTv2._init_weightsc                 C   sl   g }| j r4| jr| jr|g d n|dg | jr#|g d | jr,|dg | jr4|d |S )N)rR  rS  rT  rP  )r$   r%   
rel_pos_hwr&   rO  )rK  rJ  rL  extendrY   r
  r  rI  )rl   namesr'   r'   r-   no_weight_decay  s   
zMViTv2.no_weight_decayc                 C   s  |d |d |d }}}| j r(|d d ddd d f }|d d dd f }|jd }| j\}}	}
||	 |
 |ks=J ||	|
f|||fkr{tj|d d d d d d f d||	|
dddddd|||fd	d
}|dd|| | ddd}| j rtj||fdd}|S )Nr  r0   r   r
   r   r   r   	trilinearr2   r   )	rI  r8   rQ  r9   r:   r;   r<   rU   r   )rl   rP  bcthwthwcls_pos_embedtxy_nump_tp_hp_wr   r'   r'   r-   _get_pos_embed  s0   
"

zMViTv2._get_pos_embedc                 C   sJ  | ddddd}| |\}}|d |d |d }}}|j\}}}| jr8| j|dd}	tj|	|fdd	}| jr|| j	rq| j
d| jd dtj| j| jd | jd  dd	 }
| jrft| j|
gd}
| |
|}
||
 }n| | j|}
||
 }| jr| |}| jr| |}|||g}| jD ]	}|||\}}q| |}|S )
Nr   r   r
   r   r   rp  r  r0   r   )r<   rM  r8   rI  rO  expandrU   r   rJ  rL  rR  repeatrQ  repeat_interleaverS  rT  r{  rP  rd   pos_droprU  rW  r   )rl   r*   rr  r   r   r   r   r   r   
cls_tokensrP  thwblkr'   r'   r-   forward_features  sF   





zMViTv2.forward_featuresc                 C   s   |  |}| |}|S r{   )r  rY  rs   r'   r'   r-   rt     s   

zMViTv2.forward)ru   rv   rw   r   rc   r[  rU   jitignorero  r{  r  rt   ry   r'   r'   rq   r-   r@    sL     2	
(r@  )TF)r`   F)r
   r
   F)TN)'collectionsr   	functoolsr   numpyr  rU   torch.nnre   torch.nn.functional
functionalr9   torch.utils.checkpointutilsr:  timm.models.layersr   fairscale.nn.checkpointr   ImportErrorMViTv2_Base_configrI   r^   Moduler_   rz   floatboolr   r   r   r   r   r   r   r   r   r&  r@  r'   r'   r'   r-   <module>   sV   
#


#%0( zr