o
    7wih                     @   s   d Z ddlmZ ddlmZ ddlmZmZmZ ddl	Z
ddlZddlmZmZ ddlmZ zddlmZmZmZmZmZ W n eyM   ed	 ed
w eeZG dd deZG dd dZdS )u   Perform CTC segmentation to align utterances within audio files.

This uses the ctc-segmentation Python package.
Install it with pip or see the installing instructions in
https://github.com/lumaku/ctc-segmentation

Authors
 * Ludwig Kürzinger 2021
    )Path)SimpleNamespace)ListOptionalUnionN)
EncoderASREncoderDecoderASR)
get_logger)CtcSegmentationParametersctc_segmentationdetermine_utterance_segmentsprepare_textprepare_token_listzMImportError: Is the ctc_segmentation module installed and in your PYTHONPATH?z'The ctc_segmentation module is missing.c                   @   sX   e Zd ZdZdZdZdZdZdZdZ	dZ
dZdZdZdZdZdZdZdd Zdd	 ZdS )
CTCSegmentationTasku  Task object for CTC segmentation.

    This object is automatically generated and acts as
    a container for results of a CTCSegmentation object.

    When formatted with str(·), this object returns
    results in a kaldi-style segments file formatting.
    The human-readable output can be configured with
    the printing options.

    Attributes
    ----------
    text : list
        Utterance texts, separated by line. But without the utterance
            name at the beginning of the line (as in kaldi-style text).
    ground_truth_mat : array
        Ground truth matrix (CTC segmentation).
    utt_begin_indices : np.ndarray
        Utterance separator for the Ground truth matrix.
    timings : np.ndarray
        Time marks of the corresponding chars.
    state_list : list
        Estimated alignment of chars/tokens.
    segments : list
        Calculated segments as: (start, end, confidence score).
    config : CtcSegmentationParameters
        CTC Segmentation configuration object.
    name : str
        Name of aligned audio file (Optional). If given, name is
        considered when generating the text.
        Default: "utt".
    utt_ids : list
        The list of utterance names (Optional). This list should
        have the same length as the number of utterances.
    lpz : np.ndarray
        CTC posterior log probabilities (Optional).
    print_confidence_score : bool
        Include the confidence score.
        Default: True.
    print_utterance_text : bool
        Include utterance text.
        Default: True.

    NFuttTc                 K   s   | j | dS )zUpdate object attributes.N)__dict__update)selfkwargs r   c/home/ubuntu/sommelier/.venv/lib/python3.10/site-packages/speechbrain/alignment/ctc_segmentation.pysetj   s   zCTCSegmentationTask.setc                    s   d}t  j} jdu r fddt|D }n|t  jks!J  j}t jD ]>\}}||  d j d}||d dd|d d7 } jrT|d|d	 d
7 } jra|d j|  7 }||d 7 }q)|S )z0Return a kaldi-style ``segments`` file (string). Nc                    s   g | ]} j  d |dqS )_04)name.0ir   r   r   
<listcomp>s   s    z/CTCSegmentationTask.__str__.<locals>.<listcomp> r   z.2f      z3.4f
)	lensegmentsutt_idsrange	enumerater   print_confidence_scoreprint_utterance_texttext)r   outputnum_utts	utt_namesr   boundary	utt_entryr   r   r   __str__n   s   

