o
    iS                     @   sp  d dl Z d dlmZmZ d dlmZ d dlmZ d dl	m
Z
 d dlmZ d dlmZ d dlmZ d d	lmZ d d
lmZ d dlmZ d dlmZ d dlmZ d dlmZmZ G dd de jjZe ddG dd deZ!G dd de jjZ"e ddG dd de jjZ#e ddG dd de jjZ$e ddG dd deZ%e ddG dd de jjZ&dS )    N)ListTuple)tables)utils)repeat)DecoderLayer)	LayerNorm)PositionalEncoding)MultiHeadedAttention)make_pad_mask)BaseTransformerDecoder)PositionwiseFeedForward)"PositionwiseFeedForwardDecoderSANM)MultiHeadedAttentionSANMDecoderMultiHeadedAttentionCrossAttc                       sP   e Zd ZdZ		d fdd	ZdddZdd	d
ZdddZ	dddZ  Z	S )DecoderLayerSANMa  Single decoder layer module.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)


    TFc                    s   t t|   || _|| _|| _|| _t|| _|dur!t|| _	|dur*t|| _
tj|| _|| _|| _| jrNtj|| || _tj|| || _d| _g | _dS )z!Construct an DecoderLayer object.NF)superr   __init__size	self_attnsrc_attnfeed_forwardr   norm1norm2norm3torchnnDropoutdropoutnormalize_beforeconcat_afterLinearconcat_linear1concat_linear2reserve_attnattn_mat)selfr   r   r   r   dropout_rater   r    	__class__ T/home/ubuntu/.local/lib/python3.10/site-packages/funasr/models/paraformer/decoder.pyr   0   s$   



zDecoderLayerSANM.__init__Nc                 C   s   |}| j r
| |}| |}|}| jr+| j r| |}| ||\}}|| | }| jdur_|}| j r:| |}| jrO| j|||dd\}	}
| j	
|
 n	| j|||dd}	|| |	 }|||||fS )"  Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).

        Returns:
            torch.Tensor: Output tensor(#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).

        NTret_attnF)r   r   r   r   r   r   r   r   r$   r%   append)r&   tgttgt_maskmemorymemory_maskcacheresidualx_
x_src_attnr%   r*   r*   r+   forwardN   s(   




zDecoderLayerSANM.forwardc           
      C   t   |}|  |}| |}|}| jd ur&| |}| j|||d\}}|| }|}| |}| j|||dd\}}	|	S Nr4   Tr-   r   r   r   r   r   r   
r&   r0   r1   r2   r3   r4   r5   r6   r8   r%   r*   r*   r+   get_attn_mat{      




zDecoderLayerSANM.get_attn_matc                 C   s   |}| j r
| |}| |}|}| jr2| j r| |}| jr!d}| j|||d\}}|| | }| jdurM|}| j rA| |}|| | ||| }|||||fS )r,   Nr<   )	r   r   r   r   r   trainingr   r   r   r&   r0   r1   r2   r3   r4   r5   r6   r*   r*   r+   forward_one_step   s$   




z!DecoderLayerSANM.forward_one_stepr   c           	      C   s   |}| j r
| |}| |}|}| jr,| j r| |}| |d|\}}|| | }| jdurK|}| j r;| |}| j|||||\}}|| }||||fS )r,   N)	r   r   r   r   r   r   r   r   forward_chunk)	r&   r0   r2   
fsmn_cache	opt_cache
chunk_size	look_backr5   r6   r*   r*   r+   rD      s"   




zDecoderLayerSANM.forward_chunk)TFNN)NNNr   )
__name__
__module____qualname____doc__r   r9   r?   rC   rD   __classcell__r*   r*   r(   r+   r      s    

-
+r   decoder_classesParaformerSANMDecoderc                1       s  e Zd ZdZdddddddddd	edd	dd
ddddddddfdedededededededededededededed ed!ed"ed#e	e d$ed%ed&ed'e
d(ed)ef0 fd*d+Z					dDd,ejd-ejd.ejd/ejd0ejd1ed2ed3eejejf fd4d5Zd6d7 Zd,ejd-ejd.ejd/ejfd8d9Zd,ejd-ejd.ejd/ejfd:d;Z	dEd<ejd=ejd>ed3eejejf fd?d@Z	dEd=ejdAejd<ejd>e	ej d3eeje	ej f f
dBdCZ  ZS )FrP   
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
             皙?        embedTF   r   N      )   decoderzseq2seq/decoder
