o
    %ݫi                     @   s  d Z ddlZddlmZ ddlZddlmZ ddlm  mZ	 ddl
mZ ddlmZ d1dd	Zd
d Zdd Zd2ddZG dd dejZG dd dejZG dd dejeZG dd dejZG dd dejZG dd deZG dd dejZG dd  d ejZd3d!d"ZG d#d$ d$ejZG d%d& d&ejZG d'd( d(ejZG d)d* d*ejZ d+d,gZ!G d-d. d.ejZ"G d/d0 d0eZ#dS )4a  A UNet model implementation for use with diffusion models

Adapted from OpenAI guided diffusion, with slight modifications
and additional features
https://github.com/openai/guided-diffusion

MIT License

Copyright (c) 2021 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Authors
 * Artem Ploujnikov 2022
    N)abstractmethod)pad_divisible   )NormalizingAutoencoderTc                 C   s"   |r|   D ]}|   q| S )a*  
    Zero out the parameters of a module and return it.

    Arguments
    ---------
    module: torch.nn.Module
        a module
    use_fixup_init: bool
        whether to zero out the parameters. If set to
        false, the function is a no-op

    Returns
    -------
    The fixed module
    )
parametersdetachzero_)moduleuse_fixup_initp r   I/home/ubuntu/.local/lib/python3.10/site-packages/speechbrain/nnet/unet.pyfixup-   s   r   c                 O   V   | dkrt j|i |S | dkrt j|i |S | dkr$t j|i |S td|  )a  
    Create a 1D, 2D, or 3D convolution module.

    Arguments
    ---------
    dims: int
        The number of dimensions
    *args: tuple
    **kwargs: dict
        Any remaining arguments are passed to the constructor

    Returns
    -------
    The constructed Conv layer
    r         unsupported dimensions: )nnConv1dConv2dConv3d
ValueErrordimsargskwargsr   r   r   conv_ndC   s   r   c                 O   r   )z8
    Create a 1D, 2D, or 3D average pooling module.
    r   r   r   r   )r   	AvgPool1d	AvgPool2d	AvgPool3dr   r   r   r   r   avg_pool_nd\   s   r    '  c                 C   s   |d }t t| t jd|t jd | j| jd}| dddf  |d  }t j	t 
|t |gdd}|d rRt j	|t |ddddf gdd}|S )	a  
    Create sinusoidal timestep embeddings.

    Arguments
    ---------
    timesteps: torch.Tensor
        a 1-D Tensor of N indices, one per batch element. These may be fractional.
    dim: int
        the dimension of the output.
    max_period: int
        controls the minimum frequency of the embeddings.

    Returns
    -------
    result: torch.Tensor
         an [N x dim] Tensor of positional embeddings.
    r   r   )startenddtype)deviceNdimr   )torchexpmathlogarangefloat32tor%   floatcatcossin
zeros_like)	timestepsr(   
max_periodhalffreqsr   	embeddingr   r   r   timestep_embeddingi   s"   
r:   c                	       s>   e Zd ZdZ	ddedededef fddZd	d
 Z  ZS )AttentionPool2da  Two-dimensional attentional pooling

    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py

    Arguments
    ---------
    spatial_dim: int
        the size of the spatial dimension
    embed_dim: int
        the embedding dimension
    num_heads_channels: int
        the number of attention heads
    output_dim: int
        the output dimension

    Example
    -------
    >>> attn_pool = AttentionPool2d(
    ...     spatial_dim=64,
    ...     embed_dim=16,
    ...     num_heads_channels=2,
    ...     output_dim=4
    ... )
    >>> x = torch.randn(4, 1, 64, 64)
    >>> x_pool = attn_pool(x)
    >>> x_pool.shape
    torch.Size([4, 4])
    Nspatial_dim	embed_dimnum_heads_channels
output_dimc                    sp   t    tt||d d |d  | _td|d| d| _td||p'|d| _	|| | _
t| j
| _d S )Nr   r   g      ?r   )super__init__r   	Parameterr)   randnpositional_embeddingr   qkv_projc_proj	num_headsQKVAttention	attention)selfr<   r=   r>   r?   	__class__r   r   rA      s   

