o
    piNr                     @   s  d dl Z d dlmZ d dlmZ d dlmZmZmZ d dl	Z
d dlZd dlm  mZ d dlm  mZ d dlmZ d dlmZ d dlmZ d dlmZmZmZ d d	lmZ d d
lm Z  d dl!m"Z"m#Z# z
d dl$m%Z& dZ'W n e(yy   dZ'Y nw z
d dl)m*Z+ dZ,W n e(y   dZ,Y nw zd dl-Z.dZ/W n e(y   dZ/Y nw G dd deZ0G dd deZ1G dd deZ2G dd deZ3			d)de"deej4 deedf deeedf fddZ5G dd deZ6	 	!	"	d*d#e7d$e7de7d%ee7 fd&d'Z8e9d(krd dl:Z:e:;e8 dS dS )+    N)cached_property)Path)OptionalTextUnion)hf_hub_download)RepositoryNotFoundError)pad_sequence)	InferenceModelPipeline)BaseInference)	AudioFile)PipelineModel	get_model)EncoderClassifierTF)EncDecSpeakerLabelModelc                       s   e Zd Z		ddedeej f fddZdejfddZe	d	e
fd
dZe	d	e
fddZe	d	efddZe	d	e
fddZ	ddejdeej d	ejfddZ  ZS )NeMoPretrainedSpeakerEmbedding+nvidia/speakerverification_en_titanet_largeN	embeddingdevicec                    s^   t s
td| dt   || _|ptd| _t| j| _	| j	
  | j	| j d S )Nz!'NeMo' must be installed to use 'zQ' embeddings. Visit https://nvidia.github.io/NeMo/ for installation instructions.cpu)NEMO_IS_AVAILABLEImportErrorsuper__init__r   torchr   NeMo_EncDecSpeakerLabelModelfrom_pretrainedmodel_freezeto)selfr   r   	__class__ a/home/ubuntu/.local/lib/python3.10/site-packages/pyannote/audio/pipelines/speaker_verification.pyr   B   s   


z'NeMoPretrainedSpeakerEmbedding.__init__c                 C   8   t |tjstdt|j d| j| || _| S N5`device` must be an instance of `torch.device`, got ``
isinstancer   r   	TypeErrortype__name__r   r!   r"   r   r%   r%   r&   r!   U      z!NeMoPretrainedSpeakerEmbedding.toreturnc                 C   s   | j jjddS )Nsample_rate>  )r   _cfgtrain_dsgetr"   r%   r%   r&   r3   _   s   z*NeMoPretrainedSpeakerEmbedding.sample_ratec                 C   sL   t d| j| j}t | jg| j}| j||d\}}|j\}}|S )N   input_signalinput_signal_length)r   randr3   r!   r   tensorr   shape)r"   r;   r<   _
embeddings	dimensionr%   r%   r&   rB   c   s   

z(NeMoPretrainedSpeakerEmbedding.dimensionc                 C      dS Ncosiner%   r8   r%   r%   r&   metricm      z%NeMoPretrainedSpeakerEmbedding.metricc                 C   s   dt d| j }}|| d }|d |k rMztd|| j}t|g| j}| j||d}|}W n ty@   |}Y nw || d }|d |k s|S )N         ?r9   r:   )	roundr3   r   r=   r!   r   r>   r   RuntimeError)r"   loweruppermiddler;   r<   r@   r%   r%   r&   min_num_samplesq   s    z.NeMoPretrainedSpeakerEmbedding.min_num_samples	waveformsmasksc                 C   s2  |j \}}}|dksJ |jdd}|du r'|jdd}|j d t| }n3|j \}}	||ks2J tj|jdd|ddjdd}
|
dk}
tdd t||
D d	d
}|
j	dd}|
 }|| jk rntjt|| jf S || jk }|||< | j|| j|| jd\}	}|  }tj||  < |S )   

        Parameters
        ----------
        waveforms : (batch_size, num_channels, num_samples)
            Only num_channels == 1 is supported.
        masks : (batch_size, num_samples), optional

        Returns
        -------
        embeddings : (batch_size, dimension)

        r9   dimNnearestsizemoderI   c                 S   s   g | ]\}}|| qS r%   r%   .0waveformimaskr%   r%   r&   