zCTCSegmentationTask.__str__)__name__
__module____qualname____doc__r,   ground_truth_matutt_begin_indicestimings
char_probs
state_listr&   configdoner   r'   lpzr*   r+   r   r2   r   r   r   r   r   +   s$    -r   c                   @   sz  e Zd ZdZdZdZdZdZddgZdZ	ddgZ
d	Ze Z			d.d
eeef dededefddZ											d/dee dee dee dee dee dee dee dee dee dee dee fddZd0ddZd1ddZe deejejf fd d!Zd"d# Z d0d$d%Z!e"d&e#fd'd(Z$	d2deejejee%f d)ee&e ef d*ee d+e#fd,d-Z'dS )3CTCSegmentationu  Align text to audio using CTC segmentation.

    Usage
    -----
    Initialize with given ASR model and parameters.
    If needed, parameters for CTC segmentation can be set with ``set_config(·)``.
    Then call the instance as function to align text within an audio file.

    Arguments
    ---------
    asr_model : EncoderDecoderASR
        Speechbrain ASR interface. This requires a model that has a
        trained CTC layer for inference. It is better to use a model with
        single-character tokens to get a better time resolution.
        Please note that the inference complexity with Transformer models
        usually increases quadratically with audio length.
        It is therefore recommended to use RNN-based models, if available.
    kaldi_style_text : bool
        A kaldi-style text file includes the name of the
        utterance at the start of the line. If True, the utterance name
        is expected as first word at each line. If False, utterance
        names are automatically generated. Set this option according to
        your input data. Default: True.
    text_converter : str
        How CTC segmentation handles text.
        "tokenize": Use the ASR model tokenizer to tokenize the text.
        "classic": The text is preprocessed as text pieces which takes
        token length into account. If the ASR model has longer tokens,
        this option may yield better results. Default: "tokenize".
    time_stamps : str
        Choose the method how the time stamps are
        calculated. While "fixed" and "auto" use both the sample rate,
        the ratio of samples to one frame is either automatically
        determined for each inference or fixed at a certain ratio that
        is initially determined by the module, but can be changed via
        the parameter ``samples_to_frames_ratio``. Recommended for
        longer audio files: "auto".
    **ctc_segmentation_args
        Parameters for CTC segmentation.
        The full list of parameters is found in ``set_config``.

    Example
    -------
        >>> # using example file included in the SpeechBrain repository
        >>> from speechbrain.inference.ASR import EncoderDecoderASR
        >>> from speechbrain.alignment.ctc_segmentation import CTCSegmentation
        >>> # load an ASR model
        >>> pre_trained = "speechbrain/asr-transformer-transformerlm-librispeech"
        >>> asr_model = EncoderDecoderASR.from_hparams(source=pre_trained)  # doctest: +SKIP
        >>> aligner = CTCSegmentation(asr_model, kaldi_style_text=False)  # doctest: +SKIP
        >>> # load data
        >>> audio_path = "tests/samples/single-mic/example1.wav"
        >>> text = ["THE BIRCH CANOE", "SLID ON THE", "SMOOTH PLANKS"]
        >>> segments = aligner(audio_path, text, name="example1")  # doctest: +SKIP

    On multiprocessing
    ------------------
    To parallelize the computation with multiprocessing, these three steps
    can be separated:
    (1) ``get_lpz``: obtain the lpz,
    (2) ``prepare_segmentation_task``: prepare the task, and
    (3) ``get_segments``: perform CTC segmentation.
    Note that the function `get_segments` is a staticmethod and therefore
    independent of an already initialized CTCSegmentation object.

    References
    ----------
    CTC-Segmentation of Large Corpora for German End-to-end Speech Recognition
    2020, Kürzinger, Winkelbauer, Li, Watzel, Rigoll
    https://arxiv.org/abs/2007.09127

    More parameters are described in https://github.com/lumaku/ctc-segmentation
    i>  TNautofixedtokenizeclassicF	asr_modelkaldi_style_texttext_convertertime_stampsc           	         s  t  trt dr.t jdr.t jjdr.t  tr2t dr.t jdr.t jjds2tdt ds;td _jj	_
t  trptjjd	sStd
djjjjvr_tddtjdtjffdd}|_njjj_jj_jdjjj|||d|  fddt j D }|j_tdd |D }t|dkr|dkrtdt| d| d d S d S d S )Nmodsdecoder
ctc_weightencoderctc_linz&The given asr_model has no CTC module!	tokenizerz<The given asr_model has no tokenizer in asr_model.tokenizer!scorerz:``ScorerBuilder`` module is required for CTC segmentation.ctcz6``CTCScorer`` module is required for CTC segmentation.xreturnc                    s(    j jjjd }|| }||}|S )zForward step for CTC module.rO   )rD   hparamsrN   full_scorersctc_fcsoftmax)rP   modulelogits	log_probsr   r   r   ctc_forward_step  s   