zAttentionPool2d.__init__c                 C   s   |j ^}}}|||d}tj|jddd|gdd}|| jdddddf |j }| |}| 	|}| 
|}|dddddf S )zComputes the attention forward pass

        Arguments
        ---------
        x: torch.Tensor
            the tensor to be attended to

        Returns
        -------
        result: torch.Tensor
            the attention output
        r&   T)r(   keepdimr'   Nr   )shapereshaper)   r1   meanrD   r/   r$   rE   rI   rF   )rJ   xbc_spatialr   r   r   forward   s   $


zAttentionPool2d.forwardN)__name__
__module____qualname____doc__intrA   rU   __classcell__r   r   rK   r   r;      s    "r;   c                   @   s   e Zd ZdZedddZdS )TimestepBlockzT
    Any module where forward() takes timestep embeddings as a second argument.
    Nc                 C   s   dS )z
        Apply the module to `x` given `emb` timestep embeddings.

        Arguments
        ---------
        x: torch.Tensor
            the data tensor
        emb: torch.Tensor
            the embedding tensor
        Nr   )rJ   rQ   embr   r   r   rU      s    zTimestepBlock.forwardrV   )rW   rX   rY   rZ   r   rU   r   r   r   r   r]      s    r]   c                   @   s   e Zd ZdZdddZdS )TimestepEmbedSequentiala   A sequential module that passes timestep embeddings to the children that
    support it as an extra input.

    Example
    -------
    >>> from speechbrain.nnet.linear import Linear
    >>> class MyBlock(TimestepBlock):
    ...     def __init__(self, input_size, output_size, emb_size):
    ...         super().__init__()
    ...         self.lin = Linear(
    ...             n_neurons=output_size,
    ...             input_size=input_size
    ...         )
    ...         self.emb_proj = Linear(
    ...             n_neurons=output_size,
    ...             input_size=emb_size,
    ...         )
    ...     def forward(self, x, emb):
    ...         return self.lin(x) + self.emb_proj(emb)
    >>> tes = TimestepEmbedSequential(
    ...     MyBlock(128, 64, 16),
    ...     Linear(
    ...         n_neurons=32,
    ...         input_size=64
    ...     )
    ... )
    >>> x = torch.randn(4, 10, 128)
    >>> emb = torch.randn(4, 10, 16)
    >>> out = tes(x, emb)
    >>> out.shape
    torch.Size([4, 10, 32])
    Nc                 C   s,   | D ]}t |tr|||}q||}q|S )a  Computes a sequential pass with sequential embeddings where applicable

        Arguments
        ---------
        x: torch.Tensor
            the data tensor
        emb: torch.Tensor
            timestep embeddings

        Returns
        -------
        The processed input
        )
isinstancer]   )rJ   rQ   r^   layerr   r   r   rU     s
   

zTimestepEmbedSequential.forwardrV   )rW   rX   rY   rZ   rU   r   r   r   r   r_      s    !r_   c                       *   e Zd ZdZd fdd	Zdd Z  ZS )	Upsamplea  
    An upsampling layer with an optional convolution.

    Arguments
    ---------
    channels: torch.Tensor
        channels in the inputs and outputs.
    use_conv: bool
        a bool determining if a convolution is applied.
    dims: int
        determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    out_channels: int
        Number of output channels. If None, same as input channels.

    Example
    -------
    >>> ups = Upsample(channels=4, use_conv=True, dims=2, out_channels=8)
    >>> x = torch.randn(8, 4, 32, 32)
    >>> x_up = ups(x)
    >>> x_up.shape
    torch.Size([8, 8, 64, 64])
    r   Nc                    sJ   t    || _|p|| _|| _|| _|r#t|| j| jddd| _d S d S )Nr   r   padding)r@   rA   channelsout_channelsuse_convr   r   conv)rJ   rf   rh   r   rg   rK   r   r   rA   3  s   

zUpsample.__init__c                 C   st   |j d | jks
J | jdkr(tj||j d |j d d |j d d fdd}ntj|ddd}| jr8| |}|S )zComputes the upsampling pass

        Arguments
        ---------
        x: torch.Tensor
            layer inputs

        Returns
        -------
        result: torch.Tensor
            upsampled outputsr   r   r      nearest)mode)scale_factorrl   )rN   rf   r   Finterpolaterh   ri   rJ   rQ   r   r   r   rU   >  s   
&
zUpsample.forwardr   NrW   rX   rY   rZ   rA   rU   r\   r   r   rK   r   rc     s    rc   c                       rb   )	
Downsamplea  
    A downsampling layer with an optional convolution.

    Arguments
    ---------
    channels: int
        channels in the inputs and outputs.
    use_conv: bool
         a bool determining if a convolution is applied.
    dims: int
        determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    out_channels: int
        Number of output channels. If None, same as input channels.

    Example
    -------
    >>> ups = Downsample(channels=4, use_conv=True, dims=2, out_channels=8)
    >>> x = torch.randn(8, 4, 32, 32)
    >>> x_up = ups(x)
    >>> x_up.shape
    torch.Size([8, 8, 16, 16])
    r   Nc                    s|   t    || _|p|| _|| _|| _|dkrdnd}|r,t|| j| jd|dd| _d S | j| jks4J t|||d| _d S )Nr   r   )r   r   r   r   )stridere   )kernel_sizert   )	r@   rA   rf   rg   rh   r   r   opr    )rJ   rf   rh   r   rg   rt   rK   r   r   rA   o  s"   

	zDownsample.__init__c                 C   s   |j d | jks
