o
    ei                  	   @   s  d Z ddlZddlmZ ddlmZmZ ddlZddl	Z	ddl
mZ ddlm  mZ ddlmZ zddlmZ W n eyH   ddlmZ Y nw eeZdgZdd	d
ddZdddddZdd Zdd Zdd Z				d4dedededefddZe	jj d d! Z!G d"d# d#ej"Z#G d$d% d%ej"Z$G d&d' d'ej"Z%G d(d) d)ej"Z&G d*d+ d+ej"Z'G d,d- d-ej"Z(G d.d/ d/ej"Z)G d0d1 d1ej"Z*G d2d3 d3ej"Z+dS )5a,  
This lobe enables the integration of pretrained discrete DAC model.
Reference: http://arxiv.org/abs/2306.06546
Reference: https://descript.notion.site/Descript-Audio-Codec-11389fce0ce2419891d6591a68f814d5
Reference: https://github.com/descriptinc/descript-audio-codec

Author
 * Shubham Gupta 2023

    N)Path)ListUnion)
get_logger)weight_norm1.0.00.0.10.0.40.0.5))44khz8kbps)24khzr   )16khzr   )r   16kbpszWhttps://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pthz]https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pthz]https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pthzdhttps://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth))r   r   r   )r   r	   r   )r   r
   r   )r   r   r   c                  O      t tj| i |S )aH  
    Apply weight normalization to a 1D convolutional layer.

    Arguments
    ---------
    *args : tuple
        Variable length argument list for nn.Conv1d.
    **kwargs : dict
        Arbitrary keyword arguments for nn.Conv1d.

    Returns
    -------
    torch.nn.Module
        The weight-normalized nn.Conv1d layer.
    )r   nnConv1dargskwargs r   c/home/ubuntu/transcripts/venv/lib/python3.10/site-packages/speechbrain/lobes/models/discrete/dac.pyWNConv1dH      r   c                  O   r   )an  
    Apply weight normalization to a 1D transposed convolutional layer.

    Arguments
    ---------
    *args : tuple
        Variable length argument list for nn.ConvTranspose1d.
    **kwargs : dict
        Arbitrary keyword arguments for nn.ConvTranspose1d.

    Returns
    -------
    torch.nn.Module
        The weight-normalized nn.ConvTranspose1d layer.
    )r   r   ConvTranspose1dr   r   r   r   WNConvTranspose1d[   r   r   c                 C   s6   t | tjrtjj| jdd tj| jd dS dS )z=
    Initialize the weights of a 1D convolutional layer.
    g{Gz?)stdr   N)
isinstancer   r   inittrunc_normal_weight	constant_bias)mr   r   r   init_weightsn   s   r$   r   r   latest
model_typemodel_bitratetag
local_pathc                 C   s   |   } |  }| dv sJ d|dv sJ d|dkr"t| |f }t| ||fd}td|  |du rAtd| d	|  |du rUt d
|  d| d| d }|	 s}|j
jddd ddl}||}|jdkrwtd|j ||j |S )a?  
    Downloads a specified model file based on model type, bitrate, and tag, saving it to a local path.

    Arguments
    ---------
    model_type : str, optional
        The type of model to download. Can be '44khz', '24khz', or '16khz'. Default is '44khz'.
    model_bitrate : str, optional
        The bitrate of the model. Can be '8kbps' or '16kbps'. Default is '8kbps'.
    tag : str, optional
        A specific version tag for the model. Default is 'latest'.
    local_path : Path, optional
        The local file path where the model will be saved. If not provided, a default path will be used.

    Returns
    -------
    Path
        The local path where the model is saved.

    Raises
    ------
    ValueError
        If the model type or bitrate is not supported, or if the model cannot be found or downloaded.
    )r   r   r   z6model_type must be one of '44khz', '24khz', or '16khz')r   r   z1model_bitrate must be one of '8kbps', or '16kbps'r%   NzDownload link: zCould not find model with tag z and model type z.cache/descript/dac/weights__z.pthT)parentsexist_okr      z1Could not download model. Received response code )lower__MODEL_LATEST_TAGS____MODEL_URLS__getloggerinfo
ValueErrorr   homeexistsparentmkdirrequestsstatus_codewrite_bytescontent)r&   r'   r(   r)   download_linkr9   responser   r   r   downloadw   s>   




