o
    }oio2                    @   s2  d dl Z d dlZd dlZd dlmZ d dlmZmZmZ d dl	m
Z
mZmZmZmZmZ d dlZd dlZd dlmZmZ d dlmZmZ d dlmZmZ d dlmZmZ d d	lm Z  d d
l!m"Z" d dl#m$Z$m%Z% dd Z&G dd deZ'G dd de'Z(G dd de'Z)eG dd dZ*eG dd de*Z+dS )    N)abstractmethod)	dataclassfieldis_dataclass)CallableDictListOptionalSetUnion)
DictConfig	OmegaConf)ctc_beam_decodingctc_greedy_decoding)ConfidenceConfigConfidenceMixin)
HypothesisNBestHypotheses)DummyTokenizer)TokenizerSpec)logginglogging_modec                 C   s6   t t| j}| j|g|d |  ||d d    S )N   )listrangendimpermute)tensor	dim_indexall_dims r    f/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/asr/parts/submodules/ctc_decoding.pymove_dimension_to_the_front"   s   (r"   c                       s  e Zd ZdZd:dedee f fddZ			d;d	ej	d
ej	de
de
deee eee  f f
ddZdee de
deeeef  fddZdee dee fddZedeeeeeef f  deeeeeef f  dedee deeeeeef f  f
ddZedee defddZedee dee fddZdee defdd Zd<d"ed#efd$d%Zed"ed&ee d'edeeeeeef f  fd(d)Ze	d:deeeeeef f  deeeeeef f  dee deeeeeef f  fd*d+Ze		d=d,eeeeeef f  d-ee dee d.ee deeeeeef f  f
d/d0Zed1d2 Z e j!d3d2 Z ed4d5 Z"e"j!d6d5 Z"ed7d8 Z#e#j!d9d8 Z#  Z$S )>AbstractCTCDecodingu"  
    Used for performing CTC auto-regressive / non-auto-regressive decoding of the logprobs.

    Args:
        decoding_cfg: A dict-like object which contains the following key-value pairs.
            strategy:
                str value which represents the type of decoding that can occur.
                Possible values are :

                    greedy (for greedy decoding).

                    beam (for DeepSpeed KenLM based decoding).

            compute_timestamps:
                A bool flag, which determines whether to compute the character/subword, or
                word based timestamp mapping the output log-probabilities to discrite intervals of timestamps.
                The timestamps will be available in the returned Hypothesis.timestep as a dictionary.

            ctc_timestamp_type:
                A str value, which represents the types of timestamps that should be calculated.
                Can take the following values - "char" for character/subword time stamps, "word" for word level
                time stamps and "all" (default), for both character level and word level time stamps.

            word_seperator:
                Str token representing the seperator between words.

            segment_seperators:
                List containing tokens representing the seperator(s) between segments.

            segment_gap_threshold:
                The threshold (in frames) that caps the gap between two words necessary for forming the segments.

            preserve_alignments:
                Bool flag which preserves the history of logprobs generated during
                decoding (sample / batched). When set to true, the Hypothesis will contain
                the non-null value for `logprobs` in it. Here, `logprobs` is a torch.Tensors.

            confidence_cfg:
                A dict-like object which contains the following key-value pairs related to confidence
                scores. In order to obtain hypotheses with confidence scores, please utilize
                `ctc_decoder_predictions_tensor` function with the `preserve_frame_confidence` flag set to True.

                preserve_frame_confidence:
                    Bool flag which preserves the history of per-frame confidence scores
                    generated during decoding. When set to true, the Hypothesis will contain
                    the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of floats.

                preserve_token_confidence:
                    Bool flag which preserves the history of per-token confidence scores
                    generated during greedy decoding (sample / batched). When set to true, the Hypothesis will contain
                    the non-null value for `token_confidence` in it. Here, `token_confidence` is a List of floats.

                    The length of the list corresponds to the number of recognized tokens.

                preserve_word_confidence:
                    Bool flag which preserves the history of per-word confidence scores
                    generated during greedy decoding (sample / batched). When set to true, the Hypothesis will contain
                    the non-null value for `word_confidence` in it. Here, `word_confidence` is a List of floats.

                    The length of the list corresponds to the number of recognized words.

                exclude_blank:
                    Bool flag indicating that blank token confidence scores are to be excluded
                    from the `token_confidence`.

                aggregation:
                    Which aggregation type to use for collapsing per-token confidence into per-word confidence.
                    Valid options are `mean`, `min`, `max`, `prod`.

                tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and
                    attached to the regular frame confidence,
                    making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`).

                method_cfg:
                    A dict-like object which contains the method name and settings to compute per-frame
                    confidence scores.

                    name:
                        The method name (str).
                        Supported values:

                            'max_prob' for using the maximum token probability as a confidence.

                            'entropy' for using a normalized entropy of a log-likelihood vector.

                    entropy_type:
                        Which type of entropy to use (str).
                        Used if confidence_method_cfg.name is set to `entropy`.
                        Supported values:

                            - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided,
                                the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
                                Note that for this entropy, the alpha should comply the following inequality:
                                (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1)
                                where V is the model vocabulary size.

                            - 'tsallis' for the Tsallis entropy with the Boltzmann constant one.
                                Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)),
                                where α is a parameter. When α == 1, it works like the Gibbs entropy.
                                More: https://en.wikipedia.org/wiki/Tsallis_entropy

                            - 'renyi' for the Rényi entropy.
                                Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)),
                                where α is a parameter. When α == 1, it works like the Gibbs entropy.
                                More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy

                    alpha:
                        Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0.
                        When the alpha equals one, scaling is not applied to 'max_prob',
                        and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i))

                    entropy_norm:
                        A mapping of the entropy value to the interval [0,1].
                        Supported values:

                            - 'lin' for using the linear mapping.

                            - 'exp' for using exponential mapping with linear shift.

            batch_dim_index:
                Index of the batch dimension of ``targets`` and ``predictions`` parameters of
                ``ctc_decoder_predictions_tensor`` methods. Can be either 0 or 1.

            The config may further contain the following sub-dictionaries:

                "greedy":
                    preserve_alignments: Same as above, overrides above value.
                    compute_timestamps: Same as above, overrides above value.
                    preserve_frame_confidence: Same as above, overrides above value.
                    confidence_method_cfg: Same as above, overrides confidence_cfg.method_cfg.

                "beam":
                    beam_size:
                        int, defining the beam size for beam search. Must be >= 1.
                        If beam_size == 1, will perform cached greedy search. This might be slightly different
                        results compared to the greedy search above.

                    return_best_hypothesis:
                        optional bool, whether to return just the best hypothesis or all of the
                        hypotheses after beam search has concluded. This flag is set by default.

                    ngram_lm_alpha:
                        float, the strength of the Language model on the final score of a token.
                        final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.

                    beam_beta:
                        float, the strength of the sequence length penalty on the final score of a token.
                        final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.

                    ngram_lm_model:
                        str, path to a KenLM ARPA or .binary file (depending on the strategy chosen).
                        If the path is invalid (file is not found at path), will raise a deferred error at the moment
                        of calculation of beam search, so that users may update / change the decoding strategy
                        to point to the correct file.

        blank_id:
            The id of the RNNT blank token.
        supported_punctuation:
            Set of punctuation marks in the vocabulary.
    Nblank_idsupported_punctuationc                    s  t    t|rt|}t|tst|}t|d dg}|D ]}||vr0ti ||< q#|| _	|| _
|| _| j	dd | _| j	dd | _| j	dd| _| j	dd| _| j	d	g d
| _| j	dd | _g d}| j	j|vrtd| d| j	j | jd u r| j	jdv r| j	jdd| _n	| j	jdd| _| jd u r| j	jdv r| j	jdd| _n| j	jdv r| j	jdd| _| jrddd | jD }td| d | _| | j	dd  | js| j	jdvr| j	jddrtd| j	j d| jd ur|  j| jO  _| j	jdv rC| j	jjd ur.t d | j	jj| j	j_!| j	jj"d urCt d | j	jj"| j	j_#| j	jdkr\t$j%| j
| j| j| j| j&d| _'d S | j	jdkrt$j(| j
| j| j| j| j&| j	jdd | j	jd d!| j	jd"d#d$| _'d S | j	jd%krt)j*|| j	jd&d'd(| j	jd)d#| j| j| j	jd d*| j	jd+d!| j	jdd d,	| _'d| j'_+d S | j	jd-kr	t)j*|| j	jd&d'd-| j	jd)d#| j| j| j	jd d*| j	jd+d!| j	jdd | j	jd.d d/
| _'d| j'_+d S | j	jd0krLt)j*|| j	jd&d'd0| j	jd)d#| j| j| j	jd d*| j	jd+d!| j	jdd | j	jd1d d2
| _'d| j'_+d S | j	jd3krt)j,|| j	j-d&d'| j	j-d4d5| j	j-d)d#| j| j| j	j-d6d7| j	j-d8d| j	j-d9d:| j	j-d;d*| j	j-d<d=| j	j-d>d | j	j-d?d | j	j-d@d | j	j-dAd dB| _'d| j'_+d S | j	jdCkrt)j.|| j	jd&d'| j	jd)d#| j| j| j	jd d*| j	jd+d!| j	jdDdE| j	jdd | j	jd"d#dF
| _'d| j'_+d S tdG| dH| j	j )INFgreedypreserve_alignmentscompute_timestampsbatch_dim_indexr   word_seperator segment_seperators).?!segment_gap_threshold)r&   greedy_batchbeampyctcdecode
flashlightwfst
beam_batchz!Decoding strategy must be one of z. Given )r&   r1   )r2   |c                 S   s   g | ]}t |qS r    )reescape).0pr    r    r!   
<listcomp>   s    z0AbstractCTCDecoding.__init__.<locals>.<listcomp>z(\s)()confidence_cfgpreserve_frame_confidencez6Confidence calculation is not supported for strategy ``)r4   r5   r6   r3   r2   zh`beam_alpha` is deprecated and will be removed in a future release. Please use `ngram_lm_alpha` instead.zh`kenlm_path` is deprecated and will be removed in a future release. Please use `ngram_lm_model` instead.)r$   r'   r(   r?   confidence_method_cfgr1   ngram_lm_modelngram_lm_alphag        allow_cuda_graphsT)r$   r'   r(   r?   rA   rB   rC   rD   r2   	beam_sizer   defaultreturn_best_hypothesis      ?	beam_beta)	r$   rE   search_typerG   r'   r(   rC   rI   rB   r3   pyctcdecode_cfg)
r$   rE   rJ   rG   r'   r(   rC   rI   rB   rK   r4   flashlight_cfg)
r$   rE   rJ   rG   r'   r(   rC   rI   rB   rL   r5   rJ   rivadecoding_modenbestopen_vocabulary_decoding
beam_widthg      $@	lm_weightdevicecudaarpa_lm_pathwfst_lm_pathriva_decoding_cfgk2_decoding_cfg)r$   rE   rJ   rG   r'   r(   rN   rP   rQ   rR   rS   rU   rV   rW   rX   r6   beam_thresholdg      4@)
blank_indexrE   rG   r'   r(   rC   rI   rY   rB   rD   z5Incorrect decoding strategy supplied. Must be one of z
but was provided )/super__init__r   r   
structured
isinstancer   create
set_structcfgr$   r%   getr'   r(   r)   r*   r,   r0   strategy
ValueErrorr&   r2   joinr8   compilespace_before_punct_pattern_init_confidencer?   NotImplementedError
beam_alphar   warningrC   
kenlm_pathrB   r   GreedyCTCInferrA   decodingGreedyBatchedCTCInferr   BeamCTCInferoverride_fold_consecutive_valueWfstCTCInferr5   BeamBatchedCTCInfer)selfdecoding_cfgr$   r%   minimal_cfgitempossible_strategiespunct_pattern	__class__r    r!   r\      s2  





zAbstractCTCDecoding.__init__TFdecoder_outputsdecoder_lengthsfold_consecutivereturn_hypothesesreturnc                 C   s  t |tjrt|| j}t| jdr,| jjdur,tj	d| d| jj t
jd | jj}t  | j||d}|d }W d   n1 sFw   Y  t |d tr| jjdkr`d	d
 |D }n4g }|D ]/}|j}| ||}	| jdu r| jdd}
