o
    }oie                     @   s   d dl Z d dlZd dlmZ d dlmZ d dlZd dlmZ d dl	m
Z
 d dlmZmZ d dlmZ d dlmZmZmZmZ d d	lmZmZ d
gZdd ZG dd
 d
e
ZdS )    N)deepcopy)Dict)
DictConfig)ClusteringDiarizer)get_scale_interpolated_embssplit_input_data)OnlineSpeakerClustering)OnlineSegmentoraudio_rttm_mapgenerate_cluster_labelsget_embs_and_timestamps)loggingmodel_utilsOnlineClusteringDiarizerc                    s    fdd}|S )z
    Monitor elapsed time of the corresponding function displaying the method name.

    Args:
        method: function that is being measured

    Return:
        `timed` function for measuring the elapsed time
    c                     st   t   } | i |}t   }d|v r*|d j }t|| d |d |< |S td|| d  jf  |S )Nlog_timelog_namei  z
%2.2fms %r)timeget__name__upperintr   info)argskwargstsresulttenamemethod _/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/asr/models/online_diarizer.pytimed0   s   ztimeit.<locals>.timedr    )r   r"   r    r   r!   timeit%   s   r#   c                       s>  e Zd ZdZdef fddZdd Zdd Zd	d
 Zdd Z	dd Z
dd Zdd Zdd ZdefddZedejfddZdedejdedejfddZee dejdejfd d!Zedejd"ejdejfd#d$Ze	%d/d&eeejf dejfd'd(Zdejfd)d*Zed+ejd,ejdejfd-d.Z  ZS )0r   a  
    A class that enables online (streaming) clustering based diarization.

    - The instance created from `OnlineClusteringDiarizer` sets aside a certain amount of memory
      to provide the upcoming inference with history information

    - There are two major modules involved: `OnlineSegmentor` and `OnlineSpeakerClustering`.
        OnlineSegmentor: Take the VAD-timestamps and generate segments for each scale
        OnlineSpeakerClustering: Update the entire speaker labels of the given online session
                                 while updating the speaker labels of the streaming inputs.

    - The overall diarization process is done by calling `diarize_step` function.
      `diarize_step` function goes through the following steps:
        (1) Segmentation (`OnlineSegmentor` class)
        (2) Embedding extraction (`_extract_online_embeddings` function call)
        (3) Online speaker counting and speaker clustering (`OnlineClusteringDiarizer` class)
        (4) Label generation (`generate_cluster_labels` function call)
    cfgc                    s   t  | t|| _| jj| _t| jd 	 | _
| jdd | _| jdd| _t| jjj| _| jj| _td | jj| _tj| jsOt| j tj r^d| _td| _n	d| _td	| _|   | j  d S )
N
scale_dictuniq_iddecimals   r   TcudaFcpu) super__init__r   #convert_model_config_to_dict_configr$   diarizer_cfg_diarizermaxmultiscale_args_dictkeysbase_scale_indexr   r&   r'   r
   manifest_filepathAUDIO_RTTM_MAPsample_ratetorchmanual_seedout_dir_out_dirospathexistsmkdirr)   is_availabledevicereset_speaker_modeleval)selfr$   	__class__r    r!   r,   R   s&   




z!OnlineClusteringDiarizer.__init__c                 C   s@   t |j|j|j|j|j| jd| _|j| _|j| _	| jj| _dS )az  
        Initialize online speaker clustering module

        Attributes:
            online_clus (OnlineSpeakerClustering):
                Online clustering diarizer class instance
            history_n (int):
                History buffer size for saving history of speaker label inference
                Total number of embedding vectors saved in the buffer that is kept till the end of the session
            current_n (int):
                Current buffer (FIFO queue) size for calculating the speaker label inference
                Total number of embedding vectors saved in the FIFO queue for clustering inference
        )max_num_speakersmax_rp_thresholdsparse_search_volumehistory_buffer_sizecurrent_buffer_sizer)   N)
r   rG   rH   rI   rJ   rK   r)   online_clus	history_n	current_nrD   clustering_paramsr    r    r!   _init_online_clustering_modulen   s   z7OnlineClusteringDiarizer._init_online_clustering_modulec                 C   s   t || _dS )z
        Initialize an online segmentor module

        Attributes:
            online_segmentor (OnlineSegmentor):
                online segmentation module that generates short speech segments from the VAD input
        N)r	   online_segmentor)rD   r6   r    r    r!   _init_online_segmentor_module   s   z6OnlineClusteringDiarizer._init_online_segmentor_modulec                 C   sJ   d| _ dd | jd  D | _dd | jd  D | _tg | _dS )a  
        Variables are kept in memory for future updates

        Attributes:
            memory_margin (int):
                The number of embeddings saved in the memory buffer.
                This memory margin is dependent on the base scale length: margin = (buffer_length)/(base scale shift)
                memory margin is automatically calculated to have minimal memory usage
            memory_segment_ranges (dict):
                The segment range information kept in the memory buffer
            memory_segment_indexes (dict):
                The segment indexes kept in the memory buffer
            memory_cluster_labels (Tensor):
                The cluster labels inferred in the previous diarization steps
        r   c                 S      i | ]}|g qS r    r    .0keyr    r    r!   
