o
    %ݫi                    @   s  d Z ddlZddlZddlZddlmZmZmZ ddlZ	ddl
Z
ddlm  mZ ddlm  mZ ddl
mZmZ ddlmZmZ ddlmZ eeZG dd dejZd	d
 Zde
jde
jfddZde fddZ!G dd dejZ"G dd dejZ#G dd dejZ$G dd de
j%j&Z'dd Z(G dd dejZ)G dd dejZ*G d d! d!ejZ+d"ejddfd#d$Z,G d%d& d&Z-dS )'a]  This lobe enables the integration of pretrained BEATs: Audio Pre-Training with Acoustic Tokenizers.

Reference: https://arxiv.org/abs/2212.09058
Based on Github source: https://github.com/microsoft/unilm/tree/master/beats

You could download the checkpoints from: https://github.com/microsoft/unilm/tree/master/beats

Author
 * Pooneh Mousavi 2024

    N)DictOptionalTuple)Tensornn)	LayerNorm	Parameter)length_to_maskc                       s   e Zd ZdZ			ddedededdf fd	d
ZdejdejdejfddZ			ddejde
de
dejfddZ			ddejdeej de
de
fddZ			ddejdeej de
de
dejf
ddZ  ZS )BEATsai  
    BEATs: Audio Pre-Training with Acoustic Tokenizers.

    This class implements the BEATs model, which processes audio signals for feature extraction
    or downstream tasks. The model supports loading from a checkpoint, applying normalization,
    and optionally freezing parameters.

    Arguments
    ---------
    ckp_path : str, optional
        Path to the checkpoint file. If None, the model initializes without pre-trained weights.
        You could download the checkpoints from : https://github.com/microsoft/unilm/tree/master/beats
    freeze : bool, optional (default: False)
        If True, the model parameters are frozen and the model is set to evaluation mode.
    output_all_hiddens : bool, optional (default: False)
        If True, the forward function outputs hidden states from all transformer layers.
        For example BEATs_iter3 has 12 transformer layers and the output is of shape (13, B, T, C),
        where a projection of the CNN output is added to the beginning.
        If False, the forward function outputs the hidden states only from the last transformer layer.

    Example
    -------
    >>> audio = torch.randn(4, 10000)  # Batch of 4 audio signals
    >>> length = torch.tensor([1.0, 0.5, 0.75, 1.0])
    >>> model = BEATs()
    >>> outputs = model.extract_features(audio, length)[0]
    >>> outputs.shape
    torch.Size([4, 24, 768])
    NTFckp_pathfreezeoutput_all_hiddensreturnc                    sj  t    d\}}|r$tj|std| dt|}|dd }t	|| _
td| j
j  || _|| _| j
j| _| j| j
jkrNt| j| j
jnd | _| j
j| _tjd| j| j| j| j
jd| _t| j
j| _| j
jrz| j
jrzJ dt| j
| _t | j| _!| j
j"rt| j
j#| _#t| j
j| j
j$| _%nd | _%|r| &|d	  | jr| '  d S d S )
NNNzCheckpoint file 'z' does not exist.cfgzBEATs Config:    )kernel_sizestridebiaszLConfiguration error: 'deep_norm' and 'layer_norm_first' cannot both be True.model)(super__init__ospathexistsFileNotFoundErrortorchloadgetBEATsConfigr   loggerinfo__dict__r   r   	embed_dimembedencoder_embed_dimr   Linearpost_extract_projinput_patch_sizeConv2d	conv_biaspatch_embeddingDropoutdropout_input	deep_normlayer_norm_firstTransformerEncoderencoderr   
layer_normfinetuned_modelpredictor_dropoutpredictor_class	predictorload_state_dicteval)selfr   r   r   r   
checkpoint	__class__ R/home/ubuntu/.local/lib/python3.10/site-packages/speechbrain/lobes/models/beats.pyr   =   s\   





zBEATs.__init__featurespadding_maskc                 C   sV   | d| d }|dkr|ddd| f }|| d| dd}|dS )ak  
        Adjusts the padding mask for the given features.

        Arguments
        ---------
        features : torch.Tensor
            Input features after patch embedding.
        padding_mask : torch.Tensor
            Original padding mask for input signals.

        Returns
        -------
        torch.Tensor
            Adjusted padding mask.
        r   r   N)sizeviewall)r9   r?   r@   extrar=   r=   r>   forward_padding_mask   s   
zBEATs.forward_padding_maskP.@(9@source
fbank_mean	fbank_stdc                 C   sX   g }|D ]}| dd }tj|ddddd}|| qtj|dd}|| d	|  S )
a  
        Preprocesses the input waveform by extracting filter banks and applying normalization.

        Arguments
        ---------
        source : torch.Tensor
            Input waveform signals.
        fbank_mean : float, optional
            Mean value for filter bank normalization (default: 15.41663).
        fbank_std : float, optional
            Standard deviation for filter bank normalization (default: 6.55582).

        Returns
        -------
        torch.Tensor
            Normalized filter banks.
        r   i      i>     
   )num_mel_binssample_frequencyframe_lengthframe_shiftdim   )	unsqueezeta_kaldifbankappendr   stack)r9   rI   rJ   rK   fbankswaveformrX   r=   r=   r>   
preprocess   s   zBEATs.preprocesswavwav_lensc                 C   sN   | j rt  | ||||W  d   S 1 sw   Y  | ||||S )aY  Takes an input waveform and return its corresponding beats encoding.

        Arguments
        ---------
        wav : torch.Tensor
            A batch of audio signals to transform to features.
        wav_lens : torch.Tensor
            The relative length of the wav given in SpeechBrain format.
        fbank_mean : float, optional
            Mean value for filter bank normalization (default: 15.41663).
        fbank_std : float, optional
            Standard deviation for filter bank normalization (default: 6.55582).

        Returns
        -------
        BEATs encoded features.
        N)r   r   no_gradextract_features)r9   r^   r_   rJ   rK   r=   r=   r>   forward   s   
 zBEATs.forwardc                 C   s  |  |||}|dur|d}t|| ||jd  }|dur'| ||}|d}| |}||j	d |j	d d
