o
    i                     @   s   d Z ddlmZ ddlZddlmZ ddlmZ ddlmZ ddlmZ ddlm	Z	 dd	lm
Z
 ddlZdd
lmZ ddlmZ ddlmZ G dd deZG dd dejjZG dd dejjZdS )zBeam search module.    )chainN)Any)Dict)List)
NamedTuple)Tuple)Union)
end_detect)PartialScorerInterface)ScorerInterfacec                   @   sv   e Zd ZU dZejed< dZee	ejf ed< e
 Zeeee	ejf f ed< e
 Zeeef ed< de
fdd	Zd
S )
HypothesiszHypothesis data type.yseqr   scorescoresstatesreturnc                 C   s0   | j | j t| jdd | j D d S )z#Convert data to JSON-friendly dict.c                 S   s   i | ]	\}}|t |qS  )float).0kvr   r   S/home/ubuntu/.local/lib/python3.10/site-packages/funasr/models/scama/beam_search.py
<dictcomp>    s    z%Hypothesis.asdict.<locals>.<dictcomp>)r   r   r   )_replacer   tolistr   r   r   items_asdictselfr   r   r   asdict   s   zHypothesis.asdictN)__name__
__module____qualname____doc__torchTensor__annotations__r   r   r   dictr   r   strr   r   r   r   r   r   r   r      s   
 
 r   c                       sB  e Zd ZdZ			d<deeef deeef dededed	ed
e	e dedef fddZ
dejde	e fddZedejdedejfddZ		d=dedejdejdejdeeeejf eeef f f
ddZdedejdejdeeeejf eeef f fddZdejdejdeejejf fdd Zed!eeef d"eeejf d#ed$eeejf d%edeeejf fd&d'Zd(ed)ed%edefd*d+Z		d=d,e	e dejdejdejde	e f
d-d.Z			/	/		0d>dejd1ejdejd2ed3ed4ed5ede	e fd6d7Zd8ed4ed2ed,e	e d9e	e de	e fd:d;Z  ZS )?BeamSearchScamaBeam search implementation.N      ?scorersweights	beam_size
vocab_sizesoseos
token_listpre_beam_ratiopre_beam_score_keyc
                    R  t    || _t | _t | _t | _tj	 | _
| D ]E\}
}||
d}|dks0|du r1qt|tsBJ |
 dt| d|| j|