tt|	D ]}| |	| |
|	|< q||	 qd|r|S dd
 |D }|S | jjdkr|}n@| ||}| jdu r|r| js| jr| |}n|D ]
}|jdd |_q| jdd}
tt|D ]}| || |
||< q|r|S dd
 |D S )a  
        Decodes a sequence of labels to words

        Args:
            decoder_outputs: An integer torch.Tensor of shape [Batch, Time, {Vocabulary}] (if ``batch_index_dim == 0``) or [Time, Batch]
                (if ``batch_index_dim == 1``) of integer indices that correspond to the index of some character in the
                label set.
            decoder_lengths: Optional tensor of length `Batch` which contains the integer lengths
                of the sequence in the padded `predictions` tensor.
            fold_consecutive: Bool, determine whether to perform "ctc collapse", folding consecutive tokens
                into a single token.
            return_hypotheses: Bool flag whether to return just the decoding predictions of the model
                or a Hypothesis object that holds information such as the decoded `text`,
                the `alignment` of emited by the CTC Model, and the `length` of the sequence (if available).
                May also contain the log-probabilities of the decoder (if this method is called via
                transcribe())

        Returns:
            A list of Hypothesis objects containing additional information.
        rq   NztBeam search requires that consecutive ctc tokens are not folded. 
