o
    ib                     @   s  d dl Z d dlZd dlmZmZmZmZ d dlZd dlmZ d dl	m
Z G dd dejjZG dd dejjZG d	d
 d
ejjZG dd dejjZG dd dejjZG dd dejjZG dd dejZG dd dejZdejdededejfddZdd Zd/dejded ed!edejf
d"d#Zd0d$ejd ed%ed!edejf
d&d'Zd(ee defd)d*Zd(ee defd+d,Zd(ee defd-d.ZdS )1    N)AnyDictListOptional)nn)
functionalc                	       s^   e Zd ZdZddedededef fdd	Zed
e	j
fddZde	j
d
e	j
fddZ  ZS )_ScaledEmbeddingaF  Make continuous embeddings and boost learning rate

    Args:
        num_embeddings (int): number of embeddings
        embedding_dim (int): embedding dimensions
        scale (float, optional): amount to scale learning rate (Default: 10.0)
        smooth (bool, optional): choose to apply smoothing (Default: ``False``)
          $@Fnum_embeddingsembedding_dimscalesmoothc                    s   t    t||| _|r3tj| jjjdd}|t	d|d 
 d d d f  }|| jjjd d < | jj j|  _|| _d S )Nr   dim   )super__init__r   	Embedding	embeddingtorchcumsumweightdataarangesqrtr   )selfr
   r   r   r   r   	__class__ N/home/ubuntu/.local/lib/python3.10/site-packages/torchaudio/models/_hdemucs.pyr   -   s   
$
z_ScaledEmbedding.__init__returnc                 C   s   | j j| j S N)r   r   r   )r   r   r   r   r   8   s   z_ScaledEmbedding.weightxc                 C   s   |  || j }|S )zForward pass for embedding with scale.
        Args:
            x (torch.Tensor): input tensor of shape `(num_embeddings)`

        Returns:
            (Tensor):
                Embedding output of shape `(num_embeddings, embedding_dim)`
        )r   r   )r   r"   outr   r   r   forward<   s   	z_ScaledEmbedding.forward)r	   F)__name__
__module____qualname____doc__intfloatboolr   propertyr   Tensorr   r$   __classcell__r   r   r   r   r   #   s     	r   c                       s   e Zd ZdZ									dd	ed
ededededededededeeee	f  def fddZ
ddejdeej dejfddZ  ZS )
_HEncLayerat  Encoder layer. This used both by the time and the frequency branch.
    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        kernel_size (int, optional): Kernel size for encoder (Default: 8)
        stride (int, optional): Stride for encoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        empty (bool, optional): used to make a layer with just the first conv. this is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): boolean for whether conv layer is for frequency domain (Default: ``True``)
        norm_type (string, optional): Norm type, either ``group_norm `` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 0)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)
          FT
group_normr   Nchinchoutkernel_sizestridenorm_groupsemptyfreq	norm_typecontextdconv_kwpadc                    s  t    |
d u ri }
dd }|dkr fdd}|r|d nd}tj}|| _|| _|| _|| _|| _|rD|dg}|dg}|dg}tj	}||||||| _
||| _| jrft | _t | _t | _d S ||d| dd|	  d|	| _|d| | _t|fi |
| _d S )	Nc                 S      t  S r!   r   Identitydr   r   r   <lambda>m       z%_HEncLayer.__init__.<locals>.<lambda>r2   c                       t  | S r!   r   	GroupNormrA   r7   r   r   rC   o       r1   r   r      )r   r   r   Conv1dr9   r5   r6   r8   r=   Conv2dconvnorm1r@   rewritenorm2dconv_DConv)r   r3   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   norm_fnpad_valklassr   rH   r   r   \   s6   



z_HEncLayer.__init__r"   injectr    c           
      C   sh  | j s| dkr|j\}}}}||d|}| j s4|jd }|| j dks4t|d| j|| j  f}| |}| jr>|S |durk|jd |jd krPt	d| dkrg| dkrg|dddddf }|| }t