J | |S )zComputes the downsampling pass

        Arguments
        ---------
        x: torch.Tensor
            layer inputs

        Returns
        -------
        result: torch.Tensor
            downsampled outputs
        r   )rN   rf   rv   rp   r   r   r   rU     s   
zDownsample.forwardrq   rr   r   r   rK   r   rs   V  s    rs   c                       s:   e Zd ZdZ							d fdd	Zdd	d
Z  ZS )ResBlocka  
    A residual block that can optionally change the number of channels.

    Arguments
    ---------
    channels: int
        the number of input channels.
    emb_channels: int
        the number of timestep embedding channels.
    dropout: float
        the rate of dropout.
    out_channels: int
        if specified, the number of out channels.
    use_conv: bool
        if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    dims: int
        determines if the signal is 1D, 2D, or 3D.
    up: bool
        if True, use this block for upsampling.
    down: bool
        if True, use this block for downsampling.
    norm_num_groups: int
        the number of groups for group normalization
    use_fixup_init: bool
        whether to use FixUp initialization

    Example
    -------
    >>> res = ResBlock(
    ...     channels=4,
    ...     emb_channels=8,
    ...     dropout=0.1,
    ...     norm_num_groups=2,
    ...     use_conv=True,
    ... )
    >>> x = torch.randn(2, 4, 32, 32)
    >>> emb = torch.randn(2, 8)
    >>> res_out = res(x, emb)
    >>> res_out.shape
    torch.Size([2, 4, 32, 32])
    NFr       Tc                    sp  t    || _|| _|| _|p|| _|| _tt	|	|t
 t||| jddd| _|p/|| _|rBt|d|| _t|d|| _n|rSt|d|| _t|d|| _nt  | _| _|d urntt
 t|| j| _nd | _tt	|	| jt
 tj|dtt|| j| jddd|
d| _| j|krt | _d S |rt||| jddd| _d S t||| jd| _d S )Nr   r   rd   F)r   r
   )r@   rA   rf   emb_channelsdropoutrg   rh   r   
Sequential	GroupNormSiLUr   	in_layersupdownrc   h_updx_updrs   IdentityLinear
emb_layersDropoutr   
out_layersskip_connection)rJ   rf   rz   r{   rg   rh   r   updownnorm_num_groupsr
   rK   r   r   rA     s\   





zResBlock.__init__c                 C   s   | j r#| jdd | jd }}||}| |}| |}||}n| |}|durN| ||j}t|jt|jk rM|d }t|jt|jk s?nt	
|}|| }| |}| || S )a  
        Apply the block to a torch.Tensor, conditioned on a timestep embedding.

        Arguments
        ---------
        x: torch.Tensor
            an [N x C x ...] Tensor of features.
        emb: torch.Tensor
            an [N x emb_channels] Tensor of timestep embeddings.

        Returns
        -------
        result: torch.Tensor
            an [N x C x ...] Tensor of outputs.
        Nr&   ).N)r   r   r   r   r   typer$   lenrN   r)   r4   r   r   )rJ   rQ   r^   in_restin_convhemb_outr   r   r   rU     s"   





zResBlock.forward)NFr   FFrx   TrV   rr   r   r   rK   r   rw     s    1Drw   c                       s2   e Zd ZdZ				d
 fdd	Zdd	 Z  ZS )AttentionBlocka  
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.


    Arguments
    ---------
    channels: int
        the number of channels
    num_heads: int
        the number of attention heads
    num_head_channels: int
        the number of channels in each attention head
    norm_num_groups: int
        the number of groups used for group normalization
    use_fixup_init: bool
        whether to use FixUp initialization

    Example
    -------
    >>> attn = AttentionBlock(
    ...     channels=8,
    ...     num_heads=4,
    ...     num_head_channels=4,
    ...     norm_num_groups=2
    ... )
    >>> x = torch.randn(4, 8, 16, 16)
    >>> out = attn(x)
    >>> out.shape
    torch.Size([4, 8, 16, 16])
    r   r&   rx   Tc                    s   t    || _|dkr|| _n|| dks J d| d| || | _t||| _td||d d| _t	| j| _