Overriding provided value of `fold_consecutive` = z to mode)decoder_outputr}   r   r5   c                 S   s   g | ]}|j qS r    )n_best_hypotheses)r:   hypr    r    r!   r<     s    zFAbstractCTCDecoding.ctc_decoder_predictions_tensor.<locals>.<listcomp>Tctc_timestamp_typeallc                 S   s   g | ]	}d d |D qS )c                 S      g | ]}t |j|j|jqS r    r   score
y_sequencetextr:   hr    r    r!   r<         zQAbstractCTCDecoding.ctc_decoder_predictions_tensor.<locals>.<listcomp>.<listcomp>r    )r:   hhr    r    r!   r<     s       c                 S   r   r    r   r   r    r    r!   r<     r   )r^   torchTensorr"   r)   hasattrrn   rq   r   infor   ONCEinference_moder   ra   rc   r   decode_hypothesisr(   rb   r   lencompute_ctc_timestampsappendpreserve_word_confidencepreserve_token_confidencecompute_confidencer   )rt   r|   r}   r~   r   hypotheses_listall_hypotheses	nbest_hypn_hypsdecoded_hypstimestamp_typehyp_idxall_hyp
hypothesesr   r    r    r!   ctc_decoder_predictions_tensor  sl   


	

z2AbstractCTCDecoding.ctc_decoder_predictions_tensorr   c                 C   s  t t|D ]}|| }|j}|jdkr|jnd}|rt|tkr'|  }|dur1|d| }g }g }g }	| j}
d}d}t	|D ]6\}}||
ksO|
| jkri|| jkri|
| |
||  |}|	
| d}||
krv|
| jkrv|d7 }|}
qBt|	dkr|	dd |g }	n!|dur|d| }||| jk  }dgt| }dgt| }	| jdu r|||	f}n| |}||| _q|S )a  
        Decode a list of hypotheses into a list of strings.

        Args:
            hypotheses_list: List of Hypothesis.
            fold_consecutive: Whether to collapse the ctc blank tokens or not.

        Returns:
            A list of strings.
        r   Nr   T)r   r   r   lengthtyper   numpytolistr$   	enumerater   r(   +decode_tokens_to_str_with_strip_punctuationr   )rt   r   r~   indr   
predictionpredictions_lendecoded_predictiontoken_lengthstoken_repetitionspreviouslast_lengthlast_repetitionpidxr;   
hypothesisr    r    r!   r     sL   



z%AbstractCTCDecoding.decode_hypothesisc              
   C   s  |D ]v}t |jtrt|jdkrtd|jd }|jdd |_g }| jrD|j}d}|D ]}|| }|| |||  |}q.n1|jd }	t|	dkru|	d }
|	dd t|j	g D ]}|| |j	|
|
|   |
|7 }
q_||_
q| jr|D ]}| ||_q~|S )aI  
        Computes high-level (per-token and/or per-word) confidence scores for a list of hypotheses.
        Assumes that `frame_confidence` is present in the hypotheses.

        Args:
            hypotheses_list: List of Hypothesis.

        Returns:
            A list of hypotheses with high-level confidence scores.
           zWrong format of the `text` attribute of a hypothesis.

                    Expected: (decoded_prediction, token_repetitions)

                    The method invocation is expected between .decode_hypothesis() and .compute_ctc_timestamps()r   Nr   r   )r^   r   tupler   rd   exclude_blank_from_confidencenon_blank_frame_confidencer   _aggregate_confidenceframe_confidencetoken_confidencer   _aggregate_token_confidenceword_confidence)rt   r   r   r   r   r   itrjr   tstlr    r    r!   r   A  s6   