| |}| j r|j\}}}}|ddddd||}| |}|||||dddd}n| |}| | |}	tj|	dd	}	|	S )
a]  Forward pass for encoding layer.

        Size depends on whether frequency or time

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            inject (torch.Tensor, optional): on last layer, combine frequency and time branches through inject param,
                same shape as x (default: ``None``)

        Returns:
            Tensor
                output tensor after encoder layer of shape `(B, C, F / stride, T)` for frequency
                    and shape `(B, C, ceil(T / stride))` for time
        r1   r   NzInjection shapes do not align   rJ   r   r   )r9   r   shapeviewr6   Fr=   rM   r8   
ValueErrorgelurN   permutereshaperQ   rP   rO   glu)
r   r"   rV   BCFrTleyzr   r   r   r$      s4   



z_HEncLayer.forward)	r0   r1   r1   FTr2   r   NTr!   r%   r&   r'   r(   r)   r+   strr   r   r   r   r   r-   r$   r.   r   r   r   r   r/   I   sF    	
*,r/   c                       s   e Zd ZdZ										dd	ed
edededededededededeeee	f  def fddZ
dejdeej fddZ  ZS )
_HDecLayera  Decoder layer. This used both by the time and the frequency branches.
    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        last (bool, optional): whether current layer is final layer (Default: ``False``)
        kernel_size (int, optional): Kernel size for encoder (Default: 8)
        stride (int): Stride for encoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 1)
        empty (bool, optional): used to make a layer with just the first conv. this is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): boolean for whether conv layer is for frequency (Default: ``True``)
        norm_type (str, optional): Norm type, either ``group_norm `` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 1)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)
    Fr0   r1   r   Tr2   Nr3   r4   lastr5   r6   r7   r8   r9   r:   r;   r<   r=   c                    s  t    |d u ri }dd }|	dkr fdd}|r.|| d dkr'td|| d }nd}|| _|| _|| _|| _|| _|| _|| _	t
j}t
j}|r[|dg}|dg}t
j}t
j}|||||| _||| _| jrwt
 | _t
 | _d S ||d| dd|
  d|
| _|d| | _d S )	Nc                 S   r>   r!   r?   rA   r   r   r   rC      rD   z%_HDecLayer.__init__.<locals>.<lambda>r2   c                    rE   r!   rF   rA   rH   r   r   rC      rI   rJ   r   z#Kernel size and stride do not alignr   )r   r   r\   r=   rk   r9   r3   r8   r6   r5   r   rK   ConvTranspose1drL   ConvTranspose2dconv_trrP   r@   rO   rN   )r   r3   r4   rk   r5   r6   r7   r8   r9   r:   r;   r<   r=   rS   rU   klass_trr   rH   r   r      s@   


z_HDecLayer.__init__r"   skipc           	      C   s   | j r| dkr|j\}}}||| jd|}| js-|| }tj| | 	|dd}n
|}|dur7t
d| | |}| j rT| jrS|d| j| j ddf }n|d| j| j| f }|jd |krkt
d| jsst|}||fS )	a,  Forward pass for decoding layer.

        Size depends on whether frequency or time

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            skip (torch.Tensor, optional): on first layer, separate frequency and time branches using param
                (default: ``None``)
            length (int): Size of tensor for output

        Returns:
            (Tensor, Tensor):
                Tensor
                    output tensor after decoder layer of shape `(B, C, F * stride, T)` for frequency domain except last
                        frequency layer shape is `(B, C, kernel_size, T)`. Shape is `(B, C, stride * T)`
                        for time domain.
                Tensor
                    contains the output just before final transposed convolution, which is used when the
                        freq. and time branch separate. Otherwise, does not matter. Shape is
                        `(B, C, F, T)` for frequency and `(B, C, T)` for time.
        rX   rW   r   r   Nz%Skip must be none when empty is true..z'Last index of z must be equal to length)r9   r   rY   rZ   r3   r8   r[   r`   rN   rO   r\   rP   rn   r=   rk   r]   )	r   r"   rp   lengthra   rb   rd   rf   rg   r   r   r   r$      s(   