ttd||d|| _d S )Nr&   r   zq,k,v channels z' is not divisible by num_head_channels r   r   )r@   rA   rf   rG   r   r}   normr   qkvrH   rI   r   proj_out)rJ   rf   rG   num_head_channelsr   r
   rK   r   r   rA   K  s   

zAttentionBlock.__init__c                 C   sV   |j ^}}}|||d}| | |}| |}| |}|| j||g|R  S )zCompletes the forward pass

        Arguments
        ---------
        x: torch.Tensor
            the data to be attended to

        Returns
        -------
        result: torch.Tensor
            The data, with attention applied
        r&   )rN   rO   r   r   rI   r   )rJ   rQ   rR   rS   spatialr   r   r   r   r   rU   b  s   

zAttentionBlock.forward)r   r&   rx   Trr   r   r   rK   r   r   )  s    $r   c                       (   e Zd ZdZ fddZdd Z  ZS )rH   a  
    A module which performs QKV attention and splits in a different order.

    Arguments
    ---------
    n_heads : int
        Number of attention heads.

    Example
    -------
    >>> attn = QKVAttention(4)
    >>> n = 4
    >>> c = 8
    >>> h = 64
    >>> w = 16
    >>> qkv = torch.randn(4, (3 * h * c), w)
    >>> out = attn(qkv)
    >>> out.shape
    torch.Size([4, 512, 16])
    c                    s   t    || _d S rV   )r@   rA   n_heads)rJ   r   rK   r   r   rA     s   

zQKVAttention.__init__c              	   C   s   |j \}}}|d| j  dksJ |d| j  }|jddd\}}}dtt| }	td||	 || j ||||	 || j ||}
tj|
	 dd
|
j}
td|
||| j ||}||d|S )a  Apply QKV attention.

        Arguments
        ---------
        qkv: torch.Tensor
            an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.

        Returns
        -------
        result: torch.Tensor
            an [N x (H * C) x T] tensor after attention.
        r   r   r   r'   zbct,bcs->btsr&   zbts,bcs->bct)rN   r   chunkr+   sqrtr)   einsumviewsoftmaxr0   r   r$   rO   )rJ   r   bswidthlengthchqkvscaleweightar   r   r   rU     s   zQKVAttention.forwardrr   r   r   rK   r   rH   w  s    rH   c                 C   s`   i }| dur+|   D ] \}}|du s||r*d|v r |||< q
t|d |d||< q
t|S )a  Builds a dictionary of embedding modules for embedding
    projections

    Arguments
    ---------
    emb_config: dict
        a configuration dictionary
    proj_dim: int
        the target projection dimension
    use_emb: dict
        an optional dictionary of "switches" to turn
        embeddings on and off

    Returns
    -------
    result: torch.nn.ModuleDict
        a ModuleDict with a module for each embedding
    Nemb_projemb_dim)r   proj_dim)itemsgetEmbeddingProjectionr   
ModuleDict)
emb_configr   use_embr   keyitem_configr   r   r   build_emb_proj  s   


r   c                       sX   e Zd ZdZ													
	d fdd	ZdddZ				dddZ  ZS )	UNetModela	  
    The full UNet model with attention and timestep embedding.

    Arguments
    ---------
    in_channels: int
        channels in the input torch.Tensor.
    model_channels: int
        base channel count for the model.
    out_channels: int
        channels in the output torch.Tensor.
    num_res_blocks: int
        number of residual blocks per downsample.
    attention_resolutions: int
        a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    dropout: float
        the dropout probability.
    channel_mult: int
        channel multiplier for each level of the UNet.
    conv_resample: bool
        if True, use learned convolutions for upsampling and
        downsampling
    dims: int
        determines if the signal is 1D, 2D, or 3D.
    emb_dim: int
        time embedding dimension (defaults to model_channels * 4)
    cond_emb: dict
        embeddings on which the model will be conditioned

        Example:
        {
            "speaker": {
                "emb_dim": 256
            },
            "label": {
                "emb_dim": 12
            }
        }
    use_cond_emb: dict
        a dictionary with keys corresponding to keys in cond_emb
        and values corresponding to Booleans that turn embeddings
        on and off. This is useful in combination with hparams files
        to turn embeddings on and off with simple switches

        Example:
        {"speaker": False, "label": True}
    num_heads: int
        the number of attention heads in each attention layer.
    num_head_channels: int
        if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    num_heads_upsample: int
        works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    norm_num_groups: int
        Number of groups in the norm, default 32
    resblock_updown: bool
        use residual blocks for up/downsampling.
    use_fixup_init: bool
        whether to use FixUp initialization

    Example
    -------
    >>> model = UNetModel(
    ...    in_channels=3,
    ...    model_channels=32,
    ...    out_channels=1,
    ...    num_res_blocks=1,
    ...    attention_resolutions=[1]
    ... )
    >>> x = torch.randn(4, 3, 16, 32)
    >>> ts = torch.tensor([10, 100, 50, 25])
    >>> out = model(x, ts)
    >>> out.shape
    torch.Size([4, 1, 16, 32])
    r   r   r   rj      Tr   Nr   r&   rx   Fc                    sj  t    |dkr|}|| _|| _|| _|| _|| _|| _|| _|| _	t