z&AbstractCTCDecoding.compute_confidencechar_offsetsencoded_char_offsetsword_delimiter_charc                 C      t  )zL
        Implemented by subclass in order to get the words offsets.
        ri   )rt   r   r   r   r%   r    r    r!   get_words_offsetsm     z%AbstractCTCDecoding.get_words_offsetstokensc                 C   r   )z
        Implemented by subclass in order to decoder a token id list into a string.

        Args:
            tokens: List of int representing the token ids.

        Returns:
            A decoded string.
        r   rt   r   r    r    r!   decode_tokens_to_strz  r   z(AbstractCTCDecoding.decode_tokens_to_strc                 C   r   &  
        Implemented by subclass in order to decode a token id list into a token list.
        A token list is the string representation of each token id.

        Args:
            tokens: List of int representing the token ids.

        Returns:
            A list of decoded tokens.
        r   r   r    r    r!   decode_ids_to_tokens  s   z(AbstractCTCDecoding.decode_ids_to_tokensc                 C   s"   |  |}| jr| jd|}|S )zn
        Decodes a list of tokens to a string and removes a space before supported punctuation marks.
        z\2)r   r%   rg   sub)rt   r   r   r    r    r!   r     s   
z?AbstractCTCDecoding.decode_tokens_to_str_with_strip_punctuationr   r   r   c              
   C   s  |dv sJ |j \}}||_ d }}| ||| j}t|t|j kr9td| d|j  dt| dt|j  t|}t|j D ]\}}	| |	g|| d< qC| j	||| j
d\}}d}|d	v rn| j||| j| j
d
}d}
|dv r| j|| j| j
| jd }
}
t|jdkr|j}ng }d|i|_|dur|dv r||jd< |dur|dv r||jd< |
dur|dv r|
|jd< | |j |_ |S )a  
        Method to compute time stamps at char/subword, and word level given some hypothesis.
        Requires the input hypothesis to contain a `text` field that is the tuple. The tuple contains -
        the ctc collapsed integer ids, and the number of repetitions of each token.

        Args:
            hypothesis: A Hypothesis object, with a wrapped `text` field.
                The `text` field must contain a tuple with two values -
                The ctc collapsed integer ids
                A list of integers that represents the number of repetitions per token.
            timestamp_type: A str value that represents the type of time stamp calculated.
                Can be one of "char", "word" "segment" or "all"

        Returns:
            A Hypothesis object with a modified `timestep` value, which is now a dictionary containing
            the time stamp information.
        )charwordsegmentr   Nz`char_offsets`: z and `processed_tokens`: z9 have to be of the same length, but are: `len(offsets)`: z and `len(processed_tokens)`: r   )r   r   r%   )r   r   r   )r   r   r   r%   )r   r   )segment_delimiter_tokensr%   r0   r   timestep)r   r   )r   r   r   r   )r   _compute_offsetsr$   r   rd   copydeepcopyr   r   _refine_timestampsr%   r   r*   _get_segment_offsetsr,   r0   	timestampr   )rt   r   r   r   r   r   word_offsetsr   r   r   segment_offsetstimestep_infor    r    r!   r     sb   







z*AbstractCTCDecoding.compute_ctc_timestampsr   	ctc_tokenc                    s   d}| j durt| j dkrtd| j d d }t| }t|g|dd f}dd t| j||D }t	t
 fdd|}|S )	a  
        Utility method that calculates the indidual time indices where a token starts and ends.

        Args:
            hypothesis: A Hypothesis object that contains `text` field that holds the character / subword token
                emitted at every time step after ctc collapse.
            token_lengths: A list of ints representing the lengths of each emitted token.
            ctc_token: The integer of the ctc blank token used during ctc collapse.

        Returns:

        r   Nr   c                 S   s   g | ]\}}}|||d qS ))r   start_offset
end_offsetr    )r:   tser    r    r!   r<     s    
z8AbstractCTCDecoding._compute_offsets.<locals>.<listcomp>c                    s   | d  kS )Nr   r    )offsetsr   r    r!   <lambda>      z6AbstractCTCDecoding._compute_offsets.<locals>.<lambda>)r   r   maxnpasarraycumsumconcatenatezipr   r   filter)r   r   r   start_indexend_indicesstart_indicesr   r    r   r!   r     s   z$AbstractCTCDecoding._compute_offsetsc                 C   s^   |s| |fS t |D ] \}}|d r*|d d |v r*|dkr*|d  | | d< |d< q
| |fS )Nr   r   r   r   )r   )r   r   r%   r   offsetr    r    r!   r   "  s    z&AbstractCTCDecoding._refine_timestampsr   r   r0   c                 C   sT  |rt ||s|stjd| dtjd g }g }d}t| D ]j\}}|d }	|rZ|rZ|d | |d  d  }
|
|krY|d	|| | d | |d  d d
 |	g}|}q!n,|	r|	d |v sf|	|v r||	 |r|d	|| | d |d d
 g }|d }q!||	 q!|r| | d }|d	||| d d d
 |	  |S )a  
        Utility method which constructs segment time stamps out of word time stamps.

        Args:
            offsets: A list of dictionaries, each containing "word", "start_offset" and "end_offset".
            segments_delimiter_tokens: List containing tokens representing the seperator(s) between segments.
            supported_punctuation: Set containing punctuation marks in the vocabulary.
            segment_gap_threshold: Number of frames between 2 consecutive words necessary to form segments out of plain text.

        Returns:
            A list of dictionaries containing the segment offsets. Each item contains "segment", "start_offset" and
            "end_offset".
        z>Specified segment seperators are not in supported punctuation z. If the seperators are not punctuation marks, ignore this warning. Otherwise, specify 'segment_gap_threshold' parameter in decoding config to form segments.r   r   r   r   r   r   r+   )r   r   r   r   )
