o
    iV                     @   s   d dl Z d dlmZmZ d dlZd dlmZ d dlmZ g Z					dd	ejd
ejdejde
dedededejfddZG dd dejjZG dd dejjZG dd dejjZG dd dejjZdS )    N)OptionalUnion)Tensor)
functionalref_channelTHz>:0yE>psd_spsd_nreference_vectorsolutiondiagonal_loadingdiag_epsepsreturnc           	      C   s`   |dkrt | |||||}|S |dkrt | }n
t j| ||||d}t ||||||}|S )a  Compute the MVDR beamforming weights with ``solution`` argument.

    Args:
        psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
            Tensor with dimensions `(..., freq, channel, channel)`.
        psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
            Tensor with dimensions `(..., freq, channel, channel)`.
        reference_vector (torch.Tensor): one-hot reference channel matrix.
        solution (str, optional): Solution to compute the MVDR beamforming weights.
            Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
        diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
            (Default: ``True``)
        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
        eps (float, optional): Value to add to the denominator in the beamforming weight formula.
            (Default: ``1e-8``)

    Returns:
        torch.Tensor: the mvdr beamforming weight matrix
    r   stv_evd)r   r   )Fmvdr_weights_soudenrtf_evd	rtf_powermvdr_weights_rtf)	r	   r
   r   r   r   r   r   beamform_vectorstv r   X/home/ubuntu/.local/lib/python3.10/site-packages/torchaudio/transforms/_multi_channel.py_get_mvdr_vector   s   r   c                       sL   e Zd ZdZddededef fdd	Zddejde	ej fddZ
  ZS )PSDa  Compute cross-channel power spectral density (PSD) matrix.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Args:
        multi_mask (bool, optional): If ``True``, only accepts multi-channel Time-Frequency masks. (Default: ``False``)
        normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
        eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-15``)
    FTV瞯<
multi_mask	normalizer   c                    s    t    || _|| _|| _d S N)super__init__r   r   r   )selfr   r   r   	__class__r   r   r"   D   s   

zPSD.__init__Nspecgrammaskc                 C   s2   |dur| j r|jdd}t||| j| j}|S )a  
        Args:
            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
                Tensor with dimensions `(..., channel, freq, time)`.
            mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False`` or
                with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
                (Default: ``None``)

        Returns:
            torch.Tensor: The complex-valued PSD matrix of the input spectrum.
                Tensor with dimensions `(..., freq, channel, channel)`
        Ndim)r   meanr   psdr   r   )r#   r&   r'   r,   r   r   r   forwardJ   s
   zPSD.forward)FTr   r    )__name__
__module____qualname____doc__boolfloatr"   torchr   r   r-   __classcell__r   r   r$   r   r   7   s    $r   c                       s   e Zd ZdZ						d!dededed	ed
edef fddZ				d"de	j
de	j
de	j
de	j
de	j
deded
edede	j
fddZde	j
de	j
de	j
fddZde	j
de	j
de	j
fddZ	d#de	j
de	j
dee	j
 de	j
fdd Z  ZS )$MVDRa  Minimum Variance Distortionless Response (MVDR) module that performs MVDR beamforming with Time-Frequency masks.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Based on https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/beamformer.py

    We provide three solutions of MVDR beamforming. One is based on *reference channel selection*
    :cite:`souden2009optimal` (``solution=ref_channel``).

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}        {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}

    where :math:`\bf{\Phi}_{\textbf{SS}}` and :math:`\bf{\Phi}_{\textbf{NN}}` are the covariance        matrices of speech and noise, respectively. :math:`\bf{u}` is an one-hot vector to determine the         reference channel.

    The other two solutions are based on the steering vector (``solution=stv_evd`` or ``solution=stv_power``).

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}        {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}

    where :math:`\bm{v}` is the acoustic transfer function or the steering vector.        :math:`.^{\mathsf{H}}` denotes the Hermitian Conjugate operation.

    We apply either *eigenvalue decomposition*
    :cite:`higuchi2016robust` or the *power method* :cite:`mises1929praktische` to get the
    steering vector from the PSD matrix of speech.

    After estimating the beamforming weight, the enhanced Short-time Fourier Transform (STFT) is obtained by

    .. math::
        \hat{\bf{S}} = {\bf{w}^\mathsf{H}}{\bf{Y}}, {\bf{w}} \in \mathbb{C}^{M \times F}

    where :math:`\bf{Y}` and :math:`\hat{\bf{S}}` are the STFT of the multi-channel noisy speech and        the single-channel enhanced speech, respectively.

    For online streaming audio, we provide a *recursive method* :cite:`higuchi2017online` to update the
    PSD matrices of speech and noise, respectively.

    Args:
        ref_channel (int, optional): Reference channel for beamforming. (Default: ``0``)
        solution (str, optional): Solution to compute the MVDR beamforming weights.
            Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
        multi_mask (bool, optional): If ``True``, only accepts multi-channel Time-Frequency masks. (Default: ``False``)
        diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to the covariance matrix
            of the noise. (Default: ``True``)
        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
        online (bool, optional): If ``True``, updates the MVDR beamforming weights based on
            the previous covarience matrices. (Default: ``False``)

    Note:
        To improve the numerical stability, the input spectrogram will be converted to double precision
        (``torch.complex128`` or ``torch.cdouble``) dtype for internal computation. The output spectrogram
        is converted to the dtype of the input spectrogram to be compatible with other modules.

    Note:
        If you use ``stv_evd`` solution, the gradient of the same input may not be identical if the
        eigenvalues of the PSD matrix are not distinct (i.e. some eigenvalues are close or identical).
    r   r   FTr   r   r   diag_loadingr   onlinec                    s   t    |dvrtd||| _|| _|| _|| _|| _|| _	t
|| _td}td}td}	td}
| d| | d| | d|	 | d|
 d S )N)r   r   	stv_powerzK`solution` must be one of ["ref_channel", "stv_evd", "stv_power"]. Given {}   r	   r
   
mask_sum_s
mask_sum_n)r!   r"   
ValueErrorformatr   r   r   r7   r   r8   r   r,   r4   zerosregister_buffer)r#   r   r   r   r7   r   r8   r	   r
   r;   r<   r$   r   r   r"      s(   
	




zMVDR.__init__r   r	   r
   mask_smask_nr   r   r   r   c
           
      C   s   | j r|jdd}|jdd}| jjdkr3|| _|| _|jdd| _|jdd| _t|||||||	S | 	||}| 
||}|| _|| _| j|jdd | _| j|jdd | _t|||||||	S )a  Recursively update the MVDR beamforming vector.

        Args:
            psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
                Tensor with dimensions `(..., freq, channel, channel)`.
            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
                Tensor with dimensions `(..., freq, channel, channel)`.
            mask_s (torch.Tensor): Time-Frequency mask of the target speech.
                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
                or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
            mask_n (torch.Tensor or None, optional): Time-Frequency mask of the noise.
                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
                or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
            reference_vector (torch.Tensor): One-hot reference channel matrix.
            solution (str, optional): Solution to compute the MVDR beamforming weights.
                Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
            diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
                It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
                (Default: ``1e-8``)

        Returns:
            torch.Tensor: The MVDR beamforming weight matrix.
        r(   r)   r:   )r   r+   r	   ndimr
   sumr;   r<   r   _get_updated_psd_speech_get_updated_psd_noise)
r#   r	   r
   rA   rB   r   r   r   r   r   r   r   r   _get_updated_mvdr_vector   s    &zMVDR._get_updated_mvdr_vectorc                 C   L   | j | j |jdd  }d| j |jdd  }| j|d  ||d   }|S )a  Update psd of speech recursively.

        Args:
            psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
                Tensor with dimensions `(..., freq, channel, channel)`.
            mask_s (torch.Tensor): Time-Frequency mask of the target speech.
                Tensor with dimensions `(..., freq, time)`.

        Returns:
            torch.Tensor: The updated PSD matrix of target speech.
        rC   r)   r:   .NN)r;   rE   r	   )r#   r	   rA   	numeratordenominatorr   r   r   rF        zMVDR._get_updated_psd_speechc                 C   rI   )a  Update psd of noise recursively.

        Args:
            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
                Tensor with dimensions `(..., freq, channel, channel)`.
            mask_n (torch.Tensor or None, optional): Time-Frequency mask of the noise.
                Tensor with dimensions `(..., freq, time)`.

        Returns:
            torch.Tensor:  The updated PSD matrix of noise.
        rC   r)   r:   rJ   )r<   rE   r
   )r#   r
   rB   rK   rL   r   r   r   rG     rM   zMVDR._get_updated_psd_noiseNr&   c           
   
   C   s  |j }|jdk rtd|j | std|j  |j tjkr&| }|du r3t	d d| }| 
||}| 
||}tj| dd |jtjd}|d	| jf d | jrl| |||||| j| j| j}nt|||| j| j| j}t||}	|	|S )
a`  Perform MVDR beamforming.

        Args:
            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
                Tensor with dimensions `(..., channel, freq, time)`
            mask_s (torch.Tensor): Time-Frequency mask of target speech.
                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
                or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
            mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise.
                Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
                or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
                (Default: None)

        Returns:
            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
           z?Expected at least 3D tensor (..., channel, freq, time). Found: ziThe type of ``specgram`` tensor must be ``torch.cfloat`` or ``torch.cdouble``.                    Found: Nz=``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``.r:   )devicedtype.)rQ   rD   r=   shape