dd}| |}|durQ| ||}| jdur[| |}| |}| j||| jd\}	}
| jdur| |	}| |}|dur| rd||< |jdd}|| jddd| }n|jdd}t|}| jrtj|
dd}	|	||fS | jrtj|
dd}	|	fS )	ar  
        Extracts features from the input waveform.

        Arguments
        ---------
        wav : torch.Tensor
            A batch of audio signals to transform to features.
        wav_lens : torch.Tensor
            The relative length of the wav given in SpeechBrain format.
        fbank_mean : float, optional
            Mean value for filter bank normalization (default: 15.41663).
        fbank_std : float, optional
            Standard deviation for filter bank normalization (default: 6.55582).

        Returns
        -------
        torch.Tensor
            Extracted features from the BEATs model.
        NrA   devicer   r   rU   )r@   r   rS   )r]   rB   r	   rd   boolrF   rV   r+   reshapeshape	transposer2   r'   r-   r1   r   r6   r4   anysum	expand_asmeanr   sigmoidrZ   )r9   r^   r_   rJ   rK   rX   max_lenr@   r?   xlayer_resultsx_dlogitslprobsr=   r=   r>   ra      s^   












zBEATs.extract_features)NTF)rG   rH   )NrG   rH   )__name__
__module____qualname____doc__strre   r   r   r   rF   floatr]   r   rb   ra   __classcell__r=   r=   r;   r>   r
      st     E

(
%r
   c                 C   sH   t tdstdtj t_d|  dttj| dt| d     S )aB  
    Applies the Gaussian Error Linear Unit (GELU) activation function
    using an accurate approximation.

    Arguments
    ---------
    x: torch.Tensor
        Input tensor on which to apply the GELU activation.

    Returns
    -------
    torch.Tensor:
        Tensor with GELU activation applied element-wise.
    _arU   g      ?r   gHm?   )	hasattrgelu_accuratemathsqrtpir{   r   tanhpowro   r=   r=   r>   r~   6  s   
"r~   ro   r   c                 C   s   t jj|  | S )a  
    Applies the Gaussian Error Linear Unit (GELU) activation function.

    Arguments
    ---------
    x: torch.Tensor
        Input tensor to apply the GELU activation.

    Returns
    -------
    torch.Tensor
        Tensor with GELU activation applied element-wise.
    )r   r   
functionalgelury   type_asr   r=   r=   r>   r   N  s   r   
activationc                 C   sx   | dkrt jS | dkrtS | dkrtd tS | dkrtS | dkr%tjS | dkr-dd	 S | d
kr5dd	 S td	| )aE  
    Returns the activation function corresponding to the provided activation name.

    Arguments
    ---------
    activation : str
        Name of the activation function. Supported values:
        - "relu": Applies ReLU activation.
        - "gelu": Applies the GELU activation.
        - "gelu_fast": Alias for `gelu_accurate` with a deprecation warning.
        - "gelu_accurate": Applies the accurate GELU activation.
        - "tanh": Applies the Tanh activation.
        - "linear": Applies the identity function.
        - "glu": Applies the identity function (GLU placeholder).

    Returns
    -------
    Callable[[torch.Tensor], torch.Tensor]
        The corresponding activation function to apply to input tensors.

    Raises
    ------
    RuntimeError
        If the specified activation function is not supported.
    relur   	gelu_fastz;--activation-fn=gelu_fast has been renamed to gelu_accurater~   r   linearc                 S      | S Nr=   r   r=   r=   r>   <lambda>      z#get_activation_fn.<locals>.<lambda>gluc                 S   r   r   r=   r   r=   r=   r>   r     r   z --activation-fn {} not supported)
Fr   r   r    warnr~   r   r   RuntimeErrorformat)r   r=   r=   r>   get_activation_fn_  s(   r   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS )SamePada  
    Implements a module that adjusts the padding of a tensor after convolution
    to maintain its original size, with an option for causal padding.

    This is particularly useful for handling padding in convolutional layers
    where the kernel size or causality affects the output size.

    Arguments
    ---------
    kernel_size : int
        The size of the convolutional kernel.
    causal : bool, optional (default=False)
        If True, applies causal padding by removing `(kernel_size - 1)`
        elements from the end of the tensor. If False, removes elements
        to center-align the padding, ensuring the output size matches
        the input size.
    Fc                    s6   t    |r|d | _d S |d dkrdnd| _d S )Nr   rU   r   )r   r   remove)r9   r   causalr;   r=   r>   r     s   
zSamePad.__init__c                 C   s,   | j dkr|ddddd| j  f }|S )a  
        Adjusts the padding of the input tensor `x`.

        If `self.remove > 0`, the method slices the tensor along the last dimension
        to remove excess padding based on the `kernel_size` and `causal` settings.

        Arguments
        ---------
        x : torch.Tensor
            The input tensor to adjust padding for.

        Returns
        -------
        torch.Tensor
            The tensor with adjusted padding.
        r   N)r   r9   ro   r=   r=   r>   rb     s   
