o
    ߥiH                     @   s   d dl Z d dlZd dlmZ d dlm  mZ d dlmZm	Z	 dde
defddZG d	d
 d
ejZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZdS )    N)_pair_triple        F	drop_probtrainingc                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )a  
    From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py.
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    r      r   )r   )dtypedevice)shapendimtorchrandr   r	   floor_div)xr   r   	keep_probr
   random_tensoroutput r   i/home/ubuntu/.local/lib/python3.10/site-packages/modelscope/models/cv/action_recognition/tada_convnext.py	drop_path   s   

r   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )DropPathz
    From https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py.
    Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    Nc                    s   t t|   || _d S N)superr   __init__r   )selfr   	__class__r   r   r   )   s   
zDropPath.__init__c                 C   s   t || j| jS r   )r   r   r   r   r   r   r   r   forward-   s   zDropPath.forwardr   __name__
__module____qualname____doc__r   r   __classcell__r   r   r   r   r   #   s    r   c                       s8   e Zd ZdZ fddZdd Zdd Zdd	 Z  ZS )
TadaConvNeXta   ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s`  -
          https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    c           
   
      s  t     jjj} jjj jjj} jjj} jjjt	 jjj
dr, jjj
jnd}t	 jjj
dr< jjj
jnd}t | _ttj|d |ddf|ddf|d d ddfdtd dd	d
}| j| tdD ]"tt dd	d
tj d  ddd}| j| qqt | _dd td|t|D dtdD ]#tj fddt| D  }	| j|	 | 7 qtjd dd| _d S )NT_KERNEL_SIZE   T_STRIDEr      r   )kernel_sizestridepaddingư>channels_firstepsdata_format   )r   r(   r(   )r+   r,   c                 S   s   g | ]}|  qS r   )item).0r   r   r   r   
<listcomp>i   s    z)TadaConvNeXt.__init__.<locals>.<listcomp>c                    s(   g | ]}t   |  d qS ))dimr   layer_scale_init_value)TAdaConvNeXtBlock)r5   jcfgcurdimsdp_ratesir8   r   r   r6   n   s    
r1   )r   r   VIDEOBACKBONENUM_INPUT_CHANNELSNUM_FILTERS	DROP_PATHDEPTHLARGE_SCALE_INIT_VALUEhasattrSTEMr'   r)   nn
ModuleListdownsample_layers
SequentialConv3d	LayerNormappendrangestagesr   linspacesumnorm)
r   r<   in_chansdrop_path_ratedepthsstem_t_kernel_sizet_stridestemdownsample_layerstager   r;   r   r   @   sf   











zTadaConvNeXt.__init__c                 C   s>   t dD ]}| j| |}| j| |}q| |g dS )Nr*   )rA   )rS   rN   rT   rW   mean)r   r   r@   r   r   r   forward_features{   s   zTadaConvNeXt.forward_featuresc                 C   s    t |tr	|d }| |}|S )Nvideo)
isinstancedictrc   r   r   r   r   r      s   

zTadaConvNeXt.forwardc                 C   s   dS )N)   r   r   r   r   r   r   get_num_layers   s   zTadaConvNeXt.get_num_layers)	r!   r"   r#   r$   r   rc   r   ri   r%   r   r   r   r   r&   1   s    ;r&   c                       *   e Zd ZdZd fdd	Zdd Z  ZS )	ConvNeXtBlocka   ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    r   r.   c                    s   t    tj||dd|d| _t|dd| _t|d| | _t	 | _
td| || _|dkr>tj|t| dd	nd | _|d
krLt|| _d S t | _d S )Nr      rm   r   r3   r3   )r+   r-   groupsr.   rB   r*   r   Trequires_gradr   )r   r   rL   rP   dwconvrQ   rW   Linearpwconv1GELUactpwconv2	Parameterr   onesgammar   Identityr   )r   r<   r7   r   r8   r   r   r   r      s2   


zConvNeXtBlock.__init__c                 C   s   |}|  |}|ddddd}| |}| |}| |}| |}| jd ur.| j| }|ddddd}|| | }|S Nr   r(   r3   r*   r   )rr   permuterW   rt   rv   rw   rz   r   r   r   inputr   r   r   r      s   






zConvNeXtBlock.forwardr   r.   r    r   r   r   r   rk      s    rk   c                       s.   e Zd ZdZ		d fdd	Zdd Z  ZS )	rQ   aF   LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    r.   channels_lastc                    sT   t    tt|| _tt|| _|| _	|| _
| j
dvr$t|f| _d S )N)r   r/   )r   r   rL   rx   r   ry   weightzerosbiasr1   r2   NotImplementedErrornormalized_shape)r   r   r1   r2   r   r   r   r      s   

zLayerNorm.__init__c                 C   s   | j dkrt|| j| j| j| jS | j dkrP|jddd}|| djddd}|| t	
|| j  }| jd d d d d f | | jd d d d d f  }|S d S )Nr   r/   r   T)keepdimr(   )r2   F
layer_normr   r   r   r1   rb   powr   sqrt)r   r   usr   r   r   r      s   

&zLayerNorm.forward)r.   r   r    r   r   r   r   rQ      s    rQ   c                       rj   )	r9   a   ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_fi rst) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    r   r.   c                    s0  t    t|}t||dd|dd| _|jjjj}|dkr4t	||jjjj
|jjjj| jjd ud| _n!|dkrNt||jjjj
|jjjj| jjd ud| _ntd|t|d	d
| _t|d| | _t | _td| || _|dkrtj|t| ddnd | _|dkrt|| _d S t | _d S )Nrl   rn   cout)r+   r-   ro   cal_dimnormal)c_inratiokernelswith_bias_calnormal_lngeluzUnknown route_func_type: {}r.   rB   r*   r   Trp   r   ) r   r   float
TAdaConv2drr   rC   rD   BRANCHROUTE_FUNC_TYPERouteFuncMLPROUTE_FUNC_RROUTE_FUNC_Kr   	dwconv_rfRouteFuncMLPLnGelu
ValueErrorformatrQ   rW   rL   rs   rt   ru   rv   rw   rx   r   ry   rz   r   r{   r   )r   r<   r7   r   r8   route_func_typer   r   r   r      s`   










