o
    %ݫiI                     @   s   d Z ddlZddlZdZe dddZdd Ze efd	d
Ze d ddZd ddZ	e 					d!ddZ
d"ddZ	d#ddZd$ddZdS )%a5  
Functions for analyzing vocal characteristics: jitter, shimmer, HNR, and GNE.

These are typically used for analysis of dysarthric voices using more traditional approaches
(i.e. not deep learning). Often useful as a baseline for e.g. pathology detection. Inspired by PRAAT.

Authors
 * Peter Plantinga, 2024
    N      c           	      C   sj   t | }|dddd||f jdd\}}tjjj|dd}|d|djdd\}}|| }||fS )aB  Compute features based on autocorrelation

    Arguments
    ---------
    frames: torch.Tensor
        The audio frames to be evaluated for autocorrelation, shape [batch, frame, sample]
    min_lag: int
        The minimum number of samples to consider for potential period length.
    max_lag: int
        The maximum number of samples to consider for potential period length.
    neighbors: int
        The number of neighbors to use for rolling median -- to avoid octave errors.

    Returns
    -------
    harmonicity: torch.Tensor
        The highest autocorrelation score relative to the 0-lag score. Used to compute HNR
    best_lags: torch.Tensor
        The lag corresponding to the highest autocorrelation score, an estimate of period length.

    Example
    -------
    >>> audio = torch.rand(1, 16000)
    >>> frames = audio.unfold(-1, 800, 200)
    >>> frames.shape
    torch.Size([1, 77, 800])
    >>> harmonicity, best_lags = compute_autocorr_features(frames, 100, 200)
    >>> harmonicity.shape
    torch.Size([1, 77])
    >>> best_lags.shape
    torch.Size([1, 77])
    Ndim)   r   )pad   )autocorrelatemaxtorchnn
functionalr   unfoldmedian)	framesmin_lagmax_lag	neighborsautocorrelationharmonicitylags	best_lags_ r   Y/home/ubuntu/.local/lib/python3.10/site-packages/speechbrain/processing/vocal_features.pycompute_autocorr_features   s   "&r   c                 C   sP   |  d}tj|| jdddd}t| | | | }t||jdd}|| S )a  Generate autocorrelation scores using circular convolution.

    Arguments
    ---------
    frames: torch.Tensor
        The audio frames to be evaluated for autocorrelation, shape [batch, frame, sample]

    Returns
    -------
    autocorrelation: torch.Tensor
        The ratio of the best candidate lag's autocorrelation score against
        the theoretical maximum autocorrelation score at lag 0.
        Normalized by the autocorrelation_score of the window.

    Example
    -------
    >>> audio = torch.rand(1, 16000)
    >>> frames = audio.unfold(-1, 800, 200)
    >>> frames.shape
    torch.Size([1, 77, 800])
    >>> autocorrelation = autocorrelate(frames)
    >>> autocorrelation.shape
    torch.Size([1, 77, 401])
    r   devicer	   绽|=min)sizer   hann_windowr   viewcompute_cross_correlationclamp)r   window_sizehannr   
norm_scorer   r   r   r
   B   s
   
r
   c                 C   s  t |  }t j| d| jd}|ddd| j}|	d}|
|}|d }t j|ddd\}}	|	
|}
||
| k ||
| | k@ ||
| k||
| | k @ B }d||< g g }}t|D ]0}t j|ddd\}}	||	|d  k||	|d  k @ }d||< ||d ||	d qct j|dd	}t j|dd	}|
|}t ||| }|| jddd
  }|jdd	| }|jddd
}||  }|jdd	|djdd }||fS )a  Function to compute periodic features: jitter, shimmer

    Arguments
    ---------
    frames: torch.Tensor
        The framed audio to use for feature computation, dims [batch, frame, sample].
    best_lags: torch.Tensor
        The estimated period length for each frame, dims [batch, frame].
    neighbors: int
        Number of neighbors to use in comparison.

    Returns
    -------
    jitter: torch.Tensor
        The average absolute deviation in period over the frame.
    shimmer: torch.Tensor
        The average absolute deviation in amplitude over the frame.

    Example
    -------
    >>> audio = torch.rand(1, 16000)
    >>> frames = audio.unfold(-1, 800, 200)
    >>> frames.shape
    torch.Size([1, 77, 800])
    >>> harmonicity, best_lags = compute_autocorr_features(frames, 100, 200)
    >>> jitter, shimmer = compute_periodic_features(frames, best_lags)
    >>> jitter.shape
    torch.Size([1, 77])
    >>> shimmer.shape
    torch.Size([1, 77])
    r   r   r	   r   T)r   keepdimr   r   r   )r   keepdimsr   r    )r   clonedetacharanger"   r   r$   expandshape	unsqueeze	remainderr   rangeappendsqueezestackminimumfloatmeanabsr&   )r   r   r   masked_framesmask_indicesperiodsperiod_indicesjitter_rangepeaklaglag_indicesmaskpeaksr   ijitter_framesjitteravg_ampsamp_diffshimmerr   r   r   compute_periodic_featuresg   sF   "






