o
    }oi                     @   s  d dl mZmZ d dlmZmZmZmZmZm	Z	 d dl
Z
eG dd dZeG dd dZeG dd	 d	Zd
ee dee defddZdee de
jde
jdededeeeef  fddZG dd dZG dd dZ	ddedee dee fddZdS )    )	dataclassfield)AnyDictListOptionalTupleUnionNc                   @   s   e Zd ZU dZeed< eee e	j
f ed< dZee ed< dZeee	j
  ed< dZeeeee	j
  ee	j
 f  ed< eedZeee e	j
f ed	< dZeeee eee  f  ed
< dZeeee eee  f  ed< dZeee  ed< dZeee  ed< dZeee	j
f ed< dZee	j ed< dZeeeeef ee f  ed< dZee	j
 ed< dZeeeeef ee f  ed< dZeeee e	j
f  ed< dZ ee	j
 ed< dZ!ee	j
 ed< dZ"ee ed< e#dee fddZ$e#dee fddZ%d"ddZ&d d! Z'dS )#
Hypothesisa=  Hypothesis class for beam search algorithms.

    score: A float score obtained from an AbstractRNNTDecoder module's score_hypothesis method.

    y_sequence: Either a sequence of integer ids pointing to some vocabulary, or a packed torch.Tensor
        behaving in the same manner. dtype must be torch.Long in the latter case.

    dec_state: A list (or list of list) of LSTM-RNN decoder states. Can be None.

    text: (Optional) A decoded string after processing via CTC / RNN-T decoding (removing the CTC/RNNT
        `blank` tokens, and optionally merging word-pieces). Should be used as decoded string for
        Word Error Rate calculation.

    timestamp: (Optional) A list of integer indices representing at which index in the decoding
        process did the token appear. Should be of same length as the number of non-blank tokens.

    alignments: (Optional) Represents the CTC / RNNT token alignments as integer tokens along an axis of
        time T (for CTC) or Time x Target (TxU).
        For CTC, represented as a single list of integer indices.
        For RNNT, represented as a dangling list of list of integer indices.
        Outer list represents Time dimension (T), inner list represents Target dimension (U).
        The set of valid indices **includes** the CTC / RNNT blank token in order to represent alignments.

    frame_confidence: (Optional) Represents the CTC / RNNT per-frame confidence scores as token probabilities
        along an axis of time T (for CTC) or Time x Target (TxU).
        For CTC, represented as a single list of float indices.
        For RNNT, represented as a dangling list of list of float indices.
        Outer list represents Time dimension (T), inner list represents Target dimension (U).

    token_confidence: (Optional) Represents the CTC / RNNT per-token confidence scores as token probabilities
        along an axis of Target U.
        Represented as a single list of float indices.

    word_confidence: (Optional) Represents the CTC / RNNT per-word confidence scores as token probabilities
        along an axis of Target U.
        Represented as a single list of float indices.

    length: Represents the length of the sequence (the original length without padding), otherwise
        defaults to 0.

    y: (Unused) A list of torch.Tensors representing the list of hypotheses.

    lm_state: (Unused) A dictionary state cache used by an external Language Model.

    lm_scores: (Unused) Score of the external Language Model.

    ngram_lm_state: (Optional) State of the external n-gram Language Model.

    tokens: (Optional) A list of decoded tokens (can be characters or word-pieces.

    last_token (Optional): A token or batch of tokens which was predicted in the last step.

    last_frame (Optional): Index of the last decoding step hypothesis was updated including blank token prediction.
    score
y_sequenceNtextdec_out	dec_state)default_factory	timestamp
alignmentsframe_confidencetoken_confidenceword_confidencer   lengthylm_state	lm_scoresngram_lm_statetokens
last_tokentoken_duration
last_framereturnc                    s   g }t  jtr jd n j}t|dkrP jdurPtdd  jD rGd}d}|D ]}||kr6|}d}n|d7 }| j| |  q+|S  fdd	|D }|S )
zGet per-frame confidence for non-blank tokens according to self.timestamp

        Returns:
            List with confidence scores. The length of the list is the same as `timestamp`.
        timestepr   Nc                 s   s    | ]}t |tV  qd S N)
