o
    ϯi6                     @   s  d dl Z d dlmZ d dlm  mZ d dlmZ d dlZ	d dl
Z
d dlZd dlmZ d dlm  mZ d dlZd dlmZmZ d dlmZ d dlmZ ddlmZmZ ddlmZmZmZ d	d
 ZedZedZedZ edZ!eZ"d8de#de$fddZ%G dd dej&Z'G dd dej&Z(G dd dej&Z)dd Z*d9dd Z+d:d#d$Z,d%d& Z-d'd( Z.d)d* Z/G d+d, d,ej&Z0G d-d. d.ej&Z1G d/d0 d0ej&Z2G d1d2 d2ej&Z3G d3d4 d4ej&Z4d;d6d7Z5dS )<    N)repeat)_calculate_fan_in_and_fan_out)SpectrogramLogmelFilterBank)SpecAugmentation   )do_mixupinterpolate)iAFFAFFDAFc                    s    fdd}|S )Nc                    s    t | tjjr	| S tt|  S N)
isinstancecollectionsabcIterabletupler   xn P/home/ubuntu/.local/lib/python3.10/site-packages/laion_clap/clap_module/htsat.pyparse   s   z_ntuple.<locals>.parser   )r   r   r   r   r   _ntuple   s   r                    F	drop_probtrainingc                 C   sd   |dks|s| S d| }| j d fd| jd   }|tj|| j| jd }|  | || }|S )a&  Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    r   r   r   r   )dtypedevice)shapendimtorchrandr"   r#   floor_div)r   r   r    	keep_probr$   random_tensoroutputr   r   r   	drop_path+   s   r-   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )DropPathz^Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    Nc                    s   t t|   || _d S r   )superr.   __init__r   )selfr   	__class__r   r   r0   @   s   
zDropPath.__init__c                 C   s   t || j| jS r   )r-   r   r    r1   r   r   r   r   forwardD   s   zDropPath.forwardr   __name__
__module____qualname____doc__r0   r5   __classcell__r   r   r2   r   r.   =   s    r.   c                       s0   e Zd ZdZ			d fd
d	ZdddZ  ZS )
PatchEmbedz! 2D Image to Patch Embedding
          r      NTFNonec
                    s  t    t|}t|}t|}|| _|| _|| _|d |d  |d |d  f| _| jd | jd  | _|| _|| _	|| _
|| _|	| _|d |d  d |d |d  d f}
| jro| jdkrotj|d ||||
d| _ntj|||||
d| _|r||nt | _| jr| jdv rtj|||d |d d f|d |d d f|
d| _| jd	krt | _d S | jd
krt|dd| _d S | jdkrt|dd| _d S d S d S d S )Nr   r   r   channel_mapr   kernel_sizestridepaddingdaf_2daff_2diaff_2dr   rG   rH   2DchannelstyperI   )r/   r0   	to_2tupleimg_size
patch_sizepatch_stride	grid_sizenum_patchesflattenin_chans	embed_dimenable_fusionfusion_typennConv2dprojIdentitynorm
mel_conv2dr   fusion_modelr   r
   )r1   rO   rP   rU   rV   
norm_layerrT   rQ   rW   rX   rE   r2   r   r   r0   J   s:   
"(6


zPatchEmbed.__init__c              
   C   sB  | j r| jdv r|d d ddd d d d f }|j\}}}}|| jd kr-|| jd ksDJ d| d| d| jd  d| jd  d	| |}|d}t|dkr||dd d d d d f  }	|	j\}}}}|	|| d||}	| 	|	}	|	|||	d|	d	|	d
}	|	
d d
}	|	 \}
}}}|	d|k rtj|	tj|
||||	d f|jdgdd}	n|	d d d d d d d |f }	| || |	||< |}n2|j\}}}}|| jd kr|| jd ksJ d| d| d| jd  d| jd  d	| |}| jr|d	dd	}| |}|S )NrF   r   r   zInput image size (*z) doesn't match model (z).r   r   )r   r   r   r   r   r#   dim)rW   rX   r$   rO   r[   sizelen
contiguousviewr^   permuterT   r&   catzerosr#   r_   	transposer]   )r1   r   
longer_idxglobal_xBCHWTWlocal_xTBTCTH_r   r   r   r5   l   s<    (

 
$2  (

zPatchEmbed.forward)	r=   r>   r   r?   NTr>   Fr@   r   r6   r   r   r2   r   r<   G   s    "r<   c                       s4   e Zd ZdZddejdf fdd	Zdd Z  ZS )MlpzG MLP as used in Vision Transformer, MLP-Mixer and related networks
    Nr   c                    sN   t    |p|}|p|}t||| _| | _t||| _t|| _d S r   )	r/   r0   rY   Linearfc1actfc2Dropoutdrop)r1   in_featureshidden_featuresout_features	act_layerr   r2   r   r   r0      s   