z2CTCSegmentation.__init__.<locals>.ctc_forward_step)fsrG   rE   rF   c                    s   g | ]} j |qS r   )rM   id_to_piecer   )rD   r   r   r      s    
z,CTCSegmentation.__init__.<locals>.<listcomp>c                 S   s   g | ]}t |qS r   r%   )r   cr   r   r   r    %      i     zThe dictionary has z tokens with a max length of z>. This may lead to low alignment performance and low accuracy.r   )
isinstancer   hasattrrH   rI   r   rK   AttributeErrorrD   encode_batch_encoderR   rN   rS   torchTensor_ctclog_softmaxrM   
_tokenizer
set_configsample_rater(   
vocab_sizer<   	char_listmaxr%   loggerwarning)	r   rD   rE   rF   rG   ctc_segmentation_argsrY   rm   max_char_lenr   )rD   r   r   __init__   sp   








	zCTCSegmentation.__init__rZ   samples_to_frames_ratio	set_blankreplace_spaces_with_blanksgratis_blankmin_window_sizemax_window_sizescoring_lengthc                 C   s4  |dur|| j vrtdt| j  || _|durt|| _|dur(t|| _|dur2t|| j_	|dur<t
|| j_|durEt
|| _|dur[|| jvrXtdt| j || _|	duret|	| j_|
durot|
| j_|duryt
|| j_| jjr| jjr| jstd d| _|durt|| j_dS dS )a	  Set CTC segmentation parameters.

        Parameters for timing
        ---------------------
        time_stamps : str
            Select method how CTC index duration is estimated, and
            thus how the time stamps are calculated.
        fs : int
            Sample rate. Usually derived from ASR model; use this parameter
            to overwrite the setting.
        samples_to_frames_ratio : float
            If you want to directly determine the
            ratio of samples to CTC frames, set this parameter, and
            set ``time_stamps`` to "fixed".
            Note: If you want to calculate the time stamps from a model
            with fixed subsampling, set this parameter to:
            ``subsampling_factor * frame_duration / 1000``.

        Parameters for text preparation
        -------------------------------
        set_blank : int
            Index of blank in token list. Default: 0.
        replace_spaces_with_blanks : bool
            Inserts blanks between words, which is
            useful for handling long pauses between words. Only used in
            ``text_converter="classic"`` preprocessing mode. Default: False.
        kaldi_style_text : bool
            Determines whether the utterance name is expected
            as fist word of the utterance. Set at module initialization.
        text_converter : str
            How CTC segmentation handles text.
            Set at module initialization.

        Parameters for alignment
        ------------------------
        min_window_size : int
            Minimum number of frames considered for a single
            utterance. The current default value of 8000 corresponds to
            roughly 4 minutes (depending on ASR model) and should be OK in
            most cases. If your utterances are further apart, increase
            this value, or decrease it for smaller audio files.
        max_window_size : int
            Maximum window size. It should not be necessary
            to change this value.
        gratis_blank : bool
            If True, the transition cost of blank is set to zero.
            Useful for long preambles or if there are large unrelated segments
            between utterances. Default: False.

        Parameters for calculation of confidence score
        ----------------------------------------------
        scoring_length : int
            Block length to calculate confidence score. The
            default value of 30 should be OK in most cases.
            30 corresponds to roughly 1-2s of audio.
        Nu+   Parameter ´time_stamps´ has to be one of u.   Parameter ´text_converter´ has to be one of zBlanks are inserted between words, and also the transition cost of blank is zero. This configuration may lead to misalignments!T)choices_time_stampsNotImplementedErrorlistrG   floatrZ   rt   intr<   blankboolrv   rE   choices_text_converterrF   rx   ry   blank_transition_cost_zerowarned_about_misconfigurationro   errorscore_min_mean_over_L)r   rG   rZ   rt   ru   rv   rE   rF   rw   rx   ry   rz   r   r   r   rj   -  s^   G