isinstancelist.0i r'   _/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/asr/parts/utils/rnnt_utils.py	<genexpr>{   s    z8Hypothesis.non_blank_frame_confidence.<locals>.<genexpr>   c                    s   g | ]} j | qS r'   )r   )r%   tselfr'   r(   
<listcomp>   s    z9Hypothesis.non_blank_frame_confidence.<locals>.<listcomp>)r"   r   dictlenr   anyappend)r.   non_blank_frame_confidencer   t_prevoffsetr,   r'   r-   r(   r4   p   s   z%Hypothesis.non_blank_frame_confidencec                 C   s   | j du rg S | j  S )zVGet words from self.text

        Returns:
            List with words (str).
        N)r   splitr-   r'   r'   r(   words   s   zHypothesis.wordsotherc                 C   s  |  j |j 7  _ | jdu r|j| _nt| jtjr&tj| j|jfdd| _n| j|j |j| _| jdu r;|j| _nt| jtjrOtj| j|jfdd| _n| j|j |  j	|j	7  _	|j
| _
| jdu rl|j| _n| j|j | jdu r}|j| _n| j|j d| _| S )z4Merge (inplace) current hypothesis with another one.Nr   dim)r   r   r"   torchTensorcatextendr   r   r   r   r   r   r   )r.   r9   r'   r'   r(   merge_   s,   







zHypothesis.merge_c                 C   s
   d| _ dS )z(Clean the decoding state to save memory.N)r   r-   r'   r'   r(   clean_decoding_state_   s   
z Hypothesis.clean_decoding_state_)r9   r
   r   r
   )(__name__
__module____qualname____doc__float__annotations__r	   r   intr<   r=   r   r   strr   r   r   r#   r   r   r   r   r   r   r   tensorr   r   r   r   r   r   r   r   r   propertyr4   r8   r@   rA   r'   r'   r'   r(   r
   #   s6   
 7( $$$$
r
   c                   @   s"   e Zd ZU dZeee  ed< dS )NBestHypotheseszList of N best hypothesesn_best_hypothesesN)rB   rC   rD   rE   r   r   r
   rG   r'   r'   r'   r(   rL      s   
 rL   c                   @   s6   e Zd ZU dZdZeej ed< dZ	eej ed< dS )HATJointOutputzHATJoint outputs for beam search decoding

    hat_logprobs: standard HATJoint outputs as for RNNTJoint

    ilm_logprobs: internal language model probabilities (for ILM subtraction)
    Nhat_logprobsilm_logprobs)
rB   rC   rD   rE   rO   r   r<   r=   rG   rP   r'   r'   r'   r(   rN      s   
 rN   xprefr   c                 C   s@   t |t | kr
dS tt |D ]}|| | | kr dS qdS )z
    Obtained from https://github.com/espnet/espnet.

    Check if pref is a prefix of x.

    Args:
        x: Label ID sequence.
        pref: Prefix label ID sequence.

    Returns:
        : Whether pref is a prefix of x.
    FT)r1   range)rQ   rR   r&   r'   r'   r(   	is_prefix   s   rT   hyps	topk_idxs
topk_logpsgammabetac                    s   g }t | D ]H\}fddt|| || D }t|dd d}|d }	|d tt fdd|d	d d}
t|
dkrF||
 q||	fg q|S )