zMlp.__init__c                 C   s6   |  |}| |}| |}| |}| |}|S r   )r|   r}   r   r~   r4   r   r   r   r5      s   




zMlp.forward)	r7   r8   r9   r:   rY   GELUr0   r5   r;   r   r   r2   r   rz      s    	rz   c                 C   s   dd }||d|  k s||d|  krt jddd t B ||| | }||| | }| d| d d| d  |   | |td  | 	| | j
||d | W  d    S 1 sdw   Y  d S )	Nc                 S   s   dt | t d  d S )N      ?       @)matherfsqrtr   r   r   r   norm_cdf   s   z(_no_grad_trunc_normal_.<locals>.norm_cdfr   zjmean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.)
stacklevelr   r   )minmax)warningswarnr&   no_graduniform_erfinv_mul_r   r   add_clamp_)tensormeanstdabr   lur   r   r   _no_grad_trunc_normal_   s    

$r   r          r   c                 C   s   t | ||||S )a  Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    )r   )r   r   r   r   r   r   r   r   trunc_normal_   s   r   fan_innormalc           	      C   s   t | \}}|dkr|}n|dkr|}n
|dkr|| d }|| }|dkr3t| t|d d d S |dkrB| jt|d d S |d	krVtd
| }| | | d S td| )Nr   fan_outfan_avgr   truncated_normalg۶%?r   r   uniformr   zinvalid distribution )r   r   r   r   normal_r   
ValueError)	r   scalemodedistributionr   r   denomvarianceboundr   r   r   variance_scaling_   s    r   c                 C   s   t | ddd d S )Nr   r   )r   r   )r   )r   r   r   r   lecun_normal_   s   r   c                 C   sR   | j \}}}}| ||| ||| ||} | dddddd d|||}|S )z
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    r   r   r   r   r      rb   )r$   ri   rj   rh   )r   window_sizerp   rr   rs   rq   windowsr   r   r   window_partition   s   $r   c                 C   sb   t | jd || | |  }| ||| || ||d}|dddddd |||d}|S )z
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    r   rb   r   r   r   r   r   )intr$   ri   rj   rh   )r   r   rr   rs   rp   r   r   r   r   window_reverse  s   
$r   c                       s4   e Zd ZdZd fdd	ZdddZd	d
 Z  ZS )WindowAttentiona   Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    TNr   c                    s  t    || _|| _|| _|| }|p|d | _tt	d|d  d d|d  d  || _
t| jd }	t| jd }
tt|	|
g}t|d}|d d d d d f |d d d d d f  }|ddd }|d d d d df  | jd d 7  < |d d d d df  | jd d 7  < |d d d d df  d| jd  d 9  < |d}| d| tj||d |d| _t|| _t||| _t|| _t| j
d	d
 tjdd| _d S )Ng      r   r   r   rb   relative_position_indexr   bias{Gz?r   rd   )r/   r0   re   r   	num_headsr   rY   	Parameterr&   rl   relative_position_bias_tablearangestackmeshgridrT   rj   rh   sumregister_bufferr{   qkvr   	attn_dropr[   	proj_dropr   Softmaxsoftmax)r1   re   r   r   qkv_biasqk_scaler   r   head_dimcoords_hcoords_wcoordscoords_flattenrelative_coordsr   r2   r   r   r0   $  s4   
&,((,
zWindowAttention.__init__c                 C   sn  |j \}}}| |||d| j|| j ddddd}|d |d |d }}}	|| j }||dd }
| j| j	d 	| j
d | j
d  | j
d | j
d  d}|ddd }|
|d }
|dur|j d }|
	|| || j|||dd }