is_complexr4   cfloatcdoublewarningswarnr,   r?   sizerP   r   fill_r8   rH   r   r7   r   r   r   apply_beamformingto)
r#   r&   rA   rB   rQ   r	   r
   uw_mvdrspecgram_enhancedr   r   r   r-   #  s2   

 
zMVDR.forward)r   r   FTr   Fr   Tr   r   r    )r.   r/   r0   r1   intstrr2   r3   r"   r4   r   rH   rF   rG   r   r-   r5   r   r   r$   r   r6   a   st    E*	

9r6   c                   @   sJ   e Zd ZdZ			ddedededeeef d	ed
ededefddZ	dS )RTFMVDRa  Minimum Variance Distortionless Response (*MVDR* :cite:`capon1969high`) module
    based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the relative transfer function (RTF) matrix
    or the steering vector of target speech :math:`\bm{v}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
    a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
    complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:

    .. math::
        \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)

    where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin,
    :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.

    The beamforming weight is computed by:

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =
        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
        {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
    Tr   r   r&   rtfr
   reference_channelr   r   r   r   c           
      C   $   t ||||||}t ||}	|	S )a  
        Args:
            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
                Tensor with dimensions `(..., channel, freq, time)`
            rtf (torch.Tensor): The complex-valued RTF vector of target speech.
                Tensor with dimensions `(..., freq, channel)`.
            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
                Tensor with dimensions `(..., freq, channel, channel)`.
            reference_channel (int or torch.Tensor): Specifies the reference channel.
                If the dtype is ``int``, it represents the reference channel index.
                If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
                is one-hot.
            diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
                It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
                (Default: ``1e-8``)

        Returns:
            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
        )r   r   rZ   )
r#   r&   rc   r
   rd   r   r   r   r]   spectrum_enhancedr   r   r   r-   r      zRTFMVDR.forwardNTr   r   )
r.   r/   r0   r1   r   r   r`   r2   r3   r-   r   r   r   r   rb   W  s,     
	rb   c                   @   sL   e Zd ZdZ			ddedededeeef d	ed
