o
    i                     @   s  d Z ddlZddlmZmZmZmZ ddlZddlm	  m
Z ddlmZ ddlmZ ddlmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZ ddlm Z  ddl!m"Z"m#Z# ddl$m%Z% ddl&m'Z'm(Z( ddl)mZ* G dd deZ+dS )z(Fastspeech2 related modules for ESPnet2.    N)DictOptionalSequenceTuple)check_argument_types)force_gatherable)
initialize)AbsTTS)FastSpeech2Loss)VariancePredictor)StyleEncoder)Encoder)DurationPredictor)LengthRegulator)make_non_pad_maskmake_pad_mask)Postnet)PositionalEncodingScaledPositionalEncodingc                       s  e Zd ZdZ													
							
																																															
						dded ed!ed"ed#ed$ed%ed&ed'ed(ed)ed*ed+ed,ed-ed.ed/ed0ed1ed2ed3ed4ed5ed6ed7ed8ed9ed:ed;ed<ed=ed>ed?ed@edAedBedCedDedEedFedGedHedIedJedKedLedMedNedOedPedQedRedSedTedUedVedWee dXee dYee dZed[ed\ed]ed^ed_e	e d`edaedbedceddedeedfedgedhef fdidjZ
				ddkejdlejdmejdnejdoejdpejdqejdrejdsejdtejdueej dveej dweej dxedyeejeeejf ejf fdzd{Z										dd|ejd}ejd~eej deej deej deej deej dueej dveej dweej dededye	ej fddZ									ddkejdmeej doeej duejdveej dweej dqeej dseej dededyeeejf fddZdejduejdyejfddZd}ejdyejfddZddedeedfefddZ  ZS )FastSpeech2a  FastSpeech2 module.

    This is a module of FastSpeech2 described in `FastSpeech 2: Fast and
    High-Quality End-to-End Text to Speech`_. Instead of quantized pitch and
    energy, we use token-averaged value introduced in `FastPitch: Parallel
    Text-to-speech with Pitch Prediction`_.

    .. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`:
        https://arxiv.org/abs/2006.04558
    .. _`FastPitch: Parallel Text-to-speech with Pitch Prediction`:
        https://arxiv.org/abs/2006.06873

                           ?conv1d   TFtransformer皙?legacyrel_posrel_selfattnswish            	   Nadd
       r-   @   r.      r/   r/   xavier_uniform      ?idimodimadimaheadselayerseunitsdlayersdunitspostnet_layerspostnet_chanspostnet_filtspostnet_dropout_ratepositionwise_layer_typepositionwise_conv_kernel_sizeuse_scaled_pos_encuse_batch_normencoder_normalize_beforedecoder_normalize_beforeencoder_concat_afterdecoder_concat_afterreduction_factorencoder_typedecoder_typetransformer_enc_dropout_rate'transformer_enc_positional_dropout_rate!transformer_enc_attn_dropout_ratetransformer_dec_dropout_rate'transformer_dec_positional_dropout_rate!transformer_dec_attn_dropout_rateconformer_rel_pos_typeconformer_pos_enc_layer_typeconformer_self_attn_layer_typeconformer_activation_typeuse_macaron_style_in_conformeruse_cnn_in_conformer	zero_triuconformer_enc_kernel_sizeconformer_dec_kernel_sizeduration_predictor_layersduration_predictor_chansduration_predictor_kernel_sizeduration_predictor_dropout_rateenergy_predictor_layersenergy_predictor_chansenergy_predictor_kernel_sizeenergy_predictor_dropoutenergy_embed_kernel_sizeenergy_embed_dropout#stop_gradient_from_energy_predictorpitch_predictor_layerspitch_predictor_chanspitch_predictor_kernel_sizepitch_predictor_dropoutpitch_embed_kernel_sizepitch_embed_dropout"stop_gradient_from_pitch_predictorspkslangsspk_embed_dimspk_embed_integration_typeuse_gst
gst_tokens	gst_headsgst_conv_layersgst_conv_chans_listgst_conv_kernel_sizegst_conv_stridegst_gru_layersgst_gru_units	init_typeinit_enc_alphainit_dec_alphause_maskinguse_weighted_maskingcK           M         sp  t  sJ t   || _|| _|d | _|| _|| _|| _|8| _	|1| _
|| _|=| _d| _| jr2tnt}Kd||fv rm|dkrU|dkrId}td | dkrTd	} td
 n|dkrf|dks_J | d	kseJ ntd| tjj||| jd}L|dkrt||||||L||||K||||d| _nN|dkrtd.i d|d|d|d|d|d|Ld|d|d|d|d|d|d|d|"d|d| d |!d!|#d"|%d#|$| _nt| d$| jrt||>||?|@|A|B|C|D|Ed%
| _d&| _|9d&ur|9dkr|9| _tj|9|| _d&| _|:d&ur|:dkr|:| _tj|:|| _d&| _|;d&ur2|;dkr2|;| _|<| _ | jd&urS| j d'krHtj!| j|| _"ntj!|| j || _"t#||'|(|)|*d(| _$t%||2|3|4|5d(| _&tj'tjj(d||6|6d d) d*tj)|7| _*t%||+|,|-|.d(| _+tj'tjj(d||/|/d d) d*tj)|0| _,t- | _.|dkrtd||||d&||||K||||d| _/nL|dkrtd.i ddd|d|d|d|dd&d|d|d|d|d|d|d|d|"d|d| d |!d!|#d"|&| _/nt| d$tj!||| | _0|	dkrd&n
t1|||	|
|||d+| _2| j3|F|G|Hd, t4|I|Jd-| _5d&S )/aa  Initialize FastSpeech2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            elayers (int): Number of encoder layers.
            eunits (int): Number of encoder hidden units.
            dlayers (int): Number of decoder layers.
            dunits (int): Number of decoder hidden units.
            postnet_layers (int): Number of postnet layers.
            postnet_chans (int): Number of postnet channels.
            postnet_filts (int): Kernel size of postnet.
            postnet_dropout_rate (float): Dropout rate in postnet.
            use_scaled_pos_enc (bool): Whether to use trainable scaled pos encoding.
            use_batch_norm (bool): Whether to use batch normalization in encoder prenet.
            encoder_normalize_before (bool): Whether to apply layernorm layer before
                encoder block.
            decoder_normalize_before (bool): Whether to apply layernorm layer before
                decoder block.
            encoder_concat_after (bool): Whether to concatenate attention layer's input
                and output in encoder.
            decoder_concat_after (bool): Whether to concatenate attention layer's input
                and output in decoder.
            reduction_factor (int): Reduction factor.
            encoder_type (str): Encoder type ("transformer" or "conformer").
            decoder_type (str): Decoder type ("transformer" or "conformer").
            transformer_enc_dropout_rate (float): Dropout rate in encoder except
                attention and positional encoding.
            transformer_enc_positional_dropout_rate (float): Dropout rate after encoder
                positional encoding.
            transformer_enc_attn_dropout_rate (float): Dropout rate in encoder
                self-attention module.
            transformer_dec_dropout_rate (float): Dropout rate in decoder except
                attention & positional encoding.
            transformer_dec_positional_dropout_rate (float): Dropout rate after decoder
                positional encoding.
            transformer_dec_attn_dropout_rate (float): Dropout rate in decoder
                self-attention module.
            conformer_rel_pos_type (str): Relative pos encoding type in conformer.
            conformer_pos_enc_layer_type (str): Pos encoding layer type in conformer.
            conformer_self_attn_layer_type (str): Self-attention layer type in conformer
            conformer_activation_type (str): Activation function type in conformer.
            use_macaron_style_in_conformer: Whether to use macaron style FFN.
            use_cnn_in_conformer: Whether to use CNN in conformer.
            zero_triu: Whether to use zero triu in relative self-attention module.
            conformer_enc_kernel_size: Kernel size of encoder conformer.
            conformer_dec_kernel_size: Kernel size of decoder conformer.
            duration_predictor_layers (int): Number of duration predictor layers.
            duration_predictor_chans (int): Number of duration predictor channels.
            duration_predictor_kernel_size (int): Kernel size of duration predictor.
            duration_predictor_dropout_rate (float): Dropout rate in duration predictor.
            pitch_predictor_layers (int): Number of pitch predictor layers.
            pitch_predictor_chans (int): Number of pitch predictor channels.
            pitch_predictor_kernel_size (int): Kernel size of pitch predictor.
            pitch_predictor_dropout_rate (float): Dropout rate in pitch predictor.
            pitch_embed_kernel_size (float): Kernel size of pitch embedding.
            pitch_embed_dropout_rate (float): Dropout rate for pitch embedding.
            stop_gradient_from_pitch_predictor: Whether to stop gradient from pitch
                predictor to encoder.
            energy_predictor_layers (int): Number of energy predictor layers.
            energy_predictor_chans (int): Number of energy predictor channels.
            energy_predictor_kernel_size (int): Kernel size of energy predictor.
            energy_predictor_dropout_rate (float): Dropout rate in energy predictor.
            energy_embed_kernel_size (float): Kernel size of energy embedding.
            energy_embed_dropout_rate (float): Dropout rate for energy embedding.
            stop_gradient_from_energy_predictor: Whether to stop gradient from energy
                predictor to encoder.
            spks (Optional[int]): Number of speakers. If set to > 1, assume that the
                sids will be provided as the input and use sid embedding layer.
            langs (Optional[int]): Number of languages. If set to > 1, assume that the
                lids will be provided as the input and use sid embedding layer.
            spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0,
                assume that spembs will be provided as the input.
            spk_embed_integration_type: How to integrate speaker embedding.
            use_gst (str): Whether to use global style token.
            gst_tokens (int): The number of GST embeddings.
            gst_heads (int): The number of heads in GST multihead attention.
            gst_conv_layers (int): The number of conv layers in GST.
            gst_conv_chans_list: (Sequence[int]):
                List of the number of channels of conv layers in GST.
            gst_conv_kernel_size (int): Kernel size of conv layers in GST.
            gst_conv_stride (int): Stride size of conv layers in GST.
            gst_gru_layers (int): The number of GRU layers in GST.
            gst_gru_units (int): The number of GRU units in GST.
            init_type (str): How to initialize transformer parameters.
            init_enc_alpha (float): Initial value of alpha in scaled pos encoding of the
                encoder.
            init_dec_alpha (float): Initial value of alpha in scaled pos encoding of the
                decoder.
            use_masking (bool): Whether to apply masking for padded part in loss
                calculation.
            use_weighted_masking (bool): Whether to apply weighted masking in loss
                calculation.

        r   r   	conformerr!   r"   legacy_rel_poszFallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' due to the compatibility. If you want to use the new one, please use conformer_pos_enc_layer_type = 'latest'.r#   legacy_rel_selfattnzFallback to conformer_self_attn_layer_type = 'legacy_rel_selfattn' due to the compatibility. If you want to use the new one, please use conformer_pos_enc_layer_type = 'latest'.latestzUnknown rel_pos_type: )num_embeddingsembedding_dimpadding_idxr   )r2   attention_dimattention_headslinear_units
num_blocksinput_layerdropout_ratepositional_dropout_rateattention_dropout_ratepos_enc_classnormalize_beforeconcat_afterr>   r?   r2   r   r   r   r   r   r   r   r   r   r   r>   r?   macaron_stylepos_enc_layer_typeselfattention_layer_typeactivation_typeuse_cnn_modulecnn_module_kernelrU   z is not supported.)
r2   ro   gst_token_dimrp   conv_layersconv_chans_listconv_kernel_sizeconv_stride
gru_layers	gru_unitsNr*   )r2   n_layersn_chanskernel_sizer   r'   )in_channelsout_channelsr   padding)r2   r3   r   r   n_filtsrA   r   )rw   rx   ry   )rz   r{    )6r   super__init__r2   r3   eosrF   rG   rH   ri   rb   r@   rn   r   r   r   loggingwarning
ValueErrortorchnn	EmbeddingTransformerEncoderencoderConformerEncoderr   gstrj   sid_embrk   lid_embrl   rm   Linear
projectionr   duration_predictorr   pitch_predictor
SequentialConv1dDropoutpitch_embedenergy_predictorenergy_embedr   length_regulatordecoderfeat_outr   postnet_reset_parametersr
   	criterion)Mselfr2   r3   r4   r5   r6   r7   r8   r9   r:   r;   r<   r=   r>   r?   r@   rA   rB   rC   rD   rE   rF   rG   rH   rI   rJ   rK   rL   rM   rN   rO   rP   rQ   rR   rS   rT   rU   rV   rW   rX   rY   rZ   r[   r\   r]   r^   r_   r`   ra   rb   rc   rd   re   rf   rg   rh   ri   rj   rk   rl   rm   rn   ro   rp   rq   rr   rs   rt   ru   rv   rw   rx   ry   rz   r{   r   encoder_input_layer	__class__r   W/home/ubuntu/.local/lib/python3.10/site-packages/espnet2/tts/fastspeech2/fastspeech2.pyr   0   s   
4


	
	






	

zFastSpeech2.__init__texttext_lengthsfeatsfeats_lengths	durationsdurations_lengthspitchpitch_lengthsenergyenergy_lengthsspembssidslidsjoint_trainingreturnc           &         sP  |ddd|  f }|ddd|  f }|ddd|  f }|ddd|  f }|	ddd|
  f }	|d}t|ddgd j}t|D ]\}} j|||f< qP|d }||||	f\}}}}|} j||||||||||dd\}}}}} jdkr|	 fdd|D }t |}|ddd|f } j