zCTCSegmentation.set_configc                 C   sh   d| j ji}| jdkr| jdu r|  }|| _| j| j }n| jdks%J || }|| j }||d< |S )z+Obtain parameters to determine time stamps.index_durationrA   Nr@   )r<   r   rG   rt    estimate_samples_to_frames_ratiorZ   )r   
speech_lenlpz_len
timing_cfgratior   rt   r   r   r   get_timing_config  s   


z!CTCSegmentation.get_timing_config H c                 C   s*   t |}| |}|jd }|| }|S )a_  Determine the ratio of encoded frames to sample points.

        This method helps to determine the time a single encoded frame occupies.
        As the sample rate already gave the number of samples, only the ratio
        of samples per encoded CTC frame are needed. This function estimates them by
        doing one inference, which is only needed once.

        Arguments
        ---------
        speech_len : int
            Length of randomly generated speech vector for single
            inference. Default: 215040.

        Returns
        -------
        int
            Estimated ratio.
        r   )re   randget_lpzshape)r   r   random_inputr>   r   rt   r   r   r   r     s
   


z0CTCSegmentation.estimate_samples_to_frames_ratiospeechc                 C   sp   t |tjrt|}|d| jj}tdg| jj}| 	||}| 
| }|d  }|S )a/  Obtain CTC posterior log probabilities for given speech data.

        Arguments
        ---------
        speech : Union[torch.Tensor, np.ndarray]
            Speech audio input.

        Returns
        -------
        np.ndarray
            Numpy vector with CTC log posterior probabilities.
        r   g      ?)r`   npndarrayre   tensor	unsqueezetorD   devicerd   rg   detachsqueezecpunumpy)r   r   wav_lensencr>   r   r   r   r     s   
zCTCSegmentation.get_lpzc                 C   sr   d}t |tr| }ttt|}| jr5dd |D }tdd |}t|}dd |D }dd |D }||fS )z/Convert text to list and extract utterance IDs.Nc                 S      g | ]}| d dqS )r!   r"   )splitr   r   r   r   r   r          z/CTCSegmentation._split_text.<locals>.<listcomp>c                 S   s   t | dkS )Nr#   r\   )uir   r   r   <lambda>  s    z-CTCSegmentation._split_text.<locals>.<lambda>c                 S      g | ]}|d  qS )r   r   r   r   r   r   r      r^   c                 S   r   )r"   r   r   r   r   r   r      r^   )r`   str
splitlinesr}   filterr%   rE   )r   r,   r'   utt_ids_and_textr   r   r   _split_text  s   
zCTCSegmentation._split_textc              	      s    j }|dur|jd } ||}|jdi |  |\}} jdkrF fdd|D }	|jdfdd|	D }	t||	\}
}n jdksMJ  fd	d|D }d
d |D }t	||\}
}t
||||
|||d}|S )u	  Preprocess text, and gather text and lpz into a task object.

        Text is pre-processed and tokenized depending on configuration.
        If ``speech_len`` is given, the timing configuration is updated.
        Text, lpz, and configuration is collected in a CTCSegmentationTask
        object. The resulting object can be serialized and passed in a
        multiprocessing computation.

        It is recommended that you normalize the text beforehand, e.g.,
        change numbers into their spoken equivalent word, remove special
        characters, and convert UTF-8 characters to chars corresponding to
        your ASR model dictionary.

        The text is tokenized based on the ``text_converter`` setting:

        The "tokenize" method is more efficient and the easiest for models
        based on latin or cyrillic script that only contain the main chars,
        ["a", "b", ...] or for Japanese or Chinese ASR models with ~3000
        short Kanji / Hanzi tokens.

        The "classic" method improves the the accuracy of the alignments
        for models that contain longer tokens, but with a greater complexity
        for computation. The function scans for partial tokens which may
        improve time resolution.
        For example, the word "▁really" will be broken down into
        ``['▁', '▁r', '▁re', '▁real', '▁really']``. The alignment will be
        based on the most probable activation sequence given by the network.

        Arguments
        ---------
        text : list
            List or multiline-string with utterance ground truths.
        lpz : np.ndarray
            Log CTC posterior probabilities obtained from the CTC-network;
            numpy array shaped as ( <time steps>, <classes> ).
        name : str
            Audio file name that will be included in the segments output.
            Choose a unique name, or the original audio
            file name, to distinguish multiple audio files. Default: None.
        speech_len : int
            Number of sample points. If given, the timing
            configuration is automatically derived from length of fs, length
            of speech and length of lpz. If None is given, make sure the
            timing parameters are correct, see time_stamps for reference!
            Default: None.

        Returns
        -------
        CTCSegmentationTask
            Task object that can be passed to
            ``CTCSegmentation.get_segments()`` in order to obtain alignments.
        Nr   rB   c                    s   g | ]}t  j|qS r   )r   arrayri   encode_as_idsr   r   r   r   r    =      z=CTCSegmentation.prepare_segmentation_task.<locals>.<listcomp><unk>c                    s   g | ]}|| k qS r   r   r   )unkr   r   r    B  r   rC   c                    s   g | ]}d   j|qS )r   )joinri   encode_as_piecesr   r   r   r   r    H  r   c                 S   r   )r   r   )replacer   r   r   r   r    L  r   )r<   r   r,   r7   r8   r'   r>   r   )r<   r   r   r   r   rF   rm   indexr   r   r   )r   r,   r>   r   r   r<   r   r   r'   