<listcomp>   s    z;NeMoPretrainedSpeakerEmbedding.__call__.<locals>.<listcomp>Tbatch_firstr:   )r?   squeezer   onesFinterpolate	unsqueezer	   zipsummaxrO   npnanzerosrB   r   r!   r   r   numpyr"   rP   rQ   
batch_sizenum_channelsnum_samplessignalswav_lensbatch_size_masksr@   imasksmax_len	too_shortrA   r%   r%   r&   __call__   s>   





z'NeMoPretrainedSpeakerEmbedding.__call__)r   NN)r/   
__module____qualname__r   r   r   r   r   r!   r   intr3   rB   strrF   rO   Tensorrh   ndarrayrv   __classcell__r%   r%   r#   r&   r   A   s2    
	r   c                       s   e Zd ZdZ				ddedeej deedf dee	edf f fdd	Z
dejfd
dZedefddZedefddZedefddZedefddZ	ddejdeej dejfddZ  ZS )%SpeechBrainPretrainedSpeakerEmbeddinga  Pretrained SpeechBrain speaker embedding

    Parameters
    ----------
    embedding : str
        Name of SpeechBrain model
    device : torch.device, optional
        Device
    token : str or bool, optional
        Huggingface token to be used for downloading from Huggingface hub.
    cache_dir: Path or str, optional
        Path to the folder where files downloaded from Huggingface hub are stored.

    Usage
    -----
    >>> get_embedding = SpeechBrainPretrainedSpeakerEmbedding("speechbrain/spkrec-ecapa-voxceleb")
    >>> assert waveforms.ndim == 3
    >>> batch_size, num_channels, num_samples = waveforms.shape
    >>> assert num_channels == 1
    >>> embeddings = get_embedding(waveforms)
    >>> assert embeddings.ndim == 2
    >>> assert embeddings.shape[0] == batch_size

    >>> assert binary_masks.ndim == 1
    >>> assert binary_masks.shape[0] == batch_size
    >>> embeddings = get_embedding(waveforms, masks=binary_masks)
    !speechbrain/spkrec-ecapa-voxcelebNr   r   token	cache_dirc                    s   t s
td| dt   d|v r$|dd | _|dd | _n|| _d | _|p0td| _|| _	|| _
tj| j| j
 dd| ji| j	| j
| jd	| _d S )
Nz('speechbrain' must be installed to use 'zP' embeddings. Visit https://speechbrain.github.io for installation instructions.@r   r9   r   /speechbrainr   sourcesavedirrun_optsr   huggingface_cache_dirrevision)SPEECHBRAIN_IS_AVAILABLEr   r   r   splitr   r   r   r   r   r   SpeechBrain_EncoderClassifierfrom_hparamsclassifier_r"   r   r   r   r   r#   r%   r&   r      s*   


z.SpeechBrainPretrainedSpeakerEmbedding.__init__c                 C   sX   t |tjstdt|j dtj| j| j	 dd|i| j
| j	| jd| _|| _| S )Nr)   r*   r   r   r   )r,   r   r   r-   r.   r/   r   r   r   r   r   r   r   r0   r%   r%   r&   r!     s   
z(SpeechBrainPretrainedSpeakerEmbedding.tor2   c                 C   
   | j jjS rw   )r   audio_normalizerr3   r8   r%   r%   r&   r3        
z1SpeechBrainPretrainedSpeakerEmbedding.sample_ratec                 C   s,   t dd| j}| j|j^ }}|S )Nr9   r4   )r   r=   r!   r   r   encode_batchr?   )r"   dummy_waveformsr@   rB   r%   r%   r&   rB     s   z/SpeechBrainPretrainedSpeakerEmbedding.dimensionc                 C   rC   rD   r%   r8   r%   r%   r&   rF   #  rG   z,SpeechBrainPretrainedSpeakerEmbedding.metricc              	   C   s   t  Q dtd| j }}|| d }|d |k rMz| jt d|| j}|}W n t	y8   |}Y nw || d }|d |k sW d    |S W d    |S 1 sXw   Y  |S NrH   rI   r9   )
r   inference_moderJ   r3   r   r   randnr!   r   rK   r"   rL   rM   rN   r@   r%   r%   r&   rO   '  s*   