< t|trR|| j|
< n|| j|
< t|tjjrc|| j
|
< q|| _|| _|| _t|| | _|| _|| _|	dur|	dkr|	| jvrt|	 d| j |	| _| jduo| j| jk ot| jdk| _dS aT  Initialize beam search.

        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules
                e.g., Decoder, CTCPrefixScorer, LM
                The scorer will be ignored if it is `None`
            weights (dict[str, float]): Dict of weights for each scorers
                The scorer will be ignored if its weight is 0
            beam_size (int): The number of hypotheses kept during search
            vocab_size (int): The number of vocabulary
            sos (int): Start of sequence id
            eos (int): End of sequence id
            token_list (list[str]): List of tokens for debug log
            pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): beam size in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`

        r   Nz (z$) does not implement ScorerInterfacefullz is not found in super__init__r-   r'   r,   full_scorerspart_scorersr$   nn
ModuleDictnn_dictr   get
isinstancer   typer
   Moduler0   r1   r2   intpre_beam_sizer.   n_vocabKeyErrorr4   lendo_pre_beamr   r,   r-   r.   r/   r0   r1   r2   r3   r4   r   r   w	__class__r   r   r:   '   L   







zBeamSearchScama.__init__xr   c                 C   X   t  }t  }| j D ]\}}||||< d||< qtd||tj| jg|jddgS zGet an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            Hypothesis: The initial hypothesis.

                device)r   r   r   r   	r'   r,   r   
init_stater   r$   tensorr0   rT   r   rO   init_statesinit_scoresr   dr   r   r   init_hypq      

zBeamSearchScama.init_hypxsc                 C   $   t j|g| j| jd}t | |fS zAppend new token to prefix tokens.

        Args:
            xs (torch.Tensor): The prefix token
            x (int): The new token to append

        Returns:
            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device

        dtyperT   r$   rW   rb   rT   catr^   rO   r   r   r   append_token      zBeamSearchScama.append_tokenhypx_maskpre_acoustic_embedsc           	      C   sP   t  }t  }| j D ]\}}|j|j|j| |||d\||< ||< q||fS )  Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        ri   rj   r'   r;   r   r   r   r   )	r   rh   rO   ri   rj   r   r   r   r[   r   r   r   
score_full   s   zBeamSearchScama.score_fullidsc                 C   L   t  }t  }| j D ]\}}||j||j| |\||< ||< q||fS aa  Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 1D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(len(ids),)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        r'   r<   r   score_partialr   r   r   rh   ro   rO   r   r   r   r[   r   r   r   rs      
   &zBeamSearchScama.score_partialweighted_scoresc                 C   z   | d| dkr|| jd }||fS || }td |dd< |||< || jd }|| | jd }||fS a  Compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
            Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(self.beam_size,)`

        r      infNsizetopkr.   r   r   rv   ro   top_idstmp	local_idsr   r   r   beam      zBeamSearchScama.beamprev_scoresnext_full_scoresfull_idxnext_part_scorespart_idxc                 C   V   t  }| D ]\}}| | ||  ||< q| D ]\}}| | ||  ||< q|S a  Merge scores for new hypothesis.

        Args:
            prev_scores (Dict[str, float]):
                The previous hypothesis scores by `self.scorers`
            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
            full_idx (int): The next token id for `next_full_scores`
            next_part_scores (Dict[str, torch.Tensor]):
                scores of partial tokens by `self.part_scorers`
            part_idx (int): The new token id for `next_part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are scalar tensors by the scorers.

        r'   r   r   r   r   r   r   
new_scoresr   r   r   r   r   merge_scores      zBeamSearchScama.merge_scoresr   part_statesc                 C   L   t  }| D ]\}}|||< q| j D ]\}}||| |||< q|S a  Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.

        r'   r   r<   select_stater   r   r   r   
new_statesr   r   r[   r   r   r   merge_states
     
zBeamSearchScama.merge_statesrunning_hypsc                 C   sf  g }t j| j|jd}|D ]}t j| j|j|jd}| j||||d\}	}
| jD ]}|| j| |	|  7 }q(| j	rN| j
dkr@|n|	| j
 }t || jd }| |||\}}| jD ]}||  | j| ||  7  < qZ||j7 }t| || D ]#\}}|t|| | |j|| |j|	|||| |
||d qyt|dd d	d
dtt|| j }q|S )"  Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (List[Hypothesis]): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)

        Returns:
            List[Hypotheses]: Best sorted hypotheses

        rS   ra   rl   r7   ry   r   r   r   r   c                 S      | j S Nr   rO   r   r   r   <lambda>U      z(BeamSearchScama.search.<locals>.<lambda>TkeyreverseNr$   arangerF   rT   zerosrb   rn   r;   r-   rI   r4   r}   rE   rs   r<   r   zipr   appendr   rf   r   r   r   r   sortedminrH   r.   )r   r   rO   ri   rj   	best_hypspart_idsrh   rv   r   r   r   pre_beam_scorespart_scoresr   jpart_jr   r   r   search  s@   



 

zBeamSearchScama.searchrR   r   
scama_maskmaxlenratiominlenratiomaxlenminlenc              
      sf  |du r.|dkr|j d }n|dk rdt| }ntdt||d }t||d }tdt|j d   tdt|  tdt|   |}g }	t|D ]}
t	dt|
  d}|dur|d}t
|
|d }|dd||d ddf }d}|dur| \}}}tj|d|f|jd	j|jd
}tj||fdd}t
|
|}|dd||d ddf } j||||d} |
||||	}|dkrtdd |	D |
rtd|
   nt|dkrtd  nt	dt|  qWt|	dd dd}t|dkr'td |dk rg S  ||td|d S |D ]}d fdd|jD }t	d|j||j q)|d }|j D ] \}}t|dd j| dd| j|  dd |  qOtd!|jd" td#|jt|j d" td$t|   jdurtd%d fd&d|jdd D  d'  |S )(W  Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
                If maxlenratio<0.0, its absolute value is interpreted
                as a constant max output length.
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        Nr   ry   decoder input length: max output length: min output length: 	position rb   rS   dimrl   rR   c                 S      g | ]}|  qS r   r   r   hr   r   r   
<listcomp>      z+BeamSearchScama.forward.<locals>.<listcomp>end detected at no hypothesis. Finish decoding.remained hypotheses: c                 S   r   r   r   r   r   r   r   r     r   z)BeamSearchScama.forward.<locals>.<lambda>Tr   Othere is no N-best results, perform recognition again with smaller minlenratio.皙? c                       g | ]} j | qS r   r2   r   rO   r   r   r   r         !nbest: y: {}, yseq: {}, score: {}6.2f * 3 =  for total log probability: .2fnormalized log probability: "total number of ended hypotheses: best hypo: c                    r   r   r   r   r   r   r   r     r   
) shaperD   maxr|   logginginfor(   r\   rangedebugr   r$   r   rb   torT   rd   r   post_processr	   rH   r   warningforwardjoinr   formatr   r   r   r-   r2   )r   rO   r   rj   r   r   r   r   r   
ended_hypsimask_enctoken_num_predictortoken_id_slicepre_acoustic_embeds_curbtr[   padbest
nbest_hypsr   r   r   r   r   r   r   Z  s   




:(zBeamSearchScama.forwardr   r   c              	        t dt|   jdur't dd fdd|d jdd D   ||d kr;t d	  fd
d|D }g }|D ]D}|jd  jkr~t j	
  j
 D ]#\}}	|	|j| }
|j|  |
7  < |j|j j| |
  d}qT|| q?|| q?|S )   Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (int): The maximum length ratio in beam search.
            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.

        Returns:
            List[Hypothesis]: The new running hypotheses.

        "the number of running hypotheses: Nr   r   c                    r   r   r   r   r   r   r   r     r   z0BeamSearchScama.post_process.<locals>.<listcomp>r   ry   -adding <eos> in the last position in the loopc                    $   g | ]}|j  |j jd qS )r   r   rf   r   r1   r   r   r   r   r         r   r   r   r   rH   r2   r   r   r   r1   r   r;   r   r<   final_scorer   r   r   r   r-   r   r   r   r   r   r   r   remained_hypsrh   r   r[   sr   r   r   r     (   
(

zBeamSearchScama.post_processNr+   N)NN)NNrR   rR   Nr   )r    r!   r"   r#   r   r(   r   r   rD   r   r:   r$   r%   r   r\   staticmethodrf   r   r   rn   rs   r   r   r   r   r   r   __classcell__r   r   rL   r   r)   $   s    


	
J




>	
kr)   c                       sN  e Zd ZdZ			d=deeef deeef dededed	ed
e	e dedef fddZ
de	e fddZedejdedejfddZddi fdedejdejdejdedeeeejf eeef f fddZdedejdejdeeeejf eeef f fddZdejdejdeejejf fd d!Zed"eeef d#eeejf d$ed%eeejf d&edeeejf fd'd(Zd)ed*ed&edefd+d,Zddi fd-e	e dejdejdejdede	e fd.d/Zddd0d0dd1i fdejd2ejdejd3ed4ed5ed6edede	e fd7d8Zd9ed5ed3ed-e	e d:e	e de	e fd;d<Z  ZS )>BeamSearchScamaStreamingr*   Nr+   r,   r-   r.   r/   r0   r1   r2   r3   r4   c
                    r5   r6   r8   rJ   rL   r   r   r:     rN   z!BeamSearchScamaStreaming.__init__r   c                 C   rP   rQ   rU   rX   r   r   r   r\   C  r]   z!BeamSearchScamaStreaming.init_hypr^   rO   c                 C   r_   r`   rc   re   r   r   r   rf   [  rg   z%BeamSearchScamaStreaming.append_tokenrh   ri   rj   cachec           
   	   C   sR   t  }t  }| j D ]\}}	|	j|j|j| ||||d\||< ||< q||fS )rk   ri   rj   r  rm   )