z_HDecLayer.forward)
Fr0   r1   r   FTr2   r   NTrh   r   r   r   r   rj      sL    	
"2rj   c                +       s   e Zd ZdZ												
			
	
		
	
	
	d:dee dededededededededededededededededed ed!ed"ef* fd#d$Z	d%d& Z
d;d(d)Zd<d,ejd-ed.ed/ed0ef
d1d2Zd3d4 Zd5d6 Zd7ejfd8d9Z  ZS )=HDemucsa#
  Hybrid Demucs model from
    *Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`.

    See Also:
        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.

    Args:
        sources (List[str]): list of source names. List can contain the following source
            options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
        audio_channels (int, optional): input/output audio channels. (Default: 2)
        channels (int, optional): initial number of hidden channels. (Default: 48)
        growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2)
        nfft (int, optional): number of fft bins. Note that changing this requires careful computation of
            various shape parameters and will not work out of the box for hybrid models. (Default: 4096)
        depth (int, optional): number of layers in encoder and decoder (Default: 6)
        freq_emb (float, optional): add frequency embedding after the first frequency layer if > 0,
            the actual value controls the weight of the embedding. (Default: 0.2)
        emb_scale (int, optional): equivalent to scaling the embedding learning rate (Default: 10)
        emb_smooth (bool, optional): initialize the embedding with a smooth one (with respect to frequencies).
            (Default: ``True``)
        kernel_size (int, optional): kernel_size for encoder and decoder layers. (Default: 8)
        time_stride (int, optional): stride for the final time layer, after the merge. (Default: 2)
        stride (int, optional): stride for encoder and decoder layers. (Default: 4)
        context (int, optional): context for 1x1 conv in the decoder. (Default: 4)
        context_enc (int, optional): context for 1x1 conv in the encoder. (Default: 0)
        norm_starts (int, optional): layer at which group norm starts being used.
            decoder layers are numbered in reverse order. (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        dconv_depth (int, optional): depth of residual DConv branch. (Default: 2)
        dconv_comp (int, optional): compression of DConv branch. (Default: 4)
        dconv_attn (int, optional): adds attention layers in DConv branch starting at this layer. (Default: 4)
        dconv_lstm (int, optional): adds a LSTM layer in DConv branch starting at this layer. (Default: 4)
        dconv_init (float, optional): initial scale for the DConv branch LayerScale. (Default: 1e-4)
    rJ   0         皙?
   Tr0   r1   r   r   -C6?sourcesaudio_channelschannelsgrowthnfftdepthfreq_emb	emb_scale
emb_smoothr5   time_strider6   r;   context_encnorm_startsr7   dconv_depth
dconv_comp
dconv_attn
dconv_lstm
dconv_initc           +         s  t    || _|| _|| _|| _|
| _|| _|| _|| _	| jd | _
d | _t | _t | _t | _t | _|}|d }|}|}| jd }t| jD ]}||k}||k}||kr_dnd}|dk}|} |
}!|sy|dkrstd|d }!|} d}"d}#|r||
kr|}!d}"d}#|!| ||"|||||||d	d
}$t|$}%d|%d< |
|%d< ||%d< d|%d< t|$}&|#rt||}|}t||fd|i|$}'|r|#du r|dkrd|%d< d|%d< t||f||#d|%}(| j|( | j|' |dkr| jt| j }|d }t||f|dk|d|&})|r&t||f|#|dk|d|%}*| jd|* | jd|) |}|}t|| }t|| }|rL||
krHd}n|| }|dkr`|r`t|||	|d| _|| _qNt|  d S )Nr1   rJ   r2   noner   z$When freq is false, freqs must be 1.TF)lstmattnr~   compressinit)r5   r6   r9   r=   r:   r7   r<   r   r9   r5   r6   r=   r;      )r;   r8   )rk   r;   )r8   rk   r;   )r   r   )r   r   r~   r}   rz   ry   r5   r;   r6   r{   
hop_lengthr   r   
ModuleListfreq_encoderfreq_decodertime_encodertime_decoderranger\   dictmaxr/   appendlenrj   insertr)   r   freq_emb_scale_rescale_module)+r   ry   rz   r{   r|   r}   r~   r   r   r   r5   r   r6   r;   r   r   r7   r   r   r   r   r   r3   chin_zr4   chout_zfreqsindexr   r   r:   r9   strikerr=   	last_freqkwkwtkw_decenctencdectdecr   r   r   r   Q  s   







zHDemucs.__init__c                 C   s   | j }| j}|}||d krtdtt|jd | }|d d }| j|||||  |jd  dd}t|||dd dd d f }|jd |d krRtd	|ddd| f }|S )
Nr1   zHop length must be nfft // 4rW   rJ   rX   reflect)mode.zESpectrogram's last dimension must be 4 + input size divided by stride)	r   r}   r\   r)   mathceilrY   _pad1d_spectro)r   r"   hlr}   x0re   r=   rg   r   r   r   _spec  s   	$zHDemucs._specNc                 C   sv   | j }t|g d}t|ddg}|d d }|tt||  d|  }t|||d}|d||| f }|S )N)r   r   r   r   rJ   rX   )rq   .)r   r[   r=   r)   r   r   	_ispectro)r   rg   rq   r   r=   re   r"   r   r   r   _ispec  s   zHDemucs._ispeczero        r"   padding_leftpadding_rightr   valuec                 C   sP   |j d }|dkrt||}||krt|d|| d f}t|||f||S )zWrapper around F.pad, in order for reflect padding when num_frames is shorter than max_pad.
        Add extra zero padding around in order for padding to not break.rW   r   r   r   )rY   r   r[   r=   )r   r"   r   r   r   r   rq   max_padr   r   r   r     s   

zHDemucs._pad1dc                 C   s>   |j \}}}}t|ddddd}|||d ||}|S )Nr   r   r1   rJ   rX   )rY   r   view_as_realr^   r_   )r   rg   ra   rb   rc   rd   mr   r   r   
_magnitude  s   zHDemucs._magnitudec                 C   sF   |j \}}}}}|||dd||dddddd}t| }|S )NrW   rJ   r   r   r1      rX   )rY   rZ   r^   r   view_as_complex
contiguous)r   r   ra   Srb   rc   rd   r#   r   r   r   _mask  s   $zHDemucs._maskinputc           "      C   s|  |j dkrtd|j |jd | jkr td|jd  d|}|jd }| |}| |}|}|j\}}}}	|jddd	}
|jddd	}||
 d
|  }|}|jddd	}|jddd	}|| d
|  }g }g }g }g }t| j	D ]j\}}|
|jd  d}|t| jk r|
|jd  | j| }||}|js|
| n|}|||}|dkr| jdurtj|jd |jd}| | ddddddf |}|| j|  }|
| qut|}t|}t| jD ]b\}}|d}||||d\}}| jt| j }||krQ| j||  }|d}|jrD|jd dkr0td|j |dddddf }||d|\}}q|d}||||\}}qt|dkr]tdt|dkrhtdt|dkrstdt| j} ||| d||	}||dddf  |
dddf  }| |}!| |!|}||| d|}||dddf  |dddf  }|| }|S )a  HDemucs forward call

        Args:
            input (torch.Tensor): input mixed tensor of shape `(batch_size, channel, num_frames)`

        Returns:
            Tensor
                output tensor split into sources of shape `(batch_size, num_sources, channel, num_frames)`
        rX   zDExpected 3D tensor with dimensions (batch, channel, frames). Found: r   zZThe channel dimension of input Tensor must match `audio_channels` of HDemucs model. Found:.rW   )r   rJ   rX   T)r   keepdimgh㈵>)r   rJ   Nr   )devicerJ   z0If tdec empty is True, pre shape does not match zsaved is not emptyzlengths_t is not emptyzsaved_t is not empty)ndimr\   rY   rz   r   r   meanstd	enumerater   r   r   r   r8   r   r   r   r   t	expand_asr   
zeros_liker   popr~   r   AssertionErrorry   rZ   r   r   )"r   r   r"   rq   rg   magra   rb   Fqrd   r   r   xtmeantstdtsavedsaved_tlengths	lengths_tidxencoderV   r   frsembdecoderp   preoffsetr   length_t_r   zoutr   r   r   r$     s   





(






$
$zHDemucs.forward)rJ   rs   rJ   rt   ru   rv   rw   Tr0   rJ   r1   r   r   r1   r1   rJ   r1   r1   r1   rx   r!   )r   r   )r%   r&   r'   r(   r   ri   r)   r*   r+   r   r   r   r   r-   r   r   r   r$   r.   r   r   r   r   rr   -  s    &	
 
"

rr   c                       sf   e Zd ZdZ									dded	ed
edededededededef fddZdd Z	  Z
S )rR   a  
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.

    Args:
        channels (int): input/output channels for residual branch.
        compress (float, optional): amount of channel compression inside the branch. (default: 4)
        depth (int, optional): number of layers in the residual branch. Each layer has its own
            projection, and potentially LSTM and attention.(default: 2)
        init (float, optional): initial scale for LayerNorm. (default: 1e-4)
        norm_type (bool, optional): Norm type, either ``group_norm `` or ``none`` (Default: ``group_norm``)
        attn (bool, optional): use LocalAttention. (Default: ``False``)
        heads (int, optional): number of heads for the LocalAttention.  (default: 4)
        ndecay (int, optional): number of decay controls in the LocalAttention. (default: 4)
        lstm (bool, optional): use LSTM. (Default: ``False``)
        kernel_size (int, optional): kernel size for the (dilated) convolutions. (default: 3)
    r1   rJ   rx   r2   FrX   r{   r   r~   r   r:   r   headsndecayr   r5   c              
      s8  t    |
d dkrtd|| _|| _t|| _|dk}dd }|dkr*dd }t|| }tj	}t
g | _t| jD ][}|rGtd|nd}||
d  }tj|||
||d	||| t|d| d|d| tdt||g}|r|d
t|||d |	r|d
t|ddd tj| }| j| q>d S )NrJ   r   z(Kernel size should not be divisible by 2c                 S   r>   r!   r?   rA   r   r   r   rC     rD   z!_DConv.__init__.<locals>.<lambda>r2   c                 S   s   t d| S )Nr   rF   rA   r   r   r   rC     rI   r   )dilationpaddingrX   )r   r   T)layersrp   )r   r   r\   r{   r   absr~   r)   r   GELUr   r   r   powrK   GLU_LayerScaler   _LocalState_BLSTM
Sequentialr   )r   r{   r   r~   r   r:   r   r   r   r   r5   dilaterS   hiddenactrB   r   r   modslayerr   r   r   r     s>   


	
z_DConv.__init__c                 C   s   | j D ]}||| }q|S )zDConv forward call

        Args:
            x (torch.Tensor): input tensor for convolution

        Returns:
            Tensor
                Output after being run through layers.
        )r   )r   r"   r   r   r   r   r$     s   

z_DConv.forward)	r1   rJ   rx   r2   Fr1   r1   FrX   )r%   r&   r'   r(   r)   r*   ri   r+   r   r$   r.   r   r   r   r   rR   }  sB    	
3rR   c                       sB   e Zd ZdZddedef fddZdejd	ejfd
dZ	  Z
S )r   ae  
    BiLSTM with same hidden units as input dim.
    If `max_steps` is not None, input will be splitting in overlapping
    chunks and the LSTM applied separately on each chunk.
    Args:
        dim (int): dimensions at LSTM layer.
        layers (int, optional): number of LSTM layers. (default: 1)
        skip (bool, optional): (default: ``False``)
    r   Fr   rp   c                    s@   t    d| _tjd|||d| _td| || _|| _d S )N   T)bidirectional
num_layershidden_size
input_sizerJ   )	r   r   	max_stepsr   LSTMr   Linearlinearrp   )r   r   r   rp   r   r   r   r     s
   

z_BLSTM.__init__r"   r    c              	   C   s  |j \}}}|}d}d}d}d}	| jdur;|| jkr;| j}|d }t|||}
|
j d }	d}|
ddddd||}|ddd}| |d }| |}|ddd}|rg }||d||}
|d }t|	D ]C}|dkr||
dd|ddd| f  qi||	d kr||
dd|dd|df  qi||
dd|dd|| f  qit	
|d}|d	d|f }|}| jr|| }|S )
a  BLSTM forward call

        Args:
            x (torch.Tensor): input tensor for BLSTM shape is `(batch_size, dim, time_steps)`

        Returns:
            Tensor
                Output after being run through bidirectional LSTM. Shape is `(batch_size, dim, time_steps)`
        Fr   NrJ   Tr   rX   rW   .)rY   r   _unfoldr^   r_   r   r   r   r   r   catrp   )r   r"   ra   rb   rd   rf   framedwidthr6   nframesframesr#   limitkr   r   r   r$     sB   


&$&z_BLSTM.forward)r   F)r%   r&   r'   r(   r)   r+   r   r   r-   r$   r.   r   r   r   r   r     s    
r   c                       sF   e Zd ZdZddededef fddZdejd	ejfd
dZ  Z	S )r   a   Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).
    Also a failed experiments with trying to provide some frequency based attention.
    r1   r{   r   r   c                    s   t t|   || dkrtd|| _|| _t||d| _t||d| _	t||d| _
t||| d| _|rW| jj jd9  _| jjdu rNtdd| jjjdd< t||d  |d| _dS )z
        Args:
            channels (int): Size of Conv1d layers.
            heads (int, optional):  (default: 4)
            ndecay (int, optional): (default: 4)
        r   z$Channels must be divisible by heads.r   g{Gz?Nzbias must not be None.r   )r   r   r   r\   r   r   r   rK   contentquerykeyquery_decayr   r   biasproj)r   r{   r   r   r   r   r   r     s   z_LocalState.__init__r"   r    c                 C   sz  |j \}}}| j}tj||j|jd}|dddf |dddf  }| |||d|}| |||d|}	t	d|	|}
|
t
|	j d  }
| jrtjd| jd |j|jd}| |||d|}t|d }|ddd |  t
| j }|
t	d||7 }
|
tj||
jtjdd tj|
dd	}| |||d|}t	d
||}||d|}|| | S )zLocalState forward call

        Args:
            x (torch.Tensor): input tensor for LocalState

        Returns:
            Tensor
                Output after being run through LocalState layer.
        )r   dtypeNrW   zbhct,bhcs->bhtsrJ   r   zfts,bhfs->bhtsir   zbhts,bhct->bhcs)rY   r   r   r   r   r  r
  rZ   r  einsumr   r   r   r  sigmoidr   masked_fill_eyer+   softmaxr	  r_   r  )r   r"   ra   rb   rd   r   indexesdeltaquerieskeysdotsdecaysdecay_qdecay_kernelweightsr	  resultr   r   r   r$   6  s(   
 $z_LocalState.forward)r1   r1   )
r%   r&   r'   r(   r)   r   r   r-   r$   r.   r   r   r   r   r     s    r   c                       sB   e Zd ZdZddedef fddZdejdejfd	d
Z	  Z
S )r   zLayer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonally residual outputs close to 0 initially, then learnt.
    r   r{   r   c                    s4   t    ttj|dd| _|| jjdd< dS )z
        Args:
            channels (int): Size of  rescaling
            init (float, optional): Scale to default to (default: 0)
        T)requires_gradN)r   r   r   	Parameterr   zerosr   r   )r   r{   r   r   r   r   r   a  s   
z_LayerScale.__init__r"   r    c                 C   s   | j dddf | S )zLayerScale forward call

        Args:
            x (torch.Tensor): input tensor for LayerScale

        Returns:
            Tensor
                Output after rescaling tensor.
        N)r   )r   r"   r   r   r   r$   k  s   
z_LayerScale.forward)r   )r%   r&   r'   r(   r)   r*   r   r   r-   r$   r.   r   r   r   r   r   \  s    
r   ar5   r6   r    c                    s   t  jdd }t jd }t|| }|d | | }tj d|| gd  fddt  D }|d dkrAt	d|dd |dg }|
| |
|  ||S )	zGiven input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.
    This will pad the input so that `F = ceil(T / K)`.
    see https://github.com/pytorch/pytorch/issues/60466
    NrW   r   r   )r   r=   c                    s   g | ]}  |qS r   )r6   ).0r   r"  r   r   
<listcomp>  s    z_unfold.<locals>.<listcomp>zData should be contiguous.)listrY   r)   r   r   r[   r=   r   r   r\   r   
as_strided)r"  r5   r6   rY   rq   n_frames
tgt_lengthstridesr   r$  r   r  x  s   

r  c                 C   sp   |   D ]1}t|tjtjtjtjfr5|j 	 }|d d }|j j
|  _
|jdur5|j j
|  _
qdS )zI
    Rescales initial weight scale for all models within the module.
    g?g      ?N)modules
isinstancer   rK   rl   rL   rm   r   r   detachr   r  )modulesubr   r   r   r   r   r     s   
r      r"   n_fftr   r=   c           
      C   s   t | jd d }t| jd }| d|} tj| |d|  |t|| |ddddd	}|j\}}}	|||	g |	|S )NrW   r   Tr   )window
win_length
normalizedcenterreturn_complexpad_mode)
r&  rY   r)   r_   r   stfthann_windowtoextendrZ   )
r"   r1  r   r=   otherrq   rg   r   r   framer   r   r   r     s"   

r   rg   rq   c              
   C   s   t | jd d }t| jd }t| jd }d| d }| d||} |d|  }tj| ||t|| j|d|dd}	|	j\}
}|	| |	|S )Nr   rW   rJ   r   T)r2  r3  r4  rq   r5  )
r&  rY   r)   rZ   r   istftr9  r:  realr   )rg   r   rq   r=   r<  r   r  r1  r3  r"   r   r   r   r   r     s&   



r   ry   c                 C      t | dddS )zBuilds low nfft (1024) version of :class:`HDemucs`, suitable for sample rates around 8 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    i   r   ry   r}   r~   rr   ry   r   r   r   hdemucs_low     rD  c                 C   r@  )a  Builds medium nfft (2048) version of :class:`HDemucs`, suitable for sample rates of 16-32 kHz.

    .. note::

        Medium HDemucs has not been tested against the original Hybrid Demucs as this nfft and depth configuration is
        not compatible with the original implementation in https://github.com/facebookresearch/demucs

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    r   ru   rA  rB  rC  r   r   r   hdemucs_medium  s   rF  c                 C   r@  )zBuilds medium nfft (4096) version of :class:`HDemucs`, suitable for sample rates of 44.1-48 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model.
    rt   ru   rA  rB  rC  r   r   r   hdemucs_high  rE  rG  )r0  r   r   )r   r   r   )r   typingtpr   r   r   r   r   r   torch.nnr   r[   Moduler   r/   rj   rr   rR   r   r   r   r-   r)   r  r   r   r   ri   rD  rF  rG  r   r   r   r   <module>   s.   &nv  RWCE$$