z5SpeechBrainPretrainedSpeakerEmbedding.min_num_samplesrP   rQ   c                 C   s,  |j \}}}|dksJ |jdd}|du r'|jdd}|j d t| }n3|j \}}	||ks2J tj|jdd|ddjdd}
|
dk}
tdd t||
D d	d
}|
j	dd}|
 }|| jk rntjt|| jf S || jk }|| }d||< | jj||djdd  }tj||  < |S )rR   r9   rS   NrU   rV   rI   c                 S   s   g | ]
\}}||   qS r%   )
contiguousrY   r%   r%   r&   r]   c  s    
zBSpeechBrainPretrainedSpeakerEmbedding.__call__.<locals>.<listcomp>Tr^   g      ?)rq   )r?   r`   r   ra   rb   rc   rd   r	   re   rf   rg   rO   rh   ri   rj   rB   r   r   r   rk   rl   r%   r%   r&   rv   9  sF   


z.SpeechBrainPretrainedSpeakerEmbedding.__call__)r   NNNrw   )r/   rx   ry   __doc__r   r   r   r   r   r   r   r!   r   rz   r3   rB   r{   rF   rO   r|   rh   r}   rv   r~   r%   r%   r#   r&   r      s@    
!r   c                       s  e Zd ZdZ				d%dedeej deedf dee	edf f fdd	Z
dejfd
dZedefddZedefddZedefddZedefddZedefddZ				d&dejdededededejfd d!Z	d'dejd"eej dejfd#d$Z  ZS )('ONNXWeSpeakerPretrainedSpeakerEmbeddinga  Pretrained WeSpeaker speaker embedding

    Parameters
    ----------
    embedding : str
        Path to WeSpeaker pretrained speaker embedding
    device : torch.device, optional
        Device
    token : str or bool, optional
        Huggingface token to be used for downloading from Huggingface hub.
    cache_dir: Path or str, optional
        Path to the folder where files downloaded from Huggingface hub are stored.

    Usage
    -----
    >>> get_embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding("hbredin/wespeaker-voxceleb-resnet34-LM")
    >>> assert waveforms.ndim == 3
    >>> batch_size, num_channels, num_samples = waveforms.shape
    >>> assert num_channels == 1
    >>> embeddings = get_embedding(waveforms)
    >>> assert embeddings.ndim == 2
    >>> assert embeddings.shape[0] == batch_size

    >>> assert binary_masks.ndim == 1
    >>> assert binary_masks.shape[0] == batch_size
    >>> embeddings = get_embedding(waveforms, masks=binary_masks)
    &hbredin/wespeaker-voxceleb-resnet34-LMNr   r   r   r   c                    s|   t s
td| dt   t| s/z
t|d||d}W n ty.   td| dw || _	| 
|p:td d S )Nz('onnxruntime' must be installed to use 'z' embeddings.zspeaker-embedding.onnx)repo_idfilenamer   r   zCould not find 'z&' on huggingface.co nor on local disk.r   )ONNX_IS_AVAILABLEr   r   r   r   existsr   r   
ValueErrorr   r!   r   r   r   r#   r%   r&   r     s(   



z0ONNXWeSpeakerPretrainedSpeakerEmbedding.__init__c                 C   s   t |tjstdt|j d|jdkrdg}n|jdkr'dddifg}ntd	|j d
 td}dg}t	 }d|_
d|_tj| j||d| _|| _| S )Nr)   r*   r   CPUExecutionProvidercudaCUDAExecutionProvidercudnn_conv_algo_searchDEFAULTzUnsupported device type: z, falling back to CPUr9   )sess_options	providers)r,   r   r   r-   r.   r/   warningswarnortSessionOptionsinter_op_num_threadsintra_op_num_threadsInferenceSessionr   session_)r"   r   r   r   r%   r%   r&   r!     s2   

	
z*ONNXWeSpeakerPretrainedSpeakerEmbedding.tor2   c                 C   rC   )Nr4   r%   r8   r%   r%   r&   r3     rG   z3ONNXWeSpeakerPretrainedSpeakerEmbedding.sample_ratec                 C   sD   t ddd}| |}| jjdgd| idd }|j\}}|S )Nr9   r4   embsfeatsoutput_names
input_feedr   )r   r=   compute_fbankr   runrk   r?   )r"   r   featuresrA   r@   rB   r%   r%   r&   rB     s   

z1ONNXWeSpeakerPretrainedSpeakerEmbedding.dimensionc                 C   rC   rD   r%   r8   r%   r%   r&   rF     rG   z.ONNXWeSpeakerPretrainedSpeakerEmbedding.metricc                 C   s   dt d| j }}|| d }|d |k r\z| tdd|}W n ty3   |}|| d }Y qw | jjdgd| idd }t	
t	|rN|}n|}|| d }|d |k s|S )NrH   rI   r9   r   r   r   r   )rJ   r3   r   r   r   AssertionErrorr   r   rk   rh   anyisnan)r"   rL   rM   rN   r   rA   r%   r%   r&   rO     s*   z7ONNXWeSpeakerPretrainedSpeakerEmbedding.min_num_samplesc                 C   s   |  tdd| jjd S )Nr9   )r   r   r   rO   r?   r8   r%   r%   r&   min_num_frames  s   z6ONNXWeSpeakerPretrainedSpeakerEmbedding.min_num_framesP      
           rP   num_mel_binsframe_lengthframe_shiftditherc                    s<   |d }t  fdd|D }|t j|ddd S )af  Extract fbank features

        Parameters
        ----------
        waveforms : (batch_size, num_channels, num_samples)

        Returns
        -------
        fbank : (batch_size, num_frames, num_mel_bins)

        Source: https://github.com/wenet-e2e/wespeaker/blob/45941e7cba2c3ea99e232d02bedf617fc71b0dad/wespeaker/bin/infer_onnx.py#L30C1-L50
        i   c                    s(   g | ]}t j| jd ddqS )hammingF)r   r   r   r   sample_frequencywindow_type
use_energy)kaldifbankr3   )rZ   r[   r   r   r   r   r"   r%   r&   r]   '  s    zIONNXWeSpeakerPretrainedSpeakerEmbedding.compute_fbank.<locals>.<listcomp>r9   T)rT   keepdim)r   stackmean)r"   rP   r   r   r   r   r   r%   r   r&   r     s   z5ONNXWeSpeakerPretrainedSpeakerEmbedding.compute_fbankrQ   c                 C   s   |j \}}}|dksJ | || j}|j \}}}|du r2| jjdgd|jddidd }	|	S |j \}
}||
ks=J tj|j	dd	|d
dj
dd	}|dk}tjt|| jf }	tt||D ]*\}\}}|| }|j d | jk rvqc| jjdgd|jddd idd d |	|< qc|	S )rR   r9   Nr   r   T)forcer   r   rS   rU   rV   rI   )r?   r   r!   r   r   r   rk   rb   rc   rd   r`   rh   ri   rj   rB   	enumeratere   r   )r"   rP   rQ   rm   rn   ro   r   r@   
num_framesrA   rr   rs   ffeaturer\   masked_featurer%   r%   r&   rv   8  sB   