j| _|| _|| _|| _|| _|| _|
d u r>|d }
t||
| _t||
|d| _t|d |  }}ttt|	||dddg| _|| _|g}d}t|D ]}\}}t|D ]<}t||
|t|| |	||dg}t|| }||v r| t!|||||d	 | j t|  |  j|7  _| | qz|t"|d kr|}| j t|rt||
|||	d
||dnt#|||	|d |}| | |d9 }|  j|7  _qrtt||
||	||dt!|||||d	t||
||	||d| _$|  j|7  _tg | _%t&t|d d d D ]s\}}t|d D ]g}|' }t|| |
|t|| |	||dg}t|| }||v r_| t!|||||d	 |r||kr|}| |rzt||
|||	d
||dnt(|||	|d |d }| j% t|  |  j|7  _q0q&t)t*||t+ t,t|	||ddd|d| _-d S )Nr&   rj   )r   r   r   r   r   r   rd   rg   r   r   r
   rG   r   r   r
   Trg   r   r   r   r
   r   rg   r   r   r   r
   rg   r   r   r   r
   ry   ).r@   rA   in_channelsmodel_channelsrg   num_res_blocksattention_resolutionsr{   channel_multconv_resampler)   r.   r$   rG   r   num_heads_upsamplecond_embuse_cond_embr   
time_embedr   cond_emb_projr[   r   
ModuleListr_   r   input_blocks_feature_size	enumeraterangerw   appendr   r   rs   middle_blockoutput_blockslistpoprc   r|   r}   r~   r   out)rJ   r   r   rg   r   r   r{   r   r   r   r   r   r   rG   r   r   r   resblock_updownr
   r   input_chinput_block_chansdslevelmult_layersout_chiichrK   r   r   rA      sJ  

	


	.

zUNetModel.__init__c                 C   s   g }|  t|| j}|dur#| D ]\}}| j| |}||7 }q|| j}	| jD ]}
|
|	|}	||	 q,| 	|	|}	| j
D ]}
tj|	| gdd}	|
|	|}	qB|	|j}	| |	S )a  Apply the model to an input batch.

        Arguments
        ---------
        x: torch.Tensor
            an [N x C x ...] Tensor of inputs.
        timesteps: torch.Tensor
            a 1-D batch of timesteps.
        cond_emb: dict
            a string -> tensor dictionary of conditional
            embeddings (multiple embeddings are supported)

        Returns
        -------
        result: torch.Tensor
            an [N x C x ...] Tensor of outputs.
        Nr   r'   )r   r:   r   r   r   r   r$   r   r   r   r   r)   r1   r   r   )rJ   rQ   r5   r   hsr^   r   valuer   r   r	   r   r   r   rU     s$   





zUNetModel.forwardc                 C   s   | |||dS )zForward function suitable for wrapping by diffusion.
        For this model, `length`/`out_mask_value`/`latent_mask_value` are unused
        and discarded.
        See :meth:`~UNetModel.forward` for details.)r   r   )rJ   rQ   r5   r   r   out_mask_valuelatent_mask_valuer   r   r   diffusion_forward  s   zUNetModel.diffusion_forward)r   r   Tr   NNNr   r&   r&   rx   FTrV   )NNNN)rW   rX   rY   rZ   rA   rU   r   r\   r   r   rK   r   r     s.    W 