<dictcomp>       z@OnlineClusteringDiarizer._init_memory_buffer.<locals>.<dictcomp>r%   c                 S   rT   r    r    rU   r    r    r!   rX      rY   N)memory_marginr1   r2   memory_segment_rangesmemory_segment_indexesr7   tensormemory_cluster_labelsrD   r    r    r!   _init_memory_buffer   s   z,OnlineClusteringDiarizer._init_memory_bufferc                 C   s&   | dd| _| dd| _i | _dS )a  
        Variables needed for taking majority votes for speaker labels

        Attributes:
            use_temporal_label_major_vote (bool):
                Boolean for whether to use temporal majority voting
            temporal_label_major_vote_buffer_size (int):
                buffer size for majority voting
            base_scale_label_dict (dict):
                Dictionary containing multiple speaker labels for major voting
                Speaker labels from multiple steps are saved for each segment index.
        use_temporal_label_major_voteF%temporal_label_major_vote_buffer_size   N)r   ra   rb   base_scale_label_dictrO   r    r    r!   "_init_temporal_major_voting_module   s   
z;OnlineClusteringDiarizer._init_temporal_major_voting_modulec                 C   s|   i | _ i | _i | _i | _i | _| jd  D ]%}ddg| j|< t	g | j |< g | j|< g | j|< g | j|< g | j|< qdS )z
        Initialize segment variables for each scale.
        Note that we have `uniq_id` variable in case where multiple sessions are handled.
        r%   N)
emb_vectorstime_stampssegment_range_tssegment_raw_audiosegment_indexesr1   r2   $multiscale_embeddings_and_timestampsr7   r]   )rD   	scale_idxr    r    r!   _init_segment_variables   s   


z0OnlineClusteringDiarizer._init_segment_variablesc                 C   s   d| _ d| _d| _d| _dS )a  
        Timing variables transferred from OnlineDiarWithASR class.
        Buffer is window region where input signal is kept for ASR.
        Frame is window region where the actual inference ASR decoded results are updated

        Example:
            buffer_len = 5.0
            frame_len = 1.0

            |___Buffer___[___________]____________|
            |____________[   Frame   ]____________|

            | <- buffer_start
            |____________| <- frame_start
            |_____________________________________| <- buffer_end

            buffer_start = 12.0
            buffer_end = 17.0
            frame_start = 14.0

        These timestamps and index variables are updated by OnlineDiarWithASR.

        Attributes:
            frame_index (int):
                Integer index of frame window
            frame_start (float):
                The start of the frame window
            buffer_start (float):
                The start of the buffer window
            buffer_end (float):
                The end of the buffer
        r           N)frame_indexframe_startbuffer_start
buffer_endr_   r    r    r!   _init_buffer_frame_timestamps   s   !
z6OnlineClusteringDiarizer._init_buffer_frame_timestampsc                 C   s"   | j | j_ | j| j_| j| j_dS )zI
        Pass the timing information from streaming ASR buffers.
        N)rp   rR   rq   rr   r_   r    r    r!   !_transfer_timestamps_to_segmentor   s   

z:OnlineClusteringDiarizer._transfer_timestamps_to_segmentorc                 C   sj   t | j| jd | j d  | _|   | | jjj	 | 
| jj |   | | jjj	 |   dS )z
        Reset all the necessary variables and initialize classes.

        Attributes:
            n_embed_seg_len (int):
                Number of segments needed for 1 second of input time-series signal
        r%   r   N)r   r6   r1   r3   n_embed_seg_lenrm   rQ   r/   