a  
    Obtained from https://github.com/espnet/espnet

    Return K hypotheses candidates for expansion from a list of hypothesis.
    K candidates are selected according to the extended hypotheses probabilities
    and a prune-by-value method. Where K is equal to beam_size + beta.

    Args:
        hyps: Hypotheses.
        topk_idxs: Indices of candidates hypothesis. Shape = [B, num_candidates]
        topk_logps: Log-probabilities for hypotheses expansions. Shape = [B, V + 1]
        gamma: Allowed logp difference for prune-by-value method.
        beta: Number of additional candidates to store.

    Return:
        k_expansions: Best K expansion hypotheses candidates.
    c                    s&   g | ]\}}t | jt| fqS r'   )rH   r   rF   )r%   kv)hypr'   r(   r/      s   & z'select_k_expansions.<locals>.<listcomp>c                 S      | d S Nr+   r'   rQ   r'   r'   r(   <lambda>       z%select_k_expansions.<locals>.<lambda>)keyr   r+   c                    s     | d kS r^   r'   r_   )rX   
k_best_expr'   r(   r`     s    c                 S   r]   r^   r'   r_   r'   r'   r(   r`     ra   )	enumeratezipmaxsortedfilterr1   r3   )rU   rV   rW   rX   rY   k_expansionsr&   hyp_ik_best_exp_valk_best_exp_idx
expansionsr'   )rX   r\   rc   r(   select_k_expansions   s    rn   c                   @   s0  e Zd ZdZ		d%dededeej deej fddZ	d	d
 Z
dd Z	d&dejdejdejdejdeej f
ddZ	d&dejdejdejdejdeej f
ddZ	d&dejdejdejdejdeej f
ddZ	d&dejdejdejdejdeej f
ddZd'defddZd(d d!Zd)d#d$ZdS )*BatchedHypsz\Class to store batched hypotheses (labels, time_indices, scores) for efficient RNNT decodingN
batch_sizeinit_lengthdevicefloat_dtypec                 C   s  |dkrt d| |dkrt d| || _|| _|| _|| _tj||tjd| _tj|| jf|tjd| _	tj|| jf|tjd| _
tj|| jf|tjd| _tj|||d| _tj|fd|tjd| _tj||tjd| _tj||d| _t| j| _dS )a>  

        Args:
            batch_size: batch size for hypotheses
            init_length: initial estimate for the length of hypotheses (if the real length is higher,
                tensors will be reallocated)
            device: device for storing hypotheses
            float_dtype: float type for scores
        r   init_length must be > 0, got batch_size must be > 0, got rr   dtyper*   rr   N)
ValueError_max_lengthrp   rr   rs   r<   zeroslongcurrent_lengths
transcript
timestampstoken_durationsscoresfulllast_timestamplast_timestamp_lastsarange_batch_indices	ones_like_ones_batch)r.   rp   rq   rr   rs   r'   r'   r(   __init__  s"   zBatchedHyps.__init__c                 C   sX   | j d | jd | jd | jd | jd | jd | jd dS )2
        Clears batched hypotheses state.
        r           r*   N)r}   fill_r~   r   r   r   r   r   r-   r'   r'   r(   clear_=  s   zBatchedHyps.clear_c                 C   sl   t j| jt | jfdd| _t j| jt | jfdd| _t j| jt | jfdd| _|  jd9  _dS )
        Allocate 2x space for tensors, similar to common C++ std::vector implementations
        to maintain O(1) insertion time complexity
        r*   r:      N)r<   r>   r~   
zeros_liker   r   rz   r-   r'   r'   r(   _allocate_moreI  s   zBatchedHyps._allocate_moreactive_indiceslabelstime_indicesr   r   c                 C   sF   |j d dkr	dS | j  | jkr|   | j|||||d dS )a  
        Add results (inplace) from a decoding step to the batched hypotheses.
        We assume that all tensors have the same first dimension, and labels are non-blanks.
        Args:
            active_indices: tensor with indices of active hypotheses (indices should be within the original batch_size)
            labels: non-blank labels to add
            time_indices: tensor of time index for each label
            scores: label scores
        r   N)r   r   r   r   r   )shaper}   rf   itemrz   r   add_results_no_checks_)r.   r   r   r   r   r   r'   r'   r(   add_results_S  s   
zBatchedHyps.add_results_c                 C   s   | j |  |7  < | j| }|| j||f< || j||f< |dur'|| j||f< t| j| |k| j| d d| j|< || j|< | j|  d7  < dS )a  
        Add results (inplace) from a decoding step to the batched hypotheses without checks.
        We assume that all tensors have the same first dimension, and labels are non-blanks.
        Useful if all the memory is pre-allocated, especially with cuda graphs
        (otherwise prefer a more safe `add_results_`)
        Args:
            active_indices: tensor with indices of active hypotheses (indices should be within the original batch_size)
            labels: non-blank labels to add
            time_indices: tensor of time index for each label
            scores: label scores
            token_durations: predicted durations for each token by TDT head
        Nr+   )	r   r}   r~   r   r   r<   wherer   r   )r.   r   r   r   r   r   active_lengthsr'   r'   r(   r   r  s   


z"BatchedHyps.add_results_no_checks_active_maskc                 C   4   | j |  | jkr|   | j|||||d dS )a  
        Add results (inplace) from a decoding step to the batched hypotheses.
        We assume that all tensors have the same first dimension, and labels are non-blanks.
        Args:
            active_mask: tensor with mask for active hypotheses (of batch_size)
            labels: non-blank labels to add
            time_indices: tensor of time index for each label
            scores: label scores
            token_durations: token durations for TDT
        )r   r   r   r   r   Nr}   rf   rz   r   add_results_masked_no_checks_r.   r   r   r   r   r   r'   r'   r(   add_results_masked_  s   
zBatchedHyps.add_results_masked_c                 C   s   t j|| j| | j| jd || j| j| jf< || j| j| jf< |dur-|| j| j| jf< t jt || j	|k| j
d | j
| j
d t jt || j	|k| j| j
| j
d t j||| j	| j	d |  j|7  _dS )af  
        Add results (inplace) from a decoding step to the batched hypotheses without checks.
        We assume that all tensors have the same first dimension, and labels are non-blanks.
        Useful if all the memory is pre-allocated, especially with cuda graphs
        (otherwise prefer a more safe `add_results_`)
        Args:
            active_mask: tensor with mask for active hypotheses (of batch_size)
            labels: non-blank labels to add
            time_indices: tensor of time index for each label
            scores: label scores
            token_durations: token durations for TDT
        )outNr+   )r<   r   r   r~   r   r}   r   r   logical_andr   r   r   r   r'   r'   r(   r     s&   z)BatchedHyps.add_results_masked_no_checks_r*   pad_idc                 C   s&   t | jdk| j| j| jd f |S )z7Get last labels. For elements without labels use pad_idr   r+   )r<   r   r}   r~   r   )r.   r   r'   r'   r(   get_last_labels  s   zBatchedHyps.get_last_labelsr   c                 C   s~   t | j| j| j| jd}|j| j |j| j |j| j |j	| j	 |j
| j
 |j| j |j| j |S )Return a copy of self)rp   rq   rr   rs   )ro   rp   rz   rr   rs   r}   copy_r~   r   r   r   r   r   )r.   batched_hypsr'   r'   r(   clone  s   zBatchedHyps.cloner9   c                 C   s"  t j| jt |jfdd| _t j| jt |jfdd| _t j| jt |jfdd| _|  j|j7  _t j|jjd | j	j
d}| j	dddf |dddf  }| jjd||jd | jjd||jd | jjd||jd |  j	|j	7  _	|  j|j7  _| j|j | j|j | S )z
        Merge two batched hypotheses structures.
        NB: this will reallocate memory

        Args:
            other: BatchedHyps
        r*   r:   r+   rx   N)r;   indexsrc)r<   r>   r~   r   r   r   rz   r   r   r}   rr   scatter_r   r   r   r   )r.   r9   indicesshifted_indicesr'   r'   r(   r@     s   "zBatchedHyps.merge_NNr!   )r*   )r   ro   )r9   ro   r   ro   )rB   rC   rD   rE   rH   r   r<   rr   rw   r   r   r   r=   r   r   r   r   r   r   r@   r'   r'   r'   r(   ro     s    
,
%
+
"
0
ro   c                   @   s   e Zd ZdZ					d"dedededeej d	eej d
e	de	de	fddZ
dd Zdd Z			d#dejdejdeej deej deej f
ddZ			d#dejdejdeej deej deej f
ddZ			d#dejdejdeej deej deej f
ddZd$d d!ZdS )%BatchedAlignmentsz
    Class to store batched alignments (logits, labels, frame_confidence).
    Size is different from hypotheses, since blank outputs are preserved
    NTFrp   
logits_dimrq   rr   rs   store_alignmentsstore_frame_confidencewith_duration_confidencec	           	      C   s@  |dkrt d| |dkrt d| || _|| _|| _|| _|| _|| _|| _|| _t	j
|| jf|t	jd| _t	j
||t	jd| _t	j
d||d| _t	j
d|t	jd| _| jrut	j
|| j|f||d| _t	j
|| jf|t	jd| _t	j
d||d| _| jrt	j
| jr|| jdgn|| jg||d| _t	j||d| _dS )a  

        Args:
            batch_size: batch size for hypotheses
            logits_dim: dimension for logits
            init_length: initial estimate for the lengths of flatten alignments
            device: device for storing data
            float_dtype: expected logits/confidence data type
            store_alignments: if alignments should be stored
            store_frame_confidence: if frame confidence should be stored
        r   rt   ru   rv   r   rx   N)ry   rp   r   rr   rs   with_frame_confidencer   with_alignmentsrz   r<   r{   r|   r   r}   logitsr   r   r   r   )	r.   rp   r   rq   rr   rs   r   r   r   r'   r'   r(   r     s6   zBatchedAlignments.__init__c                 C   s@   | j d | jd | jd | jd | jd dS )r   r   r   N)r}   r   r   r   r   r   r-   r'   r'   r(   r   V  s
   zBatchedAlignments.clear_c                 C   s   t j| jt | jfdd| _| jr0t j| jt | jfdd| _t j| jt | jfdd| _| jrBt j| jt | jfdd| _|  j	d9  _	dS )r   r*   r:   r+   r   N)
r<   r>   r   r   r   r   r   r   r   rz   r-   r'   r'   r(   r   `  s   z BatchedAlignments._allocate_morer   r   r   r   
confidencec                 C   s   |j d dkr	dS | j  | jkr|   | j| }|| j||f< | jr<|dur<|dur<|| j||f< || j	||f< | j
rJ|durJ|| j||f< | j|  d7  < dS )a  
        Add results (inplace) from a decoding step to the batched hypotheses.
        All tensors must use the same fixed batch dimension.
        Args:
            active_mask: tensor with mask for active hypotheses (of batch_size)
            logits: tensor with raw network outputs
            labels: tensor with decoded labels (can contain blank)
            time_indices: tensor of time index for each label
            confidence: optional tensor with confidence for each item in batch
        r   Nr+   )r   r}   rf   r   rz   r   r   r   r   r   r   r   )r.   r   r   r   r   r   r   r'   r'   r(   r   m  s   
zBatchedAlignments.add_results_r   c                 C   r   )a  
        Add results (inplace) from a decoding step to the batched hypotheses.
        All tensors must use the same fixed batch dimension.
        Args:
            active_mask: tensor with indices of active hypotheses (indices should be within the original batch_size)
            time_indices: tensor of time index for each label
            logits: tensor with raw network outputs
            labels: tensor with decoded labels (can contain blank)
            confidence: optional tensor with confidence for each item in batch
        )r   r   r   r   r   Nr   r.   r   r   r   r   r   r'   r'   r(   r     s
   

z%BatchedAlignments.add_results_masked_c                 C   s   || j | j| jf< | jr/|dur/|dur/|| j | j| jf< || j| j| jf< || j| j| jf< | jr?|dur?|| j| j| jf< |  j|7  _dS )a  
        Add results (inplace) from a decoding step to the batched hypotheses.
        All tensors must use the same fixed batch dimension.
        Useful if all the memory is pre-allocated, especially with cuda graphs
        (otherwise prefer a more safe `add_results_masked_`)
        Args:
            active_mask: tensor with indices of active hypotheses (indices should be within the original batch_size)
            time_indices: tensor of time index for each label
            logits: tensor with raw network outputs
            labels: tensor with decoded labels (can contain blank)
            confidence: optional tensor with confidence for each item in batch
        N)r   r   r}   r   r   r   r   r   r   r'   r'   r(   r     s   z/BatchedAlignments.add_results_masked_no_checks_r   c              
   C   sr   t | j| j| j| j| j| j| j| jd}|j	