E,r   c                       sF   e Zd ZdZ											
	
		d fdd	ZdddZ  ZS )EncoderUNetModelam  
    The half UNet model with attention and timestep embedding.
    For usage, see UNetModel.

    Arguments
    ---------
    in_channels: int
        channels in the input torch.Tensor.
    model_channels: int
        base channel count for the model.
    out_channels: int
        channels in the output torch.Tensor.
    num_res_blocks: int
        number of residual blocks per downsample.
    attention_resolutions: int
        a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    dropout: float
        the dropout probability.
    channel_mult: int
        channel multiplier for each level of the UNet.
    conv_resample: bool
        if True, use learned convolutions for upsampling and
        downsampling
    dims: int
        determines if the signal is 1D, 2D, or 3D.
    num_heads: int
        the number of attention heads in each attention layer.
    num_head_channels: int
        if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    num_heads_upsample: int
        works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    norm_num_groups: int
        Number of groups in the norm, default 32.
    resblock_updown: bool
        use residual blocks for up/downsampling.
    pool: str
        Type of pooling to use, one of:
        ["adaptive", "attention", "spatial", "spatial_v2"].
    attention_pool_dim: int
        The dimension on which to apply attention pooling.
    out_kernel_size: int
        the kernel size of the output convolution
    use_fixup_init: bool
        whether to use FixUp initialization


    Example
    -------
    >>> model = EncoderUNetModel(
    ...    in_channels=3,
    ...    model_channels=32,
    ...    out_channels=1,
    ...    num_res_blocks=1,
    ...    attention_resolutions=[1]
    ... )
    >>> x = torch.randn(4, 3, 16, 32)
    >>> ts = torch.tensor([10, 100, 50, 25])
    >>> out = model(x, ts)
    >>> out.shape
    torch.Size([4, 1, 2, 4])

    r   r   Tr   r   r&   rx   FNr   c                    s  t    |dkr|
}|| _|| _|| _|| _|| _|| _|| _|| _	t
j| _|
| _|| _|| _|| _|d }tt||t t||| _t|d | }ttt|	||dddg| _|| _|g}d}t|D ]}\}}t|D ]<}t|||t|| |	||dg}t|| }||v r|t ||
|||d | jt|  |  j|7  _|| qu|t!|d kr|}| jt|rt|||||	d	||d
nt"|||	|d |}|| |d9 }|  j|7  _qmtt||||	|dt ||
|||dt||||	|d| _#|  j|7  _|| _$d| _%|d u r2ttj&||ddt t|	|||dd| _'d S |dkrWtt&||t t(dt)t|	||d|dt* | _'d S |dkrz|dkscJ tt&||t t+|| |||| _'d S |dkrtt| jdt, td| j| _'d	| _%d S |dkrtt| jdt&|dt td| j| _'d	| _%d S t-d| d)Nr&   rj   r   r   r   rd   r   r   Tr   r   r   )r   r
   Fư>num_channels
num_groupsepssameru   re   adaptive)r   r   ry   rI   r   i   
spatial_v2zUnexpected z pooling).r@   rA   r   r   rg   r   r   r{   r   r   r)   r.   r$   rG   r   r   out_kernel_sizer   r|   r   r~   r   r[   r   r_   r   r   r   r   r   rw   r   r   r   rs   r   poolspatial_poolingr}   r   AdaptiveAvgPool2dr   Flattenr;   ReLUNotImplementedError)rJ   r   r   rg   r   r   r{   r   r   r   rG   r   r   r   r   r   attention_pool_dimr   r
   r   r   r   r   r   r   r   r   r   rK   r   r   rA   b  sD  



	















zEncoderUNetModel.__init__c                 C   s   d}|dur|  t|| j}g }|| j}| jD ]}|||}| jr1|||jjdd q| 	||}| jrT|||jjdd t
j|dd}| |S ||j}| |S )O  
        Apply the model to an input batch.

        Arguments
        ---------
        x:  torch.Tensor
            an [N x C x ...] Tensor of inputs.
        timesteps: torch.Tensor
            a 1-D batch of timesteps.

        Returns
        -------
        result: torch.Tensor
            an [N x K] Tensor of outputs.
        N)r   r   r'   r&   )axis)r   r:   r   r   r$   r   r   r   rP   r   r)   r1   r   )rJ   rQ   r5   r^   resultsr   r	   r   r   r   rU   !  s&   




zEncoderUNetModel.forward)r   r   Tr   r   r&   r&   rx   FNNr   TrV   rr   r   r   rK   r   r     s$    K @r   c                       r   )r   a  A simple module that computes the projection of an
    embedding vector onto the specified number of dimensions

    Arguments
    ---------
    emb_dim: int
        the original embedding dimensionality

    proj_dim: int
        the dimensionality of the target projection
        space

    Example
    -------
    >>> mod_emb_proj = EmbeddingProjection(emb_dim=16, proj_dim=64)
    >>> emb = torch.randn(4, 16)
    >>> emb_proj = mod_emb_proj(emb)
    >>> emb_proj.shape
    torch.Size([4, 64])
    c                    s@   t    || _|| _t||| _t | _t||| _	d S rV   )
r@   rA   r   r   r   r   inputr~   actoutput)rJ   r   r   rK   r   r   rA   ]  s   