r?   c                 C   sN   | j }| |d |d d} | |d  t||  d  } | |} | S )a1  
    Applies the 'snake' activation function on the input tensor.

    This function reshapes the input tensor, applies a modified sine function to it, and then reshapes it back
    to its original shape.

    Arguments
    ---------
    x : torch.Tensor
        The input tensor to which the snake activation function will be applied.
    alpha : float
        A scalar value that modifies the sine function within the snake activation.

    Returns
    -------
    torch.Tensor
        The transformed tensor after applying the snake activation function.
    r      g&.>   )shapereshape
reciprocaltorchsinpow)xalpharC   r   r   r   snake   s
   $
rK   c                       sn   e Zd ZdZdededef fddZdejfdd	Zd
ejfddZ	d
ejfddZ
dejfddZ  ZS )VectorQuantizea  
    An implementation for Vector Quantization

    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability

    Arguments
    ---------
    input_dim : int
        Dimensionality of input
    codebook_size : int
        Size of codebook
    codebook_dim : int
        Dimensionality of codebook
    	input_dimcodebook_sizecodebook_dimc                    sH   t    || _|| _t||dd| _t||dd| _t||| _	d S )Nr@   kernel_size)
super__init__rN   rO   r   in_projout_projr   	Embeddingcodebook)selfrM   rN   rO   	__class__r   r   rS      s   
zVectorQuantize.__init__zc                 C   s|   |  |}| |\}}tj|| ddddg}tj|| ddddg}|||   }| |}|||||fS )a  Quantized the input tensor using a fixed codebook and returns
        the corresponding codebook vectors

        Arguments
        ---------
        z : torch.Tensor[B x D x T]

        Returns
        -------
        torch.Tensor[B x D x T]
            Quantized continuous representation of input
        torch.Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        torch.Tensor[1]
            Codebook loss to update the codebook
        torch.Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        torch.Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        none)	reductionr@   rB   )rT   decode_latentsFmse_lossdetachmeanrU   )rX   r[   z_ez_qindicescommitment_losscodebook_lossr   r   r   forward  s   

zVectorQuantize.forwardembed_idc                 C   s   t || jjS )a  
        Embeds an ID using the codebook weights.

        This method utilizes the codebook weights to embed the given ID.

        Arguments
        ---------
        embed_id : torch.Tensor
            The tensor containing IDs that need to be embedded.

        Returns
        -------
        torch.Tensor
            The embedded output tensor after applying the codebook weights.
        )r_   	embeddingrW   r    rX   ri   r   r   r   
embed_code-  s   zVectorQuantize.embed_codec                 C   s   |  |ddS )a  
        Decodes the embedded ID by transposing the dimensions.

        This method decodes the embedded ID by applying a transpose operation to the dimensions of the
        output tensor from the `embed_code` method.

        Arguments
        ---------
        embed_id : torch.Tensor
            The tensor containing embedded IDs.

        Returns
        -------
        torch.Tensor
            The decoded tensor
        r@   rB   )rl   	transposerk   r   r   r   decode_code?  s   zVectorQuantize.decode_codelatentsc           
      C   s   | dddd|d}| jj}t|}t|}|djdddd| |	   |djddd	  }| j
ddd }|d}| | }|||}| |}	|	|fS )a  
        Decodes latent representations into discrete codes by comparing with the codebook.

        Arguments
        ---------
        latents : torch.Tensor
            The latent tensor representations to be decoded.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            A tuple containing the decoded latent tensor (`z_q`) and the indices of the codes.
        r   rB   r@   rA   T)keepdimdim)permuterD   sizerW   r    r_   	normalizerH   sumtmaxnumelviewrn   )
rX   ro   	encodingsrW   distmax_indicesbrw   re   rd   r   r   r   r^   R  s    



zVectorQuantize.decode_latents)__name__
__module____qualname____doc__intrS   rF   Tensorrh   rl   rn   r^   __classcell__r   r   rY   r   rL      s    	*rL   c                       sz   e Zd ZdZ					ddeded	ed
eeef def
 fddZddefddZ	de