zSamePad.forward)Frt   ru   rv   rw   r   rb   rz   r=   r=   r;   r>   r     s    r   c                       s(   e Zd ZdZ fddZdd Z  ZS )Swisha%  
    Implements the Swish activation function as a PyTorch module.

    Swish is a smooth, non-monotonic activation function defined as:
        Swish(x) = x * sigmoid(x)

    It is often used in deep learning for its ability to improve training
    performance in certain architectures.

    c                    s   t t|   tj | _d S r   )r   r   r   r   r   Sigmoidactr9   r;   r=   r>   r     s   zSwish.__init__c                 C   s   ||  | S )aI  
        Applies the Swish activation function to the input tensor.

        Arguments
        ---------
        x : torch.Tensor
            The input tensor to which the Swish activation is applied.

        Returns
        -------
        torch.Tensor
            The input tensor after applying the Swish activation.
        )r   r   r=   r=   r>   rb     s   zSwish.forwardr   r=   r=   r;   r>   r     s    r   c                       s$   e Zd ZdZ	d fdd	Z  ZS )
GLU_Lineara  
    Implements a Gated Linear Unit (GLU) combined with a linear transformation.

    Arguments
    ---------
    input_dim : int
        The dimensionality of the input features.
    output_dim : int
        The dimensionality of the output features.
    glu_type : str, optional (default="sigmoid")
        The type of activation function used for gating. Supported values are:
        - "sigmoid": Uses the sigmoid activation function.
        - "swish": Uses the Swish activation function.
        - "relu": Uses the ReLU activation function.
        - "gelu": Uses the GELU activation function.
    bias_in_glu : bool, optional (default=True)
        Whether to include a bias term in the linear transformation.

    rm   Tc                    s   t t|   || _|| _|dkrtj | _n|dkr!t	 | _n|dkr,tj
 | _n
|dkr6tj | _|rDt||d d| _d S t||d d| _d S )Nrm   swishr   r   rU   TF)r   r   r   glu_type
output_dimr   r   r   glu_actr   ReLUGELUr&   r   )r9   	input_dimr   r   bias_in_glur;   r=   r>   r     s   
zGLU_Linear.__init__)rm   T)rt   ru   rv   rw   r   rz   r=   r=   r;   r>   r     s    r   c                   @   s(   e Zd ZdZedd Zedd ZdS )GradMultiplya;  
    A custom autograd function that scales gradients during the backward pass.

    This is useful for scenarios where gradient scaling is required without
    affecting the forward pass output. The forward pass returns the input as-is,
    while the backward pass scales the gradients by a specified factor.

    c                 C   s   || _ ||}|S )a  
        Performs the forward pass of the GradMultiply function.

        Arguments
        ---------
        ctx : torch.autograd.Function
            The context object to store information for the backward computation.
        x : torch.Tensor
            The input tensor to be forwarded unchanged.
        scale : float
            The factor by which the gradients will be scaled during the backward pass.

        Returns
        -------
        torch.Tensor
            A new tensor identical to the input tensor.
        )scalenew)ctxro   r   resr=   r=   r>   rb     s   
zGradMultiply.forwardc                 C   s   || j  dfS )a  
        Performs the backward pass, scaling the gradients by the stored factor.

        Arguments
        ---------
        ctx : torch.autograd.Function
            The context object containing the stored scaling factor.
        grad : torch.Tensor
            The gradient tensor from the subsequent layer.

        Returns
        -------
        Tuple[torch.Tensor, None]
            The scaled gradient tensor and None (for the scale input, which has no gradient).
        N)r   )r   gradr=   r=   r>   backward/  s   zGradMultiply.backwardN)rt   ru   rv   rw   staticmethodrb   r   r=   r=   r=   r>   r     s    	
r   c                 C   s   |dkr| S t | tjtjtjfsJ | jjdk}|s+| jd| dks)J ddS | jdkr=| j	| dks;J ddS | jd | jd  }|| dksQJ ddS )	a   
    Wraps modules and applies quantization noise to their weights for
    subsequent quantization using Iterative Product Quantization (iPQ).

    This approach is described in the paper:
    "Training with Quantization Noise for Extreme Model Compression." It
    introduces quantization noise during training to improve model robustness
    for extreme weight compression scenarios.

    Arguments
    ---------
    module : nn.Module
        The module to which quantization noise will be applied. Supported modules
        are Linear, Embedding, and Conv2d.
    p : float
        The amount of quantization noise to apply. Typically a probability or scaling factor.
    block_size : int
        The size of the blocks for subsequent quantization with iPQ.

    Returns
    -------
    None

    r      r   z0Input features must be a multiple of block sizes)r   r   z0Input channels must be a multiple of block sizesz,Kernel size must be a multiple of block sizeN)

isinstancer   r&   	Embeddingr)   weightndimrB   r   in_channels)modulep
block_sizeis_convkr=   r=   r>   quant_noiseC  s    
r   c                       s4   e Zd ZdZ fddZd	ddZd	ddZ  ZS )
r0   z
    Implements the Transformer Encoder module.

    Arguments
    ---------
    args : Namespace or dict
        A collection of model hyperparameters and configurations.

    c                    sT  t     j_ j_tjjj j jd  jd_	d}t