vocab_sizeencoder_output_sizeattention_headslinear_units
num_blocksr'   positional_dropout_rateself_attention_dropout_ratesrc_attention_dropout_rateinput_layeruse_output_layerwo_input_layerr   r    att_layer_numkernel_size
sanm_shfit	lora_list	lora_rank
lora_alphalora_dropoutchunk_multiply_factor!tf2torch_tensor_name_prefix_torchtf2torch_tensor_name_prefix_tfc                    s  t  j||||
||
d | |rd | _n;|
dkr'tjtj| | _n*|
dkrJtjtj| tj tj	tj
 | || _ntd|
 
| _| jr\t | _|rgtj || _nd | _|| _|| _d u rzd d t| 	
fdd| _|| d	krd | _nt||  
fd
d| _td 
fdd| _|| _|| _|| _d S )Nr]   r^   r'   rb   re   rf   pos_enc_classr   rW   linearz'only 'embed' or 'linear' is supported: r[      c                    s8   t  t dt 	t 
S )Nrj   )r   r   r   r   lnumattention_dimr_   r    r'   ri   r`   rm   rn   rk   rl   r   rj   rc   rd   r*   r+   <lambda>1  s&    
	z0ParaformerSANMDecoder.__init__.<locals>.<lambda>r   c                    s(   t  t ddd t S )Nr   rv   )r   r   r   rw   )rz   r    r'   ri   r`   r   rc   r*   r+   r{   J  s    
c                    s   t  d d t S N)r   r   rw   )rz   r    r'   r`   r   r*   r+   r{   Y  s    
)r   r   rW   r   r   
Sequential	Embeddingr!   r   r   ReLU
ValueErrorr   
after_normoutput_layerrh   ra   r   decoders	decoders2	decoders3rp   rq   ro   )r&   r]   r^   r_   r`   ra   r'   rb   rc   rd   re   rf   rg   rs   r   r    rh   ri   rj   rk   rl   rm   rn   ro   rp   rq   r(   ry   r+   r      sj   


$
zParaformerSANMDecoder.__init__hs_padhlens	ys_in_pad
ys_in_lens
chunk_maskreturn_hiddenreturn_bothreturnc                 C   sR  |}t j||jddddddf }	|}
t j||
jddddddf }|durL|| }|	d|dkrLtj||ddddddf fdd}|}| ||	|
|\}}	}
}}| jdurm| ||	|
|\}}	}
}}| ||	|
|\}}	}
}}| j	r| 
|}|	d}| jdur|du r| |}||fS |r| |}|||fS ||fS )@  Forward decoder.

        Args:
            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        deviceNr[   dimF)myutilssequence_maskr   r   r   catr   r   r   r   r   sumr   )r&   r   r   r   r   r   r   r   r0   r1   r2   r3   r6   r7   hiddenolensr*   r*   r+   r9   g  s.   ""(





zParaformerSANMDecoder.forwardc                 C   sd   t jtjt|gtjd|jddddddf }| j|d||d|d\}}|	d|fS )zScore.dtyper   Nr   r<   )
r   r   r   tensorlenint32r   rC   	unsqueezesqueeze)r&   ysstater6   ys_masklogpr*   r*   r+   score  s   "zParaformerSANMDecoder.scorec                 C   s   |}t j||jdd d d d d f }|}t j||jdd d d d d f }| jd ||||\}}}}}	| jjd ||||}
|
S Nr   r   r[   )r   r   r   r   modelr?   r&   r   r   r   r   r0   r1   r2   r3   r7   r%   r*   r*   r+   forward_asf2  s   ""z"ParaformerSANMDecoder.forward_asf2c                 C   s   |}t j||jdd d d d d f }|}t j||jdd d d d d f }| jd ||||\}}}}}	| jd ||||\}}}}}	| jd ||||\}}}}}	| jd ||||\}}}}}	| jd ||||\}}}}}	| jd ||||}
|
S Nr   r   r[   ru      rR      )r   r   r   r   r?   r   r*   r*   r+   forward_asf6  s   ""z"ParaformerSANMDecoder.forward_asf6r2   r0   r4   c              	   C   s  |}|d du rt | j}| jdur|t | j7 }dg| }n|d }|d du r4t | j}dg| }n|d }t| jD ]"}| j| }	|	j|||| || |d |d d\}}||< ||< q=| j| j dkrt| j| j D ]}|| j }
| j| }	|	j||||
 d\}}||
< }qp| jD ]}	|	||\}}}}q| jr| 	|}| j
dur| 
|}||d< |d d	ks|d d
kr||d< |S )r   decode_fsmnNoptrG   decoder_chunk_look_back)rE   rF   rG   rH   r[   )rE   r   r   )r   r   r   rangerh   rD   ra   r   r   r   r   )r&   r2   r0   r4   r6   cache_layer_numrE   rF   ir\   jr7   r*   r*   r+   rD     sL   



	






z#ParaformerSANMDecoder.forward_chunkr1   c                 C   st  |  |}|du rt| j}| jdur|t| j7 }dg| }g }t| jD ]}| j| }	|| }
|	j|||d|
d\}}}}}|| q&| j| j dkr{t| j| j D ]$}|| j }| j| }	|| }
|	j|||d|
d\}}}}}|| qV| j	D ]}	|	j|||ddd\}}}}}q~| j
r| |dddf }n|dddf }| jdurtj| |dd}||fS )a5  Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask,  (batch, maxlen_out)
                      dtype=torch.uint8 in PyTorch 1.2-
                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32  (batch, maxlen_in, feat)
            cache: cached output list of (batch, max_time_out-1, size)
        Returns:
            y, cache: NN output value and cache per `self.decoders`.
            y.shape` is (batch, maxlen_out, token)
        Nr<   r[   r   r   )rW   r   r   r   r   rh   rC   r/   ra   r   r   r   r   r   log_softmax)r&   r0   r1   r2   r4   r6   r   	new_cacher   r\   cr3   c_retr   r7   yr*   r*   r+   rC     sB   