edede	jfddZ
dS )
SoudenMVDRa  Minimum Variance Distortionless Response (*MVDR* :cite:`capon1969high`) module
    based on the method proposed by *Souden et, al.* :cite:`souden2009optimal`.

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the power spectral density (PSD) matrix
    of target speech :math:`\bf{\Phi}_{\textbf{SS}}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
    a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
    complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:

    .. math::
        \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)

    where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin.

    The beamforming weight is computed by:

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =
        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
        {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
    Tr   r   r&   r	   r
   rd   r   r   r   r   c           
      C   re   )a  
        Args:
            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
                Tensor with dimensions `(..., channel, freq, time)`.
            psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
                Tensor with dimensions `(..., freq, channel, channel)`.
            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
                Tensor with dimensions `(..., freq, channel, channel)`.
            reference_channel (int or torch.Tensor): Specifies the reference channel.
                If the dtype is ``int``, it represents the reference channel index.
                If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
                is one-hot.
            diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
                It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
                (Default: ``1e-8``)

        Returns:
            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
        )r   r   rZ   )
r#   r&   r	   r
   rd   r   r   r   r]   rf   r   r   r   r-     rg   zSoudenMVDR.forwardNrh   )r.   r/   r0   r1   r   r   r`   r2   r3   r4   r-   r   r   r   r   ri     s,    
	ri   r_   )rV   typingr   r   r4   r   
torchaudior   r   __all__ra   r2   r3   r   nnModuler   r6   rb   ri   r   r   r   r   <module>   s@   
)* w@