setintersectionr   rk   r   r   r   r   re   clear)r   r   r%   r0   r   segment_wordsprevious_word_indexr   r   r   gap_between_wordsr   r    r    r!   r   6  sf   




z(AbstractCTCDecoding._get_segment_offsetsc                 C      | j S N)_preserve_alignmentsrt   r    r    r!   r'        z'AbstractCTCDecoding.preserve_alignmentsc                 C       || _ t| dr|| j_d S d S Nrn   )r  r   rn   r'   rt   valuer    r    r!   r'        
c                 C   r   r  )_compute_timestampsr  r    r    r!   r(     r  z&AbstractCTCDecoding.compute_timestampsc                 C   r  r  )r
  r   rn   r(   r  r    r    r!   r(     r	  c                 C   r   r  )_preserve_frame_confidencer  r    r    r!   r?     r  z-AbstractCTCDecoding.preserve_frame_confidencec                 C   r  r  )r  r   rn   r?   r  r    r    r!   r?     r	  r  )NTF)r   )NN)%__name__
__module____qualname____doc__intr	   r
   r\   r   r   boolr   r   r   r   r   r   r   r   r   strfloatr   r   r   r   r   staticmethodr   r   r   propertyr'   setterr(   r?   __classcell__r    r    rz   r!   r#   '   s     " J
i
I,
^$V




r#   c                       s   e Zd ZdZ fddZdedee fddZdee	 de
fd	d
Zdee	 dee
 fddZe		ddeee