z&ParaformerSANMDecoder.forward_one_step)NFFr|   )rJ   rK   rL   rM   r	   intfloatstrboolr   tupler   r   Tensorr   r9   r   r   r   dictrD   rC   rN   r*   r*   r(   r+   rP      s   
	
 	
6


Jc                       s0   e Zd Z fddZdddZdddZ  ZS )	DecoderLayerSANMExportc                    sb   t    |j| _|j| _|j| _|j| _t|dr|jnd | _t|dr(|jnd | _|j	| _	d S )Nr   r   )
r   r   r   r   r   r   hasattrr   r   r   )r&   r   r(   r*   r+   r   R  s   
zDecoderLayerSANMExport.__init__Nc                 C   s   |}|  |}| |}|}| jd ur&| |}| j|||d\}}|| }| jd ur;|}| |}|| ||| }|||||fS Nr<   )r   r   r   r   r   r   rB   r*   r*   r+   r9   \  s   





zDecoderLayerSANMExport.forwardc           
      C   r:   r;   r=   r>   r*   r*   r+   r?   o  r@   z#DecoderLayerSANMExport.get_attn_matrI   )rJ   rK   rL   r   r9   r?   rN   r*   r*   r(   r+   r   P  s    

r   ParaformerSANMDecoderExportc                       s   e Zd Zddef fddZdd Z				dd
ejdejdejdejdedefddZd
ejdejdejdejfddZ	d
ejdejdejdejfddZ
  ZS )r      r\   Tonnxc                    s"  t    ddlm} || _||dd| _ddlm} ddlm} t	| jj