du rd} j|||||||||||d	\}} }!}"||  |! |" }#t| |  |! |" d
}$ jdkr jr|$j jjd jj d  jdkr jr|$j jjd jj d |s|$j|# d t|#|$|f|#j\}#}$}%|#|$|%fS |#|$|dur%|fS |fS )a@  Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded token ids (B, T_text).
            text_lengths (LongTensor): Batch of lengths of each input (B,).
            feats (Tensor): Batch of padded target features (B, T_feats, odim).
            feats_lengths (LongTensor): Batch of the lengths of each target (B,).
            durations (LongTensor): Batch of padded durations (B, T_text + 1).
            durations_lengths (LongTensor): Batch of duration lengths (B, T_text + 1).
            pitch (Tensor): Batch of padded token-averaged pitch (B, T_text + 1, 1).
            pitch_lengths (LongTensor): Batch of pitch lengths (B, T_text + 1).
            energy (Tensor): Batch of padded token-averaged energy (B, T_text + 1, 1).
            energy_lengths (LongTensor): Batch of energy lengths (B, T_text + 1).
            spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
            lids (Optional[Tensor]): Batch of language IDs (B, 1).
            joint_training (bool): Whether to perform joint training with vocoder.

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value if not joint training else model outputs.

        Nr   r   constantF)r   r   r   is_inferencec                    s   g | ]	}|| j   qS r   rF   .0olenr   r   r   