|
	d| j||}
| |
}
n| |
}
| |
}
|
|	 dd|||}| |}| |}||
fS )	z
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        r   r   r   r   r   rb   N)r$   r   reshaper   rj   r   rm   r   r   ri   r   rh   	unsqueezer   r   r[   r   )r1   r   maskB_Nrq   r   qkvattnrelative_position_biasnWr   r   r   r5   F  s*   .
&
(



zWindowAttention.forwardc                 C      d| j  d| j d| j S )Ndim=, window_size=, num_heads=)re   r   r   r1   r   r   r   
extra_reprg     zWindowAttention.extra_repr)TNr   r   r   )r7   r8   r9   r:   r0   r5   r   r;   r   r   r2   r   r     s
    
"!r   c                       sL   e Zd ZdZddddddddejejdf fd	d
	Zdd Zdd Z	  Z
S )SwinTransformerBlocka   Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resulotion.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
       r         @TNr   lnc              	      s:  t     | _|| _|| _|| _|| _|| _|| _t	| j| jkr+d| _t	| j| _d| j  kr:| jk s?J d J d| | _
t t| j||||
|	d| _|dkr[t|nt | _| jdkrlt | _n| jdkry fdd| _ntt | }t |||	d	| _| jdkr| j\}}td
||d
f}td| j t| j | j t| j d f}td| j t| j | j t| j d f}d}|D ]}|D ]}||d d ||d d f< |d
7 }qqt|| j}|d| j| j }|d
|d }||dktd|dktd}nd }|  d| d S )Nr   z shift_size must in 0-window_size)r   r   r   r   r   r   r   r   bnc                    s   t  | ddddS )Nr   r   )rY   BatchNorm1drm   r   rd   r   r   <lambda>  s    z/SwinTransformerBlock.__init__.<locals>.<lambda>)r   r   r   r   r   rb   r   g      Y	attn_mask)!r/   r0   re   input_resolutionr   r   
shift_size	mlp_rationorm_before_mlpr   norm1r   rN   r   r.   rY   r\   r-   	LayerNormnorm2NotImplementedErrorr   rz   mlpr&   rl   slicer   ri   r   masked_fillfloatr   )r1   re   r   r   r   r   r   r   r   r   r   r-   r   r`   r   mlp_hidden_dimrr   rs   img_maskh_slicesw_slicescnthwmask_windowsr   r2   rd   r   r0   ~  s`   
(




&zSwinTransformerBlock.__init__c                 C   s$  | j \}}|j\}}}|}| |}|||||}| jdkr.tj|| j | j fdd}n|}t|| j}	|	d| j| j |}	| j	|	| j
d\}
}|
d| j| j|}
t|
| j||}| jdkrotj|| j| jfdd}n|}|||| |}|| | }|| | | | }||fS )Nr   )r   r   )shiftsdimsrb   )r   )r   r$   r   ri   r   r&   rollr   r   r   r   r   r-   r   r   )r1   r   rr   rs   rp   Lrq   shortcut	shifted_x	x_windowsattn_windowsr   r   r   r   r5     s(   



zSwinTransformerBlock.forwardc                 C   s4   d| j  d| j d| j d| j d| j d| j S )Nr   , input_resolution=r   r   z, shift_size=z, mlp_ratio=)re   r   r   r   r   r   r   r   r   r   r     s   zSwinTransformerBlock.extra_repr)r7   r8   r9   r:   rY   r   r   r0   r5   r   r;   r   r   r2   r   r   l  s    
9+r   c                       s6   e Zd ZdZejf fdd	Zdd Zdd Z  Z	S )PatchMergingz Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    c                    sB   t    || _|| _tjd| d| dd| _|d| | _d S )Nr   r   Fr   )r/   r0   r   re   rY   r{   	reductionr]   )r1   r   re   r`   r2   r   r   r0     s
   
zPatchMerging.__init__c                 C   s6  | j \}}|j\}}}||| ksJ d|d dkr!|d dks,J d| d| d|||||}|ddddddddddf }|ddddddddddf }|ddddddddddf }	|ddddddddddf }
t|||	|
gd	}||d	d
| }| |}| |}|S )z
        x: B, H*W, C
        zinput feature has wrong sizer   r   zx size (ra   z) are not even.Nr   rb   r   )r   r$   ri   r&   rk   r]   r  )r1   r   rr   rs   rp   r  rq   x0x1x2x3r   r   r   r5     s   
.$$$$

zPatchMerging.forwardc                 C   s   d| j  d| j S )Nzinput_resolution=z, dim=)r   re   r   r   r   r   r     s   zPatchMerging.extra_repr
r7   r8   r9   r:   rY   r   r0   r5   r   r;   r   r   r2   r   r
    s
    r
  c                
       sH   e Zd ZdZddddddejdddf
 fdd		Zd
d Zdd Z  Z	S )
BasicLayera.   A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    r   TNr   Fr   c                    s|   t    | _| _|| _|| _t 	
fddt|D | _	|d ur9|d| _
d S d | _
d S )Nc                    sT   g | ]&}t |d  dkrdnd  
	 ttr"| ndqS )r   r   )re   r   r   r   r   r   r   r   r   r   r-   r`   r   )r   r   list).0ir   re   r   r-   r   r   r   r`   r   r   r   r   r   r   
<listcomp>1  s    	z'BasicLayer.__init__.<locals>.<listcomp>)re   r`   )r/   r0   re   r   depthuse_checkpointrY   
ModuleListrangeblocks
downsample)r1   re   r   r  r   r   r   r   r   r   r   r-   r`   r  r  r   r2   r  r   r0   %  s   
$	

zBasicLayer.__init__c                 C   s   g }| j D ]}| jrt||}q||\}}| js"||d q| jd ur-| |}| js>tj|dd}tj	|dd}||fS )Nr   rd   )
r  r  
checkpointr    appendr   r  r&   rk   r   )r1   r   attnsblkr   r   r   r   r5   B  s   


zBasicLayer.forwardc                 C   r   )Nr   r	  z, depth=)re   r   r  r   r   r   r   r   R  r   zBasicLayer.extra_reprr  r   r   r2   r   r    s    

r  c                       s   e Zd ZdZddddddg dg d	d
ddddddejdddddddf fdd	Zdd Zej	j
dd Zej	j
dd Zd&ddZd&ddZdd  Zd!d" Zd'd#ejfd$d%Z  ZS )(HTSAT_Swin_Transformera*  HTSAT based on the Swin Transformer
    Args:
        spec_size (int | tuple(int)): Input Spectrogram size. Default 256
        patch_size (int | tuple(int)): Patch size. Default: 4
        path_stride (iot | tuple(int)): Patch Stride for Frequency and Time Axis. Default: 4
        in_chans (int): Number of input image channels. Default: 1 (mono)
        num_classes (int): Number of classes for classification head. Default: 527
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 8
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        config (module): The configuration Module from config.py
       r   r   r   r   i  `   r   r      r   r      r>       r(  r   TNr   g?Fr   r@   c           %         s  t t|   || _|| _|| _|| _|	| _|| _|| _	|| _
|| _|| _|| _t| j	| _t| jd| jd   | _|| _|| _|| _|| _d | _|| _| jrR|nd | _|| _|
| _|| _|| _|| _| j| jj | _d}d}d}d}d}d }d| _ t!|j|j"|j|||dd	| _#t$|j%|j|j|j&|j'|||dd
	| _(t)ddddd| _*t+,| jj| _-t.| j| j| j| j| j|| j| jd| _/| j/j0}| j/j1} | | _2| j