ee
ef f  deee
ee
ef f  de
dee deee
ee
ef f  f
ddZ  ZS )CTCDecodingu"  
    Used for performing CTC auto-regressive / non-auto-regressive decoding of the logprobs for character
    based models.

    Args:
        decoding_cfg: A dict-like object which contains the following key-value pairs.

            strategy:
                str value which represents the type of decoding that can occur.
                Possible values are :

                    -   greedy (for greedy decoding).

                    -   beam (for DeepSpeed KenLM based decoding).

            compute_timestamps:
                A bool flag, which determines whether to compute the character/subword, or
                word based timestamp mapping the output log-probabilities to discrite intervals of timestamps.
                The timestamps will be available in the returned Hypothesis.timestep as a dictionary.

            ctc_timestamp_type:
                A str value, which represents the types of timestamps that should be calculated.
                Can take the following values - "char" for character/subword time stamps, "word" for word level
                time stamps and "all" (default), for both character level and word level time stamps.

            word_seperator:
                Str token representing the seperator between words.

            segment_seperators:
                List containing tokens representing the seperator(s) between segments.

            segment_gap_threshold:
                The threshold (in frames) that caps the gap between two words necessary for forming the segments.

            preserve_alignments:
                Bool flag which preserves the history of logprobs generated during
                decoding (sample / batched). When set to true, the Hypothesis will contain
                the non-null value for `logprobs` in it. Here, `logprobs` is a torch.Tensors.

            confidence_cfg:
                A dict-like object which contains the following key-value pairs related to confidence
                scores. In order to obtain hypotheses with confidence scores, please utilize
                `ctc_decoder_predictions_tensor` function with the `preserve_frame_confidence` flag set to True.

                preserve_frame_confidence:
                    Bool flag which preserves the history of per-frame confidence scores
                    generated during decoding. When set to true, the Hypothesis will contain
                    the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of floats.

                preserve_token_confidence:
                    Bool flag which preserves the history of per-token confidence scores
                    generated during greedy decoding (sample / batched). When set to true, the Hypothesis will contain
                    the non-null value for `token_confidence` in it. Here, `token_confidence` is a List of floats.

                    The length of the list corresponds to the number of recognized tokens.

                preserve_word_confidence:
                    Bool flag which preserves the history of per-word confidence scores
                    generated during greedy decoding (sample / batched). When set to true, the Hypothesis will contain
                    the non-null value for `word_confidence` in it. Here, `word_confidence` is a List of floats.

                    The length of the list corresponds to the number of recognized words.

                exclude_blank:
                    Bool flag indicating that blank token confidence scores are to be excluded
                    from the `token_confidence`.

                aggregation:
                    Which aggregation type to use for collapsing per-token confidence into per-word confidence.
                    Valid options are `mean`, `min`, `max`, `prod`.

                tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and
                    attached to the regular frame confidence,
                    making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`).

                method_cfg:
                    A dict-like object which contains the method name and settings to compute per-frame
                    confidence scores.

                    name:
                        The method name (str).
                        Supported values:

                            - 'max_prob' for using the maximum token probability as a confidence.

                            - 'entropy' for using a normalized entropy of a log-likelihood vector.

                    entropy_type:
                        Which type of entropy to use (str).
                        Used if confidence_method_cfg.name is set to `entropy`.
                        Supported values:

                            - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided,
                                the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
                                Note that for this entropy, the alpha should comply the following inequality:
                                (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1)
                                where V is the model vocabulary size.

                            - 'tsallis' for the Tsallis entropy with the Boltzmann constant one.
                                Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)),
                                where α is a parameter. When α == 1, it works like the Gibbs entropy.
                                More: https://en.wikipedia.org/wiki/Tsallis_entropy

                            - 'renyi' for the Rényi entropy.
                                Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)),
                                where α is a parameter. When α == 1, it works like the Gibbs entropy.
                                More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy

                    alpha:
                        Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0.
                        When the alpha equals one, scaling is not applied to 'max_prob',
                        and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i))

                    entropy_norm:
                        A mapping of the entropy value to the interval [0,1].
                        Supported values:

                            - 'lin' for using the linear mapping.

                            - 'exp' for using exponential mapping with linear shift.

            batch_dim_index:
                Index of the batch dimension of ``targets`` and ``predictions`` parameters of
                ``ctc_decoder_predictions_tensor`` methods. Can be either 0 or 1.

            The config may further contain the following sub-dictionaries:

                "greedy":
                    preserve_alignments: Same as above, overrides above value.
                    compute_timestamps: Same as above, overrides above value.
                    preserve_frame_confidence: Same as above, overrides above value.
                    confidence_method_cfg: Same as above, overrides confidence_cfg.method_cfg.

                "beam":
                    beam_size:
                        int, defining the beam size for beam search. Must be >= 1.
                        If beam_size == 1, will perform cached greedy search. This might be slightly different
                        results compared to the greedy search above.

                    return_best_hypothesis:
                        optional bool, whether to return just the best hypothesis or all of the
                        hypotheses after beam search has concluded. This flag is set by default.

                    ngram_lm_alpha:
                        float, the strength of the Language model on the final score of a token.
                        final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.

                    beam_beta:
                        float, the strength of the sequence length penalty on the final score of a token.
                        final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.

                    ngram_lm_model:
                        str, path to a KenLM ARPA or .binary file (depending on the strategy chosen).
                        If the path is invalid (file is not found at path), will raise a deferred error at the moment
                        of calculation of beam search, so that users may update / change the decoding strategy
                        to point to the correct file.

        blank_id: The id of the RNNT blank token.
    c                    s~   t  } | _t fddtt  D | _dd  D }t j|||d t| jt	j
r=| j| j | jd d S d S )Nc                    s   g | ]}| | fqS r    r    )r:   r   
vocabularyr    r!   r<   W  s    z(CTCDecoding.__init__.<locals>.<listcomp>c                 S   *   h | ]}|D ]}t |d r|qqS Punicodedatacategory
startswithr:   tokenr   r    r    r!   	<setcomp>Y      z'CTCDecoding.__init__.<locals>.<setcomp>ru   r$   r%   r   )r   r  dictr   
labels_mapr[   r\   r^   rn   r   AbstractBeamCTCInferset_vocabularyset_decoding_type)rt   ru   r  r$   r%   rz   r  r!   r\   P  s    zCTCDecoding.__init__r   r   c                 C   s   |  | |jd  |jS )z
        Implemented by subclass in order to aggregate token confidence to a word-level confidence.

        Args:
            hypothesis: Hypothesis

        Returns:
            A list of word-level confidence scores.
        r   )!_aggregate_token_confidence_charsr   r   splitr   rt   r   r    r    r!   r   d  s   
z'CTCDecoding._aggregate_token_confidencer   c                 C   s   d | |}|S )
        Implemented by subclass in order to decoder a token list into a string.

        Args:
            tokens: List of int representing the token ids.

        Returns:
            A decoded string.
         )re   r   rt   r   r   r    r    r!   r   r  s   
z CTCDecoding.decode_tokens_to_strc                    s    fdd|D }|S )r   c                    s    g | ]}| j kr j| qS r    )r$   r(  )r:   cr  r    r!   r<     s     z4CTCDecoding.decode_ids_to_tokens.<locals>.<listcomp>r    rt   r   
token_listr    r  r!   r     s   z CTCDecoding.decode_ids_to_tokensr+   Nr   r   r   r%   c                 C   s   g }d}d}d}d}t | D ]]\}	}
|
d }||krdnd}|	t| d k r.| |	d  d nd}|r;||v o9||k}nd}|d	krD|rDq||krQ|
d
 }||7 }n|dkr_||||d n
|
d }|
d
 }|}|}q|dkry||||d |S )a  
        Utility method which constructs word time stamps out of character time stamps.

        References:
            This code is a port of the Hugging Face code for word time stamp construction.

        Args:
            char_offsets: A list of dictionaries, each containing "char", "start_offset" and "end_offset",
                        where "char" is decoded with the tokenizer.
            encoded_char_offsets: A list of dictionaries, each containing "char", "start_offset" and "end_offset",
                        where "char" is the original id/ids from the hypotheses (not decoded with the tokenizer).
                        As we are working with char-based models here, we are using the `char_offsets` to get the word offsets.
                        `encoded_char_offsets` is passed for keeping the consistency with `AbstractRNNTDecoding`'s abstract method.
            word_delimiter_char: Character token that represents the word delimiter. By default, " ".
            supported_punctuation: Set containing punctuation marks in the vocabulary.

        Returns:
            A list of dictionaries containing the word offsets. Each item contains "word", "start_offset" and
            "end_offset".
        	DELIMITERr0  r   r   WORDr   NFr+   r   r   r   r   r   )r   r   r   )r   r   r   r%   r   
last_stater   r   r   r   r   r   state	next_charnext_punctuationr    r    r!   r     s4   $
zCTCDecoding.get_words_offsetsr+   N)r  r  r  r  r\   r   r   r  r   r  r  r   r   r  r   r   r	   r
   r   r  r    r    rz   r!   r    s*     !r  c                       s  e Zd ZdZdef fddZedee defddZ	ed	ed
ede
eegef fddZdedee fddZdee defddZdee dee fddZ		ddeeeeeef f  deeeeeef f  d
edee deeeeeef f  f
ddZ  ZS )CTCBPEDecodingu!  
    Used for performing CTC auto-regressive / non-auto-regressive decoding of the logprobs for subword based
    models.

    Args:
        decoding_cfg: A dict-like object which contains the following key-value pairs.

            strategy:
                str value which represents the type of decoding that can occur.
                Possible values are :

                    -   greedy (for greedy decoding).

                    -   beam (for DeepSpeed KenLM based decoding).

            compute_timestamps:
                A bool flag, which determines whether to compute the character/subword, or
                word based timestamp mapping the output log-probabilities to discrite intervals of timestamps.
                The timestamps will be available in the returned Hypothesis.timestep as a dictionary.

            ctc_timestamp_type:
                A str value, which represents the types of timestamps that should be calculated.
                Can take the following values - "char" for character/subword time stamps, "word" for word level
                time stamps and "all" (default), for both character level and word level time stamps.

            word_seperator:
                Str token representing the seperator between words.

            preserve_alignments:
                Bool flag which preserves the history of logprobs generated during
                decoding (sample / batched). When set to true, the Hypothesis will contain
                the non-null value for `logprobs` in it. Here, `logprobs` is a torch.Tensors.

            confidence_cfg:
                A dict-like object which contains the following key-value pairs related to confidence
                scores. In order to obtain hypotheses with confidence scores, please utilize
                `ctc_decoder_predictions_tensor` function with the `preserve_frame_confidence` flag set to True.

                preserve_frame_confidence:
                    Bool flag which preserves the history of per-frame confidence scores
                    generated during decoding. When set to true, the Hypothesis will contain
                    the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of floats.

                preserve_token_confidence:
                    Bool flag which preserves the history of per-token confidence scores
                    generated during greedy decoding (sample / batched). When set to true, the Hypothesis will contain
                    the non-null value for `token_confidence` in it. Here, `token_confidence` is a List of floats.

                    The length of the list corresponds to the number of recognized tokens.

                preserve_word_confidence:
                    Bool flag which preserves the history of per-word confidence scores
                    generated during greedy decoding (sample / batched). When set to true, the Hypothesis will contain
                    the non-null value for `word_confidence` in it. Here, `word_confidence` is a List of floats.

                    The length of the list corresponds to the number of recognized words.

                exclude_blank:
                    Bool flag indicating that blank token confidence scores are to be excluded
                    from the `token_confidence`.

                aggregation:
                    Which aggregation type to use for collapsing per-token confidence into per-word confidence.
                    Valid options are `mean`, `min`, `max`, `prod`.

                tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and
                    attached to the regular frame confidence,
                    making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`).

                method_cfg:
                    A dict-like object which contains the method name and settings to compute per-frame
                    confidence scores.

                    name:
                        The method name (str).
                        Supported values:

                            - 'max_prob' for using the maximum token probability as a confidence.

                            - 'entropy' for using a normalized entropy of a log-likelihood vector.

                    entropy_type:
                        Which type of entropy to use (str).
                        Used if confidence_method_cfg.name is set to `entropy`.
                        Supported values:

                            - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided,
                                the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
                                Note that for this entropy, the alpha should comply the following inequality:
                                (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1)
                                where V is the model vocabulary size.

                            - 'tsallis' for the Tsallis entropy with the Boltzmann constant one.
                                Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)),
                                where α is a parameter. When α == 1, it works like the Gibbs entropy.
                                More: https://en.wikipedia.org/wiki/Tsallis_entropy

                            - 'renyi' for the Rényi entropy.
                                Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)),
                                where α is a parameter. When α == 1, it works like the Gibbs entropy.
                                More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy

                    alpha:
                        Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0.
                        When the alpha equals one, scaling is not applied to 'max_prob',
                        and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i))

                    entropy_norm:
                        A mapping of the entropy value to the interval [0,1].
                        Supported values:

                            - 'lin' for using the linear mapping.

                            - 'exp' for using exponential mapping with linear shift.

            batch_dim_index:
                Index of the batch dimension of ``targets`` and ``predictions`` parameters of
                ``ctc_decoder_predictions_tensor`` methods. Can be either 0 or 1.

            The config may further contain the following sub-dictionaries:

                "greedy":
                    preserve_alignments: Same as above, overrides above value.
                    compute_timestamps: Same as above, overrides above value.
                    preserve_frame_confidence: Same as above, overrides above value.
                    confidence_method_cfg: Same as above, overrides confidence_cfg.method_cfg.

                "beam":
                    beam_size:
                        int, defining the beam size for beam search. Must be >= 1.
                        If beam_size == 1, will perform cached greedy search. This might be slightly different
                        results compared to the greedy search above.

                    return_best_hypothesis:
                        optional bool, whether to return just the best hypothesis or all of the
                        hypotheses after beam search has concluded. This flag is set by default.

                    ngram_lm_alpha:
                        float, the strength of the Language model on the final score of a token.
                        final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.

                    beam_beta:
                        float, the strength of the sequence length penalty on the final score of a token.
                        final_score = acoustic_score + ngram_lm_alpha * lm_score + beam_beta * seq_length.

                    ngram_lm_model:
                        str, path to a KenLM ARPA or .binary file (depending on the strategy chosen).
                        If the path is invalid (file is not found at path), will raise a deferred error at the moment
                        of calculation of beam search, so that users may update / change the decoding strategy
                        to point to the correct file.

        tokenizer: NeMo tokenizer object, which inherits from TokenizerSpec.
    	tokenizerc                    s   |j j}|| _ | |j| _t|dr|j}ndd |jD }t j|||d t	| j