<listcomp>-  s    z'FastSpeech2.forward.<locals>.<listcomp>)
after_outsbefore_outsd_outsp_outse_outsysdspsesilensolens)l1_lossduration_loss
pitch_lossenergy_lossr   )encoder_alpha)decoder_alpha)loss)maxsizeFpadr   	enumerater   _forwardrF   newr   r   dictitemrG   r@   updater   embedalphadatarH   r   r   device)&r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   
batch_sizexsilr   r   r   r   r   r   r   r   r   r   r   max_olenr   r   r   r   r   statsweightr   r   r   forward  s   )




zFastSpeech2.forwardr   r   r   r   r   r   r   r   r   c                    sz    |} ||\}} jr |}||d } jd ur0 |	d}||d } jd urD 	|
d}||d } j
d urO ||}t||j} jrf | |d}n	 ||d} jr~ | |d}n	 ||d}|r j||} |dddd} |dddd}|| | } |||}n, ||} |dddd} |dddd}|| | } ||}|d ur|s jdkr| fdd|D }n|}  |}nd } ||\}} ||dd j} j d u r'|}n|  |dddd }|||||fS )Nr   r   r'   c                    s   g | ]}| j  qS r   r   r   r   r   r   r     s    z(FastSpeech2._forward.<locals>.<listcomp>r   )!_source_maskr   rn   r   	unsqueezerj   r   viewrk   r   rl   _integrate_with_spk_embedr   tor   ri   r   detachrb   r   r   	inferencer   	transposer   r   rF   r   r   r   r   r3   r   )r   r   r   r   r   r   r   r   r   r   r   r   r   x_maskshs_
style_embssid_embslid_embsd_masksr   r   r   p_embse_embsolens_inh_maskszsr   r   r   r   r   r   _  sb   