jfddZde
jfddZ  ZS )ResidualVectorQuantizea  
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312

    Arguments
    ---------
    input_dim : int, optional, by default 512
    n_codebooks : int, optional, by default 9
    codebook_size : int, optional, by default 1024
    codebook_dim : Union[int, list], optional,  by default 8
    quantizer_dropout : float, optional, by default 0.0

    Example
    -------
    Using a pretrained RVQ unit.

    >>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
    >>> quantizer = dac.quantizer
    >>> continuous_embeddings = torch.randn(1, 1024, 100) # Example shape: [Batch, Channels, Time]
    >>> discrete_embeddings, codes, _, _, _ = quantizer(continuous_embeddings)
       	                 rM   n_codebooksrN   rO   quantizer_dropoutc                    sh   t    t tr fddt|D  || _ | _| _t	 fddt|D | _
|| _d S )Nc                    s   g | ]} qS r   r   ).0r*   rO   r   r   
<listcomp>  s    z3ResidualVectorQuantize.__init__.<locals>.<listcomp>c                    s   g | ]
}t  | qS r   )rL   )r   irO   rN   rM   r   r   r     s    )rR   rS   r   r   ranger   rO   rN   r   
ModuleList
quantizersr   )rX   rM   r   rN   rO   r   rY   r   r   rS     s   


zResidualVectorQuantize.__init__Nn_quantizersc                 C   sr  d}|}d}d}g }g }|du r| j }| jrLt|jd f| j  d }td| j d |jd f}	t|jd | j }
|	d|
 |d|
< ||j	}t
| jD ]R\}}| jdu r`||kr` nD||\}}}}}tj|jd f||j	d|k }|||ddddf   }|| }|||  7 }|||  7 }|| || qQtj|dd}tj|dd}|||||fS )a  Quantized the input tensor using a fixed set of `n` codebooks and returns
        the corresponding codebook vectors

        Arguments
        ---------
        z : torch.Tensor
            Shape [B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
                when in training mode, and a random number of quantizers is used.
        Returns
        -------
        z : torch.Tensor[B x D x T]
            Quantized continuous representation of input
        codes : torch.Tensor[B x N x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        latents : torch.Tensor[B x N*D x T]
            Projected latents (continuous representation of input before quantization)
        vq/commitment_loss : torch.Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        vq/codebook_loss : torch.Tensor[1]
            Codebook loss to update the codebook
        r   Nr@   F)
fill_valuedevicerq   )r   trainingrF   onesrC   randintr   r   tor   	enumerater   fullrb   appendstackcat)rX   r[   r   rd   residualrf   rg   codebook_indicesro   dropout	n_dropoutr   	quantizerz_q_icommitment_loss_icodebook_loss_i	indices_iz_e_imaskcodesr   r   r   rh     sJ   
zResidualVectorQuantize.forwardr   c                 C   sx   d}g }|j d }t|D ]$}| j| |dd|ddf }|| | j| |}|| }q|tj|dd|fS )aK  Given the quantized codes, reconstruct the continuous representation

        Arguments
        ---------
        codes : torch.Tensor[B x N x T]
            Quantized discrete representation of input

        Returns
        -------
        torch.Tensor[B x D x T]
            Quantized continuous representation of input
        r   r@   Nrq   )rC   r   r   rn   r   rU   rF   r   )rX   r   rd   z_pr   r   z_p_ir   r   r   r   
from_codes  s   
"

z!ResidualVectorQuantize.from_codesro   c                 C   s   d}g }g }t dgdd | jD  }t ||jd kd jdddd }t|D ]8}|| ||d  }}	| j| |dd||	ddf \}
}||
 || | j| 	|
}|| }q+|t
j|ddt
j|ddfS )	a  Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Arguments
        ---------
        latents : torch.Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        torch.Tensor[B x D x T]
            Quantized representation of full-projected space
        torch.Tensor[B x D x T]
            Quantized representation of latent space
        r   c                 S   s   g | ]}|j qS r   r   )r   qr   r   r   r     s    z7ResidualVectorQuantize.from_latents.<locals>.<listcomp>r@   T)axiskeepdimsNrq   )npcumsumr   whererC   rx   r   r^   r   rU   rF   r   r   )rX   ro   rd   r   r   dimsr   r   jkr   codes_ir   r   r   r   from_latents  s&   



z#ResidualVectorQuantize.from_latents)r   r   r   r   r   N)r   r   r   r   r   r   listfloatrS   rh   rF   r   r   r   r   r   r   rY   r   r   x  s*    
Mr   c                       s(   e Zd ZdZ fddZdd Z  ZS )Snake1dz
    A PyTorch module implementing the Snake activation function in 1D.

    Arguments
    ---------
    channels : int
        The number of channels in the input tensor.
    c                    s$   t    ttd|d| _d S )Nr@   )rR   rS   r   	ParameterrF   r   rJ   )rX   channelsrY   r   r   rS   <  s   