tjrft| j j drY| j j  }t	| j j trF|}nt| }| j
| | j
| ntd | j
d d S d S )Nr%   c                 S   r  r  r  r"  r    r    r!   r$  y  r%  z*CTCBPEDecoding.__init__.<locals>.<setcomp>r&  	get_vocabz3Could not resolve the vocabulary of the tokenizer !subword)r>  
vocab_sizedefine_tokenizer_typevocabtokenizer_typer   r%   r[   r\   r^   rn   r   r)  r?  r   r   keysr*  set_tokenizerr   rk   r+  )rt   ru   r>  r$   r%   
vocab_dictrC  rz   r    r!   r\   q  s(   

zCTCBPEDecoding.__init__r  r   c                 C   s   t dd | D rdS dS )zD
        Define the tokenizer type based on the vocabulary.
        c                 s   s    | ]}| d V  qdS )##Nr!  )r:   r#  r    r    r!   	<genexpr>  s    z7CTCBPEDecoding.define_tokenizer_type.<locals>.<genexpr>wpebpe)anyr  r    r    r!   rB    s   z$CTCBPEDecoding.define_tokenizer_typerD  r   c                    s,    dkr| dkrdd S dd S  fddS )zk
        Define the word start condition based on the tokenizer type and word delimiter character.
        r+   rK  c                 S   s   |o| d S )NrH  rI  r#  