dd|   jj  }tjjj	jd|d tjj	jd tjjj	ddd_	tj	t jt _	t d	ro j_ j_ j_n	d
_d_d_t fddt jD _jrtd jD ]}j| j`jd jjj| j_q j _ t!j_" j#_$%t&  j'r!t
(d j d}t jD ]T}tjj)j| jj*jdd tjj)j| jj+j|d tjj)j| jj,jdd tjj)j| jj-j|d tjj)j| j.j|d tjj)j| j/j|d qt0 dd_1d S )NrU   )r   paddinggroupsr   r         ?rl   stdr   )namerT   relative_position_embeddingFc                    sL   g | ]"}t j j jj j j j j j	j
jj j jd qS ))embedding_dimffn_embedding_dimnum_attention_headsdropoutattention_dropoutactivation_dropoutactivation_fnr/   r.   has_relative_attention_biasnum_bucketsmax_distancegru_rel_posencoder_layers)TransformerSentenceEncoderLayerr   encoder_ffn_embed_dimencoder_attention_headsr   r   r   r   r/   r.   r   r   r   r   r   ).0iargsr9   r=   r>   
<listcomp>  s&    z/TransformerEncoder.__init__.<locals>.<listcomp>r      g      пgainlayer_wise_gradient_decay_ratio)2r   r   r   r%   r   r   Conv1dconv_posconv_pos_groupspos_convr   r   initnormal_r   	constant_r   utilsweight_norm
Sequentialr   r   r}   r   r   r   
ModuleListranger   layers	self_attnrelative_attention_biasr/   r   r2   encoder_layerdrop	layerdropapplyinit_bert_paramsr.   r   xavier_normal_k_projv_projq_projout_projfc1fc2getattrr   )r9   r   r   r   r   deep_norm_betar;   r   r>   r     s   




zTransformerEncoder.__init__Nc                 C   s.   |  |||\}}| jr|r| |}||fS )aN  
        Processes the input sequence through the Transformer Encoder layers.


        Arguments
        ---------
        x : torch.Tensor
            The input tensor of shape `(seq_len, batch_size, embed_dim)` containing
            the input embeddings.
        padding_mask : torch.Tensor, optional
            A binary mask of shape `(batch_size, seq_len)` indicating which positions
            are padding and should be ignored in attention computations.
            Default is `None`.
        output_all_hiddens : bool, optional
            If True, returns the hidden states from all encoder layers in addition
            to the final output. Default is `None`.

        Returns
        -------
        Tuple[torch.Tensor, List[torch.Tensor]]
            - The final output tensor of shape `(seq_len, batch_size, embed_dim)`.
        )ra   r/   r2   )r9   ro   r@   r   rp   r=   r=   r>   rb     s   

zTransformerEncoder.forwardc                 C   s
  |durd||< |  |dd}|dd}|| }| js#| |}tj|| j| jd}|dd}g }d}|r>|| d}d}t| j	D ]-\}	}
| j
dkrWt|| j
}tj }| jrd|| jkro|
||d|d\}}}|| qG|dur{|}|dd}||fS )	a  
        Extracts features from the input sequence using positional convolution,
        layer normalization, dropout, and a series of Transformer Encoder layers.


        Arguments
        ---------
        x : torch.Tensor
            The input tensor of shape `(batch_size, seq_len, embed_dim)` containing
            the input embeddings.
        padding_mask : torch.Tensor, optional
            A binary mask of shape `(batch_size, seq_len)` indicating which positions
            are padding and should be ignored in computations. Default is `None`.
        output_all_hiddens : bool, optional
            If True, collects and returns the hidden states from all encoder layers
            in addition to the final output. Default is `None`.

        Returns
        -------
        Tuple[torch.Tensor, List[torch.Tensor]]
            - The final output tensor of shape `(batch_size, seq_len, embed_dim)`.
        Nr   r   rU   )r   trainingr   F)self_attn_padding_maskneed_weightspos_bias)r   rh   r/   r2   r   r   r   rY   	enumerater   r   r   r   nprandomr   )r9   ro   r@   r   x_convrp   zrr   r   layerdropout_probabilityr=   r=   r>   ra     s>   



z#TransformerEncoder.extract_featuresr   )rt   ru   rv   rw   r   rb   ra   rz   r=   r=   r;   r>   r0   |  s
    

a r0   c                !       s   e Zd ZdZ															d"d	ed
ededededededededededededededdf  fddZ				d#de	j
de	j
de	j
defd d!Z  ZS )$r   a  
    Implements a single Transformer Sentence Encoder layer.

    Arguments
    ---------
    embedding_dim : float, optional (default=768)
        The dimensionality of input embeddings.
    ffn_embedding_dim : float, optional (default=3072)
        The dimensionality of the feed-forward network's hidden layer.
    num_attention_heads : float, optional (default=8)
        The number of attention heads for self-attention.
    dropout : float, optional (default=0.1)
        The dropout rate applied to the output of the feed-forward network and attention layers.
    attention_dropout : float, optional (default=0.1)
        The dropout rate applied within the attention mechanism.
    activation_dropout : float, optional (default=0.1)
        The dropout rate applied after the activation function in the feed-forward network.
    activation_fn : str, optional (default="relu")
        The activation function used in the feed-forward network. Supported values include "relu" and "gelu".
    layer_norm_first : bool, optional (default=False)
        If True, applies layer normalization before attention and feed-forward layers; otherwise, applies it afterward.
    deep_norm : bool, optional (default=False)
        If True, uses deep normalization scaling for residual connections.
    has_relative_attention_bias : bool, optional (default=False)
        If True, includes relative position bias in the attention mechanism.
    num_buckets : int, optional (default=0)
        The number of buckets used for relative attention bias (if enabled).
    max_distance : int, optional (default=0)
        The maximum distance for relative attention bias (if enabled).
    rescale_init : bool, optional (default=False)
        If True, rescales parameter initialization for improved stability.
    gru_rel_pos : bool, optional (default=False)
        If True, incorporates GRU-style relative position encoding.
    encoder_layers : int, optional (default=0)
        The number of encoder layers in the Transformer.
          r   皙?r   Fr   r   r   r   r   r   r   r   r/   r.   r   r   r   rescale_initr   r   r   Nc                    s   t    || _|| _|| _|| _t|| _t| j||d|
||||d	| _	t
|| _t
| j| _t
|| _|| _t| j| _| jdkrOt| j|d| _nt
| j|| _t
|| j| _t| j| _|	| _| jrvtd| d| _d S d| _d S )NT)r   self_attentionr   r   r   r  r   r   r   rU   g      ?r   )r   r   r   r   r   activation_namer   r   MultiheadAttentionr   r   r,   dropout1dropout2dropout3r/   r   self_attn_layer_normr   r   r&   r   final_layer_normr.   r   r   deep_norm_alpha)r9   r   r   r   r   r   r   r   r/   r.   r   r   r   r  r   r   r;   r=   r>   r   q  sB   



z(TransformerSentenceEncoderLayer.__init__ro   self_attn_maskr   r   c              	   C   sJ  |}| j rP| |}| j||||d||d\}}}| |}|| }|}| |}| jdkr4| |}n| | |}| |}| 	|}| 
|}|| }nP| j|||||||d\}}}| |}|| j | }| |}|}| jdkr}| |}n| | |}| |}| 	|}| 
|}|| j | }| |}|||fS )ai  
        Processes the input tensor through the Transformer sentence encoder layer.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of shape `(seq_len, batch_size, embed_dim)`.
        self_attn_mask : torch.Tensor, optional
            Mask for the self-attention mechanism, typically used for causal or
            padding masking. Default is `None`.
        self_attn_padding_mask : torch.Tensor, optional
            Padding mask of shape `(batch_size, seq_len)`, indicating which tokens
            should be ignored in attention computations. Default is `None`.
        need_weights : bool, optional (default=False)
            Whether to return attention weights. If `True`, attention weights are
            included in the output.
        pos_bias : optional
            Positional bias for relative attention, if applicable. Default is `None`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, optional]
            - `x` (torch.Tensor): The output tensor of shape `(seq_len, batch_size, embed_dim)`
            after applying the encoder layer.

        F)querykeyvaluekey_padding_maskr   	attn_maskposition_biasr   )r/   r  r   r
  r  r  r   r   r  r   r  r  )r9   ro   r  r   r   r   residualattnr=   r=   r>   rb     sZ   "

	














z'TransformerSentenceEncoderLayer.forward)r  r  r   r  r  r  r   FFFr   r   FFr   )NNFN)rt   ru   rv   rw   ry   rx   re   intr   r   r   rb   rz   r=   r=   r;   r>   r   K  s~    '	
Ar   c                       s  e Zd ZdZ															d7 fd	d
	Zdd Z	d8ddZdededej	fddZ
								d9de	dee	 dee	 dee	 deeeeeee	 f f  dededee	 dededee	 dee	ee	 ee	 f fdd Zd!d" Zd#d$ Zd:d%d&Z			d;d'd(Zedee	 d)ee	 d*ed+ededee	 fd,d-Zdeeeeeee	 f f  deeee	 f fd.d/Zdeeeeee	 f f d0eeee	 f fd1d2Zd3ed+ed4efd5d6Z  ZS )<r	  ad  
    Implements multi-headed attention with support for advanced features like relative position
    embeddings and gated relative position embedding (GRU-based).

    Arguments
    ---------
    embed_dim : int
        Total number of dimensions for input embeddings.
    num_heads : int
        Number of attention heads.
    kdim : int, optional
        Dimensionality of key embeddings. Defaults to `embed_dim`.
    vdim : int, optional
        Dimensionality of value embeddings. Defaults to `embed_dim`.
    dropout : float, optional
        Dropout probability for attention weights. Defaults to 0.0.
    bias : bool, optional
        Whether to include a bias term in projections. Defaults to True.
    add_bias_kv : bool, optional
        Whether to include bias for key and value projections. Defaults to False.
    add_zero_attn : bool, optional
        Whether to include zero attention vectors. Defaults to False.
    self_attention : bool, optional
        Whether the layer is for self-attention. Defaults to False.
    encoder_decoder_attention : bool, optional
        Whether the layer is for encoder-decoder attention. Defaults to False.
    q_noise : float, optional
        Noise level for quantization. Defaults to 0.0.
    qn_block_size : int, optional
        Block size for quantization. Defaults to 8.
    has_relative_attention_bias : bool, optional
        Whether to use relative position embeddings. Defaults to False.
    num_buckets : int, optional
        Number of buckets for relative position embeddings. Defaults to 32.
    max_distance : int, optional
        Maximum distance for relative position embeddings. Defaults to 128.
    gru_rel_pos : bool, optional
        Whether to use gated relative position embeddings. Defaults to False.
    rescale_init : bool, optional
        Whether to rescale the initialization of weights. Defaults to False.
    N        TFr       rL   c                    s  t    || _|d ur|n|| _|d ur|n|| _| j|ko#| j|k| _|| _t|| _	|| _
|| _|| _| j
rAt||| _|| | _| j| _| j| _| j| | jksZJ d| jd | _|	| _|
| _| jrp| jspJ dttj| j|| d||| _ttj| j||d||| _ttj|||d||| _ttj|||d||| _|rttdd|| _ttdd|| _nd  | _| _|| _ || _!| j!rt| jd| _"tt#d|dd| _$| %  d S )Nz(embed_dim must be divisible by num_headsg      zESelf-attention requires query, key, and value to be of the same size.)r   r   r   )&r   r   r#   kdimvdimqkv_same_dim	num_headsr   r,   dropout_moduler   r   r   r   r   head_dim
q_head_dim
k_head_dimscalingr  encoder_decoder_attentionr   r&   r   r   r   r   r   r   r   bias_kbias_vadd_zero_attnr   grep_linearonesgrep_areset_parameters)r9   r#   r  r  r  r   r   add_bias_kvr(  r  r%  q_noiseqn_block_sizer   r   r   r   r  r;   r=   r>   r   5  sf   

zMultiheadAttention.__init__c                 C   s  | j r1tjj| jjdtd d tjj| jjdtd d tjj| j	jdtd d ntj| jj tj| jj tj| j	j tj| j
j | j
jdur`tj| j
jd | jdurltj| j | jdurxtj| j | jrtj| jj dS dS )ze
        Initializes the weights for the projection layers and relative position embeddings.
        r   rU   r   Nr  )r  r   r   xavier_uniform_r   r   r   r   r   r   r   r   r   r&  r   r'  r   r   r   r=   r=   r>   r,    s"    