zEmbeddingProjection.__init__c                 C   s"   |  |}| |}| |}|S )zComputes the forward pass

        Arguments
        ---------
        emb: torch.Tensor
            the original embedding tensor

        Returns
        -------
        result: torch.Tensor
            the target embedding space
        )r   r   r   )rJ   r^   rQ   r   r   r   rU   e  s   


zEmbeddingProjection.forwardrr   r   r   rK   r   r   G  s    r   c                       sB   e Zd ZdZ											
	d fdd	ZdddZ  ZS )DecoderUNetModela  
    The half UNet model with attention and timestep embedding.
    For usage, see UNet.

    Arguments
    ---------
    in_channels: int
        channels in the input torch.Tensor.
    model_channels: int
        base channel count for the model.
    out_channels: int
        channels in the output torch.Tensor.
    num_res_blocks: int
        number of residual blocks per downsample.
    attention_resolutions: int
        a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    dropout: float
        the dropout probability.
    channel_mult: int
        channel multiplier for each level of the UNet.
    conv_resample: bool
        if True, use learned convolutions for upsampling and
        downsampling
    dims: int
        determines if the signal is 1D, 2D, or 3D.
    num_heads: int
        the number of attention heads in each attention layer.
    num_head_channels: int
        if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    num_heads_upsample: int
        works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    resblock_updown: bool
        use residual blocks for up/downsampling.
    norm_num_groups: int
        Number of groups to use in norm, default 32
    out_kernel_size: int
        Output kernel size, default 3
    use_fixup_init: bool
        whether to use FixUp initialization

    Example
    -------
    >>> model = DecoderUNetModel(
    ...    in_channels=1,
    ...    model_channels=32,
    ...    out_channels=3,
    ...    num_res_blocks=1,
    ...    attention_resolutions=[1]
    ... )
    >>> x = torch.randn(4, 1, 2, 4)
    >>> ts = torch.tensor([10, 100, 50, 25])
    >>> out = model(x, ts)
    >>> out.shape
    torch.Size([4, 3, 16, 32])
    r   r   Tr   r   r&   Frx   r   c                    s8  t    |dkr|
}|| _|| _|| _|| _|| _|| _|| _|| _	t
j| _|
| _|| _|| _|d }tt||t t||| _t|d | }tt|	||ddd| _tt||||	||dt||
|||dt||||	||d| _t | _|| _d}tt |D ]s\}}t!|D ]7}t|||t|| |	||d	g}t|| }||v r|"t||
|||d | j"t|  |  j|7  _q|t#|d kr|}| j"t|rt|||||	d
||dnt$|||	|d |}|d9 }|  j|7  _qttj%||ddt t|	|||dd| _&|  j|7  _d S )Nr&   rj   r   r   r   rd   r   r   r   Tr   r   r   r   r   r   r   )'r@   rA   r   r   rg   r   r   r{   r   r   r)   r.   r$   rG   r   r   r   r|   r   r~   r   r[   r_   r   input_blockrw   r   r   r   upsample_blocksr   r   reversedr   r   r   rc   r}   r   )rJ   r   r   rg   r   r   r{   r   r   r   rG   r   r   r   r   r   r
   r   r   r   r   r   r   r   r   rK   r   r   rA     s   




	zDecoderUNetModel.__init__Nc                 C   sf   d}|dur|  t|| j}|| j}| ||}| ||}| jD ]}|||}q$| |}|S )r   N)	r   r:   r   r   r$   r   r   r   r   )rJ   rQ   r5   r^   r   r	   r   r   r   rU   E  s   


zDecoderUNetModel.forward)r   r   Tr   r   r&   r&   Frx   r   TrV   rr   r   r   rK   r   r   x  s     D r   r   r   c                       s,   e Zd ZdZd fdd	Zd	ddZ  ZS )
DownsamplingPaddinga  A wrapper module that applies the necessary padding for
    the downsampling factor

    Arguments
    ---------
    factor: int
        the downsampling / divisibility factor
    len_dim: int
        the index of the dimension in which the length will vary
    dims: list
        the list of dimensions to be included in padding

    Example
    -------
    >>> padding = DownsamplingPadding(factor=4, dims=[1, 2], len_dim=1)
    >>> x = torch.randn(4, 7, 14)
    >>> length = torch.tensor([1., 0.8, 1., 0.7])
    >>> x, length_new = padding(x, length)
    >>> x.shape
    torch.Size([4, 8, 16])
    >>> length_new
    tensor([0.8750, 0.7000, 0.8750, 0.6125])
    r   Nc                    s,   t    || _|| _|d u rt}|| _d S rV   )r@   rA   factorlen_dimDEFAULT_PADDING_DIMSr   )rJ   r  r  r   rK   r   r   rA     s   