zFastSpeech2._forwarduse_teacher_forcingc                 C   s  ||}}||||f\}}}}t |ddgd| j}tj|jd gtj|jd}|dd}}|dur9|d}|durB|d}|
ri|d|d|d}}}| j	|||||||||d	\}}}}}n| j	||||||d|	d\}}}}}t
|d |d |d |d d	S )
a  Generate the sequence of features given the sequences of characters.

        Args:
            text (LongTensor): Input sequence of characters (T_text,).
            feats (Optional[Tensor): Feature sequence to extract style (N, idim).
            durations (Optional[Tensor): Groundtruth of duration (T_text + 1,).
            spembs (Optional[Tensor): Speaker embedding vector (spk_embed_dim,).
            sids (Optional[Tensor]): Speaker ID (1,).
            lids (Optional[Tensor]): Language ID (1,).
            pitch (Optional[Tensor]): Groundtruth of token-avg pitch (T_text + 1, 1).
            energy (Optional[Tensor]): Groundtruth of token-avg energy (T_text + 1, 1).
            alpha (float): Alpha to control the speed.
            use_teacher_forcing (bool): Whether to use teacher forcing.
                If true, groundtruth of duration, pitch and energy will be used.

        Returns:
            Dict[str, Tensor]: Output dict including the following items:
                * feat_gen (Tensor): Output sequence of features (T_feats, odim).
                * duration (Tensor): Duration sequence (T_text + 1,).
                * pitch (Tensor): Pitch sequence (T_text + 1,).
                * energy (Tensor): Energy sequence (T_text + 1,).

        r   r   r   )dtyper   N)r   r   r   r   r   r   T)r   r   r   r   r   )feat_gendurationr   r   )r   r   r   r   tensorshapelongr   r  r   r   )r   r   r   r   r   r   r   r   r   r   r  xyspembdper   r   r   r   r   r   r  outsr   r   r   r   r   r   r    sL   