clustering
parametersrS   r$   r`   re   rs   r_   r    r    r!   rA      s   zOnlineClusteringDiarizer.resetrl   c                 C   s   | j d | j d }t| j| j | | _ttt| j| 	 tt| j| j 	  | j
| j  }|| j }| j| | d | j|< | j| | d | j|< | j| | d | j|< | j| | d | j|< dS )a  
        Calculate how many segments should be removed from memory (`memory_margin`) and
        save the necessary information.
        `keep_range` determines how many segments and their corresponding embedding, raw audio,
        timestamps in the memory of the online diarizer instance.

        Args:
            scale_idx (int):
                Scale index in integer type
        r%   rc   N)r1   r3   r   rr   rq   rZ   lensetscale_mapping_dicttolistrM   rN   rf   ri   rh   rj   )rD   rl   base_scale_shiftscale_buffer_size
keep_ranger    r    r!   _clear_memory  s   

z&OnlineClusteringDiarizer._clear_memoryreturnc              	   C   s   g }| j | j D ]J}|| jvr| j| g| j|< n't| j| | jkr5| j| d t| j| | jks#| j| | j|  |t	t
| j| d   q|S )aN  
        Take a majority voting for every segment on temporal steps. This feature significantly reduces the error coming
        from unstable speaker counting in the beginning of sessions.

        Returns:
            maj_vote_labels (list):
                List containing the major-voted speaker labels on temporal domain
        r   )r\   r3   rd   r^   rx   rb   popappendr7   moder]   item)rD   maj_vote_labelsseg_idxr    r    r!   _temporal_label_major_vote%  s   

&z3OnlineClusteringDiarizer._temporal_label_major_votetotal_cluster_labels	is_onlinec                 C   s  |  }|s%t| j| | j|< t| j| | j|< || jkr$t|| _n| j| d | j| d krtt| j| | j	 d}t
| j| }t
||kd d }t| j| |d | j| |d< t| j| |d | j| |d< || jkrt||d | j|d< t| jt| j| krtdt| j dt| j|  d| | t| j| t| j|   krt| j|   krt| j| ksn tdt| j|  dt| j|  dt| j|  d	t| j|  d
	| jr|  }|S | j}|S )a   
        Save the temporary input to the class memory buffer.

        - Clustering is done for (hist_N + curr_N) number of embeddings.
        - Thus, we need to remove the clustering results on the embedding memory.
        - If self.diar.history_buffer_seg_end is not None, that indicates streaming diarization system
          is starting to save embeddings to its memory. Thus, the new incoming clustering label should be separated.
        - If `is_online = True`, old embeddings outside the window are removed to save GPU memory.

        Args:
            scale_idx (int):
                Scale index in integer
            total_cluster_labels (Tensor):
                The speaker labels from the beginning of the session to the current position
            is_online (bool)
                Boolean variable that indicates whether the system is currently in online mode or not

        Returns:
            cluster_label_hyp (Tensor):
                Majority voted speaker labels over multiple inferences
        r   Nzlself.memory_cluster_labels and self.memory_segment_ranges should always have the same length, but they have z and .zself.emb_vectors, self.segment_raw_audio, self.segment_indexes, and self.segment_range_ts should always have the same length, but they have z, z, and z, respectively.)r{   r   rh   r[   rj   r\   r3   r^   r0   rZ   r7   r]   whererx   
ValueErrorr   rf   ri   ra   r   )rD   rl   r   r   global_stt_idxsegment_indexes_matbuffer_stt_idxcluster_label_hypr    r    r!   save_history_data;  sf   



z*OnlineClusteringDiarizer.save_history_dataaudio_signalc                    sX   t |  j}t  fddt|jd D  j} jj	||d\}}|S )a  
        Call `forward` function of the speaker embedding model.

        Args:
            audio_signal (Tensor):
                Torch tensor containing time-series signal

        Returns:
            Speaker embedding vectors for the given time-series input `audio_signal`.
        c                    s   g | ]} j qS r    )ru   )rV   kr_   r    r!   
<listcomp>  rY   zEOnlineClusteringDiarizer._run_embedding_extractor.<locals>.<listcomp>r   )input_signalinput_signal_length)
r7   stackfloattor@   r]   rangeshaperB   forward)rD   r   audio_signal_lens_
torch_embsr    r_   r!   _run_embedding_extractor  s   *z1OnlineClusteringDiarizer._run_embedding_extractorsegment_rangesc                 C   s   |du rdn|j d }t|}||kr:| ||| }|du s'|j d dkr*|}nt|d|ddf |f}n||k rF|dt| }t||j d krStd|S )aa  
        Incrementally extract speaker embeddings based on `audio_signal` and `segment_ranges` variables.
        Unlike offline speaker diarization, speaker embedding and subsegment ranges are not saved to disk.
        Measures the mismatch between `segment_ranges` and `embeddings` then extract the necessary amount of
        speaker embeddings.

        Args:
            audio_signal (Tensor):
                Torch tensor containing time-series audio signal
            embeddings (Tensor):
                Previously existing Torch tensor containing speaker embedding vector
            segment_ranges(Tensor):
                Torch tensor containing the start and end of each segment

        Returns:
            embeddings (Tensor):
                Concatenated speaker embedding vectors that match segment range information in `segment_ranges`.
        Nr   z2Segment ranges and embeddings shapes do not match.)r   rx   r   r7   vstackr   )rD   r   r   
embeddingsstt_idxend_idxr   r    r    r!   _extract_online_embeddings  s    z3OnlineClusteringDiarizer._extract_online_embeddingsFuniq_embs_and_timestampsc                 C   s   |rt dnt d}t|d |d |d d\}}t|d |||d\}| _t | j| j |j}| j	j
||| j|d	}| jd
  D ]\}	\}
}| |	|| j	j}qG|S )a  
        Launch online clustering for `uniq_embs_and_timestamps` input variable.

        Args:
            uniq_embs_and_timestamps (dict):
                Dictionary containing embeddings, timestamps and multiscale weights.
                If uniq_embs_and_timestamps contains only one scale, single scale diarization
                is performed.
            cuda (bool):
                Boolean indicator for cuda usages
        r)   r*   r   
timestampsmultiscale_segment_counts)embeddings_in_scalestimestamps_in_scalesr   multiscale_weights)r   r   r   r@   )curr_embbase_segment_indexesro   r)   r%   )r7   r@   r   r   rz   r]   rj   r3   r   rL   forward_inferro   r1   itemsr   r   )rD   r   r)   r@   r   r   r   r   merged_clus_labelsrl   windowshiftr   r    r    r!   _perform_online_clustering  s&   

z3OnlineClusteringDiarizer._perform_online_clusteringc                 C   sP   t | jdks| jdk rtd| jggdg\}}|S t| j| j | j\}}|S )z
        In case buffer is not filled or there is no speech activity in the input, generate temporary output.

        Returns:
            diar_hyp (Tensor): Speaker labels based on the previously saved segments and speaker labels
        r   rn   )rx   r^   rq   r   total_buffer_in_secsr[   r3   )rD   diar_hypr   r    r    r!   _get_interim_output  s   z,OnlineClusteringDiarizer._get_interim_outputaudio_buffervad_timestampsc              
   C   s  |    | jdk st|dkr|  S | jd  D ]Q\}\}}| jj||| j| | j	| | j
| ||d\}}}|| j|< || j	|< || j
|< | j| j| | j	| | j| d}	|	| j|< | j|	i| j|ig| j|< qt| j| j}
| j|
| j | jd}t| j| j |\}}|S )a  
        A function for a unit diarization step. Each diarization step goes through the following steps:

        1. Segmentation:
            Using `OnlineSegmentor` class, call `run_online_segmentation` method to get the segments.
        2. Embedding Extraction:
            Extract multiscale embeddings from the extracted speech segments.
        3. Online Clustering & Counting
            Perform online speaker clustering by using `OnlineSpeakerClustering` class.
        4. Generate speaker labels:
            Generate start and end timestamps of speaker labels based on the diarization results.

        c.f.) Also see method `diarize` in `ClusteringDiarizer` class.

        Args:
            audio_buffer (Tensor):
                Tensor variable containing the time series signal at the current frame
                Dimensions: (Number of audio time-series samples) x 1
            vad_timestamps (Tensor):
                List containing VAD timestamps.
                Dimensions: (Number of segments) x 2
                Example:
                    >>> vad_timestamps = torch.Tensor([[0.05, 2.52], [3.12, 6.85]])

        Returns:
            diar_hyp (Tensor):
                Speaker label hypothesis from the start of the session to the current position
        r   r%   )r   r   ri   rh   rj   r   r   )r   r   r   )r)   )rt   rq   rx   r   r1   r   rR   run_online_segmentationri   rh   rj   r   rf   r&   rk   r   r   r)   r   r[   r3   )rD   r   r   rl   r   r   
audio_sigsr   
range_indsr   embs_and_timestampsr   r   r   r    r    r!   diarize_step  s>   
	


z%OnlineClusteringDiarizer.diarize_step)F)r   
__module____qualname____doc__r   r,   rQ   rS   r`   re   rm   rs   rt   rA   r   r   r#   r7   Tensorr   boolr   no_gradr   r   r   strr   r   r   __classcell__r    r    rE   r!   r   >   sH    
&J&(&)r;   r   copyr   typingr   r7   	omegaconfr   nemo.collections.asr.modelsr   3nemo.collections.asr.parts.utils.offline_clusteringr   r   2nemo.collections.asr.parts.utils.online_clusteringr   .nemo.collections.asr.parts.utils.speaker_utilsr	   r
   r   r   
nemo.utilsr   r   __all__r#   r   r    r    r    r!   <module>   s   