D ]$\}	}
t|
jtr7||
j|
_t|
jtrC||
j|
_t|
| jj
|	< q'| jjd urqt	| jjD ]\}	}
t|
jtrh||
j|
_t|
| jj|	< qXt	| jjD ]\}	}
t|
| jj|	< qw|j| _|j| _|| _d S Nr   r   Fflip)%MultiHeadedAttentionSANMDecoderExport)"MultiHeadedAttentionCrossAttExport)r   r   funasr.utils.torch_functionr   r   r   funasr.models.sanm.attentionr   r   	enumerater   
isinstancer   r   r   r   r   r   r   r   r   
model_namer&   r   max_seq_lenr   r   kwargsr   r   r   r   dr(   r*   r+   r     s,   

z$ParaformerSANMDecoderExport.__init__c                 C   z   |d d d d d f }t |jdkr!d|d d d d d d f  }nt |jdkr5d|d d d d d f  }|d }||fS Nru   r[   r   g     r   shaper&   maskmask_3d_btdmask_4d_bhltr*   r*   r+   prepare_mask     z(ParaformerSANMDecoderExport.prepare_maskFr   r   r   r   r   r   c                 C   s   |}|  |}| |\}}	|}
|  |}| |\}	}|}| j|||
|\}}}
}}	| jjd ur@| j|||
|\}}}
}}	| j|||
|\}}}
}}	| |}| jd ure|du re| |}||fS |rq| |}|||fS ||fS )NF)r   r   r   r   r   r   r   r   )r&   r   r   r   r   r   r   r0   r1   r7   r2   r3   r6   r   r*   r*   r+   r9     s*   






z#ParaformerSANMDecoderExport.forwardc                 C   s   |}t j||jdd d d d d f }|}t j||jdd d d d d f }| |\}	}| jjd ||||\}}}}}	| jjd ||||}
|
S r   r   r   r   r   r   r   r?   r   r*   r*   r+   r     s   ""
z(ParaformerSANMDecoderExport.forward_asf2c                 C   s  |}t j||jdd d d d d f }|}t j||jdd d d d d f }| |\}	}| jjd ||||\}}}}}	| jjd ||||\}}}}}	| jjd ||||\}}}}}	| jjd ||||\}}}}}	| jjd ||||\}}}}}	| jjd ||||}
|
S r   r   r   r*   r*   r+   r     s,   ""




z(ParaformerSANMDecoderExport.forward_asf6r   r\   T)FF)rJ   rK   rL   r   r   r   r   r   r9   r   r   rN   r*   r*   r(   r+   r     sJ    !
&

!!ParaformerSANMDecoderOnlineExportc                       sn   e Zd Zddef fddZdd Zd	ejd
ejdejdejfddZdd Z	dd Z
dd Zdd Z  ZS )r   r   r\   Tr   c                    s(  t    || _ddlm} || _||dd| _ddlm} ddlm} t	| jj
D ]$\}	}
t|
jtr:||
j|
_t|
jtrF||
j|
_t|
| jj
|	< q*| jjd urtt	| jjD ]\}	}
t|
jtrk||
j|
_t|
| jj|	< q[t	| jjD ]\}	}
t|
| jj|	< qz|j| _|j| _|| _d S r   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r(   r*   r+   r   ?  s.   

z*ParaformerSANMDecoderOnlineExport.__init__c                 C   r   r   r   r   r*   r*   r+   r   a  r   z.ParaformerSANMDecoderOnlineExport.prepare_maskr   r   r   r   c                 G   s  |}|  |}| |\}}|}	|  |}
| |
\}}
|}t }t| jjD ]\}}|| }||||	|
|d\}}}	}
}|| q'| jjd urqt| jjD ]!\}}||t| jj  }||||	|
|d\}}}	}
}|| qO| j	|||	|
\}}}	}
}| 
|}| |}||fS r   )r   r   listr   r   r   r/   r   r   r   r   r   )r&   r   r   r   r   argsr0   r1   r7   r2   r3   r6   
out_cachesr   r\   in_cache	out_cacher*   r*   r+   r9   k  s2   	





z)ParaformerSANMDecoderOnlineExport.forwardc                    s   t dd|t j}t jddgt jd}t dd|t j}t jddgt jd}t jj}t	 jdrD jj
d urD|t jj
7 } fdd	t|D }||||g|R S )
Nru   d      r   
   r   r   c                    s<   g | ]}t jd  jjd j jjd jjd ft jdqS )ru   r   r[   r   )r   zerosr   r   r   r   ri   float32.0r7   r&   r*   r+   