r   rh   rO   ri   rj   r  r   r   r   r[   r   r   r   rn   j  s   z#BeamSearchScamaStreaming.score_fullro   c                 C   rp   rq   rr   rt   r   r   r   rs     ru   z&BeamSearchScamaStreaming.score_partialrv   c                 C   rw   rx   r{   r~   r   r   r   r     r   zBeamSearchScamaStreaming.beamr   r   r   r   r   c                 C   r   r   r   r   r   r   r   r     r   z%BeamSearchScamaStreaming.merge_scoresr   r   c                 C   r   r   r   r   r   r   r   r     r   z%BeamSearchScamaStreaming.merge_statesr   c                 C   sh  g }t j| j|jd}|D ]}t j| j|j|jd}	| j|||||d\}
}| jD ]}|	| j| |
|  7 }	q)| j	rO| j
dkrA|	n|
| j
 }t || jd }| |||\}}| jD ]}|	|  | j| ||  7  < q[|	|j7 }	t| |	| D ]#\}}|t|	| | |j|| |j|
|||| |||d qzt|dd d	d
dtt|| j }q|S )r   rS   ra   r  r7   ry   r   c                 S   r   r   r   r   r   r   r   r   .  r   z1BeamSearchScamaStreaming.search.<locals>.<lambda>Tr   Nr   )r   r   rO   ri   rj   r  r   r   rh   rv   r   r   r   r   r   r   r   r   r   r   r   r     s@   




 