zTAdaConvNeXtBlock.__init__c                 C   s   |}|  || |}|ddddd}| |}| |}| |}| |}| jd ur2| j| }|ddddd}|| | }|S r|   )	rr   r   r}   rW   rt   rv   rw   rz   r   r~   r   r   r   r     s   





zTAdaConvNeXtBlock.forwardr   r    r   r   r   r   r9      s    &r9   c                       s0   e Zd ZdZ			d	 fdd	Zdd Z  ZS )
r   zF
    The routing function for generating the calibration weights.
    Fh㈵>皙?c                    s@  t t|   || _|| _td| _td| _tj	||ddd| _
tj	|t|| |d ddg|d d ddgd| _tt|| ddd| _t | _tj	t|| ||d ddg|d d ddgd	d
| _d| j_| jjj  |rtj	t|| ||d ddg|d d ddgd	d
| _d| j_| jjj  dS dS )z
        Args:
            c_in (int): number of input channels.
            ratio (int): reduction ratio for the routing function.
            kernels (list): temporal kernel size of the stacked 1D convolutions
        )Nr   r   r   r   )in_channelsout_channelsr+   r-   r(   r.   r/   r0   F)r   r   r+   r-   r   TN)r   r   r   r   r   rL   AdaptiveAvgPool3davgpool
globalpoolrP   gintarQ   lnru   gelub	skip_initr   datazero_b_bias)r   r   r   r   r   bn_epsbn_mmtr   r   r   r   !  sR   



zRouteFuncMLPLnGelu.__init__c                 C   sl   |  |}| |}| || | }| |}| |}| jr/| |d | |d gS | |d S )Nr   )	r   r   r   r   r   r   r   r   r   )r   r   r   r   r   r   r   X  s   



zRouteFuncMLPLnGelu.forward)Fr   r   r    r   r   r   r   r     s    7r   c                       s>   e Zd ZdZ						d fdd	Zdd	 Zd
d Z  ZS )r   z
    Performs temporally adaptive 2D convolution.
    Currently, only application on 5D tensors is supported, which makes TAdaConv2d
        essentially a 3D convolution with temporal kernel size of 1.
    r   r   Tcinc
                    sZ  t t|   	 t|}t|}t|}t|}|d dks J |d dks(J |d dks0J |d dks8J |	dv s>J || _|| _|| _|| _|| _|| _	|| _
|	| _ttdd||| |d |d | _|rxttdd|| _n| dd  tjj| jtdd | jd urtj| j\}
}dt|
 }tj| j| | d S d S )Nr   r   )r   r   r(   r      )r   )r   r   r   r   r   r   r+   r,   r-   dilationro   r   rL   rx   r   Tensorr   r   register_parameterinitkaiming_uniform_mathr   _calculate_fan_in_and_fan_outuniform_)r   r   r   r+   r,   r-   r   ro   r   r   fan_in_boundr   r   r   r   m  sB   

zTAdaConv2d.__init__c              	   C   s  t |tr|d |d }}n|}d}| j \}}}}}}	| \}
}}}}|ddddddd||}| jdkrT|dddddd| j d|| j ||	}n| jd	krq|dddddd| j d|| j ||	}d}| j	dur|dur|ddddd
 | j	 d}n| j	|
|dd}tj|||| jdd | jdd | jdd | j|
 | d
}||
|||d|dddddd}|S )z
        Args:
            x (tensor): feature to perform convolution on.
            alpha (tensor): calibration weight for the base weights.
                W_t = alpha_t * W_b
        r   r   Nr(   r3   r*   rA   r   r   )r   r   r,   r-   r   ro   ra   )re   listr   sizer}   reshaper   	unsqueezero   r   squeezerepeatr   conv2dr,   r-   r   view)r   r   alphaw_alphab_alphar   c_outr   khkwr   thwr   r   r   r   r   r   r     sV   



	zTAdaConv2d.forwardc              
   C   sH   d| j  d| j d| j dd| j d| j d| jd u d| j d	 S )	NzTAdaConv2d(z, z, kernel_size=zstride=z
, padding=z, bias=z, cal_dim="z"))r   r   r+   r,   r-   r   r   rh   r   r   r   __repr__  s   (zTAdaConv2d.__repr__)r   r   r   r   Tr   )r!   r"   r#   r$   r   r   r   r%   r   r   r   r   r   f  s    
<1r   )r   F)r   r   torch.nnrL   torch.nn.functional
functionalr   torch.nn.modules.utilsr   r   r   boolr   Moduler   r&   rk   rQ   r9   r   r   r   r   r   r   <module>   s   [-!BJ