$

"zFastSpeech2.inferencer  c                 C   sz   | j dkr| t|}||d }|S | j dkr9t|dd|dd}| tj||gdd}|S t	d)aE  Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, T_text, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, T_text, adim).

        r*   r   concatr   )dimzsupport only add or concat.)
rm   r   r   	normalizer  expandr   r   catNotImplementedError)r   r  r   r   r   r   r
  	  s   

 z%FastSpeech2._integrate_with_spk_embedc                 C   s"   t |t|  j}|dS )a  Make masks for self-attention.

        Args:
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]]], dtype=torch.uint8)

        )r   r  next
parametersr   r  )r   r   r  r   r   r   r  #  s   
zFastSpeech2._source_maskc                 C   sj   |dkr	t | | | jdkr| jrt|| jjd j_| j	dkr1| jr3t|| j
jd j_d S d S d S )Npytorchr   r   )r   rG   r@   r   r  r   r   r   r   rH   r   )r   rw   rx   ry   r   r   r   r   8  s   
zFastSpeech2._reset_parameters)Hr   r   r   r   r   r   r   r   r   r   r   r   TTTTFFr   r   r   r    r    r    r    r    r    r!   r"   r#   r$   TTFr%   r&   r'   r   r(   r    r'   r   r(   r   r)   r   Fr'   r   r(   r   r)   r   FNNNr*   Fr+   r   r   r,   r(   r'   r   r/   r0   r1   r1   FF)NNNF)
NNNNNNNNFr1   )	NNNNNNNr1   F)__name__
__module____qualname____doc__intfloatstrboolr   r   r   r   Tensorr   r   r  r   r  r
  r  r   __classcell__r   r   r   r   r   !   s   	
!"#$%&'()+,-.012345689:;<=>@ABCDEFGHIJKLNOPQR   A	

 	

Z	

S
r   ),r6  r   typingr   r   r   r   r   torch.nn.functionalr   
functionalr   	typeguardr    espnet2.torch_utils.device_funcsr   espnet2.torch_utils.initializer   espnet2.tts.abs_ttsr	   espnet2.tts.fastspeech2.lossr
   *espnet2.tts.fastspeech2.variance_predictorr   espnet2.tts.gst.style_encoderr   -espnet.nets.pytorch_backend.conformer.encoderr   r   9espnet.nets.pytorch_backend.fastspeech.duration_predictorr   7espnet.nets.pytorch_backend.fastspeech.length_regulatorr   &espnet.nets.pytorch_backend.nets_utilsr   r   -espnet.nets.pytorch_backend.tacotron2.decoderr   1espnet.nets.pytorch_backend.transformer.embeddingr   r   /espnet.nets.pytorch_backend.transformer.encoderr   r   r   r   r   r   <module>   s(   