z0ONNXWeSpeakerPretrainedSpeakerEmbedding.__call__)r   NNN)r   r   r   r   rw   )r/   rx   ry   r   r   r   r   r   r   r   r   r!   r   rz   r3   rB   r{   rF   rO   r   r|   floatr   rh   r}   rv   r~   r%   r%   r#   r&   r     sf    
 "	
)r   c                       s   e Zd ZdZ				ddedeej dee	df dee
e	df f fdd	Zdejfd
dZedefddZedefddZedefddZedefddZ	ddejdeej dejfddZ  ZS )'PyannoteAudioPretrainedSpeakerEmbeddinga  Pretrained pyannote.audio speaker embedding

    Parameters
    ----------
    embedding : PipelineModel
        pyannote.audio model
    device : torch.device, optional
        Device
    token : str or bool, optional
        Huggingface token to be used for downloading from Huggingface hub.
    cache_dir: Path or str, optional
        Path to the folder where files downloaded from Huggingface hub are stored.

    Usage
    -----
    >>> get_embedding = PyannoteAudioPretrainedSpeakerEmbedding("pyannote/embedding")
    >>> assert waveforms.ndim == 3
    >>> batch_size, num_channels, num_samples = waveforms.shape
    >>> assert num_channels == 1
    >>> embeddings = get_embedding(waveforms)
    >>> assert embeddings.ndim == 2
    >>> assert embeddings.shape[0] == batch_size

    >>> assert masks.ndim == 1
    >>> assert masks.shape[0] == batch_size
    >>> embeddings = get_embedding(waveforms, masks=masks)
    pyannote/embeddingNr   r   r   r   c                    sN   t    || _|ptd| _t| j||d| _| j  | j| j d S )Nr   r   r   )	r   r   r   r   r   r   r   evalr!   r   r#   r%   r&   r     s   