zSnake1d.__init__c                 C   s   t || jS z}

        Arguments
        ---------
        x : torch.Tensor

        Returns
        -------
        torch.Tensor
        )rK   rJ   rX   rI   r   r   r   rh   @  s   zSnake1d.forward)r   r   r   r   rS   rh   r   r   r   rY   r   r   2  s    	r   c                       sB   e Zd ZdZddedef fddZdejd	ejfd
dZ	  Z
S )ResidualUnita  
    A residual unit module for convolutional neural networks.

    Arguments
    ---------
    dim : int, optional
        The number of channels in the input tensor. Default is 16.
    dilation : int, optional
        The dilation rate for the convolutional layers. Default is 1.

       r@   rr   dilationc              
      sL   t    d| d }tt|t||d||dt|t||dd| _d S )N   rB      )rQ   r   paddingr@   rP   )rR   rS   r   
Sequentialr   r   block)rX   rr   r   padrY   r   r   rS   [  s   

zResidualUnit.__init__rI   returnc                 C   sD   |  |}|jd |jd  d }|dkr|d|| f }|| S )|
        Arguments
        ---------
        x : torch.Tensor

        Returns
        -------
        torch.Tensor
        rA   rB   r   .)r   rC   )rX   rI   yr   r   r   r   rh   e  s
   

zResidualUnit.forwardr   r@   )r   r   r   r   r   rS   rF   r   tensorrh   r   r   r   rY   r   r   N  s    
r   c                       s<   e Zd ZdZddedef fddZdejfd	d
Z  Z	S )EncoderBlocka  
    An encoder block module for convolutional neural networks.

    This module constructs an encoder block consisting of a series of ResidualUnits and a final Snake1d
    activation followed by a weighted normalized 1D convolution. This block can be used as part of an
    encoder in architectures like autoencoders.

    Arguments
    ---------
    dim : int, optional
        The number of output channels. Default is 16.
    stride : int, optional
        The stride for the final convolutional layer. Default is 1.
    r   r@   rr   stridec                    sn   t    tt|d ddt|d ddt|d ddt|d t|d |d| |t|d d| _	d S )NrB   r@   r      r   rQ   r   r   )
rR   rS   r   r   r   r   r   mathceilr   )rX   rr   r   rY   r   r   rS     s   


zEncoderBlock.__init__rI   c                 C   
   |  |S r   r   r   r   r   r   rh        

zEncoderBlock.forwardr   )
r   r   r   r   r   rS   rF   r   rh   r   r   r   rY   r   r   v  s    r   c                       sB   e Zd ZdZdg ddfdededef fddZd	d
 Z  ZS )Encodera  
    A PyTorch module for the Encoder part of DAC.

    Arguments
    ---------
    d_model : int, optional
        The initial dimensionality of the model. Default is 64.
    strides : list, optional
        A list of stride values for downsampling in each EncoderBlock. Default is [2, 4, 8, 8].
    d_latent : int, optional
        The dimensionality of the output latent space. Default is 64.

    Example
    -------
    Creating an Encoder instance
    >>> encoder = Encoder()
    >>> audio_input = torch.randn(1, 1, 44100) # Example shape: [Batch, Channels, Time]
    >>> continuous_embedding = encoder(audio_input)

    Using a pretrained encoder.

    >>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
    >>> encoder = dac.encoder
    >>> audio_input = torch.randn(1, 1, 44100) # Example shape: [Batch, Channels, Time]
    >>> continuous_embeddings = encoder(audio_input)
    @   rB      r   r   d_modelstridesd_latentc              	      s   t    td|dddg| _|D ]}|d9 }|  jt||dg7  _q|  jt|t||dddg7  _tj| j | _|| _d S )Nr@   r   r   rQ   r   rB   )r   )	rR   rS   r   r   r   r   r   r   enc_dim)rX   r   r   r   r   rY   r   r   rS     s   