zBeamSearchScamaStreaming.searchrR   r   r   r   r   r   r   c	              
      s*  |du r.|dkr|j d }n|dk rdt| }ntdt||d }t||d }tdt|j d   tdt|  tdt|  |d }	g }
t|D ]}td	t|  d}d}|dur| \}}}t	j
|d|f|jd
j|jd}t	j||fdd}t||}|dd||d ddf } j|	||||d d} |||||
}	|dkrtdd |
D |rtd|   nt|	dkrtd  ntdt|	  qVt|
dd dd}t|dkr	td |dk rg S  ||td|d S |D ]}d fdd|jD }td|j||j q|d }|j D ] \}}t|dd j| d d!| j|  dd"|  q1td#|jd$ td%|jt|j d$ td&t|   jdurtd'd fd(d|jdd D  d)  |S )*r   Nr   r   ry   r   r   r   r   r   r   rS   r   decoderr  rR   c                 S   r   r   r   r   r   r   r   r     r   z4BeamSearchScamaStreaming.forward.<locals>.<listcomp>r   r   r   c                 S   r   r   r   r   r   r   r   r     r   z2BeamSearchScamaStreaming.forward.<locals>.<lambda>Tr   r   r   r   c                    r   r   r   r   r   r   r   r     r   r   r   r   r   r   r   r   r   r   r   r   c                    r   r   r   r   r   r   r   r     r   r   )r   rD   r   r|   r   r   r(   r   r   r$   r   rb   r   rT   rd   r   r   r   r	   rH   r   r   r   r   r   r   r   r   r   r-   r2   )r   rO   r   rj   r   r   r   r   r  r   r   r   r   r   r   r   r[   r   r   r   r   r   r   r   r   r   r   r   3  s   

:(z BeamSearchScamaStreaming.forwardr   r   c              	      r   )r   r   Nr   r   c                    r   r   r   r   r   r   r   r     r   z9BeamSearchScamaStreaming.post_process.<locals>.<listcomp>r   ry   r   c                    r   r   r   r   r   r   r   r     r   r   r   r   r   r   r   r   r     r  z%BeamSearchScamaStreaming.post_processr  )r    r!   r"   r#   r   r(   r   r   rD   r   r:   r   r\   r  r$   r%   rf   r'   r   r   rn   rs   r   r   r   r   r   r   r  r   r   rL   r   r    s   


	
J
#



?	

qr  )r#   	itertoolsr   r   typingr   r   r   r   r   r   r$   funasr.metrics.commonr	   2funasr.models.transformer.scorers.scorer_interfacer
   r   r   r=   rC   r)   r  r   r   r   r   <module>   s&       U