rt+3t45d|| j| _6t7| j6dd t+j8| jd| _9dd t4:d| jt;| j	D }!t+< | _=t>| jD ]]}"t?t| jd|"  | d d|"  | d d|"  f| j	|" | j|" | j| j| j| j| j| j|!t;| j	d |" t;| j	d |"d   | j|"| jd k rTt@nd || jd}#| j=A|# q| | j| _Bt+Cd| _Dt+Ed| _F| jdt| j	d   | jd  | j }$t+jG| j| j|$dfdd| _Ht+I||| _J| jr| jdv rt+Kt+jLddddddt+Md| _N| jdkrtO | _Pn| jdkrtQddd| _Pn| jd krtRddd| _P| S| jT d S )!Nr   r   hannTreflectr   g|=r)  )n_fft
hop_length
win_lengthwindowcenterpad_modefreeze_parameters)	srr,  n_melsfminfmaxrefamintop_dbr2  @   r(  )time_drop_widthtime_stripes_numfreq_drop_widthfreq_stripes_num)rO   rP   rU   rV   r`   rQ   rW   rX   r   r   )pc                 S   s   g | ]}|  qS r   )item)r  r   r   r   r   r    s    z3HTSAT_Swin_Transformer.__init__.<locals>.<listcomp>r   )re   r   r  r   r   r   r   r   r   r   r-   r`   r  r  r   r   )r   r   )in_channelsout_channelsrC   rE   daf_1daff_1diaff_1dr   rB   rD  rE  1DrK   rF  )Ur/   r!  r0   config	spec_sizerQ   rP   r   rV   depthsaperU   num_classesr   rg   
num_layersr   num_features	drop_rateattn_drop_ratedrop_path_rater   r   
patch_normr`   r   r   r  rW   rX   mel_bins
freq_ratiointerpolate_ratior   hop_sizespectrogram_extractorr   sample_rater5  r6  logmel_extractorr   spec_augmenterrY   BatchNorm2dbn0r<   patch_embedrS   rR   patches_resolutionr   r&   rl   absolute_pos_embedr   r   pos_droplinspacer   r  layersr  r  r
  r  r]   AdaptiveAvgPool1davgpoolAdaptiveMaxPool1dmaxpoolrZ   
tscam_convr{   head
SequentialConv1dr   
mel_conv1dr   r_   r   r
   apply_init_weights)%r1   rI  rP   rQ   rU   rL  rV   rJ  r   r   r   r   r   rO  rP  rQ  r`   rK  rR  r  r   rH  rW   rX   kwargsr/  r0  r1  r7  r8  r9  rS   r^  dpri_layerlayerSFr2   r   r   r0   p  s   	



 