z#MultiheadAttention.reset_parametersc           	      C   s   | j }| j}d}|r |d }||dktj| 7 }t|}n
t|t| }|d }||k }|t|	 | t
||  ||  tj }t|t||d }|t|||7 }|S )a  Computes bucket indices for relative positions for relative attention bias.

        Arguments
        ---------
        relative_positions : torch.Tensor
            A tensor of relative positions, where negative values indicate positions to the
            left and positive values indicate positions to the right.
        bidirectional : bool, optional, (default: True)
            If True, separate buckets are used for positive and negative positions.

        Returns
        -------
        torch.Tensor
            A tensor of the same shape as `relative_positions`, where each value is the
            bucket index corresponding to the relative position.
        r   rU   r   )r   r   tor   longabsmin
zeros_likelogry   r   	full_likewhere)	r9   relative_positionsbidirectionalr   r   relative_buckets	max_exactis_smallrelative_position_if_larger=   r=   r>   _relative_positions_bucket  s@   

z-MultiheadAttention._relative_positions_bucketquery_length
key_lengthr   c                 C   sz   t j|t jddddf }t j|t jddddf }|| }| j|dd}|| jjj}| |}|g d}|S )a  
        Computes relative position bias for attention scores.


        Arguments
        ---------
        query_length : int
            The length of the query sequence.
        key_length : int
            The length of the key sequence.

        Returns
        -------
        torch.Tensor
            A tensor of shape `(num_heads, query_length, key_length)` containing
            the relative position bias values for each attention head.
        )dtypeNT)r:  )rU   r   r   )	r   aranger2  r?  r1  r   r   rd   permute)r9   r@  rA  context_positionmemory_positionrelative_positionrelative_position_bucketvaluesr=   r=   r>   compute_bias  s   