zEncoder.__init__c                 C   r   r   r   r   r   r   r   rh     r   zEncoder.forward)	r   r   r   r   r   r   rS   rh   r   r   r   rY   r   r     s    r   c                       s:   e Zd ZdZ	ddededef fdd	Zd
d Z  ZS )DecoderBlocka  
    A PyTorch module representing a block within the Decoder architecture.

    Arguments
    ---------
    input_dim : int, optional
        The number of input channels. Default is 16.
    output_dim : int, optional
        The number of output channels. Default is 8.
    stride : int, optional
        The stride for the transposed convolution, controlling the upsampling. Default is 1.
    r   r   r@   rM   
output_dimr   c                    sZ   t    tt|t||d| |t|d dt|ddt|ddt|dd| _	d S )NrB   r   r@   r   r   r   )
rR   rS   r   r   r   r   r   r   r   r   )rX   rM   r   r   rY   r   r   rS     s   




zDecoderBlock.__init__c                 C   r   r   r   r   r   r   r   rh        
zDecoderBlock.forward)r   r   r@   )r   r   r   r   r   rS   rh   r   r   r   rY   r   r     s    r   c                	       sB   e Zd ZdZ	ddededee def fddZd	d
 Z  ZS )DecoderaP  
    A PyTorch module for the Decoder part of DAC.

    Arguments
    ---------
    input_channel : int
        The number of channels in the input tensor.
    channels : int
        The base number of channels for the convolutional layers.
    rates : list
        A list of stride rates for each decoder block
    d_out: int
        The out dimension of the final conv layer, Default is 1.

    Example
    -------
    Creating a Decoder instance

    >>> decoder = Decoder(256, 1536,  [8, 8, 4, 2])
    >>> discrete_embeddings = torch.randn(2, 256, 200) # Example shape: [Batch, Channels, Time]
    >>> recovered_audio = decoder(discrete_embeddings)

    Using a pretrained decoder. Note that the actual input should be proper discrete representation.
    Using randomly generated input here for illustration of use.

    >>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
    >>> decoder = dac.decoder
    >>> discrete_embeddings = torch.randn(1, 1024, 500) # Example shape: [Batch, Channels, Time]
    >>> recovered_audio = decoder(discrete_embeddings)
    r@   input_channelr   ratesd_outc           
         s   t    t||dddg}t|D ]\}}|d|  }|d|d   }	|t||	|g7 }q|t|	t|	|dddt g7 }tj| | _	d S )Nr   r   r   rB   r@   )
rR   rS   r   r   r   r   r   Tanhr   model)
rX   r   r   r   r   layersr   r   rM   r   rY   r   r   rS   3  s   
zDecoder.__init__c                 C   r   r   )r   r   r   r   r   rh   N  r   zDecoder.forward)r@   )	r   r   r   r   r   r   rS   rh   r   r   r   rY   r   r     s    $r   c                !       s   e Zd ZdZdg dddg dddd	d
dddddd
d
fdedee dededee dededeeef dedede	de	de	de	dedef  fdd Z
	d*d!ejd"efd#d$Zd%ejfd&d'Z		d+d!ejded"efd(d)Z  ZS ),DACa	  
    Discrete Autoencoder Codec (DAC) for audio data encoding and decoding.

    This class implements an autoencoder architecture with quantization for efficient audio processing.
    It includes an encoder, quantizer, and decoder for transforming audio data into a compressed latent representation and reconstructing it back into audio.
    This implementation supports both initializing a new model and loading a pretrained model.

    Arguments
    ---------
    encoder_dim : int
        Dimensionality of the encoder.
    encoder_rates : List[int]
        Downsampling rates for each encoder layer.
    latent_dim : int, optional
        Dimensionality of the latent space, automatically calculated if None.
    decoder_dim : int
        Dimensionality of the decoder.
    decoder_rates : List[int]
        Upsampling rates for each decoder layer.
    n_codebooks : int
        Number of codebooks for vector quantization.
    codebook_size : int
        Size of each codebook.
    codebook_dim : Union[int, list]
        Dimensionality of each codebook entry.
    quantizer_dropout : bool
        Whether to use dropout in the quantizer.
    sample_rate : int
        Sample rate of the audio data.
    model_type : str
        Type of the model to load (if pretrained).
    model_bitrate : str
        Bitrate of the model to load (if pretrained).
    tag : str
        Specific tag of the model to load (if pretrained).
    load_path : str, optional
        Path to load the pretrained model from, automatically downloaded if None.
    strict : bool
        Whether to strictly enforce the state dictionary match.
    load_pretrained : bool
        Whether to load a pretrained model.

    Example
    -------
    Creating a new DAC instance:

    >>> dac = DAC()
    >>> audio_data = torch.randn(1, 1, 16000) # Example shape: [Batch, Channels, Time]
    >>> tokens, embeddings = dac(audio_data)

    Loading a pretrained DAC instance:

    >>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
    >>> audio_data = torch.randn(1, 1, 16000) # Example shape: [Batch, Channels, Time]
    >>> tokens, embeddings = dac(audio_data)

    The tokens and the discrete embeddings obtained above or from other sources can be decoded:

    >>> dac = DAC(load_pretrained=True, model_type="44KHz", model_bitrate="8kbps", tag="latest")
    >>> audio_data = torch.randn(1, 1, 16000) # Example shape: [Batch, Channels, Time]
    >>> tokens, embeddings = dac(audio_data)
    >>> decoded_audio = dac.decode(embeddings)
    r   r   Ni   )r   r   r   rB   r   r   r   FiD  r   r   r%   encoder_dimencoder_rates