token_listr7   r8   text_piecestaskr   )r   r   r   prepare_segmentation_task  sB   5




	z)CTCSegmentation.prepare_segmentation_taskr   c                 C   st   t | tsJ | jdusJ | j}| j}| j}| j}| j}t|||\}}}t|||||}	| j	||||	dd}
|
S )a  Obtain segments for given utterance texts and CTC log posteriors.

        Arguments
        ---------
        task : CTCSegmentationTask
            Task object that contains ground truth and
            CTC posterior probabilities.

        Returns
        -------
        dict
            Dictionary with alignments. Combine this with the task
            object to obtain a human-readable segments representation.
        NT)r   r9   r:   r;   r&   r=   )
r`   r   r<   r>   r7   r8   r,   r   r   r   )r   r<   r>   r7   r8   r,   r9   r:   r;   r&   resultr   r   r   get_segments[  s*   

zCTCSegmentation.get_segmentsr,   r   rQ   c                 C   s^   t |ts
t |tr| j|}| |}| ||||jd }| |}|j	di | |S )uv  Align utterances.

        Arguments
        ---------
        speech : Union[torch.Tensor, np.ndarray, str, Path]
            Audio file that can be given as path or as array.
        text : Union[List[str], str]
            List or multiline-string with utterance ground truths.
            The required formatting depends on the setting ``kaldi_style_text``.
        name : str
            Name of the file. Utterance names are derived from it.

        Returns
        -------
        CTCSegmentationTask
            Task object with segments. Apply str(·) or print(·) on it
            to obtain the segments list.
        r   Nr   )
r`   r   r   rD   
load_audior   r   r   r   r   )r   r   r,   r   r>   r   r&   r   r   r   __call__  s   

zCTCSegmentation.__call__)TrB   r@   )NNNNNNNNNNN)NN)r   )N)(r3   r4   r5   r6   rZ   rE   rt   rG   r{   rF   r   r   r
   r<   r   r   r   r   r   rs   r   r   r~   rj   r   r   re   no_gradrf   r   r   r   r   r   staticmethodr   r   r   r   r   r   r   r   r   r?      s    J

S	


w

^-r?   )r6   pathlibr   typesr   typingr   r   r   r   r   re   speechbrain.inference.ASRr   r   speechbrain.utils.loggerr	   r   r
   r   r   r   ImportErrorprintr3   ro   r   r?   r   r   r   r   <module>   s&   
 \