zMultiheadAttention.compute_biasr  r  r  r  incremental_stater   	static_kvr  before_softmaxneed_head_weightsr  c                 C   s   |
rd}|  \}}}|}|| jksJ t|  |||gks!J |durJ|  \}}}tj sJ||ks7J |dus=J |sJJ ||jdd k| jrj|du rj| ||}|	d
|ddd|| j ||}|dur| |}|durd|v r|r| jr| jrJ d }}nd}d}| j|||||||dd\}}}}}|durRd|v r|d }|dusJ ||| j d	| j}|r|}n|dusJ tj||gdd
}| d}d|v r|d }|dusJ ||| j d	| j}|r|}n|dusJ tj||gdd
}d}d|v r|d }|dur|dusJ tj|||| d|d}||| jd	| j|d< ||| jd	| j|d< ||d< |dusLJ | ||}|dusYJ | d|kscJ | |||||||||	\}}|	rz|||fS | ||||||||||
|\}}|||fS )a	  
        Forward pass for multi-head attention with support for relative position embeddings,
        caching, and optional dropout.

        This method implements the core functionality of multi-head attention with
        optional features such as relative position bias, incremental decoding, and
        support for various masking options.

        Arguments
        ---------
        query : torch.Tensor
            Query tensor of shape `(target_length, batch_size, embed_dim)`.
        key : torch.Tensor, optional
            Key tensor of shape `(source_length, batch_size, embed_dim)`. Defaults to `None`.
        value : torch.Tensor, optional
            Value tensor of shape `(source_length, batch_size, embed_dim)`. Defaults to `None`.
        key_padding_mask : torch.Tensor, optional
            Mask to exclude padding keys, of shape `(batch_size, source_length)`,
            where padding elements are indicated by 1s. Defaults to `None`.
        incremental_state : dict, optional
            Stores cached key and value tensors for incremental decoding. Defaults to `None`.
        need_weights : bool, optional
            If True, returns the attention weights. Defaults to `True`.
        static_kv : bool, optional
            If True, the key and value tensors remain static for incremental decoding.
            Defaults to `False`.
        attn_mask : torch.Tensor, optional
            Attention mask to prevent certain positions from attending, typically for
            causal attention. Shape: `(target_length, source_length)`. Defaults to `None`.
        before_softmax : bool, optional
            If True, returns raw attention scores before softmax. Defaults to `False`.
        need_head_weights : bool, optional
            If True, returns attention weights for each head. Implies `need_weights=True`.
            Defaults to `False`.
        position_bias : torch.Tensor, optional
            Precomputed position bias tensor. If `None`, it is computed during the forward pass.

        Returns
        -------
        attn : torch.Tensor
            Attention output of shape `(target_length, batch_size, embed_dim)`.
        attn_weights : torch.Tensor, optional
            Attention weights of shape `(batch_size, num_heads, target_length, source_length)`,
            averaged across heads if `need_head_weights=False`.
        position_bias : torch.Tensor, optional
            Computed or passed relative position bias of shape `(num_heads, target_length, source_length)`.
        TNrU   r   r   prev_keyr  )alpharA   rS   
prev_valueprev_key_padding_mask)r  rR  
batch_sizesrc_lenrL  )rB   r#   listr   jitis_scriptingrg   r   rJ  rV   repeatrC   r  _get_input_bufferr%  r  _prepare_attention_inputsr!  catr	  _append_prev_key_padding_mask_set_input_buffer_process_attention_weights_compute_attention_output)r9   r  r  r  r  rK  r   rL  r  rM  rN  r  tgt_lenbszr#   rT  key_bsz_saved_staterP  qr   v	_prev_keyrO  _prev_valuerQ  rR  attn_weightsr  r=   r=   r>   rb     s   @







zMultiheadAttention.forwardc              	   C   sx  |dur_|}| j dkrT||| j|| j| | j }| \}}}}t| ||||ddj	dddj
ddd\}}||| j d	  d
 }||| j |d| }|| }|| }tj|dd}||}| |}|dusvJ t||}t| || j || jgksJ |dd |||}| |}d}|	r||| j||dd}|