zDownsamplingPadding.__init__c                 C   s<   |}| j D ]}t||| j|d\}}|| jkr|}q||fS )aV  Applies the padding

        Arguments
        ---------
        x: torch.Tensor
            the sample
        length: torch.Tensor
            the length tensor

        Returns
        -------
        x_pad: torch.Tensor
            the padded tensor
        lens: torch.Tensor
            the new, adjusted lengths, if applicable
        )r  )r   r   r  r  )rJ   rQ   r   updated_lengthr(   
length_padr   r   r   rU     s   

zDownsamplingPadding.forwardrq   rV   rr   r   r   rK   r   r   g  s    r   c                       s>   e Zd ZdZ												
	
		d fdd	Z  ZS )UNetNormalizingAutoencodera	  A convenience class for a UNet-based Variational Autoencoder (VAE) -
    useful in constructing Latent Diffusion models

    Arguments
    ---------
    in_channels: int
        the number of input channels
    model_channels: int
        the number of channels in the convolutional layers of the
        UNet encoder and decoder
    encoder_out_channels: int
        the number of channels the encoder will output
    latent_channels: int
        the number of channels in the latent space
    encoder_num_res_blocks: int
        the number of residual blocks in the encoder
    encoder_attention_resolutions: list
        the resolutions at which to apply attention layers in the encoder
    decoder_num_res_blocks: int
        the number of residual blocks in the decoder
    decoder_attention_resolutions: list
        the resolutions at which to apply attention layers in the encoder
    dropout: float
        the dropout probability
    channel_mult: tuple
        channel multipliers for each layer
    dims: int
        the convolution dimension to use (1, 2 or 3)
    num_heads: int
        the number of attention heads
    num_head_channels: int
        the number of channels in attention heads
    num_heads_upsample: int
        the number of upsampling heads
    norm_num_groups: int
        Number of norm groups, default 32
    resblock_updown: bool
        whether to use residual blocks for upsampling and downsampling
    out_kernel_size: int
        the kernel size for output convolution layers (if applicable)
    len_dim: int
        Size of the output.
    out_mask_value: float
        Value to fill when masking the output.
    latent_mask_value: float
        Value to fill when masking the latent variable.
    use_fixup_norm: bool
        whether to use FixUp normalization
    downsampling_padding: int
        Amount of padding to apply in downsampling, default 2 ** len(channel_mult)

    Example
    -------
    >>> unet_ae = UNetNormalizingAutoencoder(
    ...     in_channels=1,
    ...     model_channels=4,
    ...     encoder_out_channels=16,
    ...     latent_channels=3,
    ...     encoder_num_res_blocks=1,
    ...     encoder_attention_resolutions=[],
    ...     decoder_num_res_blocks=1,
    ...     decoder_attention_resolutions=[],
    ...     norm_num_groups=2,
    ... )
    >>> x = torch.randn(4, 1, 32, 32)
    >>> x_enc = unet_ae.encode(x)
    >>> x_enc.shape
    torch.Size([4, 3, 4, 4])
    >>> x_dec = unet_ae.decode(x_enc)
    >>> x_dec.shape
    torch.Size([4, 1, 32, 32])
    r   r   r   r   r&   rx   Fr           Nc                    s   t ||||||	|
||||||||d}t|t|||dd}|d u r)dt|
 }t|}t||||||	t|
||||||||d}t j	||||||d d S )N)r   r   rg   r   r   r{   r   r   rG   r   r   r   r   r   r
   r   )r   r   rg   ru   r   )r   rg   r   r   r   r{   r   r   rG   r   r   r   r   r   r
   )encoderlatent_paddingdecoderr  r   r   )
r   r   r|   r   r   r   r   r   r@   rA   )rJ   r   r   encoder_out_channelslatent_channelsencoder_num_res_blocksencoder_attention_resolutionsdecoder_num_res_blocksdecoder_attention_resolutionsr{   r   r   rG   r   r   r   r   r   r  r   r   use_fixup_normdownsampling_paddingencoder_unetr  encoder_padr
  rK   r   r   rA     sl   	
z#UNetNormalizingAutoencoder.__init__)r   r   r   r   r&   r&   rx   Fr   r   r  r  FN)rW   rX   rY   rZ   rA   r\   r   r   rK   r   r    s"    Sr  )T)r!   )NN)$rZ   r+   abcr   r)   torch.nnr   torch.nn.functional
functionalrn   speechbrain.utils.data_utilsr   autoencodersr   r   r   r    r:   Moduler;   r]   r|   r_   rc   rs   rw   r   rH   r   r   r   r   r   r  r   r  r   r   r   r   <module>   sD     

!E8<> N
8   P  ,1 m;