<listcomp>  s    "zFParaformerSANMDecoderOnlineExport.get_dummy_inputs.<locals>.<listcomp>)r   randntyper   r   r   r   r   r   r   r   r   )r&   enc_sizeencenc_lenacoustic_embedsacoustic_embeds_len	cache_numr4   r*   r   r+   get_dummy_inputs  s   
z2ParaformerSANMDecoderOnlineExport.get_dummy_inputsc                 C   sN   t | jj}t| jdr| jjd ur|t | jj7 }g ddd t|D  S )Nr   )r   r   r   r   c                 S      g | ]}d | qS )in_cache_%dr*   r   r   r*   r*   r+   r     s    zEParaformerSANMDecoderOnlineExport.get_input_names.<locals>.<listcomp>r   r   r   r   r   r   r&   r   r*   r*   r+   get_input_names  s   z1ParaformerSANMDecoderOnlineExport.get_input_namesc                 C   sN   t | jj}t| jdr| jjd ur|t | jj7 }ddgdd t|D  S )Nr   logits
sample_idsc                 S   r   out_cache_%dr*   r   r*   r*   r+   r         zFParaformerSANMDecoderOnlineExport.get_output_names.<locals>.<listcomp>r   r   r*   r*   r+   get_output_names  s   z2ParaformerSANMDecoderOnlineExport.get_output_namesc                 C   s   ddddddddiddid}t | jj}t| jdr+| jjd ur+|t | jj7 }|dd	 t|D  |d
d	 t|D  |S )N
batch_size
enc_lengthr   r[   token_lengthr   )r   r   r   r   r   c                 S      i | ]	}d | ddiqS )r   r   r  r*   r   r   r*   r*   r+   
<dictcomp>      zFParaformerSANMDecoderOnlineExport.get_dynamic_axes.<locals>.<dictcomp>c                 S   r  )r  r   r  r*   r  r*   r*   r+   r    r  )r   r   r   r   r   updater   r&   retr   r*   r*   r+   get_dynamic_axes  s*   
		z2ParaformerSANMDecoderOnlineExport.get_dynamic_axesr   )rJ   rK   rL   r   r   r   r   r   r9   r   r   r  r  rN   r*   r*   r(   r+   r   =  s     "

(ParaformerSANDecoderc                       s   e Zd ZdZdddddddddedd	d
fdedededededededededededededef fddZ	de
jde
jde
jde
jdee
je
jf f
d d!Z  ZS )"r  rQ   rR   rS   rT   rU   rV   rW   TFr   r]   r^   r_   r`   ra   r'   rb   rc   rd   re   rf   r   r    	embeds_idc              
      sR   t  j||||
||d | t| fdd| _|| _ | _d S )Nrr   c                    s,   t  t t t S r|   )r   r
   r   rw   rz   r_   r    r'   r`   r   rc   rd   r*   r+   r{     s    