s|jdd}||fS )a-  
        Computes the final attention output, including relative position bias adjustments,
        attention weight computation, and attention projection.

        Arguments
        ---------
        q : torch.Tensor
            Query tensor.
        v : torch.Tensor
            Value tensor.
        attn_weights : torch.Tensor
            Attention weights tensor.
        position_bias : Optional[torch.Tensor]
            Relative position bias tensor.
        bsz : int
            Batch size.
        tgt_len : int
            Target sequence length.
        src_len : int
            Source sequence length.
        embed_dim : int
            Embedding dimension.
        need_weights : bool
            Whether to return attention weights.
        need_head_weights : bool
            Whether to return head-specific weights.
        alpha : float
            Scaling factor for relative position.

        Returns
        -------
        Tuple[torch.Tensor, Optional[torch.Tensor]]
            Final attention output and optional attention weights.
        Nr   rU   r   rA   F)keepdimrS   r   g       @r   )r   rC   r  r"  r$  rB   r   rm   r)  rj   chunkr+  r   softmaxr   r   bmmrU  r!  rh   
contiguousr   rl   )r9   re  rf  ri  r  ra  r`  rT  r#   r   rN  rP  attn_mask_rel_posquery_layer_B_H_L__gate_agate_bgate_a_1attn_weights_float
attn_probsr  attn_weights_outr=   r=   r>   r_    s^   1







z,MultiheadAttention._compute_attention_outputc
                 C   s&  |j jdk}
|dur| dkrd}|dur(|d|ksJ |d|ks(J | jr|dus1J |d7 }tj|||ddf| dd  gdd}tj|||ddf| dd  gdd}|dur{tj|||ddgdd}|durtj|t|dd	|gdd}t
||dd}||jddd	d  |	 }| ||||}t| || j ||gksJ |dur|d}||7 }|dur||| j||}|
s||ddtjtd
}n|dd}||td
}|dd}||| j ||}||fS )ao  
        Processes attention weights, including handling key padding masks, adding zero attention if required,
        and computing the attention weights with masking.

        Arguments
        ---------
        q : torch.Tensor
            Query tensor.
        k : torch.Tensor
            Key tensor.
        v : torch.Tensor
            Value tensor.
        attn_mask : torch.Tensor
           Attention mask
        key_padding_mask : torch.Tensor
           Key padding mask.
        bsz : int
            Batch size.
        tgt_len : int
            Target sequence length.
        src_len : int
            Source sequence length.
        alpha : float
            Scaling factor for relative position.

        Returns
        -------
        Tuple[torch.Tensor, Optional[torch.Tensor]]
            Computed attention weights and the updated attention mask.
        xlaNr   r   rU   rS   rA   T)rT   rj  z-inf)rd   typerT   rB   r(  r   r[  	new_zeroszerosr   rm  rh   maxapply_sparse_maskrU  r  rV   rC   masked_fillr1  re   ry   )r9   re  r   rf  r  r  ra  r`  rT  rP  is_tpuri  r=   r=   r>   r^  %  s|   !((




z-MultiheadAttention._process_attention_weightsc                 C   s   | j durT| jdusJ dtj|| j d|dgdd}tj|| jd|dgdd}|dur@tj|||ddgdd}|durTtj|||ddgdd}||||fS )a  
        Applies bias_k and bias_v to the key and value tensors, updating
        the attention mask and key padding mask accordingly.

        Arguments
        ---------
        k : torch.Tensor
            Key tensor.
        v : torch.Tensor
            Value tensor.
        bsz : int
            Batch size.
        attn_mask : torch.Tensor
            Attention mask
        key_padding_mask : torch.Tensor
           Key padding mask.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: Updated key, value,
            attention mask, and key padding mask.
        Nz(bias_k and bias_v must both be provided.r   r   rS   )r&  r'  r   r[  rX  r}  rB   )r9   r   rf  ra  r  r  r=   r=   r>   
apply_bias  s&   
zMultiheadAttention.apply_biasc	                 C   s6  | j r| |}	| |}
| |}n;| jr5| |}	|du r*|du s%J d }
}n$| |}
| |}n|dur=|dus?J | |}	| |}
| |}|	| j9 }	|	d| 9 }	|	 ||| j | j	
dd}	|
dur|
 d|| j | j
dd}
|dur| d|| j | j
dd}|	|
|||fS )a  
        Prepares and scales the projections, applies biases, and reshapes the query, key, and value tensors
        for multi-head attention.

        Arguments
        ---------
        query : torch.Tensor
            Query tensor.
        key : torch.Tensor
            Key tensor.
        value : torch.Tensor
            Value tensor.
        bsz : int
            Batch size.
        tgt_len : int
            Target sequence length.
        key_padding_mask : torch.Tensor
           Key padding mask.
        attn_mask : torch.Tensor
           Attention mask
        alpha : float, optional
            Scaling factor for relative position. Default is 32.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
            Scaled and reshaped query, key, and value tensors, along with updated attention and key padding masks.
        Nr   r   rA   )r  r   r   r   r%  r$  rn  rC   r  r"  rh   r#  r!  )r9   r  r  r  ra  r`  r  r  rP  re  r   rf  r=   r=   r>   rZ    s@   (