z0PyannoteAudioPretrainedSpeakerEmbedding.__init__c                 C   r'   r(   r+   r0   r%   r%   r&   r!     r1   z*PyannoteAudioPretrainedSpeakerEmbedding.tor2   c                 C   r   rw   )r   audior3   r8   r%   r%   r&   r3     r   z3PyannoteAudioPretrainedSpeakerEmbedding.sample_ratec                 C   s   | j jS rw   )r   rB   r8   r%   r%   r&   rB     s   z1PyannoteAudioPretrainedSpeakerEmbedding.dimensionc                 C   rC   rD   r%   r8   r%   r%   r&   rF     rG   z.PyannoteAudioPretrainedSpeakerEmbedding.metricc              	   C   s   t  Q dtd| j }}|| d }|d |k rMz| t dd|| j}|}W n ty8   |}Y nw || d }|d |k sW d    |S W d    |S 1 sXw   Y  |S r   )	r   r   rJ   r3   r   r   r!   r   	Exceptionr   r%   r%   r&   rO     s&   


z7PyannoteAudioPretrainedSpeakerEmbedding.min_num_samplesrP   rQ   c              	   C   s   t  > |d u r| || j}n(t  td | j|| j|| jd}W d    n1 s6w   Y  W d    n1 sEw   Y  | 	 S )Nignoreweights)
r   r   r   r!   r   r   catch_warningssimplefilterr   rk   )r"   rP   rQ   rA   r%   r%   r&   rv     s   


	z0PyannoteAudioPretrainedSpeakerEmbedding.__call__r   NNNrw   )r/   rx   ry   r   r   r   r   r   r   r   r   r   r!   r   rz   r3   rB   r{   rF   rO   r|   rh   r}   rv   r~   r%   r%   r#   r&   r   n  s@    

r   r   r   r   r   c                 C   s   t | trd| v rt| |||dS t | tr"d| v r"t| |||dS t | tr1d| v r1t| |dS t | trBd| v rBt| |||dS t| |||dS )a~  Pretrained speaker embedding

    Parameters
    ----------
    embedding : Text
        Can be a SpeechBrain (e.g. "speechbrain/spkrec-ecapa-voxceleb")
        or a pyannote.audio model.
    device : torch.device, optional
        Device
    token : str or bool, optional
        Huggingface token to be used for downloading from Huggingface hub.
    cache_dir: Path or str, optional
        Path to the folder where files downloaded from Huggingface hub are stored.

    Usage
    -----
    >>> get_embedding = PretrainedSpeakerEmbedding("pyannote/embedding")
    >>> get_embedding = PretrainedSpeakerEmbedding("speechbrain/spkrec-ecapa-voxceleb")
    >>> get_embedding = PretrainedSpeakerEmbedding("nvidia/speakerverification_en_titanet_large")
    >>> assert waveforms.ndim == 3
    >>> batch_size, num_channels, num_samples = waveforms.shape
    >>> assert num_channels == 1
    >>> embeddings = get_embedding(waveforms)
    >>> assert embeddings.ndim == 2
    >>> assert embeddings.shape[0] == batch_size

    >>> assert masks.ndim == 1
    >>> assert masks.shape[0] == batch_size
    >>> embeddings = get_embedding(waveforms, masks=masks)
    pyannote)r   r   r   speechbrainnvidia)r   	wespeaker)r,   r{   r   r   r   r   )r   r   r   r   r%   r%   r&   PretrainedSpeakerEmbedding  s"   %r   c                       sf   e Zd ZdZ				ddedee deedf deeedf f fdd	Z	d