rK   r   c              	   C   sN  |  d}tjdd|| jdddd}t|| d}t|| d |  }t|| d | |d |  }t|| d | |d |  }|d}| | | 	  j
dd }| | 	 
d }	|	| j
dd|  }
| jdd| jdd|  }| d	d	ddd	d	f }tj| d|d
dj
dd }tj||||||
||fddS )a  Compute statistical measures on spectral frames
    such as flux, skew, spread, flatness.

    Reference page for computing values:
    https://www.mathworks.com/help/audio/ug/spectral-descriptors.html

    Arguments
    ---------
    spectrum: torch.Tensor
        The spectrum to use for feature computation, dims [batch, frame, freq].
    eps: float
        A small value to avoid division by 0.

    Returns
    -------
    features: torch.Tensor
        A [batch, frame, 8] tensor of spectral features for each frame:
         * centroid: The mean of the spectrum.
         * spread: The stdev of the spectrum.
         * skew: The spectral balance.
         * kurtosis: The spectral tailedness.
         * entropy: The peakiness of the spectrum.
         * flatness: The ratio of geometric mean to arithmetic mean.
         * crest: The ratio of spectral maximum to arithmetic mean.
         * flux: The average delta-squared between one spectral value and it's successor.

    Example
    -------
    >>> audio = torch.rand(1, 16000)
    >>> window_size = 800
    >>> frames = audio.unfold(-1, window_size, 200)
    >>> frames.shape
    torch.Size([1, 77, 800])
    >>> hann = torch.hann_window(window_size).view(1, 1, -1)
    >>> windowed_frames = frames * hann
    >>> spectrum = torch.abs(torch.fft.rfft(windowed_frames))
    >>> spectral_features = compute_spectral_features(spectrum)
    >>> spectral_features.shape
    torch.Size([1, 77, 8])
    r   r   r	   r   r      r   r   N)r   prepend)r"   r   linspacer   r$   	spec_normr1   sqrtr5   logr9   expamaxsumdiffpowr6   )spectrumepsnfreqfreqscentroidspreadskewkurtentropygeomeanflatnesscrestr   fluxr   r   r   compute_spectral_features   s    
+
"rd   c                 C   s    | | j dd|j dd|  S )z*Normalize the given value by the spectrum.r   r   )rT   )valuerW   rX   r   r   r   rO     s    rO   >    ,  Q?{Gz?c                    s   |   dks
J dd}tj| |} t| }t| }tj|| jdddd}	| j	d||d|	 }
t
|
dd	 d d  d  }}t||| fd
dD  fddD }tj|ddjddS )aj  An algorithm for GNE computation from the original paper:

    "Glottal-to-Noise Excitation Ratio - a New Measure for Describing
    Pathological Voices" by D. Michaelis, T. Oramss, and H. W. Strube.

    This algorithm divides the signal into frequency bands, and compares
    the correlation between the bands. High correlation indicates a
    relatively low amount of noise in the signal, whereas lower correlation
    could be a sign of pathology in the vocal signal.

    Godino-Llorente et al. in "The Effectiveness of the Glottal to Noise
    Excitation Ratio for the Screening of Voice Disorders." explore the
    goodness of the bandwidth and frequency shift parameters, the defaults
    here are the ones recommended in that work.

    Arguments
    ---------
    audio : torch.Tensor
        The batched audio signal to use for GNE computation, [batch, sample]
    sample_rate : float
        The sample rate of the input audio.
    bandwidth : float
        The width of the frequency bands used for computing correlation.
    fshift : float
        The shift between frequency bands used for computing correlation.
    frame_len : float
        Length of each analysis frame, in seconds.
    hop_len : float
        Length of time between the start of each analysis frame, in seconds.

    Returns
    -------
    gne : torch.Tensor
        The glottal-to-noise-excitation ratio for each frame of the audio signal.

    Example
    -------
    >>> sample_rate = 16000
    >>> audio = torch.rand(1, sample_rate) # 1s of audio
    >>> gne = compute_gne(audio, sample_rate=sample_rate)
    >>> gne.shape
    torch.Size([1, 98])
    r   z3Expected audio to be 2-dimensional, [batch, sample]'  r   r	   r   )	dimensionr"   step   )	lpc_orderc                    s   i | ]
}|t | qS r   )compute_hilbert_envelopes).0center_freq)	bandwidthexcitation_framessample_rater   r   
<dictcomp>Q  s    zcompute_gne.<locals>.<dictcomp>c                    s<   g | ]}D ]}||  d  krt | | ddqqS )r   rL   width)r%   )rq   freq_ifreq_j)rs   center_freqs	envelopesr   r   
<listcomp>Y  s    zcompute_gne.<locals>.<listcomp>r   )r   rL   )r   
torchaudior   resampleintr   r#   r   r$   r   inverse_filterr3   r6   rS   )audioru   rs   fshift	frame_lenhop_lenold_sample_rate
frame_sizehop_sizewindowr   min_freqmax_freqcorrelationsr   )rs   r{   r|   rt   ru   r   compute_gne  s&   6
r   rn   c                 C   s   t | | |d}|j\}}}||| d}| || d}d|dd|f< |ddddf d|djdd}|dd|d df }tj||}	tjj	j
|	 ddd	}
t|
}d|ddd
f< tj	j|||
dd}|||dS )a  Perform inverse filtering on frames to estimate glottal pulse train.

    Uses autocorrelation method and Linear Predictive Coding (LPC).
    Algorithm from https://course.ece.cmu.edu/~ece792/handouts/RS_Chap_LPC.pdf

    Arguments
    ---------
    frames : torch.Tensor
        The audio frames to filter using inverse filter.
    lpc_order : int
        The size of the filter to compute and use on the frames.

    Returns
    -------
    filtered_frames : torch.Tensor
        The frames after the inverse filter is applied

    Example
    -------
    >>> audio = torch.rand(1, 10000)
    >>> frames = audio.unfold(-1, 300, 100)
    >>> frames.shape
    torch.Size([1, 98, 300])
    >>> filtered_frames = inverse_filter(frames)
    >>> filtered_frames.shape
    torch.Size([1, 98, 300])
    rw   r   g      ?Nr	   )r	   )dims)r	   r   )re   r   F)r&   )r%   r0   r$   r   flipr   linalgsolver   r   r   