z,MultiheadAttention._prepare_attention_inputsrR  rS  rT  c                 C   s  |dur
|r
|}|S |dur!| dur!t j| |  gdd}|S |durP||dkrJt j|||d f|jd}t j| | gdd}|S | }|S | dur|| dkryt j||| d f| jd}t j| |  gdd}|S |  }|S |}|S )a  
        Combines the previous and current key padding masks to create a unified mask.

        Arguments
        ---------
        key_padding_mask : Optional[torch.Tensor]
            The current key padding mask of shape `(batch_size, seq_len)`, or `None`.
        prev_key_padding_mask : Optional[torch.Tensor]
            The previous key padding mask of shape `(batch_size, seq_len)`, or `None`.
        batch_size : int
            The batch size of the input.
        src_len : int
            The source sequence length to which the masks need to align.
        static_kv : bool
            If `True`, indicates that the key-value pairs are static and only the
            previous key padding mask should be used.

        Returns
        -------
        Optional[torch.Tensor]
            The combined key padding mask of shape `(batch_size, src_len)`, or `None`
            if both input masks are `None`.

        Nr   rS   rc   )r   r[  ry   rB   r~  rd   )r  rR  rS  rT  rL  new_key_padding_maskfillerr=   r=   r>   r\    sD   ! z0MultiheadAttention._append_prev_key_padding_maskc                 C   s    |  |d}|dur|S i }|S )aU  
        Retrieves the input buffer for incremental decoding.

        Arguments
        ---------
        incremental_state : Optional[Dict[str, Dict[str, Optional[Tensor]]]]
            The state dictionary used for incremental decoding. It stores intermediate
            computation states, such as attention states, for efficient sequential processing.

        Returns
        -------
        Dict[str, Optional[Tensor]]
            The attention state dictionary containing keys and values for incremental
            decoding. If no state exists, an empty dictionary is returned.

        
attn_stateN)get_incremental_state)r9   rK  resultempty_resultr=   r=   r>   rY  _  s
   z$MultiheadAttention._get_input_bufferbufferc                 C   s   |  |d|S )a  
        Updates the input buffer for incremental decoding.

        Arguments
        ---------
        incremental_state : Dict[str, Dict[str, Optional[Tensor]]]
            The state dictionary used for incremental decoding. It stores intermediate
            computation states, such as attention states.
        buffer : Dict[str, Optional[Tensor]]
            The attention state dictionary containing keys and values to be stored
            for incremental decoding.
        Returns
        -------
        None
        r  )set_incremental_state)r9   rK  r  r=   r=   r>   r]  z  s   z$MultiheadAttention._set_input_bufferr`  ra  c                 C   s   |S )aI  
        Applies a sparse mask to the attention weights.

        Arguments
        ---------
        attn_weights : torch.Tensor
            The attention weights tensor of shape `(batch_size * num_heads, tgt_len, src_len)`.
        tgt_len : int
            The target sequence length.
        src_len : int
            The source sequence length.
        bsz : int
            The batch size.

        Returns
        -------
        torch.Tensor
            The (potentially modified) attention weights tensor. By default, this is
            the same as the input tensor.
        r=   )r9   ri  r`  rT  ra  r=   r=   r>   r    s   z$MultiheadAttention.apply_sparse_mask)NNr  TFFFFr  r   Fr  rL   FF)T)NNTFNFFNr   )NNr  )rt   ru   rv   rw   r   r,  r?  r  r   r   rJ  r   r   rx   re   r   rb   r_  r^  r  rZ  r   r\  rY  r]  r  rz   r=   r=   r;   r>   r	  	  s    /W
50	

 <e
p:
SC

r	  r   c                 C   s   dt jddfdd}t| tjr%|| jj | jdur#| jj  dS dS t| tj	rC|| jj | j
durA| jj| j
   dS dS t| tr_|| jjj || jjj || jjj dS dS )z
    Initializes weights and biases for modules in the BERT model.

    Arguments
    ---------
    module : nn.Module
        The module to initialize. Can be one of `nn.Linear`, `nn.Embedding`, or `MultiheadAttention`.

    datar   Nc                 S   s$   |  |  jddd| j dS )z
        Initializes a tensor with values drawn from a normal distribution.

        Arguments
        ---------
        data : torch.Tensor
            The tensor to initialize.
        r  g{Gz?r   N)copy_cpur   r1  rd   )r  r=   r=   r>   r     s   $
z!init_bert_params.<locals>.normal_)r   r   r   r   r&   r   r  r   zero_r   padding_idxr	  r   r   r   )r   r   r=   r=   r>   r     s    


r   c                   @   s(   e Zd ZdZdddZdefddZdS )	r   a  
    Configuration class for the BEATs model.

    This class defines the configuration for the BEATs model. It provides a default
    configuration that can be updated with custom settings via the `update` method.

    Arguments
    ---------
    cfg : dict, optional
        A dictionary containing custom configuration values. If provided, it will override
        the default settings.
    Nc                 C   s   d| _ d| _d| _d| _d| _d| _d| _d| _d| _d| _	d| _
d	| _d	| _d
| _d
| _d
| _d| _d| _d| _d| _d| _d| _d| _d	| _d| _|d urV| | d S d S )N   i   F   r  r  r   r   r  r  rL   i@  i   i  )r(   r#   r*   r   r%   r   r   r   r   r/   r.   r   r   r   r   r-   r   r   r   r   r   r   r3   r4   r5   updater9   r   r=   r=   r>   r     sV   zBEATsConfig.__init__r   c                 C   s   | j | dS )a  
        Updates the instance's attributes with key-value pairs from a given configuration dictionary.

        Arguments
        ---------
        cfg : dict
            A dictionary containing the configuration values to update the instance with.
        N)r"   r  r  r=   r=   r>   r  *  s   	zBEATsConfig.updater   )rt   ru   rv   rw   r   dictr  r=   r=   r=   r>   r     s    
Fr   ).rw   loggingr   r   typingr   r   r   numpyr   r   torch.nn.functionalr   r   r   torchaudio.compliance.kaldi
compliancekaldirW   r   torch.nnr   r   speechbrain.dataio.dataior	   	getLoggerrt   r    Moduler
   r~   r   rx   r   r   r   r   autogradFunctionr   r   r0   r   r	  r   r   r=   r=   r=   r>   <module>   sL    
  20!,59 P ?       **