e
dejfddZ  ZS )SpeakerEmbeddinga  Speaker embedding pipeline

    This pipeline assumes that each file contains exactly one speaker
    and extracts one single embedding from the whole file.

    Parameters
    ----------
    embedding : Model, str, or dict, optional
        Pretrained embedding model. Defaults to "pyannote/embedding".
        See pyannote.audio.pipelines.utils.get_model for supported format.
    segmentation : Model, str, or dict, optional
        Pretrained segmentation (or voice activity detection) model.
        See pyannote.audio.pipelines.utils.get_model for supported format.
        Defaults to no voice activity detection.
    token : str or bool, optional
        Huggingface token to be used for downloading from Huggingface hub.
    cache_dir: Path or str, optional
        Path to the folder where files downloaded from Huggingface hub are stored.

    Usage
    -----
    >>> from pyannote.audio.pipelines import SpeakerEmbedding
    >>> pipeline = SpeakerEmbedding()
    >>> emb1 = pipeline("speaker1.wav")
    >>> emb2 = pipeline("speaker2.wav")
    >>> from scipy.spatial.distance import cdist
    >>> distance = cdist(emb1, emb2, metric="cosine")[0,0]
    r   Nr   segmentationr   r   c                    sZ   t    || _|| _t|||d| _| jd ur+t| j||d}t|dd d| _d S d S )Nr   c                 S   s   t j| dddS )NT)axiskeepdims)rh   rg   )scoresr%   r%   r&   <lambda>A  s    z+SpeakerEmbedding.__init__.<locals>.<lambda>)pre_aggregation_hook)r   r   r   r   r   embedding_model_r
   _segmentation)r"   r   r   r   r   segmentation_modelr#   r%   r&   r   +  s   

zSpeakerEmbedding.__init__filer2   c                 C   s   | j j}| j |d d  |}| jd u rd }n| |j}d|t|< t	
|d d d d df |}t	  | j ||d  W  d    S 1 sQw   Y  d S )Nr   r      r   )r   r   r   r!   r   r   datarh   r   r   
from_numpyno_gradr   rk   )r"   r   r   r[   r   r%   r%   r&   applyF  s   
"
$zSpeakerEmbedding.applyr   )r/   rx   ry   r   r   r   r   r   r   r   r   rh   r}   r   r~   r%   r%   r#   r&   r     s     
r   &VoxCeleb.SpeakerVerification.VoxCeleb1testr   protocolsubsetr   c                 C   sT  dd l }ddlm}m} ddlm} ddlm} ddlm}	 t	||d}
|| d| id} g g }}t
 }t| | d	 }t|	|D ]=\}}|d
 d }||vrX|
|||< |d d }||vrh|
|||< |||| || ddd d  ||d  qD||t|dd\}}}}|| j d| d| d| dd| dd
 d S )Nr   )
FileFinderget_protocol)	det_curve)cdist)tqdm)r   r   r   )preprocessors_trialfile1file2rE   )rF   	referenceT)	distancesz | z	 | EER = d   z.3f%)typerpyannote.databaser   r   &pyannote.metrics.binary_classificationr   scipy.spatial.distancer   r  r   dictgetattrr   appendrh   arrayechoname)r   r   r   r   r
  r   r   r   r   r  pipeliney_truey_predembtrialsttrialaudio1audio2r@   eerr%   r%   r&   mainZ  s.   
$(r  __main__)NNN)r   r   r   N)<r   	functoolsr   pathlibr   typingr   r   r   rk   rh   r   torch.nn.functionalnn
functionalrb   torchaudio.compliance.kaldi
compliancer   huggingface_hubr   huggingface_hub.utilsr   torch.nn.utils.rnnr	   pyannote.audior
   r   r   pyannote.audio.core.inferencer   pyannote.audio.core.ior   pyannote.audio.pipelines.utilsr   r   speechbrain.inferencer   r   r   r   nemo.collections.asr.modelsr   r   r   onnxruntimer   r   r   r   r   r   r   r   r   r{   r  r/   r
  r   r%   r%   r%   r&   <module>   s    
 9 mc

>N

(