latent_dimdecoder_dimdecoder_ratesr   rN   rO   r   sample_rater&   r'   r(   	load_pathstrictload_pretrainedc                    sD  t    || _|| _|| _|| _|
| _|| _|| _|| _	|| _
|	| _|rQ|s6t|||d}td|  t|d}|d }|d  D ]
\}}t| || qFt| j| _| j
d u rh| jdt| j  | _
t| j| j| j
| _t| j
| j| j| j	| jd| _t| j
| j| j| _| t |r| j|d |d	 || _ d S d S )
N)r&   r'   r(   zObtained load path as: cpumetadatar   rB   )rM   r   rN   rO   r   
state_dict)r   )!rR   rS   r   r   r   r   r   r   rN   rO   r   r   r?   r2   r3   rF   loaditemssetattrr   prod
hop_lengthlenr   encoderr   r   r   decoderapplyr$   load_state_dictr   )rX   r   r   r   r   r   r   rN   rO   r   r   r&   r'   r(   r   r   r   
model_dictr   keyvaluerY   r   r   rS     sX   



zDAC.__init__
audio_datar   c                 C   s.   |  |}| ||\}}}}}|||||fS )a  Encode given audio data and return quantized latent codes

        Arguments
        ---------
        audio_data : torch.Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.

        Returns
        -------
        "z" : torch.Tensor[B x D x T]
            Quantized continuous representation of input
        "codes" : torch.Tensor[B x N x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        "latents" : torch.Tensor[B x N*D x T]
            Projected latents (continuous representation of input before quantization)
        "vq/commitment_loss" : torch.Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        "vq/codebook_loss" : torch.Tensor[1]
            Codebook loss to update the codebook
        "length" : int
            Number of samples in input audio
        )r  r   )rX   r
  r   r[   r   ro   rf   rg   r   r   r   encode  s
   
 z
DAC.encoder[   c                 C   r   )a9  Decode given latent codes and return audio data

        Arguments
        ---------
        z : torch.Tensor
            Shape [B x D x T]
            Quantized continuous representation of input

        Returns
        -------
        torch.Tensor: shape B x 1 x length
            Decoded audio data.
        )r  )rX   r[   r   r   r   decode  s   
z
DAC.decodec           	      C   sT   |j d }t|| j | j | }tj|d|f}| ||\}}}}}||fS )a  Model forward pass

        Arguments
        ---------
        audio_data : torch.Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        "tokens" : torch.Tensor[B x N x T]
            Codebook indices for each codebook
            (quantized discrete representation of input)
        "embeddings" : torch.Tensor[B x D x T]
            Quantized continuous representation of input
        rA   r   )rC   r   r   r  r   
functionalr   r  )	rX   r
  r   r   length	right_padr[   r   r*   r   r   r   rh     s   
zDAC.forwardr   )NN)r   r   r   r   r   r   r   r   boolstrrS   rF   r   r  r  rh   r   r   r   rY   r   r   \  s    B
	
F
&r   )r   r   r%   N),r   r   pathlibr   typingr   r   numpyr   rF   torch.nnr   torch.nn.functionalr  r_   speechbrain.utils.loggerr   torch.nn.utils.parametrizationsr   ImportErrortorch.nn.utilsr   r2   SUPPORTED_VERSIONSr/   r0   r   r   r$   r  r?   jitscriptrK   ModulerL   r   r   r   r   r   r   r   r   r   r   r   r   <module>   sn    

Q
  ;(-B.I