o
    }oiGl                     @   s2  d dl Z d dlZd dlZd dlmZmZ d dlmZ d dlm	Z	 d dl
mZmZmZmZmZmZmZ d dlZd dlmZ d dlmZ d dlmZmZ d	Zd
d Zdedeeef fddZdd ZG dd deZ G dd deZ!G dd dZ"dee" dedee" fddZ#G dd deZ$G dd de$Z%dS )     N)ABCabstractmethod)defaultdict)Path)AnyDictList
NamedTupleOptionalTupleUnion)wer)
DictConfig)TW_BREAKkaldifst_importerzriva decoder is not installed or is installed incorrectly.
please run `bash scripts/installers/install_riva_decoder.sh` or `pip install riva-asrlib-decoder` to install.c               	   C   s:   zddl m  m  m}  W | S  ttfy   ttw )z`Import helper function that returns Riva asrlib decoder package or raises ImportError exception.r   N)"riva.asrlib.decoder.python_decoderasrlibdecoderpython_decoderImportErrorModuleNotFoundError!RIVA_DECODER_INSTALLATION_MESSAGE)riva_decoder r   f/home/ubuntu/.local/lib/python3.10/site-packages/nemo/collections/asr/parts/submodules/wfst_decoder.pyriva_decoder_importer#   s   r   confreturnc                 C   sF   i }|   D ]}|ds t| |}|jjdkr|nt|||< q|S )z
    Helper function for parsing Riva configs (namely BatchedMappedDecoderCudaConfig) into a dictionary.

    Args:
      conf:
        Inner Riva config.

    Returns:
      Dictionary corresponding to the Riva config.
    __builtins)__dir__
startswithgetattr	__class__
__module___riva_config_to_dict)r   resultname	attributer   r   r   r%   ,   s   

r%   c                 C   s>   |  D ]\}}t|trtt| || qt| || qdS )a"  
    Helper function for filling Riva configs (namely BatchedMappedDecoderCudaConfig)
    according to the corresponding NeMo config.

    Note: in-place for the first argument.

    Args:
      riva_conf:
        Inner Riva config.

      nemo_conf:
        Corresponding NeMo config.
    N)items
isinstancer   _fill_inner_riva_config_r"   setattr)	riva_conf	nemo_confnemo_knemo_vr   r   r   r+   A   s
   
r+   c                       s    e Zd ZdZ fddZ  ZS )RivaDecoderConfigz1
    NeMo config for the RivaGpuWfstDecoder.
    c                    s   z4t  }| }d|jj_d|_d|jj_d|jj_d|j_	d|j_
d|j_d|j_d|jj_t|}W n ty?   i }Y nw t | d S )	N      $@2   g      4@i'  Ti              )r   BatchedMappedDecoderCudaConfigonline_optslattice_postprocessor_optsacoustic_scalen_input_per_chunkdecoder_optsdefault_beam
max_activedeterminize_latticemax_batch_sizenum_channelsframe_shift_secondsword_ins_penaltyr%   r   super__init__)selfr   configcontentr#   r   r   rD   [   s"   



zRivaDecoderConfig.__init__)__name__r$   __qualname____doc__rD   __classcell__r   r   rH   r   r1   V   s    r1   c                   @   s>   e Zd ZU dZee ed< ee ed< ee ed< eed< dS )WfstNbestUnitzF
    Container for a single RivaGpuWfstDecoder n-best hypothesis.
    words	timesteps	alignmentscoreN)	rI   r$   rJ   rK   r   str__annotations__intfloatr   r   r   r   rM   p   s   
 rM   c                
   @   s   e Zd ZdZdeeee ee ee ef  fddZdd Z	dd Z
d	d
 Zdedeeeee ee ee ef f fddZedd Zedd Zedd Zedd ZdS )WfstNbestHypothesiszm
    Container for the RivaGpuWfstDecoder n-best results represented as a list of WfstNbestUnit objects.
    raw_hypothesesc                 C   s(  t |D ]Z\}}t|d tsJ |d  t|d ts*J |d  d|d  t|d ts8J |d  t|d tsFJ |d  t|d t|d ks^t|d dks^J dqtdd |D d	d
 d| _t| j| _dd | jD | _t| jd j	dk| _
t| jd jdk| _d S )Nr   r4   z,       zwords do not match timestepsc                 S   s   g | ]}t | qS r   )rM   ).0rhr   r   r   
<listcomp>       z0WfstNbestHypothesis.__init__.<locals>.<listcomp>c                 S      | j S N)rQ   )hypr   r   r   <lambda>   s    z.WfstNbestHypothesis.__init__.<locals>.<lambda>)keyc                 S   s   g | ]}t |jqS r   )lenrN   rZ   hr   r   r   r\          )	enumerater*   tuplerU   rc   sorted_hypotheses_shape0_shape1rO   _has_timestepsrP   _has_alignment)rE   rW   ir[   r   r   r   rD      s   (0zWfstNbestHypothesis.__init__c                 c   s    | j E d H  d S r_   rj   rE   r   r   r   __iter__   s   zWfstNbestHypothesis.__iter__c                 C   s
   | j | S r_   rp   )rE   indexr   r   r   __getitem__   s   
zWfstNbestHypothesis.__getitem__c                 C   r^   r_   )shape0rq   r   r   r   __len__   s   zWfstNbestHypothesis.__len__rs   new_unitc                 C   s  d|  kr| j k sJ  J | jrt|d t|d ks+| js)t|d dks+J |dkrBt| jdksq|d | j|d  jksq|| j d krU| j|d  j|d ksq| j|d  j|d   krn| j|d  jksqJ  J t|tszt| }|| j|< t|j| j|< dS )z
        Replaces a WfstNbestUnit by index.

        Note: in-place operation.

        Args:
          index:
            Index of the unit to be replaced.

          new_unit:
            Replacement unit.
        r   r4   rY   N)	ru   has_timestepsrc   rj   rQ   r*   rM   rN   rl   )rE   rs   rw   r   r   r   replace_unit_   s    &8

z!WfstNbestHypothesis.replace_unit_c                 C   r^   r_   )rk   rq   r   r   r   ru         zWfstNbestHypothesis.shape0c                 C   r^   r_   )rl   rq   r   r   r   shape1   rz   zWfstNbestHypothesis.shape1c                 C   r^   r_   )rm   rq   r   r   r   rx      rz   z!WfstNbestHypothesis.has_timestepsc                 C   r^   r_   )rn   rq   r   r   r   has_alignment   rz   z!WfstNbestHypothesis.has_alignmentN)rI   r$   rJ   rK   r   rR   rT   rU   rD   rr   rt   rv   r   rM   ry   propertyru   r{   rx   r|   r   r   r   r   rV   {   s&    *"
#


rV   
hypothesestokenword_disambig_strc              	   C   s~  t | }|D ]}t|D ]\}}g }t|jD ]\}}||kr%|| qt|dkrt|j}	t|j}
t|	}t|d dkrG|| g g }}d}t|ddd |ddd D ]9\}}||	|| 7 }d	|	|d | 
t | ddd }|| ||
|| |
| g 7 }|d }q]||k r||	|| 7 }||
|| 7 }||t|t||j|jf qq|S )a%  
    Searches for tokenwords in the input hypotheses and collapses them into words.

    Args:
      hypotheses:
        List of input WfstNbestHypothesis.

      tokenword_disambig_str:
        Tokenword disambiguation symbol (e.g. `#1`).

    Returns:
      List of WfstNbestHypothesis.
    r   rX   r4   N )copydeepcopyrg   rN   appendrc   listrO   zipjoinreplacer   ry   rh   rP   rQ   )r~   r   new_hypothesesr`   kh_unit	twds_listro   word	old_wordsold_timesteps	words_len	new_wordsnew_timestepsj_prevjnew_wordr   r   r   collapse_tokenword_hypotheses   s>   





&.

 r   c                   @   s  e Zd ZdZ		d4dedededee ded	efd
dZ	e
d5dee fddZe
defddZe
dd Zedd ZejdefddZe
defddZedd ZejdefddZe
defddZedd Zejdefd dZe
defd!d"Zed#d$ Zed%d& Ze
d'ejd(ejd)ee fd*d+Ze
d,ee d)ee fd-d.Ze
d'ejd(ejd/ee d)eeef fd0d1Ze
d'ejd(ejd/ee d)eeee f fd2d3ZdS )6AbstractWFSTDecodera  
    Used for performing WFST decoding of the logprobs.

    Args:
      lm_fst:
        Language model WFST.

      decoding_mode:
        Decoding mode. E.g. `nbest`.

      beam_size:
        Beam width (float) for the WFST decoding.

      config:
        Decoder config.

      tokenword_disambig_id:
        Tokenword disambiguation index. Set to -1 to disable the tokenword mode.

      lm_weight:
        Language model weight in decoding.
          ?lm_fstdecoding_mode	beam_sizerF   tokenword_disambig_id	lm_weightc                 C   sl   || _ || _|| _| jdk| _|| _d\| _| _d\| _| _d\| _	| _
| _| | | | |   d S )Nr   )NN)NNN)_lm_fst
_beam_size_tokenword_disambig_id_open_vocabulary_decoding
_lm_weight_id2word_word2id	_id2token	_token2id_decoding_mode_config_decoder_set_decoding_mode_set_decoder_config_init_decoder)rE   r   r   r   rF   r   r   r   r   r   rD     s   	

zAbstractWFSTDecoder.__init__Nc                 C      d S r_   r   rE   rF   r   r   r   r   )     z'AbstractWFSTDecoder._set_decoder_configc                 C   r   r_   r   rE   r   r   r   r   r   -  r   z&AbstractWFSTDecoder._set_decoding_modec                 C   r   r_   r   rq   r   r   r   r   1  r   z!AbstractWFSTDecoder._init_decoderc                 C   r^   r_   )r   rq   r   r   r   r   5  rz   z!AbstractWFSTDecoder.decoding_modevaluec                 C      |  | d S r_   )_decoding_mode_setterrE   r   r   r   r   r   9     c                 C   r   r_   r   r   r   r   r   r   =  r   z)AbstractWFSTDecoder._decoding_mode_setterc                 C   r^   r_   )r   rq   r   r   r   r   A  rz   zAbstractWFSTDecoder.beam_sizec                 C   r   r_   )_beam_size_setterr   r   r   r   r   E  r   c                 C   r   r_   r   r   r   r   r   r   I  r   z%AbstractWFSTDecoder._beam_size_setterc                 C   r^   r_   )r   rq   r   r   r   r   M  rz   zAbstractWFSTDecoder.lm_weightc                 C   r   r_   )_lm_weight_setterr   r   r   r   r   Q  r   c                 C   r   r_   r   r   r   r   r   r   U  r   z%AbstractWFSTDecoder._lm_weight_setterc                 C   r^   r_   )r   rq   r   r   r   r   Y  rz   z)AbstractWFSTDecoder.tokenword_disambig_idc                 C   r^   r_   )r   rq   r   r   r   open_vocabulary_decoding]  rz   z,AbstractWFSTDecoder.open_vocabulary_decoding	log_probslog_probs_lengthr   c                 C      dS )  
        Decodes logprobs into recognition hypotheses.

        Args:
          log_probs:
            A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary].

          log_probs_length:
            A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements.

        Returns:
          List of recognition hypotheses.
        Nr   )rE   r   r   r   r   r   decodea  s   zAbstractWFSTDecoder.decoder~   c                 C   r   )
        Does various post-processing of the recognition hypotheses.

        Args:
          hypotheses:
            List of recognition hypotheses.

        Returns:
          List of processed recognition hypotheses.
        Nr   rE   r~   r   r   r   _post_decoder  s   z AbstractWFSTDecoder._post_decodereference_textsc                 C   r   )  
        Calibrates LM weight to achieve the best WER for given logprob-text pairs.

        Args:
          log_probs:
            A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary].

          log_probs_length:
            A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements.

          reference_texts:
            List of reference word sequences.

        Returns:
          Pair of (best_lm_weight, best_wer).
        Nr   rE   r   r   r   r   r   r   calibrate_lm_weight     z'AbstractWFSTDecoder.calibrate_lm_weightc                 C   r   )  
        Calculates the oracle (the best possible WER for given logprob-text pairs.

        Args:
          log_probs:
            A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary].

          log_probs_length:
            A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements.

          reference_texts:
            List of reference word sequences.

        Returns:
          Pair of (oracle_wer, oracle_wer_per_utterance).
        Nr   r   r   r   r   calculate_oracle_wer  r   z(AbstractWFSTDecoder.calculate_oracle_wer)r   r   r_   ) rI   r$   rJ   rK   r   rR   rU   r
   rT   rD   r   r   r   r   r}   r   setterr   r   r   r   r   r   r   torchTensorr   r   r   r   r   r   r   r   r   r   r      s    






 
r   c                       s  e Zd ZdZ						d=ded	eef d
ededed de	dede	f fddZ
d>ded fddZdd Zd
efddZdefddZdefddZdefddZed d! Zejdefd"d!Zdefd#d$Zd%ejd&ejd'ee fd(d)Zd%ejd&ejd'ee fd*d+Zd%ejd&ejd'ed, fd-d.Zd%ejd&ejd'eee ed, f fd/d0Zd1eee ed, f d'eee ed, f fd2d3Zd%ejd&ejd4ee d'eeef fd5d6Zd%ejd&ejd4ee d'eeee f fd7d8Z d9d: Z!d;d< Z"  Z#S )?RivaGpuWfstDecoderaU  
    Used for performing WFST decoding of the logprobs with the Riva WFST decoder.

    Args:
      lm_fst:
        Kaldi-type language model WFST or its path.

      decoding_mode:
        Decoding mode. Choices: `nbest`, `mbr`, `lattice`.

      beam_size:
        Beam width (float) for the WFST decoding.

      config:
        Riva Decoder config.

      tokenword_disambig_id:
        Tokenword disambiguation index. Set to -1 to disable the tokenword mode.

      lm_weight:
        Language model weight in decoding.

      nbest_size:
        N-best size for decoding_mode == `nbest`
    mbrr2   Nr   r   r4   r   zkaldifst.StdFstr   r   rF   r1   r   r   
nbest_sizec                    s&   || _ d | _t |||||| d S r_   )_nbest_size_load_word_latticerC   rD   )rE   r   r   r   rF   r   r   r   rH   r   r   rD     s   
zRivaGpuWfstDecoder.__init__c                 C   sj   |d u s
t |dkrt }t|dst  td| j|jj_| j	|jj
j |jj
_| j|jj
_|| _d S )Nr   r7   z/Unexpected config error. Please debug manually.)rc   r1   hasattrr   RuntimeErrorr   r7   r;   lattice_beamr   r8   r9   lm_scaler   nbestr   r   r   r   r   r     s   

z&RivaGpuWfstDecoder._set_decoder_configc                    s  t  }t }ddlm} || _| j}d }d }t|ttfr$|j	
|}n$t||j	|jfr?|}tjdd}||j |j}n	tdt| |jdd }| jd u rdd	 t|j d
D | _| jtt| j }|d tfdd| _| D ]	\}	}
|
| j|	< q| jd u rdd	 t|j d
D | _| jtt| j }|d  t fdd| _| D ]	\}	}
|
| j|	< qtjdd#}|j |j |! }t"|| j# |$|||j|| _%W d    n1 sw   Y  |r|&  d S d S )Nr   )load_word_latticezw+t)modezUnsupported lm_fst type: z#0r4   c                 S   *   i | ]}t |d d |d d qS 	r4   r   rT   splitrZ   liner   r   r   
<dictcomp>      z4RivaGpuWfstDecoder._init_decoder.<locals>.<dictcomp>
z<unk>c                          S r_   r   r   )word_unk_idr   r   ra         z2RivaGpuWfstDecoder._init_decoder.<locals>.<lambda>c                 S   r   r   r   r   r   r   r   r     r   c                      r   r_   r   r   )token_unk_idr   r   ra     r   )'r   r   +nemo.collections.asr.parts.utils.wfst_utilsr   r   r   r*   r   rR   StdVectorFstreadStdConstFsttempfileNamedTemporaryFilewriter'   
ValueErrortypeinput_symbolsfindr   output_symbolsstripr   r#   mapreversedr)   r   r   r   r   
write_textr6   r+   r   BatchedMappedDecoderCudar   close)rE   kaldifstr   r   r   tmp_fsttmp_fst_filenum_tokens_with_blankword2idr   vtoken2id	words_tmprF   r   )r   r   r   r     sZ   


z RivaGpuWfstDecoder._init_decoderc                 C   sN   |dkr	| j | _n|dkr| j| _n|dkr| j| _ntd| || _d S )Nr   r   latticezUnsupported mode: )_decode_nbest_decode_decode_mbr_decode_latticer   r   r   r   r   r   r     s   



z%RivaGpuWfstDecoder._set_decoding_moder   c                 C   4   | j |kr|   || jjj_|   || _ d S d S r_   )r   _release_gpu_memoryr   r7   r;   r   r   r   r   r   r   r   *     

z$RivaGpuWfstDecoder._beam_size_setterc                 C   s@   | j |kr|   || jjjj | jjj_|   || _ d S d S r_   )r   r  r   r7   r8   r9   r   r   r   r   r   r   r   1  s   


z$RivaGpuWfstDecoder._lm_weight_setterc                 C   s   | j |kr| | d S d S r_   )r   r   r   r   r   r   r   :  s   
z(RivaGpuWfstDecoder._decoding_mode_setterc                 C   r^   r_   )r   rq   r   r   r   r   >  rz   zRivaGpuWfstDecoder.nbest_sizec                 C   r   r_   )_nbest_size_setterr   r   r   r   r   B  r   c                 C   r  r_   )r   r  r   r7   r8   r   r   r   r   r   r   r  F  r  z%RivaGpuWfstDecoder._nbest_size_setterr   r   r   c              	   C   s   | j ||}g }|D ]R}g }|D ]B}g g }}	t|j|jD ]\}
}|
dkr6|| j|
  |	t| qdd |jD }|j	}|t
t
|t
|	t
||g q|tt
| q|S )a  
        Decodes logprobs into recognition hypotheses via the N-best decoding decoding.

        Args:
          log_probs:
            A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary].

          log_probs_length:
            A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements.

        Returns:
          List of WfstNbestHypothesis with empty alignment and trivial score.
        r   c                 S   s   g | ]}|d  qS )r4   r   )rZ   ilabelr   r   r   r\   g  r]   z4RivaGpuWfstDecoder._decode_nbest.<locals>.<listcomp>)r   decode_nbestr   rN   word_start_times_secondsr   r   rT   ilabelsrQ   rh   rV   )rE   r   r   hypotheses_nbestr~   nhnbest_containerre   rN   rO   wtrP   rQ   r   r   r   r   M  s    
$z RivaGpuWfstDecoder._decode_nbestc           	   
   C   s~   | j ||}g }|D ]1}g g }}|D ]}||d  |t|d  q|tttt|t|t dgg q|S )a  
        Decodes logprobs into recognition hypotheses via the Minimum Bayes Risk (MBR) decoding.

        Args:
          log_probs:
            A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary].

          log_probs_length:
            A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements.

        Returns:
          List of WfstNbestHypothesis with empty alignment and trivial score.
        r   r4   r5   )r   
decode_mbrr   rT   rV   rh   )	rE   r   r   hypotheses_mbrr~   re   rN   rO   er   r   r   r  m  s   
,zRivaGpuWfstDecoder._decode_mbrKaldiWordLatticec              	      s   t  :}|j d}| j||dd tt|D d|  | || j| j  fddtt|D }W d   |S 1 sAw   Y  |S )ax  
        Decodes logprobs into kaldi-type lattices.

        Args:
          log_probs:
            A torch.Tensor of the predicted log-probabilities of shape [Batch, Time, Vocabulary].

          log_probs_length:
            A torch.Tensor of length `Batch` which contains the lengths of the log_probs elements.

        Returns:
          List of KaldiWordLattice.
        z.latsc                 S   s   g | ]}t |qS r   rR   rZ   ro   r   r   r   r\     r]   z6RivaGpuWfstDecoder._decode_lattice.<locals>.<listcomp>zark,t:c                    s   g | ]} t | qS r   r  r  hypotheses_latticer   r   r\     s    N)	r   r   r'   r   decode_write_latticerangerc   r   r   )rE   r   r   tmp_lattmp_lat_namer~   r   r  r   r    s   
 

		z"RivaGpuWfstDecoder._decode_latticec                 C   s8   |  }|tjd  }| ||}| |}|S )r   cpu)
contiguoustor   longr  r   )rE   r   r   r~   r   r   r   r     s
   
zRivaGpuWfstDecoder.decoder~   c                 C   s&   | j r| jdv rt|| j| j S |S )r   )r   r   )r   r   r   r   r   r   r   r   r   r     s   zRivaGpuWfstDecoder._post_decoder   c                 C   s   t |t |ks
J | j}| j}d| _dtd}}tddD ]&}|d | _| ||}	tdd |	D |}
t||
 |
|k rE| j|
}}q|| _|| _||fS )	r   r   g      infr4      
   c                 S   s   g | ]
}d  |d jqS ) r   )r   rN   rd   r   r   r   r\     s    z:RivaGpuWfstDecoder.calibrate_lm_weight.<locals>.<listcomp>)rc   r   r   rU   r  r   word_error_rateprint)rE   r   r   r   decoding_mode_backuplm_weight_backupbest_lm_weightbest_werr   r~   r   r   r   r   r     s    

z&RivaGpuWfstDecoder.calibrate_lm_weightc                    s    j rtt|t|ksJ  j}d _ ||}g g g }}}t||D ]/\}	}
 fdd|
  D }||rAt|nd ||		| ||d |d   q(| _t
|t
| |fS )r   r   c                    s   g | ]} j | qS r   )r   )rZ   r  rq   r   r   r\     rf   z;RivaGpuWfstDecoder.calculate_oracle_wer.<locals>.<listcomp>r4   r   )r   NotImplementedErrorrc   r   r   r   r   r   r   edit_distancesum)rE   r   r   r   r'  latticesscorescountswer_per_uttr   textword_idsr   rq   r   r     s   z'RivaGpuWfstDecoder.calculate_oracle_werc                 C   s(   z| ` W n	 ty   Y nw t  dS )zS
        Forces freeing of GPU memory by deleting the Riva decoder object.
        N)r   	Exceptiongccollectrq   r   r   r   r    s   z&RivaGpuWfstDecoder._release_gpu_memoryc                 C   s   |    d S r_   )r  rq   r   r   r   __del__  s   zRivaGpuWfstDecoder.__del__)r   r2   Nr   r   r4   r_   )$rI   r$   rJ   rK   r   r   rR   rU   r
   rT   rD   r   r   r   r   r   r   r}   r   r   r  r   r   r   rV   r   r  r  r   r   r   r   r   r  r7  rL   r   r   rH   r   r     s    9	

 



#
"r   )&r   r5  r   abcr   r   collectionsr   pathlibr   typingr   r   r   r	   r
   r   r   r   jiwerr   r%  	omegaconfr   r   r   r   r   r   rR   r%   r+   r1   rM   rV   r   r   r   r   r   r   r   <module>   s:   $	P
0 3