token_textr    r    r!   r     s    z<CTCBPEDecoding.define_word_start_condition.<locals>.<lambda>c                 S   s   | |kS r  r    rN  r    r    r!   r         c                    s   | kS r  r    rN  r   r    r!   r     rP  r    )rD  r   r    rQ  r!   define_word_start_condition  s
   z*CTCBPEDecoding.define_word_start_conditionr   c                 C   s&   |  | |jd  |j|jd S )a%  
        Implemented by subclass in order to aggregate token confidence to a word-level confidence.

        **Note**: Only supports Sentencepiece based tokenizers!

        Args:
            hypothesis: Hypothesis

        Returns:
            A list of word-level confidence scores.
        r   )2_aggregate_token_confidence_subwords_sentencepiecer   r   r-  r   r.  r    r    r!   r     s   z*CTCBPEDecoding._aggregate_token_confidencer   c                 C      | j |}|S )r/  )r>  ids_to_textr1  r    r    r!   r     s   
z#CTCBPEDecoding.decode_tokens_to_strc                 C   rT  r   )r>  ids_to_tokensr3  r    r    r!   r     s   z#CTCBPEDecoding.decode_ids_to_tokensr+   Nr   r   r%   c                 C   s  |  }g }d}g }| | j|}t|D ]\}	}
|
d }| |gd }| |g }|o2||v }|||rf|sf|rV| |}|rV|||| d ||	d  d d |  ||kre|| |	}q|r|s|d }|
d |d< |d d d	kr|d d
d |d< |d  |7  < q|s|	}|| qt	|dkr|r| |}|r|||d d |d d d |S |d d |d d< |r| |}|r|||| d |d d d |S )a  
        Utility method which constructs word time stamps out of sub-word time stamps.

        **Note**: Only supports Sentencepiece based tokenizers !

        Args:
            char_offsets: A list of dictionaries, each containing "char", "start_offset" and "end_offset",
                        where "char" is decoded with the tokenizer.
            encoded_char_offsets: A list of dictionaries, each containing "char", "start_offset" and "end_offset",
                        where "char" is the original id/ids from the hypotheses (not decoded with the tokenizer).
                        This is needed for subword tokenization models.
            word_delimiter_char: Character token that represents the word delimiter. By default, " ".
            supported_punctuation: Set containing punctuation marks in the vocabulary.

        Returns:
            A list of dictionaries containing the word offsets. Each item contains "word", "start_offset" and
            "end_offset".
        r   r   r   r   r   r7  r   r   r+   N)
r   rR  rD  r   r   r   stripr   r   r   )rt   r   r   r   r%   r   previous_token_indexbuilt_tokenscondition_for_word_startr   r   r   r#  rO  curr_punctuation
built_wordlast_built_wordr    r    r!   r     sp   

	






z CTCBPEDecoding.get_words_offsetsr<  )r  r  r  r  r   r\   r  r   r  rB  r   r  rR  r   r  r   r  r   r   r   r   r	   r
   r   r  r    r    rz   r!   r=    s0     $r=  c                   @   s   e Zd ZU dZeed< dZee ed< dZ	ee ed< dZ
eed< edd	 d
Zeee  ed< dZee ed< dZeed< dZeed< edd	 d
Zejed< edd	 d
Zejed< edd	 d
Zejed< edd	 d
Zeed< dZeed< dS )CTCDecodingConfigr1   rc   Nr'   r(   r+   r*   c                   C   s   g dS )N)r-   r/   r.   r    r    r    r    r!   r   W  rP  zCTCDecodingConfig.<lambda>)default_factoryr,   r0   r   r   r   r)   c                   C   s   t  S r  )r   GreedyCTCInferConfigr    r    r    r!   r   d  rP  r&   c                   C      t jddS N   )rE   )r   BeamCTCInferConfigr    r    r    r!   r   i  r   r2   c                   C   ra  rb  )r   WfstCTCInferConfigr    r    r    r!   r   n  r   r5   c                   C   s   t  S r  )r   r    r    r    r!   r   r  s    r>   rH   temperature)r  r  r  rc   r  __annotations__r'   r	   r  r(   r*   r   r,   r   r0   r  r   r)   r&   r   r`  r2   r   rd  r5   re  r>   r   rf  r  r    r    r    r!   r^  I  s(   
 r^  c                   @   s   e Zd ZdS )CTCBPEDecodingConfigN)r  r  r  r    r    r    r!   rh  x  s    rh  ),r   r8   r  abcr   dataclassesr   r   r   typingr   r   r   r	   r
   r   r   r   r   	omegaconfr   r   %nemo.collections.asr.parts.submodulesr   r   5nemo.collections.asr.parts.utils.asr_confidence_utilsr   r   +nemo.collections.asr.parts.utils.rnnt_utilsr   r   6nemo.collections.common.tokenizers.aggregate_tokenizerr   1nemo.collections.common.tokenizers.tokenizer_specr   
nemo.utilsr   r   r"   r#   r  r=  r^  rh  r    r    r    r!   <module>   sD             )  u.