*(
zHTSAT_Swin_Transformer.__init__c                 C   s   t |tjr&t|jdd t |tjr"|jd ur$tj|jd d S d S d S t |tjr>tj|jd tj|jd d S d S )Nr   r   r   r   )	r   rY   r{   r   weightr   init	constant_r   )r1   mr   r   r   rm    s   z$HTSAT_Swin_Transformer._init_weightsc                 C      dhS )Nr_  r   r   r   r   r   no_weight_decay     z&HTSAT_Swin_Transformer.no_weight_decayc                 C   rw  )Nr   r   r   r   r   r   no_weight_decay_keywords  ry  z/HTSAT_Swin_Transformer.no_weight_decay_keywordsc                 C   s  |j d }| j||d}| jr|| j }| |}t| jD ]
\}}||\}}q| |}|j \}}}	|dt| j	d   | j
d  }
|dt| j	d   | j
d  }|ddd ||	|
|}|j \}}	}}|| j }|||	|| ||}|ddddd ||	|d}tj|dd}t|ddd d	| j
d  }| t|d}t|d}| |}t|d}tt|ddd d	| j
d  }| |}t|d}|t|||d
}|S )Nr   rn   r   r   r   r   rb   rd   r(  )framewise_outputclipwise_outputfine_grained_embedding	embedding)r$   r]  rK  r_  r`  	enumeraterb  r]   rg   rJ  rQ   rj   rh   r   rT  r&   r   r	   rd  rT   rg  sigmoid)r1   r   rn   
frames_numr  rq  r   rp   r   rq   rr  STFT
c_freq_binfine_grained_latent_outputlatent_outputfpxoutput_dictr   r   r   forward_features  s>   



  
""
(
z'HTSAT_Swin_Transformer.forward_featuresc                 C   s   |j d }t|j d |j d ||j d |j}tt|D ]%}|d u r1td|| d }n|}||d||| d d f || d< q |S )Nr   r   r   r   )	r$   r&   rl   tor#   r  rg   randomrandint)r1   r   	crop_sizespe_pos
time_stepstxr  crop_posr   r   r   crop_wav5  s   
*&zHTSAT_Swin_Transformer.crop_wavc                 C   s   |j \}}}}t| j| j }| j| j }||kr||ks!J d||k r4tjj|||j d fddd}||k rGtjj||j d |fddd}|dddd }|	|j d |j d |j d | j|j d | j }|ddddd	 }|	|j d |j d |j d |j d  |j d	 }|S )
N=the wav size should less than or equal to the swin input sizer   bicubicTr   align_cornersr   r   r   r   )
r$   r   rI  rT  rY   
functionalr	   rj   rh   r   )r1   r   rp   rq   r  r  target_Ttarget_Fr   r   r   reshape_wav2imgA  s   22z&HTSAT_Swin_Transformer.reshape_wav2imgc           	      C   s   |j \}}}}t| j| j }| j| j }||kr||ks!J d||k r4tjj|||j d fddd}||k rGtjj||j d |fddd}|dddd }|d d d d d d ||| j f }|j	d	d
}|S )Nr  r   r  Tr  r   r   r   )r   r   r   r   )repeats)
r$   r   rI  rT  rY   r  r	   rj   rh   r   )	r1   r   cur_posrp   rq   r  r  r  r  r   r   r   repeat_wat2imgS  s   &z%HTSAT_Swin_Transformer.repeat_wat2imgr   c              	   C   sF  | j rG|d  dkrG| jrd|d td|d jd d< n(|d j|dd}|dd}| |}|dd}| 	|}| j
|g d	}|S | j s|d
 j|dd}| |}| |}|dd}| |}|dd}| jrv| |}| jr|d urt||}| 	|}| 
|}|S |d j|dd}|d j|dd}|dd}| |}|dd}t|d }| jdv rv|d d ddd d d d f   }t|dkrs||dd d d d d f   }	|	 \}
}}}|	|
| ||}	t|	d }	| |	}	|	|
|||	d}	t|	d d}	|	d|k r>tj|	tj|
|||	d f|dgdd}	n|	d d d d d |f }	|dd }| || |	||< |d d d d d d d d f }n|}n| jdv r~|}| jr| |}| jr|d urt||}| 	|}| j
||d	}|S )Nlongerr   Tr!   
mel_fusion)r#   non_blockingr   r   r{  waveformrC  )r   r   r   rb   )r   r   r   r   r   rc   rd   )rG   rH   rI   rA   )rW   r   r    r&   r  r$   r  rm   r\  r  r  rW  rY  rZ  r   whererX   clonerh   rg   rf   ri   rj   rk  rT   rk   rl   squeezer_   )r1   r   mixup_lambda
infer_moder#   r  longer_listlonger_list_idxnew_xfusion_x_localFBFCFTFFr   r   r   r5   b  sp   "








T
($
.(


-zHTSAT_Swin_Transformer.forwardr   )NFN)r7   r8   r9   r:   rY   r   r0   rm  r&   jitignorerx  rz  r  r  r  r  Tensorr5   r;   r   r   r2   r   r!  W  s.     	



/r!  r@   c                 C   s   z]| j dv s
J d| j dkr&tddd| jdg dg d	d
| ||d}|W S | j dkrBtddd| jdg dg d	d
| ||d}|W S | j dkr[tddd| jdg dg d	d
| ||d}|W S    td| j  d)N)tinybaselargezmodel name for HTS-AT is wrong!r  r"  r   r#  r$  r%  r'  r(  )rI  rP   rQ   rL  rV   rJ  r   r   rH  rW   rX   r     )r   r      r   r  zImport Model for z7 not found, or the audio cfg parameters are not enough.)
model_namer!  	class_numRuntimeError)	audio_cfgrW   rX   modelr   r   r   create_htsat_model  sb   
*

r  )r   F)r   r   r   r   )r   r   r   )Fr@   )6r&   torch.nnrY   torch.nn.functionalr  r  	itertoolsr   collections.abcr   r   r   torch.nn.initr   torch.utils.checkpointutilsr  r  torchlibrosa.stftr   r   torchlibrosa.augmentationr   r   r	   feature_fusionr
   r   r   r   	to_1tuplerN   	to_3tuple	to_4tuple	to_ntupler   boolr-   Moduler.   r<   rz   r   r   r   r   r   r   r   r   r
  r  r!  r  r   r   r   r   <module>   sP   
K
$
U|*E   