zeros_liker~   lfilter)r   ro   r   batchframe_countr   reshaped_framesRrlpc
lpc_coeffsa_coeffsinverse_filteredr   r   r   r   d  s   &
r   rk   c                 C   s   ||d  }||d  }t j| }t j|dd| }t j|t jd}||k ||k @ }	t j|	 |jd}
|
|dddd|	f< t j	|| }|
 S )a  Compute the hilbert envelope of the signal in a specific frequency band using FFT.

    Arguments
    ---------
    frames : torch.Tensor
        A set of frames from a signal for which to compute envelopes.
    center_freq : float
        The target frequency for the envelope.
    bandwidth : float
        The size of the band to use for the envelope.
    sample_rate : float
        The number of samples per second in the frame signals.

    Returns
    -------
    envelopes : torch.Tensor
        The computed envelopes.

    Example
    -------
    >>> audio = torch.rand(1, 10000)
    >>> frames = audio.unfold(-1, 300, 100)
    >>> frames.shape
    torch.Size([1, 98, 300])
    >>> envelope = compute_hilbert_envelopes(frames, 1000)
    >>> envelope.shape
    torch.Size([1, 98, 300])
    r   r   r	   )dtyper   N)r   fftfftfreqr"   r   r8   r#   rT   r   ifftr:   )r   rr   rs   ru   low_freq	high_freqspectrarZ   rC   window_binsr   analytic_signalr   r   r   rp     s   !rp   c                 C   s   | j \}}}|du rd|d fn||f}tjjj| |dd}|| }|d|d}	||dd}
tjjj|	|
|d}|||d}t| d jdd	|d jdd	 }||	dj
d
d }|S )a  Computes the correlation between two sets of frames.

    Arguments
    ---------
    frames_a : torch.Tensor
    frames_b : torch.Tensor
        The two sets of frames to compare using cross-correlation,
        shape [batch, frame, sample]
    width : int, default is None
        The number of samples before and after 0 lag. A width of 3 returns 7 results.
        If None, 0 lag is put at the front, and the result is 1/2 the original length + 1,
        a nice default for autocorrelation as there are no repeated values.

    Returns
    -------
    The cross-correlation between frames_a and frames_b.

    Example
    -------
    >>> frames = torch.arange(10).view(1, 1, -1).float()
    >>> compute_cross_correlation(frames, frames, width=3)
    tensor([[[0.6316, 0.7193, 0.8421, 1.0000, 0.8421, 0.7193, 0.6316]]])
    >>> compute_cross_correlation(frames, frames)
    tensor([[[1.0000, 0.8421, 0.7193, 0.6316, 0.5789, 0.5614]]])
    Nr   r   circular)moder	   r   )inputweightgroupsr   r   r    )r0   r   r   r   r   r$   conv1drP   rT   r1   r&   )frames_aframes_brx   
batch_sizer   r   r   padded_frames_amerged_size
reshaped_a
reshaped_bcross_correlationnormr   r   r   r%     s   &r%   )r   )r   )rf   rg   rh   ri   rj   )rn   )rg   rk   )N)__doc__r   r~   PERIODIC_NEIGHBORSno_gradr   r
   rK   rd   rO   r   r   rp   r%   r   r   r   r   <module>   s.    
0%R
H
[@
5