| j	 |j
| j |j
| j |j
| j |j
| j |S )r   )rp   r   rq   rr   rs   r   r   r   )r   rp   r   rz   rr   rs   r   r   r   r}   r   r   r   r   r   )r.   batched_alignmentsr'   r'   r(   r     s    
zBatchedAlignments.clone)NNTFF)NNN)r   r   )rB   rC   rD   rE   rH   r   r<   rr   rw   boolr   r   r   r=   r   r   r   r   r'   r'   r'   r(   r     s    
	
;

+

!r   r   r   c           	         s  |du s|j jd ksJ |du rj jd n|}j  j j 
j 		
fddt|D }|dur|j  }|jrV|j	 |j
  |jr^|j tt|D ][g | _|jrsg | _tj|jd| f dd\}}d| D ]3}|jr| j fddt|D  |jr| jfddt|D  |7 qqd|S )	aB  
    Convert batched hypotheses to a list of Hypothesis objects.
    Keep this function separate to allow for jit compilation for BatchedHyps class (see tests)

    Args:
        batched_hyps: BatchedHyps object
        alignments: BatchedAlignments object, optional; must correspond to BatchedHyps if present
        batch_size: Batch Size to retrieve hypotheses. When working with CUDA graphs the batch size for all tensors
            is constant, thus we need here the real batch size to return only necessary hypotheses

    Returns:
        list of Hypothesis objects
    Nr   c                    sz   g | ]9}t |  |d | f |d  j| f t j|d  j| f  dks2ntdd d dqS )Nr   )r   r   r   r   r   r   )r
   r   r}   r<   allr   emptyr$   )r   r}   	durationsr   r   r~   r'   r(   r/     s    
z.batched_hyps_to_hypotheses.<locals>.<listcomp>T)return_countsc                    s,   g | ]}| f  | f fqS r'   r'   r%   j)alignment_labelsalignment_logitsr&   startr'   r(   r/     s    c                    s   g | ]
} | f qS r'   r'   r   )r   r&   r   r'   r(   r/   %  s    )r   r   cpur}   r~   r   rS   tolistr   r   r   r   r   r1   r   r<   unique_consecutiver3   )	r   r   rp   num_hyps
hypothesesalignment_lengths_grouped_countstimestamp_cntr'   )r   r   r   r}   r   r   r&   r   r   r   r~   r(   batched_hyps_to_hypotheses  sL   












r   r   )dataclassesr   r   typingr   r   r   r   r   r	   r<   r
   rL   rN   rH   r   rT   r=   rF   rn   ro   r   r   r'   r'   r'   r(   <module>   sJ     
.  	 N