z/ParaformerSANDecoder.__init__.<locals>.<lambda>)r   r   r   r   r  rz   )r&   r]   r^   r_   r`   ra   r'   rb   rc   rd   re   rf   rs   r   r    r  r(   r  r+   r     s"   
zParaformerSANDecoder.__init__r   r   r   r   r   c                 C   s"  |}t |dddddf  |j}|}t ||dd dddddf |j}|jd |jd krM|jd |jd  }	tjj|d|	fdd}|}
d}t	| j
D ]\}}||
|||\}
}}}|| jkrl|
}qV| jru| |
}
| jdur| |
}
|d}|dur|
||fS |
|fS )r   Nr[   )maxlenr   r   constantF)r   tor   r   r   r   r   
functionalpadr   r   r  r   r   r   r   )r&   r   r   r   r   r0   r1   r2   r3   padlenr6   embeds_outputslayer_idr\   r   r*   r*   r+   r9   
  s,   $.





zParaformerSANDecoder.forward)rJ   rK   rL   rM   r	   r   r   r   r   r   r   r   r   r9   rN   r*   r*   r(   r+   r    sn    
	
-ParaformerDecoderSANExportc                       s|   e Zd Z			ddef fddZdd Zd	ejd
ejdejdejfddZdd Z	dd Z
dd Zdd Zdd Z  ZS )r  r   r\   Tr   c           
         s   t    || _ddlm} || _||dd| _ddlm} ddlm	} t
| jjD ]\}}	t|	jtr:||	j|	_||	| jj|< q*|j| _|j| _|| _d S )Nr   r   Fr   )DecoderLayerExport)MultiHeadedAttentionExport)r   r   r   r   r   r   !funasr.models.transformer.decoderr  #funasr.models.transformer.attentionr   r   r   r   r   r
   r   r   r   )
r&   r   r   r   r   r   r  r   r   r   r(   r*   r+   r   A  s   

z#ParaformerDecoderSANExport.__init__c                 C   r   r   r   r   r*   r*   r+   r   ^  r   z'ParaformerDecoderSANExport.prepare_maskr   r   r   r   c                 C   sr   |}|  |}| |\}}|}|  |}	| |	\}}	|}
| j|
|||	\}
}}}	| |
}
| |
}
|
|fS r|   )r   r   r   r   r   r   )r&   r   r   r   r   r0   r1   r7   r2   r3   r6   r*   r*   r+   r9   h  s   



z"ParaformerDecoderSANExport.forwardc                    sh   t dgd}t dd|}t dd|}t jjt jj } fddt|D }||||fS )Nr   r[   r   c                    s2   g | ]}t d  jjd j jjd jjfqS )r[   r   )r   r   r   r   r   r   ri   r   r   r*   r+   r     s    z?ParaformerDecoderSANExport.get_dummy_inputs.<locals>.<listcomp>)	r   
LongTensorr   r   r   r   r   r   r   )r&   r   r0   r2   pre_acoustic_embedsr   r4   r*   r   r+   r     s   
z+ParaformerDecoderSANExport.get_dummy_inputsc                 C   s   dS )NTr*   r   r*   r*   r+   is_optimizable  s   z)ParaformerDecoderSANExport.is_optimizablec                 C   s2   t | jjt | jj }g ddd t|D  S )Nr0   r2   r$  c                 S   r   )cache_%dr*   r   r*   r*   r+   r     r  z>ParaformerDecoderSANExport.get_input_names.<locals>.<listcomp>r   r   r   r   r   r   r*   r*   r+   r     s   z*ParaformerDecoderSANExport.get_input_namesc                 C   s0   t | jjt | jj }dgdd t|D  S )Nr   c                 S   r   r  r*   r   r*   r*   r+   r     r  z?ParaformerDecoderSANExport.get_output_names.<locals>.<listcomp>r(  r   r*   r*   r+   r    s   z+ParaformerDecoderSANExport.get_output_namesc                 C   sR   dddddddddd}t | jjt | jj }|d	d
 t|D  |S )N	tgt_batch
tgt_lengthr	  memory_batchmemory_lengthacoustic_embeds_batchacoustic_embeds_lengthr&  c                 S   s$   i | ]}d | d| d| dqS )r'  zcache_%d_batchzcache_%d_length)r   ru   r*   r  r*   r*   r+   r    s    z?ParaformerDecoderSANExport.get_dynamic_axes.<locals>.<dictcomp>)r   r   r   r   r  r   r  r*   r*   r+   r    s   z+ParaformerDecoderSANExport.get_dynamic_axesr   )rJ   rK   rL   r   r   r   r   r   r9   r   r%  r   r  r  rN   r*   r*   r(   r+   r  ?  s,    

)'r   typingr   r   funasr.registerr   funasr.models.scamar   r   &funasr.models.transformer.utils.repeatr   r!  r   $funasr.models.transformer.layer_normr   #funasr.models.transformer.embeddingr	   r"  r
   *funasr.models.transformer.utils.nets_utilsr   r   3funasr.models.transformer.positionwise_feed_forwardr   ,funasr.models.sanm.positionwise_feed_forwardr   r   r   r   r   Moduler   registerrP   r   r   r   r  r  r*   r*   r*   r+   <module>   s>    
G  q
0 
= 

i