o
    iA                    @   s  d Z ddlZddlmZmZ ddlZddlZddlmZ ddl	m
Z
mZmZ ddlmZ ddlmZmZmZ dd	lmZ dd
lmZ ddlmZ ddlmZmZ ddlmZ ddlmZm Z m!Z!m"Z"m#Z# ddl$m%Z%m&Z& ddl'm(Z(m)Z) ddl*m+Z+ ddl,m-Z-m.Z. e)/e0Z1dZ2dej3de4de4fddZ5	ddej3de4deej3 fddZ6		dde7e4e4f de8d e4deej9 d!e4d"ej:fd#d$Z;G d%d& d&eZ<G d'd( d(eZ=G d)d* d*eZ>G d+d, d,ej?Z@G d-d. d.ej?ZAG d/d0 d0ej?ZBG d1d2 d2ejj?ZCG d3d4 d4ej?ZDG d5d6 d6ej?ZEG d7d8 d8ej?ZFG d9d: d:ej?ZGG d;d< d<ej?ZHG d=d> d>ej?ZIG d?d@ d@ej?ZJG dAdB dBej?e%ZKG dCdD dDej?e%ZLG dEdF dFej?e%ZMG dGdH dHej?ZNG dIdJ dJej?ZOG dKdL dLeZPG dMdN dNeZQe(G dOdP dPe&ZRG dQdR dReRZSG dSdT dTeRZTG dUdV dVeRZUG dWdX dXeRZVG dYdZ dZeRZWG d[d\ d\eRZXG d]d^ d^eRZYG d_d` d`eRZZG dadb dbej?Z[G dcdd ddej?Z\e(dedfG dgdh dheRZ]e(didfG djdk dkeReZ^			l	m	n		o	oddpeRdej_dqeej_ deej9 dre8dse8dte8dueej? dve`dwe`d"eej_e7ej_ej_f f fdxdyZae(dzdfG d{d| d|eRZbe(d}dfG d~d deRZcG dd dej?Zde(ddfG dd de&Zeg dZfdS )zPyTorch SpeechT5 model.    N)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossL1Loss   )ACT2FN)CacheDynamicCacheEncoderDecoderCache)GenerationMixin)is_deepspeed_zero3_enabled)is_fsdp_managed_module)_prepare_4d_attention_mask!_prepare_4d_causal_attention_mask)GradientCheckpointingLayer)BaseModelOutput)BaseModelOutputWithPastAndCrossAttentionsSeq2SeqLMOutputSeq2SeqModelOutputSeq2SeqSpectrogramOutput)EmbeddingAccessMixinPreTrainedModel)auto_docstringlogging)deprecate_kwarg   )SpeechT5ConfigSpeechT5HifiGanConfig	input_idspad_token_iddecoder_start_token_idc                 C   sh   |  | j}| ddddf  |ddddf< ||dddf< |du r*td||dk| |S )z1
    Shift input ids one token to the right.
    Nr   r   z1self.model.config.pad_token_id has to be defined.i)	new_zerosshapeclone
ValueErrormasked_fill_)r    r!   r"   shifted_input_ids r*   b/home/ubuntu/.local/lib/python3.10/site-packages/transformers/models/speecht5/modeling_speecht5.pyshift_tokens_right6   s   (r,   input_valuesreduction_factorattention_maskc                 C   s   |dkr"| dd|d d|f } |dur"|dd|d d|f }|  | j}| ddddf  |ddddf< ||dkd ||fS )zw
    Shift input spectrograms one timestep to the right. Also applies the reduction factor to the sequence length.
    r   Nr#         Y        )r$   r%   r&   r(   )r-   r.   r/   shifted_input_valuesr*   r*   r+   shift_spectrograms_rightF   s   (r3   r%   	mask_probmask_length	min_masksreturnc                    s  | \}dk rt dkrt d d dtjd   fdd}|dur:| d	 n
fd
dt|D }tj	|ft
d}g }	|}
|
dkrZ|S |D ];}||}tjjt|d  |dd}t|dkr}d }n|d }t|tj|
| tjd| g}|	| q\t|	}	t|	dddddf ||
f}	|	||
 }	tddddf }t|||
f||
 }|	| }	|	 d krd |	|	d k< t||	dd	 |S )an  
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
    r   z&`mask_length` has to be bigger than 0.zO`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: z and `sequence_length`: `c                    sX   t |     }t|}| kr }| d  |k r*t| d  d}|S )z;Given input length, compute how many spans should be maskedr   r   )intmax)input_lengthnum_masked_spanepsilonr5   r4   r6   sequence_lengthr*   r+   compute_num_masked_span   s   
z6_compute_mask_indices.<locals>.compute_num_masked_spanNr#   c                    s   g | ]} qS r*   r*   .0_)r?   r*   r+   
<listcomp>   s    z)_compute_mask_indices.<locals>.<listcomp>dtyper   F)replace)r'   nprandomranditemdetachsumtolistrangezerosboolchoicearangelenconcatenateonesint32appendarraybroadcast_toreshaper:   put_along_axis)r%   r4   r5   r/   r6   
batch_sizer@   input_lengthsspec_aug_maskspec_aug_mask_idxsmax_num_masked_spanr;   r<   spec_aug_mask_idxdummy_mask_idxoffsetsr*   r=   r+   _compute_mask_indices\   s\   

re   c                       &   e Zd Zd fdd	Zdd Z  ZS )SpeechT5NoLayerNormConvLayerr   c                    sj   t    |dkr|j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _d S )Nr   r   kernel_sizestridebias)super__init__conv_dimin_conv_dimout_conv_dimr   Conv1dconv_kernelconv_stride	conv_biasconvr	   feat_extract_activation
activationselfconfiglayer_id	__class__r*   r+   rm      s   
z%SpeechT5NoLayerNormConvLayer.__init__c                 C   s   |  |}| |}|S N)ru   rw   ry   hidden_statesr*   r*   r+   forward      

z$SpeechT5NoLayerNormConvLayer.forwardr   __name__
__module____qualname__rm   r   __classcell__r*   r*   r|   r+   rg      s    rg   c                       rf   )SpeechT5LayerNormConvLayerr   c                    s|   t    |dkr|j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
tj| jdd| _t|j | _d S )Nr   r   rh   T)elementwise_affine)rl   rm   rn   ro   rp   r   rq   rr   rs   rt   ru   	LayerNorm
layer_normr	   rv   rw   rx   r|   r*   r+   rm      s   
z#SpeechT5LayerNormConvLayer.__init__c                 C   s:   |  |}|dd}| |}|dd}| |}|S )Nr#   )ru   	transposer   rw   r   r*   r*   r+   r      s   


z"SpeechT5LayerNormConvLayer.forwardr   r   r*   r*   r|   r+   r      s    r   c                       rf   )SpeechT5GroupNormConvLayerr   c                    s   t    |dkr|j|d  nd| _|j| | _tj| j| j|j| |j| |j	d| _
t|j | _tj| j| jdd| _d S )Nr   r   rh   T)
num_groupsnum_channelsaffine)rl   rm   rn   ro   rp   r   rq   rr   rs   rt   ru   r	   rv   rw   	GroupNormr   rx   r|   r*   r+   rm     s   
z#SpeechT5GroupNormConvLayer.__init__c                 C   s"   |  |}| |}| |}|S r~   )ru   r   rw   r   r*   r*   r+   r     s   


z"SpeechT5GroupNormConvLayer.forwardr   r   r*   r*   r|   r+   r     s    r   c                	       s   e Zd ZdZddededee f fddZddededee fd	d
Zeddededee fddZ	e
 dde
jdefddZ	dde
jdedee fddZ  ZS )%SpeechT5SinusoidalPositionalEmbeddingzDThis module produces sinusoidal positional embeddings of any length.Nnum_positionsembedding_dimpadding_idxc                    s4   t    d| _|| _|| _| || j || d S N   )rl   rm   offsetr   r   make_weights)ry   r   r   r   r|   r*   r+   rm   "  s
   
z.SpeechT5SinusoidalPositionalEmbedding.__init__num_embeddingsc                 C   sB   |  |||}t| dr|j| jj| jjd}| jd|dd d S )NweightsrF   deviceF
persistent)get_embeddinghasattrtor   rF   r   register_buffer)ry   r   r   r   emb_weightsr*   r*   r+   r   )  s   
z2SpeechT5SinusoidalPositionalEmbedding.make_weightsc                 C   s   |d }t d|d  }ttj|tjd |  }tj| tjd d|d }tjt	|t
|gdd| d}|d dkrUtj|t| dgdd}|durad||ddf< |t S )	z
        Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
        description in Section 3.5 of "Attention Is All You Need".
        r   i'  r   rE   r   dimr#   N)mathlogtorchexprS   int64float	unsqueezecatsincosviewrP   r   get_default_dtype)r   r   r   half_dimembr*   r*   r+   r   1  s    $&z3SpeechT5SinusoidalPositionalEmbedding.get_embeddingr   r    past_key_values_lengthc                 C   s|   |  \}}| || j||j}| jd | }|| j dkr-| || j | j| j | j	d|
d
||d S )Nr   r   r#   )size"create_position_ids_from_input_idsr   r   r   r   r   r   r   index_selectr   rL   )ry   r    r   bszseq_lenposition_idsmax_posr*   r*   r+   r   C  s   "z-SpeechT5SinusoidalPositionalEmbedding.forwardc                 C   s6   | | }tj|dd|| | }| | S )a  
        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
        symbols are ignored. This is modified from fairseq's `utils.make_positions`.

        Args:
            x: torch.Tensor x:
        Returns: torch.Tensor
        r   r   )ner9   r   cumsumtype_aslong)ry   r    r   r   maskincremental_indicesr*   r*   r+   r   R  s   zHSpeechT5SinusoidalPositionalEmbedding.create_position_ids_from_input_idsr~   r   )r   r   r   __doc__r9   r   rm   r   staticmethodr   r   no_gradTensorr   r   r   r*   r*   r|   r+   r     s      r   c                       $   e Zd Z fddZdd Z  ZS )SpeechT5PositionalConvEmbeddingc                    s$  t    tj|j|j|j|jd |jd| _tjj	}t
tjjdr'tjjj	}t r{dd l}|jj| jjdd || jddd| _W d    n1 sLw   Y  t
| jdrd| jjjj}| jjjj}n| jj}| jj}|j| | |j| | n	|| jddd| _t|j| _t|j | _d S )	Nr   )ri   paddinggroupsweight_normr   )modifier_rankweight)namer   parametrizations)rl   rm   r   rq   hidden_sizenum_conv_pos_embeddingsnum_conv_pos_embedding_groupsru   utilsr   r   r   r   	deepspeedzeroGatheredParametersr   	original0	original1weight_gweight_vregister_external_parameterSpeechT5SamePadLayerr   r	   rv   rw   )ry   rz   r   r   r   r   r|   r*   r+   rm   e  s4   

z(SpeechT5PositionalConvEmbedding.__init__c                 C   s:   | dd}| |}| |}| |}| dd}|S Nr   r   )r   ru   r   rw   r   r*   r*   r+   r     s   


z'SpeechT5PositionalConvEmbedding.forwardr   r*   r*   r|   r+   r   d  s    !r   c                       s*   e Zd ZdZd fdd	Zdd Z  ZS ) SpeechT5ScaledPositionalEncodingu[   
    Scaled positional encoding, see §3.2 in https://huggingface.co/papers/1809.08895
      c                    s   t ||}t d|d}t t jd|dt jd td|   }t 	| | |d d dd df< t 
| | |d d dd df< |d}t   | jd|dd tj|d	| _|| _tt d
| _d S )Nr   r   r   rE   g     @peFr   p      ?)r   rP   rS   r   r   r   r   r   r   r   r   rl   rm   r   r   Dropoutdropoutr   	Parametertensoralpha)ry   r   r   max_lenr   positiondiv_termr|   r*   r+   rm     s   .$$

z)SpeechT5ScaledPositionalEncoding.__init__c                 C   s4   || j | jd d d |df   }| |}|S )Nr   )r   r   r   r   )ry   r   r*   r*   r+   r     s   &
z(SpeechT5ScaledPositionalEncoding.forward)r   )r   r   r   r   rm   r   r   r*   r*   r|   r+   r     s    r   c                       rf   )"SpeechT5RelativePositionalEncoding  c                    s.   t    || _|| _tjd| || _d S r   )rl   rm   r   
max_lengthr   r   	Embeddingpe_k)ry   r   r   r|   r*   r+   rm     s   
z+SpeechT5RelativePositionalEncoding.__init__c                 C   s   |j d }td|j|jtjd}|d d d f |d d d f  }| j ||| j k < | jd ||| jk< || j }| |S )Nr   r   r   rF   )r%   r   rS   r   r   r   r   r   )ry   r   r   pos_seqr*   r*   r+   r     s   
 

z*SpeechT5RelativePositionalEncoding.forward)r   r   r*   r*   r|   r+   r     s    r   c                       r   )r   c                    s*   t    |d dkrd| _d S d| _d S )Nr   r   r   )rl   rm   num_pad_remove)ry   r   r|   r*   r+   rm     s   
 zSpeechT5SamePadLayer.__init__c                 C   s,   | j dkr|d d d d d | j  f }|S Nr   )r   r   r*   r*   r+   r     s   
zSpeechT5SamePadLayer.forwardr   r*   r*   r|   r+   r     s    r   c                       s0   e Zd ZdZ fddZdd Zdd Z  ZS )SpeechT5FeatureEncoderz.Construct the features from raw audio waveformc                    s   t     jdkr t ddg fddt jd D  }n jdkr2 fddt jD }n	td	 j d
t|| _	d| _
d| _d S )Ngroupr   r{   c                    s   g | ]
}t  |d  dqS )r   r   )rg   rB   irz   r*   r+   rD     s    z3SpeechT5FeatureEncoder.__init__.<locals>.<listcomp>r   layerc                       g | ]}t  |d qS )r   )r   r   r   r*   r+   rD     s    z`config.feat_extract_norm` is z), but has to be one of ['group', 'layer']FT)rl   rm   feat_extract_normr   rO   num_feat_extract_layersr'   r   
ModuleListconv_layersgradient_checkpointing_requires_grad)ry   rz   r  r|   r   r+   rm     s   





zSpeechT5FeatureEncoder.__init__c                 C   s   |   D ]}d|_qd| _d S )NF)
parametersrequires_gradr  )ry   paramr*   r*   r+   _freeze_parameters  s   
z)SpeechT5FeatureEncoder._freeze_parametersc                 C   s:   |d d d f }| j r| jrd|_| jD ]}||}q|S NT)r  trainingr  r  )ry   r-   r   
conv_layerr*   r*   r+   r     s   

zSpeechT5FeatureEncoder.forward)r   r   r   r   rm   r  r   r   r*   r*   r|   r+   r     s
    r   c                       r   )SpeechT5FeatureProjectionc                    sJ   t    tj|jd |jd| _t|jd |j| _	t
|j| _d S )Nr#   eps)rl   rm   r   r   rn   layer_norm_epsr   Linearr   
projectionr   feat_proj_dropoutr   ry   rz   r|   r*   r+   rm     s   
z"SpeechT5FeatureProjection.__init__c                 C   s&   |  |}| |}| |}||fS r~   )r   r  r   )ry   r   norm_hidden_statesr*   r*   r+   r     s   


z!SpeechT5FeatureProjection.forwardr   r*   r*   r|   r+   r    s    r  c                       s   e Zd Z fddZdd Z		ddejdeej deej	 fd	d
Z
dedejfddZdeejef fddZ		ddej	deej	 deej fddZ  ZS )SpeechT5SpeechEncoderPrenetc                    s|   t    || _t|| _t|| _|jdks|jdkr(t	
t|j | _t|| _t|j|j d |j|j| _d S )Nr1   r   )rl   rm   rz   r   feature_encoderr  feature_projectionmask_time_probmask_feature_probr   r   r   r   r   uniform_masked_spec_embedr   pos_conv_embedr   max_speech_positionsr!   pos_sinusoidal_embedr  r|   r*   r+   rm     s   




z$SpeechT5SpeechEncoderPrenet.__init__c                 C   s   | j   d S r~   )r  r  ry   r*   r*   r+   freeze_feature_encoder  s   z2SpeechT5SpeechEncoderPrenet.freeze_feature_encoderNr-   r/   mask_time_indicesc           	      C   s   |  |}|dd}|d ur| |jd |}| |\}}| j|||d}| |}|| }|d ur<|d }nt	j
|jd d t	j|jd}| |}|| }||fS )Nr   r   )r!  r/   r   )r  r   "_get_feature_vector_attention_maskr%   r  _mask_hidden_statesr  r   r   r   rP   r   r  )	ry   r-   r/   r!  extract_featuresr   positional_conv_embeddingpadding_mask positional_sinusoidal_embeddingsr*   r*   r+   r     s&   


z#SpeechT5SpeechEncoderPrenet.forwardfeature_vector_lengthc                 C   s   |j ddd d df }| |tj}|jd }tj||f|j|jd}d|tj	|jd |jd|d f< |
dg d
dg }|S )Nr#   r   r   r   r   r   )r    _get_feat_extract_output_lengthsr   r   r   r%   rP   rF   r   rS   fliprQ   )ry   r(  r/   non_padded_lengthsoutput_lengthsr]   r*   r*   r+   r"  9  s   
"z>SpeechT5SpeechEncoderPrenet._get_feature_vector_attention_maskr^   c                 C   s4   dd }t | jj| jjD ]
\}}||||}q|S )zH
        Computes the output length of the convolutional layers
        c                 S   s   t j| | |ddd S )Nfloor)rounding_moder   )r   div)r;   ri   rj   r*   r*   r+   _conv_out_lengthN  s   zVSpeechT5SpeechEncoderPrenet._get_feat_extract_output_lengths.<locals>._conv_out_length)ziprz   rr   rs   )ry   r^   r1  ri   rj   r*   r*   r+   r*  I  s   z<SpeechT5SpeechEncoderPrenet._get_feat_extract_output_lengthsr   c                 C   s  t | jdds	|S | \}}}|dur| j|j||< n-| jjdkrK| jrKt||f| jj| jj	|| jj
d}tj||jtjd}| j|j||< | jjdkr| jrt||f| jj| jj| jjd}tj||jtjd}|dddf d|d}d||< |S )	z
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).
        apply_spec_augmentTNr   )r4   r5   r/   r6   r   )r4   r5   r6   r#   )getattrrz   r   r  r   rF   r  r
  re   mask_time_lengthmask_time_min_masksr   r   r   rQ   r  mask_feature_lengthmask_feature_min_masksexpand)ry   r   r!  r/   r]   r?   r   mask_feature_indicesr*   r*   r+   r#  Y  s4   z/SpeechT5SpeechEncoderPrenet._mask_hidden_statesNN)r   r   r   rm   r   r   r   r   
LongTensorFloatTensorr   r9   r"  r   r*  r#  r   r*   r*   r|   r+   r    s.    
#r  c                       sB   e Zd Z fddZdd Z	d
dejdeej fdd	Z  Z	S )SpeechT5SpeechDecoderPrenetc                    sr   t     | _t fddt jD | _t j	 j
| _t j j
 j| _t j j
  j
| _d S )Nc                    s*   g | ]}t |d kr jn j jqS r   )r   r  num_mel_binsspeech_decoder_prenet_unitsr   r   r*   r+   rD     s    z8SpeechT5SpeechDecoderPrenet.__init__.<locals>.<listcomp>)rl   rm   rz   r   r  rO   speech_decoder_prenet_layerslayersr  r@  r   final_layerr   positional_dropoutr  encode_positionsspeaker_embedding_dimspeaker_embeds_layerr  r|   r   r+   rm     s   


z$SpeechT5SpeechDecoderPrenet.__init__c                 C   sJ   t j|d |d}|d|ddd}t |dk|dd d|  S )Nr   r   r   )r   	bernoullir   repeatr   where)ry   inputs_embedsr   r   	all_masksr*   r*   r+   _consistent_dropout  s   z/SpeechT5SpeechDecoderPrenet._consistent_dropoutNr-   speaker_embeddingsc                 C   s   |}| j D ]}tj||}| || jj}q| |}| |}|d urKtj	|}|
dd|dd}tj||gdd}tj| |}|S )Nr   r#   r   )rB  r   
functionalrelurM  rz   speech_decoder_prenet_dropoutrC  rE  	normalizer   r9  r   r   r   rG  )ry   r-   rN  rK  r   r*   r*   r+   r     s   


z#SpeechT5SpeechDecoderPrenet.forwardr~   )
r   r   r   rm   rM  r   r   r   r   r   r*   r*   r|   r+   r>    s    r>  c                       rf   )SpeechT5BatchNormConvLayerr   c                    s   t    |dkr|j}n|j}||jd kr|j}n|j}tj|||jd|jd d dd| _t	|| _
||jd k rCt | _nd | _t|j| _d S )Nr   r   r   F)ri   rj   r   rk   )rl   rm   r?  speech_decoder_postnet_unitsspeech_decoder_postnet_layersr   rq   speech_decoder_postnet_kernelru   BatchNorm1d
batch_normTanhrw   r   speech_decoder_postnet_dropoutr   )ry   rz   r{   ro   rp   r|   r*   r+   rm     s(   
z#SpeechT5BatchNormConvLayer.__init__c                 C   s6   |  |}| |}| jd ur| |}| |}|S r~   )ru   rX  rw   r   r   r*   r*   r+   r     s   




z"SpeechT5BatchNormConvLayer.forwardr   r   r*   r*   r|   r+   rS    s    rS  c                       s<   e Zd Z fddZdejfddZdejfddZ  ZS )SpeechT5SpeechDecoderPostnetc                    s^   t     | _t j j j | _t j j| _	t
 fddt jD | _d S )Nc                    s   g | ]}t  |qS r*   )rS  r   r   r*   r+   rD     s    z9SpeechT5SpeechDecoderPostnet.__init__.<locals>.<listcomp>)rl   rm   rz   r   r  r   r?  r.   feat_outprob_outr  rO   rU  rB  r  r|   r   r+   rm     s   

z%SpeechT5SpeechDecoderPostnet.__init__r   c                 C   sJ   |  ||dd| jj}| |}| ||dd}|||fS )Nr   r#   )r\  r   r   rz   r?  postnetr]  )ry   r   outputs_before_postnetoutputs_after_postnetlogitsr*   r*   r+   r     s   

z$SpeechT5SpeechDecoderPostnet.forwardc                 C   s0   | dd}| jD ]}||}q	|| dd S r   )r   rB  )ry   r   layer_outputr   r*   r*   r+   r^    s   

z$SpeechT5SpeechDecoderPostnet.postnet)	r   r   r   rm   r   r   r   r^  r   r*   r*   r|   r+   r[    s    r[  c                       s,   e Zd Z fddZdejfddZ  ZS )SpeechT5TextEncoderPrenetc                    s>   t    || _t|j|j|j| _t	|j
|j|j| _d S r~   )rl   rm   rz   r   r   
vocab_sizer   r!   embed_tokensr   rD  max_text_positionsrE  r  r|   r*   r+   rm     s   

z"SpeechT5TextEncoderPrenet.__init__r    c                 C   s   |  |}| |}|S r~   )re  rE  )ry   r    rK  r*   r*   r+   r     r   z!SpeechT5TextEncoderPrenet.forward)r   r   r   rm   r   r   r   r   r*   r*   r|   r+   rc    s    
rc  c                       sD   e Zd Z fddZ		d	dejdeej dee fddZ	  Z
S )
SpeechT5TextDecoderPrenetc                    sn   t    || _t|j| _|jrt	|j
nd| _t|j|j
|j| _t|j|j d |j
|j| _d S )Nr   r   )rl   rm   rz   r   r   rD  r   scale_embeddingr   sqrtr   embed_scaler   rd  r!   re  r   rf  embed_positionsr  r|   r*   r+   rm     s   

z"SpeechT5TextDecoderPrenet.__init__Nr    r/   past_key_valuesc                 C   s   |d ur|  }|d|d }ntdd}|d ur-t|ts)|d d jd n| }| ||}| || j	 }||7 }| 
|}||fS )Nr#   z'You have to specify `decoder_input_ids`r   r   )r   r   r'   
isinstancer
   r%   get_seq_lengthrk  re  rj  r   )ry   r    r/   rl  input_shaper   	positionsrK  r*   r*   r+   r     s   
z!SpeechT5TextDecoderPrenet.forwardr;  )r   r   r   rm   r   r   r   r<  r
   r   r   r*   r*   r|   r+   rg    s    rg  c                       s<   e Zd Z fddZdejfddZdd Zdd	 Z  Z	S )
SpeechT5TextDecoderPostnetc                    s*   t    || _tj|j|jdd| _d S )NFrk   )rl   rm   rz   r   r  r   rd  lm_headr  r|   r*   r+   rm   ;  s   
z#SpeechT5TextDecoderPostnet.__init__r   c                 C   s
   |  |S r~   rs  r   r*   r*   r+   r   @     
z"SpeechT5TextDecoderPostnet.forwardc                 C      | j S r~   rt  r  r*   r*   r+   get_output_embeddingsC  s   z0SpeechT5TextDecoderPostnet.get_output_embeddingsc                 C   s
   || _ d S r~   rt  ry   new_embeddingsr*   r*   r+   set_output_embeddingsH  ru  z0SpeechT5TextDecoderPostnet.set_output_embeddings)
r   r   r   rm   r   r   r   rw  rz  r   r*   r*   r|   r+   rq  :  s
    rq  c                       s   e Zd ZdZ				ddededee d	ee d
ee dee f fddZe	dddd							dde
jdee
j dee dee
j dee
j dee
j dedee
j dee
jee
j ee f fddZ  ZS )SpeechT5Attentionz
    Multi-headed attention from 'Attention Is All You Need' paper with relative position bias (see
    https://aclanthology.org/N18-2074.pdf)
    r1   FTN	embed_dim	num_headsr   
is_decoderrk   	layer_idxc                    s   t    || _|| _|| _|| | _| j| | jkr'td| j d| d| jd | _|| _|| _	t
j|||d| _t
j|||d| _t
j|||d| _t
j|||d| _d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      rr  )rl   rm   r|  r}  r   head_dimr'   scalingr~  r  r   r  k_projv_projq_projout_proj)ry   r|  r}  r   r~  rk   r  r|   r*   r+   rm   R  s$   
	

zSpeechT5Attention.__init__past_key_valuerl  4.58new_nameversionr   key_value_statesr/   layer_head_maskposition_biasoutput_attentionscache_positionr7   c	                 C   s  |du}	|  \}
}}| || j }d}|dur1t|tr/|j| j}|	r+|j}n|j	}n|}|	r5|n|}|	rN|durN|rN|j
| j j}|j
| j j}nJ| |}| |}||
d| j| jdd}||
d| j| jdd}|dur|	s||nd}|||| jd|i\}}|	rt|trd|j| j< |
| j d| jf}||
|| j| jdd}|j| }|j| }|j| }| d}t||dd}|  |
| j ||fkrtd|
| j ||f d	|   |dur#| |
| j d| jd
d}t||dd}|d
d|
| j | d
| d}||7 }|durX|  |
d||fkrCtd|
d||f d	|   ||
| j||| }||
| j ||}tjj|dd}|dur|  | jfkr|td| jf d	|   |dddd||
| j|| }||
| j ||}|r||
| j||}||
| j ||}nd}tjj|| j| jd}t||}|  |
| j || jfkrtd|
| j|| jf d	|   ||
| j|| j}|dd}||
|| j}|  |}||fS )z#Input shape: Batch x Time x ChannelNFr#   r   r   r  Tz$Attention weights should be of size z	, but is r   r   z!Attention mask should be of size r   z/Head mask for a single layer should be of size )r   r
  z `attn_output` should be of size )!r   r  r  rm  r   
is_updatedgetr  cross_attention_cacheself_attention_cacherB  keysvaluesr  r  r   r}  r  r   updater[   r   bmmr'   
contiguousmatmulr   rO  softmaxr   r
  r|  r  )ry   r   r  rl  r/   r  r  r  r  is_cross_attentionr   tgt_lenrC   query_statesr  curr_past_key_valuecurrent_states
key_statesvalue_states
proj_shapesrc_lenattn_weights	reshape_qrel_pos_biasattn_weights_reshaped
attn_probsattn_outputr*   r*   r+   r   o  s   







"

"
zSpeechT5Attention.forward)r1   FTN)NNNNNFN)r   r   r   r   r9   r   r   rQ   rm   r   r   r   r
   tupler   r   r*   r*   r|   r+   r{  L  s\    		
r{  c                       r   )SpeechT5FeedForwardc                    sl   t    t|j| _t|j|| _t	|j
tr!t|j
 | _n|j
| _t||j| _t|j| _d S r~   )rl   rm   r   r   activation_dropoutintermediate_dropoutr  r   intermediate_denserm  
hidden_actstrr	   intermediate_act_fnoutput_densehidden_dropoutoutput_dropout)ry   rz   intermediate_sizer|   r*   r+   rm     s   
zSpeechT5FeedForward.__init__c                 C   s6   |  |}| |}| |}| |}| |}|S r~   )r  r  r  r  r  r   r*   r*   r+   r      s   




zSpeechT5FeedForward.forwardr   r*   r*   r|   r+   r    s    r  c                       s^   e Zd Zdef fddZ				ddejdeej deej d	eej d
ef
ddZ	  Z
S )SpeechT5EncoderLayerrz   c                    sj   t    t|j|j|jdd| _t|j	| _
tj|j|jd| _t||j| _tj|j|jd| _d S )NF)r|  r}  r   r~  r  )rl   rm   r{  r   encoder_attention_headsattention_dropout	attentionr   r   r  r   r   r  r   r  encoder_ffn_dimfeed_forwardfinal_layer_normr  r|   r*   r+   rm     s   
zSpeechT5EncoderLayer.__init__NFr   r/   r  r  r  c           	      C   sh   |}| j |||||d\}}| |}|| }| |}|| | }| |}|f}|r2||f7 }|S )as  
        Args:
            hidden_states (`torch.FloatTensor`):
                input to the layer of shape `(batch, seq_len, hidden_size)`
            attention_mask (`torch.FloatTensor`):
                attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very
                large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(config.encoder_attention_heads,)`.
            position_bias (`torch.FloatTensor`):
                relative position embeddings of size `(seq_len, seq_len, hidden_size // encoder_attention_heads)`
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )r   r/   r  r  r  )r  r   r   r  r  )	ry   r   r/   r  r  r  residualr  outputsr*   r*   r+   r     s"   




zSpeechT5EncoderLayer.forward)NNNF)r   r   r   r   rm   r   r   r   rQ   r   r   r*   r*   r|   r+   r  
  s"    r  c                       s   e Zd Zddef fddZedddd									
	ddejdeej deej deej deej deej dee	 dee
 dee
 deej fddZ  ZS )SpeechT5DecoderLayerNrz   c                    s   t    t|j|j|jd|d| _t|j	| _
tj|j|jd| _t|j|j|jd|d| _tj|j|jd| _t||j| _tj|j|jd| _d S )NT)r|  r}  r   r~  r  r  )r   r~  r  )rl   rm   r{  r   decoder_attention_headsr  	self_attnr   r   r  r   r   r  self_attn_layer_normencoder_attnencoder_attn_layer_normr  decoder_ffn_dimr  r  )ry   rz   r  r|   r*   r+   rm   H  s(   
zSpeechT5DecoderLayer.__init__r  rl  r  r  FTr   r/   encoder_hidden_statesencoder_attention_maskr  cross_attn_layer_head_maskr  	use_cacher  c              	   C   s   |}| j ||||||
d\}}| |}|| }| |}d}|durA|}| j|||||||
d\}}| |}|| }| |}|| | }| |}|f}|rX|||f7 }|S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            encoder_hidden_states (`torch.FloatTensor`):
                cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
            encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
                size `(decoder_attention_heads,)`.
            past_key_values (`Cache`): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )r   rl  r/   r  r  r  N)r   r  r/   r  rl  r  r  )r  r   r  r  r  r  r  )ry   r   r/   r  r  r  r  rl  r  r  r  r  self_attn_weightscross_attn_weightsr  r*   r*   r+   r   `  sB    




	

zSpeechT5DecoderLayer.forwardr~   )	NNNNNNFTN)r   r   r   r   rm   r   r   r   r   r
   rQ   r   r   r*   r*   r|   r+   r  G  sB    	
r  c                   @   s2   e Zd ZU eed< dZdZdZdej	fddZ
dS )	SpeechT5PreTrainedModelrz   speecht5r-   Tmodulec              	   C   s  | j j}t|tr-tjj|jjddt	
d|jjd |jj   d tj|jjd nt|tr:|jjd nt|trat	
d|jj }tjj|jj| |d tjj|jj| |d n}t|tjr||jjjd|d |jdur{|jj  nbt|tjtjtjfr|jj  |jjd nIt|tjrtj|j |jdurt	
|j|j|jd   }tjj|j| |d nt|tjr|jjjd|d |j dur|jj|j    t!|d	rtj|j" dS dS )
zInitialize the weightsr   r   r   meanstdr   )abr1   Nr  )#rz   initializer_rangerm  r   r   initnormal_ru   r   r   ri  ri   in_channels	constant_rk   r   r   datafill_r  r  in_featuresr  r  zero_r   r   rW  rq   kaiming_normal_r   r   r   r   r  )ry   r  r  kr*   r*   r+   _init_weights  sF   
 





z%SpeechT5PreTrainedModel._init_weightsN)r   r   r   r   __annotations__base_model_prefixmain_input_namesupports_gradient_checkpointingr   Moduler  r*   r*   r*   r+   r    s   
 r  c                       z   e Zd ZdZdef fddZ					ddejdeej	 deej	 d	ee
 d
ee
 dee
 deeef fddZ  ZS )SpeechT5Encoderzu
    Transformer encoder consisting of *config.encoder_layers* layers. Each layer is a [`SpeechT5EncoderLayer`].
    rz   c                    s~   t    tj j jd| _t j| _	 j
| _t fddt jD | _t j j  j| _d| _|   d S )Nr  c                    s   g | ]}t  qS r*   )r  rA   r   r*   r+   rD         z,SpeechT5Encoder.__init__.<locals>.<listcomp>F)rl   rm   r   r   r   r  r   r   r  r   encoder_layerdrop	layerdropr  rO   encoder_layersrB  r   r  encoder_max_relative_positionrk  r  	post_initr  r|   r   r+   rm     s    zSpeechT5Encoder.__init__Nr   r/   	head_maskr  output_hidden_statesreturn_dictr7   c                 C   s  |dur|n| j j}|dur|n| j j}|dur|n| j j}|dur(t||j}| |}| |}| |}t	 p=t
| }|rBdnd}	|rHdnd}
|durk| d t| jkrktdt| j d| d  dt| jD ]@\}}|r{|	|f }	d}| jrtg }|| jk }|r|r|||||dur|| nd|d}|d }|rd	}|r|
|d
 f }
qp|r|	|f }	|stdd ||	|
fD S t||	|
dS )a  
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
                Features extracted from the speech or text input by the encoder prenet.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
                `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Nr*   r   z&The head_mask should be specified for  layers, but it is for .F)r/   r  r  r  r;  r   c                 s       | ]	}|d ur|V  qd S r~   r*   rB   vr*   r*   r+   	<genexpr>O  s    z*SpeechT5Encoder.forward.<locals>.<genexpr>last_hidden_stater   
attentions)rz   r  r  use_return_dictr   rF   r   r   rk  r   r   r   rT   rB  r'   	enumerater
  r   rJ   r  r  r   )ry   r   r/   r  r  r  r  r  synced_gpusall_hidden_statesall_self_attentionsidxencoder_layerskip_the_layerdropout_probabilitylayer_outputsr*   r*   r+   r     sb   $







zSpeechT5Encoder.forwardNNNNNr   r   r   r   r   rm   r   r=  r   r   rQ   r   r  r   r   r   r*   r*   r|   r+   r    s.    
r  c                       r  )SpeechT5EncoderWithSpeechPrenetz
    Wrapper around SpeechT5Encoder that applies SpeechT5SpeechEncoderPrenet to convert the audio waveform data to
    hidden features.
    rz   c                    ,   t  | t|| _t|| _|   d S r~   )rl   rm   r  prenetr  wrapped_encoderr  r  r|   r*   r+   rm   ^     

z(SpeechT5EncoderWithSpeechPrenet.__init__Nr-   r/   r  r  r  r  r7   c           	      C   s*   |  ||\}}| j||||||d}|S N)r   r/   r  r  r  r  r  r  	ry   r-   r/   r  r  r  r  r   r  r*   r*   r+   r   f  s   		z'SpeechT5EncoderWithSpeechPrenet.forwardr  r  r*   r*   r|   r+   r  X  s.    
r  c                       s   e Zd ZdZdef fddZdd Zdd Z										dd
ej	de
ej de
ej de
e de
e de
e deeef fddZ  ZS )SpeechT5EncoderWithTextPrenetz|
    Wrapper around SpeechT5Encoder that applies SpeechT5TextEncoderPrenet to convert the input_ids to hidden features.
    rz   c                    r  r~   )rl   rm   rc  r  r  r  r  r  r|   r*   r+   rm     r	  z&SpeechT5EncoderWithTextPrenet.__init__c                 C   
   | j  S r~   r  get_input_embeddingsr  r*   r*   r+   r    ru  z2SpeechT5EncoderWithTextPrenet.get_input_embeddingsc                 C      | j | d S r~   r  set_input_embeddingsry   valuer*   r*   r+   r       z2SpeechT5EncoderWithTextPrenet.set_input_embeddingsNr-   r/   r  r  r  r  r7   c           	      C   s$   |  |}| j||||||d}|S r
  r  r  r*   r*   r+   r     s   
		z%SpeechT5EncoderWithTextPrenet.forwardr  )r   r   r   r   r   rm   r  r  r   r=  r   r   rQ   r   r  r   r   r   r*   r*   r|   r+   r  }  s2    
r  c                       r  )SpeechT5EncoderWithoutPrenet
    This wrapper class is a helper class to correctly load pretrained checkpoints when used in combination with
    [`SpeechT5Model`].
    rz   c                    "   t  | t|| _|   d S r~   )rl   rm   r  r  r  r  r|   r*   r+   rm        
z%SpeechT5EncoderWithoutPrenet.__init__Nr-   r/   r  r  r  r  r7   c                 C   s   | j ||||||dS r
  )r  )ry   r-   r/   r  r  r  r  r*   r*   r+   r     s   	z$SpeechT5EncoderWithoutPrenet.forwardr  r  r*   r*   r|   r+   r    s.    

r  c                          e Zd ZdZdef fddZ												ddeej deej	 deej d	eej	 d
eej
 deej
 dee dee dee dee dee deej
 deeef fddZ  ZS )SpeechT5Decoderzt
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`SpeechT5DecoderLayer`]
    rz   c                    sF   t     j| _t fddt jD | _d| _	| 
  d S )Nc                    r   ))r  )r  r   r   r*   r+   rD         z,SpeechT5Decoder.__init__.<locals>.<listcomp>F)rl   rm   decoder_layerdropr  r   r  rO   decoder_layersrB  r  r  r  r|   r   r+   rm     s
    zSpeechT5Decoder.__init__Nr   r/   r  r  r  cross_attn_head_maskrl  r  r  r  r  r  r7   c                 C   s  |	dur|	n| j j}	|
dur|
n| j j}
|dur|n| j j}|dur$|n| j j}| dd }| jr?| jr?|r?t	d d}|rR|du rRt
t| j dt| j d}|rct|trct	d t
|}|durk| nd}t||||}|dur|durt||j|d d}t pt| }|
rd	nd}|	rd	nd}|	r|durd	nd}t||gd
dgD ](\}}|dur| d t| jkrtd| dt| j d| d  dqt| jD ]Y\}}|
r||f }d}| jrtg }|| jk }|r|sq||||||dur|| nd|dur|| nd||	||d
}|d }|	r3||d f }|dur3||d f }q|
r<||f }|sMtdd |||||fD S t|||||dS )aA  
        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
                Features extracted from the speech or text input by the decoder prenet.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
                cross-attention on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        Nr#   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...Fr   zPassing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.r   )r  r*   r  r   zThe `z` should be specified for r  r  )r  r  r  rl  r  r  r  r   r   c                 s   r  r~   r*   r  r*   r*   r+   r  y  s    z*SpeechT5Decoder.forward.<locals>.<genexpr>)r  rl  r   r  cross_attentions)rz   r  r  r  r  r   r  r
  loggerwarning_oncer   r   rm  r  from_legacy_cachern  r   r   rF   r   r   r2  rT   rB  r'   r  r   rJ   r  r   )ry   r   r/   r  r  r  r   rl  r  r  r  r  r  ro  r   r  r  r  all_cross_attentions	attn_mask	mask_namer  decoder_layerr   r  r  r*   r*   r+   r     s   H






zSpeechT5Decoder.forwardNNNNNNNNNNNNr   r   r   r   r   rm   r   r   r=  r<  r   r
   rQ   r   r  r   r   r   r*   r*   r|   r+   r    sT    	

r  c                       s   e Zd ZdZdef fddZ													ddeej deej	 deej d	eej	 d
eej
 deej
 deej
 dee dee dee dee dee deej
 deeef fddZ  ZS )SpeechT5DecoderWithSpeechPrenetz
    Wrapper around SpeechT5Decoder that applies SpeechT5SpeechDecoderPrenet to convert log-mel filterbanks to hidden
    features.
    rz   c                    r  r~   )rl   rm   r>  r  r  wrapped_decoderr  r  r|   r*   r+   rm     r	  z(SpeechT5DecoderWithSpeechPrenet.__init__Nr-   r/   r  r  rN  r  r   rl  r  r  r  r  r  r7   c                 C   s2   |  ||}| j||||||||	|
|||d}|S N)r   r/   r  r  r  r   rl  r  r  r  r  r  r  r,  )ry   r-   r/   r  r  rN  r  r   rl  r  r  r  r  r  decoder_hidden_statesr  r*   r*   r+   r     s    z'SpeechT5DecoderWithSpeechPrenet.forward)NNNNNNNNNNNNNr*  r*   r*   r|   r+   r+    sZ    
	

r+  c                       s   e Zd ZdZdef fddZdd Zdd Z																								dd
ee	j
 dee	j dee	j
 dee	j dee	j dee	j dee dee dee dee dee dee	j deeef fddZ  ZS )SpeechT5DecoderWithTextPrenetz{
    Wrapper around SpeechT5Decoder that applies SpeechT5TextDecoderPrenet to convert input tokens to hidden features.
    rz   c                    r  r~   )rl   rm   rg  r  r  r,  r  r  r|   r*   r+   rm     r	  z&SpeechT5DecoderWithTextPrenet.__init__c                 C   r  r~   r  r  r*   r*   r+   r    ru  z2SpeechT5DecoderWithTextPrenet.get_input_embeddingsc                 C   r  r~   r  r  r*   r*   r+   r    r  z2SpeechT5DecoderWithTextPrenet.set_input_embeddingsNr-   r/   r  r  r  r   rl  r  r  r  r  r  r7   c                 C   s8   |  |||\}}| j|||||||||	|
||d}|S r-  r.  )ry   r-   r/   r  r  r  r   rl  r  r  r  r  r  r/  r  r*   r*   r+   r     s    z%SpeechT5DecoderWithTextPrenet.forwardr)  )r   r   r   r   r   rm   r  r  r   r   r=  r<  r   r
   rQ   r   r  r   r   r   r*   r*   r|   r+   r0    sX    	

r0  c                       r  )SpeechT5DecoderWithoutPrenetr  rz   c                    r  r~   )rl   rm   r  r,  r  r  r|   r*   r+   rm     r  z%SpeechT5DecoderWithoutPrenet.__init__Nr-   r/   r  r  r  r   rl  r  r  r  r  r  r7   c                 C   s&   | j |||||||||	|
||d}|S r-  )r,  )ry   r-   r/   r  r  r  r   rl  r  r  r  r  r  r  r*   r*   r+   r     s   z$SpeechT5DecoderWithoutPrenet.forwardr)  r*  r*   r*   r|   r+   r1    sT    		

r1  c                       s\   e Zd ZdZdef fddZdejdejdejdej	fd	d
Z
dd Zedd Z  ZS )$SpeechT5GuidedMultiheadAttentionLossz
    Guided attention loss from the paper [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
    Networks with Guided Attention](https://huggingface.co/papers/1710.08969), adapted for multi-head attention.
    rz   c                    s   t    |j| _|j| _d S r~   )rl   rm   guided_attention_loss_sigmasigmaguided_attention_loss_scalescaler  r|   r*   r+   rm   #  s   
z-SpeechT5GuidedMultiheadAttentionLoss.__init__r  input_masksoutput_masksr7   c                 C   sX   |  |||j}|d|d@ }||jd}|| }t||}| j| S )aY  
        Compute the attention loss.

        Args:
            attentions (`torch.FloatTensor` of shape `(batch_size, layers * heads, output_sequence_length, input_sequence_length)`):
                Batch of multi-head attention weights
            input_masks (`torch.BoolTensor` of shape `(batch_size, input_sequence_length)`):
                Input attention mask as booleans.
            output_masks (`torch.BoolTensor` of shape `(batch_size, output_sequence_length)`):
                Target attention mask as booleans.

        Returns:
            `torch.Tensor` with the loss value
        r#   r   r   )_make_guided_attention_masksr   r   r   r   r  masked_selectr6  )ry   r  r7  r8  guided_attn_masksmaskslosseslossr*   r*   r+   r   (  s   
z,SpeechT5GuidedMultiheadAttentionLoss.forwardc           
      C   s   | d}| d}tjt||jd |jd f|d}tt||D ]\}\}}	| ||	| j|||d |	d |f< q#|	dS )Nr#   r   r)  )
rM   r   rP   rT   r%   r  r2  _make_guided_attention_maskr4  r   )
ry   r7  r8  r   r^   r-  r;  r  ilenolenr*   r*   r+   r9  A  s   

$&
zASpeechT5GuidedMultiheadAttentionLoss._make_guided_attention_masksc                 C   sd   t jt j| |dt j||ddd\}}| | }| |  }dt || d  d|d    S )Nr)  xy)indexingr   r   )r   meshgridrS   r   r   )r;   output_lengthr4  r   grid_ygrid_xr*   r*   r+   r?  L  s   
$z@SpeechT5GuidedMultiheadAttentionLoss._make_guided_attention_mask)r   r   r   r   r   rm   r   r=  
BoolTensorr   r   r9  r   r?  r   r*   r*   r|   r+   r2    s    
r2  c                       sb   e Zd ZdZdef fddZ	ddejdejdejd	ejd
ejde	ej dej
fddZ  ZS )SpeechT5SpectrogramLossz;
    Loss computation used by SpeechT5ForTextToSpeech.
    rz   c                    sT   t    |j| _|j| _|j| _t | _tt	dd| _
| jr(t|| _d S d S )Ng      @)
pos_weight)rl   rm   use_guided_attention_lossguided_attention_loss_num_headsr.   r   l1_criterionr   r   r   bce_criterionr2  attn_criterionr  r|   r*   r+   rm   ]  s   
z SpeechT5SpectrogramLoss.__init__Nr/   r_  r`  ra  labelsr!  r7   c                    s<  |dk}| |}| |}| |} || || }|d d d d df }	tj|	 d t|	dd|	jgdd}
|
d d dd f  |	}
| |	} ||
}|| } j	rtj fdd|D dd}|dk}|d d d d df } j
dkr|d d  j
d d  j
f } |||}||7 }|S )Nr0   r   r   r   r   c                    s"   g | ]}|d d d  j f qS r~   )rL  )rB   xr  r*   r+   rD     s   " z3SpeechT5SpectrogramLoss.forward.<locals>.<listcomp>)r:  rM  r   r   rV   r   r   r   rN  rK  r.   rO  )ry   r/   r_  r`  ra  rP  r!  r&  l1_lossr<  stop_labelsbce_lossr>  attnr7  r8  	attn_lossr*   r  r+   r   i  s(   	


.

zSpeechT5SpectrogramLoss.forwardr~   )r   r   r   r   r   rm   r   r<  r=  r   r   r   r   r*   r*   r|   r+   rI  X  s&    rI  zv
    The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets.
    custom_introc                $       s0  e Zd Z		d!dedeej deej f fddZdd Zd	d
 Z	dd Z
dd Ze															d"deej deej deej deej deej deej deej deeeej   dee dee deej dee dee dee deej deeej ef f dd Z  ZS )#SpeechT5ModelNrz   encoderdecoderc                    sJ   t  | || _|du rt|n|| _|du rt|n|| _|   dS )z
        encoder (`PreTrainedModel`, *optional*):
            The encoder model to use.
        decoder (`PreTrainedModel`, *optional*):
            The decoder model to use.
        N)rl   rm   rz   r  rZ  r1  r[  r  )ry   rz   rZ  r[  r|   r*   r+   rm     s
   zSpeechT5Model.__init__c                 C   s0   t | jtr| j S t | jtr| j S tr~   )rm  rZ  r  r  r[  r0  NotImplementedErrorr  r*   r*   r+   r    s
   

z"SpeechT5Model.get_input_embeddingsc                 C   s8   t | jtr| j| t | jtr| j| d S d S r~   )rm  rZ  r  r  r[  r0  r  r*   r*   r+   r    s
   z"SpeechT5Model.set_input_embeddingsc                 C   rv  r~   )rZ  r  r*   r*   r+   get_encoder  s   zSpeechT5Model.get_encoderc                 C   s    t | jtr| jj  dS dS z
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        N)rm  rZ  r  r  r   r  r*   r*   r+   r     s   z$SpeechT5Model.freeze_feature_encoderr-   r/   decoder_input_valuesdecoder_attention_maskr  decoder_head_maskr   encoder_outputsrl  r  rN  r  r  r  r  r7   c                 C   sh  |dur|n| j j}|dur|n| j j}|
dur|
n| j j}
|dur$|n| j j}|du r8| j||||||d}n$|r\t|ts\t|d t|dkrM|d ndt|dkrX|d ndd}|durtt| jt	rt| jj
|d jd |}n|}t| jtrd|i}ni }| jd
|||d ||||	|
||||d|}|s|| S t|j|j|j|j|j|j|j|jd	S )a  
        input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
            Depending on which encoder is being used, the `input_values` are either: float values of the input raw
            speech waveform, or indices of input sequence tokens in the vocabulary, or hidden states.
        decoder_input_values (`torch.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Depending on which decoder is being used, the `decoder_input_values` are either: float values of log-mel
            filterbank features extracted from the raw speech waveform, or indices of decoder input sequence tokens in
            the vocabulary, or hidden states.
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will
            also be used by default.

            If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
            Tensor containing the speaker embeddings.
        N)r-   r/   r  r  r  r  r   r   r   r  rN  )r-   r/   r  r  r  r   rl  r  r  r  r  r  )r  rl  r/  decoder_attentionsr!  encoder_last_hidden_stater  encoder_attentionsr*   )rz   r  r  r  r  rZ  rm  r   rT   r  r  r"  r%   r[  r+  r   r  rl  r   r  r!  )ry   r-   r/   r_  r`  r  ra  r   rb  rl  r  rN  r  r  r  r  r  decoder_argsdecoder_outputsr*   r*   r+   r     sp   *	
zSpeechT5Model.forwardr;  NNNNNNNNNNNNNNN)r   r   r   r   r   r   r  rm   r  r  r]  r   r   r   r   r<  r=  r  r
   rQ   r   r   r   r   r*   r*   r|   r+   rY    s~    		
rY  zB
    SpeechT5 Model with a speech encoder and a text decoder.
    c                $       s  e Zd ZdgZdef fddZdd Zdd Zd	d
 Zdd Z	dd Z
e															d"deej deej deej deej deej deej deej deeeej   dee dee dee dee dee deej deej deeef f d d!Z  ZS )#SpeechT5ForSpeechToTextz#text_decoder_postnet.lm_head.weightrz   c                    \   t  | |jd u rtd| j dt|}t|}t|||| _t	|| _
|   d S )NYou are trying to instantiate a    with a configuration that does not define the vocabulary size of the language model head. Please instantiate the model as follows: `SpeechT5ForSpeechToText.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of your model's configuration.)rl   rm   rd  r'   r}   r  r0  rY  r  rq  text_decoder_postnetr  )ry   rz   speech_encodertext_decoderr|   r*   r+   rm   >     

z SpeechT5ForSpeechToText.__init__c                 C   r  r~   r  r]  r  r*   r*   r+   r]  R  ru  z#SpeechT5ForSpeechToText.get_encoderc                 C   r  r~   r  get_decoderr  r*   r*   r+   rr  U  ru  z#SpeechT5ForSpeechToText.get_decoderc                 C      |   j  dS r^  r]  r  r   r  r*   r*   r+   r   X     z.SpeechT5ForSpeechToText.freeze_feature_encoderc                 C   r  r~   )rl  rw  r  r*   r*   r+   rw  _  ru  z-SpeechT5ForSpeechToText.get_output_embeddingsc                 C   r  r~   )rl  rz  rx  r*   r*   r+   rz  b  r  z-SpeechT5ForSpeechToText.set_output_embeddingsNr-   r/   decoder_input_idsr`  r  ra  r   rb  rl  r  r  r  r  rP  r  r7   c                 C   s   |dur|n| j j}|dur|du rt|| j j| j j}| j|||||||||	|
||d|d}| |d }d}|durMt }||d| j j	|d}|sc|f|dd  }|dura|f| S |S t
|||j|j|j|j|j|j|jd	S )a(  
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
            To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding
            and conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`SpeechT5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            SpeechT5 uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will
            also be used by default.

            If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
            or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
            only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            Label indices can be obtained using [`SpeechT5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

        Example:

        ```python
        >>> from transformers import SpeechT5Processor, SpeechT5ForSpeechToText
        >>> from datasets import load_dataset

        >>> dataset = load_dataset(
        ...     "hf-internal-testing/librispeech_asr_demo", "clean", split="validation"
        ... )  # doctest: +IGNORE_RESULT
        >>> dataset = dataset.sort("id")
        >>> sampling_rate = dataset.features["audio"].sampling_rate

        >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_asr")
        >>> model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr")

        >>> # audio file is decoded on the fly
        >>> inputs = processor(audio=dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
        >>> predicted_ids = model.generate(**inputs, max_length=100)

        >>> # transcribe speech
        >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        >>> transcription[0]
        'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel'
        ```

        ```python
        >>> inputs["labels"] = processor(text_target=dataset[0]["text"], return_tensors="pt").input_ids

        >>> # compute loss
        >>> loss = model(**inputs).loss
        >>> round(loss.item(), 2)
        19.68
        ```
        NT)r-   r/   r_  r`  r  ra  r   rb  rl  r  r  r  r  r  r   r#   r   )	r>  ra  rl  r/  rc  r!  rd  r  re  )rz   r  r,   r!   r"   r  rl  r   r   rd  r   rl  r/  rc  r!  rd  r  re  )ry   r-   r/   rv  r`  r  ra  r   rb  rl  r  r  r  r  rP  r  r  ra  r>  loss_fctoutputr*   r*   r+   r   e  sR   [zSpeechT5ForSpeechToText.forwardrh  )r   r   r   _tied_weights_keysr   rm   r]  rr  r   rw  rz  r   r   r   r=  r<  r   r  r
   rQ   r   r   r   r   r*   r*   r|   r+   ri  6  sr    	

ri        ?r1         4@FmodelrN  	thresholdminlenratiomaxlenratiovocoderoutput_cross_attentionsreturn_output_lengthsc
           "   
      s  |d u rt d|d u rd|| jjk  }
n|}
|d}| jj||
dd}|j}t| jjt	r?| jjj
|d jd |
}
t|d| | jj }t|d| | jj }||d| jj}g }g }d }d}i  	 |d7 }| jj
||}| jjj|d d dd f d ||
|d|dd}|r|tj|jdd |jd}|j}| j|}||| jj| jj}|| |d d dd d f |d| jj}tj||fdd}t| j|}||k rql||k rtj|dd|k}t|d  }nt t!|} fd	d
|D }t!|dkr3t"|}|#dd$dd}| j%|}|D ]	}||  |< q)t! |kr;nqm fdd
t t! D }|	s|dkrU|d n	tj&j'j(j)|dd}|d uri||}n|}|rtj|dd}|dkr|j|t|d| g| dd  R  }||f}|S g t |D ]} ||  d q|d u rtj&j'j(j)|dd}|f}ng tj&j'j(j)|dd}||fdd
D }!|!f}|rtj|dd}|j|t|d| g| dd  R  }g ||R }|S )Na  `speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following
                    the code snippet provided in this link:
                    https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors
                    r   r   T)r-   r/   r  r#   )r   r/   r  r  rl  r  r  r  r   c                    s   g | ]}| vr|qS r*   r*   r   result_spectrogramr*   r+   rD   R	  r  z$_generate_speech.<locals>.<listcomp>r   c                    s   g | ]} | qS r*   r*   r   r  r*   r+   rD   [	  r  )batch_firstc                    s&   g | ]}t d t  | qS r   )r9   r   r:   r   )spectrogram_lengths	waveformsr*   r+   rD   u	  s   & )*r'   rz   r!   r9   r   r  rZ  r  rm  r  r  r"  r%   r.   r$   r?  r[  r,  rX   r   r   r!  squeezerl  speech_decoder_postnetr\  r   sigmoidr]  rM   rJ  rN   rO   rT   stackr   flattenr^  r   r   rnnpad_sequence)"r|  r-   rN  r/   r}  r~  r  r  r  r  r  r   encoder_outrd  maxlenminlenoutput_sequencespectrogramr!  rl  r  r/  decoder_outlast_decoder_outputspectrumnew_spectrogramprobmeet_thresholdsmeet_indexesspectrograms
meet_indexr  r   waveform_lengthsr*   )r  r  r  r+   _generate_speech  s   


$
5&




r  zB
    SpeechT5 Model with a text encoder and a speech decoder.
    c                (       s  e Zd ZdZdef fddZedefddZdd	 Z	d
d Z
e																	d-deej deej deej deej deej deej deej deeeej   dee dee dee dee dee deej deej deej deej deeef f$ddZe 				 	!		"	"d.dejdeej deej d#ed$ed%ed&eej d'ed(edeejeejejf f fd)d*Ze 				 	!		"	"d.dejdeej deej d#ed$ed%ed&eej d'ed(edeejeejejf f fd+d,Z  ZS )/SpeechT5ForTextToSpeechr    rz   c                    rj  )Nrk  a    with a configuration that does not define the vocabulary size of the language model head. Please instantiate the model as follows: `SpeechT5ForTextToSpeech.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of your model's configuration.)rl   rm   rd  r'   r}   r  r+  rY  r  r[  r  r  )ry   rz   text_encoderspeech_decoderr|   r*   r+   rm   	  ro  z SpeechT5ForTextToSpeech.__init__r7   c                 C   s   dS r	  r*   )clsr*   r*   r+   can_generate	  s   z$SpeechT5ForTextToSpeech.can_generatec                 C   r  r~   rp  r  r*   r*   r+   r]  	  ru  z#SpeechT5ForTextToSpeech.get_encoderc                 C   r  r~   rq  r  r*   r*   r+   rr  	  ru  z#SpeechT5ForTextToSpeech.get_decoderNr/   r_  r`  r  ra  r   rb  rl  r  r  r  r  rN  rP  rS  r  c                 C   s   |dur|n| j j}|dur"|du rt|| j j|\}}| j jr"d}| j|||||||||	|
|||d|d}| |d \}}}d}|durUt| j }|||||||j}|sk|f|dd  }|duri|f| S |S t	|||j
|j|j|j|j|j|jd	S )a  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
            [`~PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`):
            Float values of input mel spectrogram.

            SpeechT5 uses an all-zero spectrum as the starting token for `decoder_input_values` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_values` have to be input (see
            `past_key_values`).
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will
            also be used by default.

            If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
            Tensor containing the speaker embeddings.
        labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*):
            Float values of target mel spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
            computation. Spectrograms can be obtained using [`SpeechT5Processor`]. See [`SpeechT5Processor.__call__`]
            for details.
        stop_labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Binary tensor indicating the position of the stop token in the sequence.

        Example:

        ```python
        >>> from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, set_seed
        >>> import torch

        >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        >>> model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
        >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

        >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
        >>> speaker_embeddings = torch.zeros((1, 512))  # or load xvectors from a file

        >>> set_seed(555)  # make deterministic

        >>> # generate speech
        >>> speech = model.generate(inputs["input_ids"], speaker_embeddings=speaker_embeddings, vocoder=vocoder)
        >>> speech.shape
        torch.Size([15872])
        ```
        NTr-   r/   r_  r`  r  ra  r   rb  rl  r  rN  r  r  r  r  r   r   	r>  r  rl  r/  rc  r!  rd  r  re  )rz   r  r3   r.   rK  r  r  rI  r!  r   rl  r/  rc  rd  r  re  )ry   r    r/   r_  r`  r  ra  r   rb  rl  r  r  r  r  rN  rP  rS  r  r  r_  r`  ra  r>  	criterionrx  r*   r*   r+   r   	  sf   M

	zSpeechT5ForTextToSpeech.forwardrz  r1   r{  Fr}  r~  r  r  r  r  c
                 K   s^   |dur"| d}| d|kr"| ddkr||d}ntdt| |||||||||	
S )aE  
        Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
        speech waveform using a vocoder.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.

                Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
                [`~PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Attention mask from the tokenizer, required for batched inference to signal to the model where to
                ignore padded tokens from the input_ids.
            speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
                Tensor containing the speaker embeddings.
            threshold (`float`, *optional*, defaults to 0.5):
                The generated sequence ends when the predicted stop token probability exceeds this value.
            minlenratio (`float`, *optional*, defaults to 0.0):
                Used to calculate the minimum required length for the output sequence.
            maxlenratio (`float`, *optional*, defaults to 20.0):
                Used to calculate the maximum allowed length for the output sequence.
            vocoder (`nn.Module`, *optional*):
                The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
                spectrogram.
            output_cross_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of the decoder's cross-attention layers.
            return_output_lengths (`bool`, *optional*, defaults to `False`):
                Whether or not to return the concrete spectrogram/waveform lengths.

        Returns:
            `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
            - when `return_output_lengths` is False
                - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
                `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
                - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
                `(num_frames,)` -- The predicted speech waveform.
                - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
                `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
                output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
            - when `return_output_lengths` is True
                - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
                `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
                are padded to the maximum length.
                - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of
                all the concrete lengths for each spectrogram.
                - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
                `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
                - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all
                the concrete lengths for each waveform.
                - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
                `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
                output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
        Nr   r   zUThe first dimension of speaker_embeddings must be either 1 or the same as batch_size.r   rI  r'   r  )ry   r    r/   rN  r}  r~  r  r  r  r  kwargsr]   r*   r*   r+   generate0
  s(   E
z SpeechT5ForTextToSpeech.generatec
                 C   s^   |dur"| d}
| d|
kr"| ddkr||
d}ntdt| |||||||||	
S )a  
        Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
        speech waveform using a vocoder.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.

                Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
                [`~PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
                Tensor containing the speaker embeddings.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
                `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            threshold (`float`, *optional*, defaults to 0.5):
                The generated sequence ends when the predicted stop token probability exceeds this value.
            minlenratio (`float`, *optional*, defaults to 0.0):
                Used to calculate the minimum required length for the output sequence.
            maxlenratio (`float`, *optional*, defaults to 20.0):
                Used to calculate the maximum allowed length for the output sequence.
            vocoder (`nn.Module`, *optional*, defaults to `None`):
                The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
                spectrogram.
            output_cross_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of the decoder's cross-attention layers.
            return_output_lengths (`bool`, *optional*, defaults to `False`):
                Whether or not to return the concrete spectrogram/waveform lengths.

        Returns:
            `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
            - when `return_output_lengths` is False
                - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
                `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
                - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
                `(num_frames,)` -- The predicted speech waveform.
                - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
                `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
                output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
            - when `return_output_lengths` is True
                - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
                `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
                are padded to the maximum length.
                - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of
                all the concrete lengths for each spectrogram.
                - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
                `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
                - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all
                the concrete lengths for each waveform.
                - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
                `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
                output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
        Nr   r   zUThe first dimension of speaker_embeddings must be either 1 or the same as batch size.r  )ry   r    rN  r/   r}  r~  r  r  r  r  r]   r*   r*   r+   generate_speech
  s(   I
z'SpeechT5ForTextToSpeech.generate_speechNNNNNNNNNNNNNNNNNNNrz  r1   r{  NFF)r   r   r   r  r   rm   classmethodrQ   r  r]  rr  r   r   r   r<  r=  r   r  r
   r   r   r   r   r   r   r  r  r  r   r*   r*   r|   r+   r  	  s    	

 	
[	
r  zD
    SpeechT5 Model with a speech encoder and a speech decoder.
    c                (       s  e Zd Zdef fddZdd Zdd Zdd	 Ze	
	
	
	
	
	
	
	
	
	
	
	
	
	
	
	
	
d+de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	e
j de	eee
j   de	e de	e de	e de	e de	e de	e
j de	e
j de	e
j de	e
j deeef f$ddZe
 	
	
		 	!	
	"	"d,de
jde	e
j de	e
j d#ed$ed%ed&e	ej d'ed(ede
jfd)d*Z  ZS )-SpeechT5ForSpeechToSpeechrz   c                    s@   t  | t|}t|}t|||| _t|| _|   d S r~   )	rl   rm   r  r+  rY  r  r[  r  r  )ry   rz   rm  r  r|   r*   r+   rm   
  s   
z"SpeechT5ForSpeechToSpeech.__init__c                 C   r  r~   rp  r  r*   r*   r+   r]  
  ru  z%SpeechT5ForSpeechToSpeech.get_encoderc                 C   r  r~   rq  r  r*   r*   r+   rr    ru  z%SpeechT5ForSpeechToSpeech.get_decoderc                 C   rs  r^  rt  r  r*   r*   r+   r     ru  z0SpeechT5ForSpeechToSpeech.freeze_feature_encoderNr-   r/   r_  r`  r  ra  r   rb  rl  r  r  r  r  rN  rP  rS  r  r7   c                 C   s   |dur|n| j j}|dur|du rt|| j j|\}}| j|||||||||	|
|||d|d}| |d \}}}d}|sR|f|dd  }|durP|f| S |S t|||j|j|j	|j
|j|j|jd	S )a  
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
            into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
            (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
            To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding and conversion into
            a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
        decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`):
            Float values of input mel spectrogram.

            SpeechT5 uses an all-zero spectrum as the starting token for `decoder_input_values` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_values` have to be input (see
            `past_key_values`).
        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will
            also be used by default.

            If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
            information on the default strategy.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
            Tensor containing the speaker embeddings.
        labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*):
            Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See
            [`SpeechT5Processor.__call__`] for details.
        stop_labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Binary tensor indicating the position of the stop token in the sequence.

        Example:

        ```python
        >>> from transformers import SpeechT5Processor, SpeechT5ForSpeechToSpeech, SpeechT5HifiGan, set_seed
        >>> from datasets import load_dataset
        >>> import torch

        >>> dataset = load_dataset(
        ...     "hf-internal-testing/librispeech_asr_demo", "clean", split="validation"
        ... )  # doctest: +IGNORE_RESULT
        >>> dataset = dataset.sort("id")
        >>> sampling_rate = dataset.features["audio"].sampling_rate

        >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_vc")
        >>> model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc")
        >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

        >>> # audio file is decoded on the fly
        >>> inputs = processor(audio=dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")

        >>> speaker_embeddings = torch.zeros((1, 512))  # or load xvectors from a file

        >>> set_seed(555)  # make deterministic

        >>> # generate speech
        >>> speech = model.generate_speech(inputs["input_values"], speaker_embeddings, vocoder=vocoder)
        >>> speech.shape
        torch.Size([77824])
        ```
        NTr  r   r   r  )rz   r  r3   r.   r  r  r   rl  r/  rc  r!  rd  r  re  )ry   r-   r/   r_  r`  r  ra  r   rb  rl  r  r  r  r  rN  rP  rS  r  r  rC   r  ra  r>  rx  r*   r*   r+   r     sN   T
z!SpeechT5ForSpeechToSpeech.forwardrz  r1   r{  Fr}  r~  r  r  r  r  c
           
      C   s2   |du rt jd|jd}t| |||||||||	
S )a'  
        Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a
        speech waveform using a vocoder.

        Args:
            input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
                Float values of input raw speech waveform.

                Values can be obtained by loading a *.flac* or *.wav* audio file into an array of type `list[float]`,
                a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`)
                or the soundfile library (`pip install soundfile`).
                To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding and
                conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
            speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
                Tensor containing the speaker embeddings.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
                `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            threshold (`float`, *optional*, defaults to 0.5):
                The generated sequence ends when the predicted stop token probability exceeds this value.
            minlenratio (`float`, *optional*, defaults to 0.0):
                Used to calculate the minimum required length for the output sequence.
            maxlenratio (`float`, *optional*, defaults to 20.0):
                Used to calculate the maximum allowed length for the output sequence.
            vocoder (`nn.Module`, *optional*, defaults to `None`):
                The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
                spectrogram.
            output_cross_attentions (`bool`, *optional*, defaults to `False`):
                Whether or not to return the attentions tensors of the decoder's cross-attention layers.
            return_output_lengths (`bool`, *optional*, defaults to `False`):
                Whether or not to return the concrete spectrogram/waveform lengths.

        Returns:
            `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
            - when `return_output_lengths` is False
                - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
                `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
                - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
                `(num_frames,)` -- The predicted speech waveform.
                - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
                `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
                output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
            - when `return_output_lengths` is True
                - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
                `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
                are padded to the maximum length.
                - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of
                all the concrete lengths for each spectrogram.
                - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
                `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
                - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all
                the concrete lengths for each waveform.
                - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
                `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
                output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
        N)r   i   r)  )r   rP   r   r  )
ry   r-   rN  r/   r}  r~  r  r  r  r  r*   r*   r+   r    s   Jz)SpeechT5ForSpeechToSpeech.generate_speechr  r  )r   r   r   r   rm   r]  rr  r   r   r   r   r=  r<  r   r  r
   rQ   r   r   r   r   r   r   r  r  r   r*   r*   r|   r+   r  
  s    	

 	
r  c                       s@   e Zd Zd fdd	ZdddZd	d
 Zdd Zdd Z  ZS )HifiGanResidualBlockr   r   r      皙?c                    sb   t    |_t fddttD _t fddttD _d S )Nc                    s2   g | ]}t j  d | | dqS r   )rj   dilationr   r   rq   get_paddingr   channelsr  ri   ry   r*   r+   rD     s    	z1HifiGanResidualBlock.__init__.<locals>.<listcomp>c                    s*   g | ]}t j  d d d dqS r  r  rA   )r  ri   ry   r*   r+   rD     s    	
)	rl   rm   leaky_relu_sloper   r  rO   rT   convs1convs2)ry   r  ri   r  r  r|   r  r+   rm     s   

	
	
zHifiGanResidualBlock.__init__r   c                 C   s   || | d S r   r*   )ry   ri   r  r*   r*   r+   r  	  r  z HifiGanResidualBlock.get_paddingc                 C   sL   t jj}tt jjdrt jjj}| jD ]}|| q| jD ]}|| qd S Nr   )r   r   r   r   r   r  r  ry   r   r   r*   r*   r+   apply_weight_norm  s   




z&HifiGanResidualBlock.apply_weight_normc                 C   s4   | j D ]}tj| q| jD ]}tj| qd S r~   )r  r   r   remove_weight_normr  ry   r   r*   r*   r+   r    s
   

z'HifiGanResidualBlock.remove_weight_normc                 C   sX   t | j| jD ]"\}}|}tj|| j}||}tj|| j}||}|| }q|S r~   )r2  r  r  r   rO  
leaky_relur  )ry   r   conv1conv2r  r*   r*   r+   r     s   
zHifiGanResidualBlock.forward)r   r  r  r  )	r   r   r   rm   r  r  r  r   r   r*   r*   r|   r+   r    s    

r  z
    HiFi-GAN vocoder.
    c                       sp   e Zd ZU eed< dZdef fddZdejfddZ	dd	 Z
d
d ZedddejdejfddZ  ZS )SpeechT5HifiGanrz   r  c              
      sN  t  | t|j| _t|j| _tj|j	|j
dddd| _t | _tt|j|jD ]$\}\}}| jtj|j
d|  |j
d|d   |||| d d q-t | _tt| jD ]#}|j
d|d   }t|j|jD ]\}}| jt||||j qpq^tj|ddddd| _| dt|j	 | dt|j	 |   d S )N   r   r   )ri   rj   r   r   r  r6  )rl   rm   rT   resblock_kernel_sizesnum_kernelsupsample_ratesnum_upsamplesr   rq   model_in_dimupsample_initial_channelconv_prer  	upsamplerr  r2  upsample_kernel_sizesrX   ConvTranspose1d	resblocksrO   resblock_dilation_sizesr  r  	conv_postr   r   rP   rV   r  )ry   rz   r   upsample_rateri   r  r  r|   r*   r+   rm   0  s>   



zSpeechT5HifiGan.__init__r  c                 C   sJ   t |tjtjfr!|jjjd| jjd |j	dur#|j	j
  dS dS dS )zInitialize the weights.r1   r  N)rm  r   rq   r  r   r  r  rz   r  rk   r  )ry   r  r*   r*   r+   r  V  s   
zSpeechT5HifiGan._init_weightsc                 C   s`   t jj}tt jjdrt jjj}|| j | jD ]}|| q| jD ]}|  q"|| j	 d S r  )
r   r   r   r   r   r  r  r  r  r  r  r*   r*   r+   r  ]  s   





z!SpeechT5HifiGan.apply_weight_normc                 C   sL   t j| j | jD ]}t j| q
| jD ]}|  qt j| j d S r~   )r   r   r  r  r  r  r  r  r*   r*   r+   r  i  s   


z"SpeechT5HifiGan.remove_weight_norma  
        Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch
        of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech
        waveform.
        rW  r7   c                 C   s  | j jr|| j | j }| dk}|s|d}|dd}| |}t| j	D ]8}t
j|| j j}| j| |}| j|| j  |}td| jD ]}|| j|| j |  |7 }qK|| j }q)t
j|}| |}t|}|s|dddd}|S |d}|S )a  
        spectrogram (`torch.FloatTensor`):
            Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length,
            config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`.

        Returns:
            `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of
            shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`.
        r   r   r   r   r#   )rz   normalize_beforer  r6  r   r   r   r  rO   r  r   rO  r  r  r  r  r  r  r   tanhr  r   )ry   r  
is_batchedr   r   	res_statejwaveformr*   r*   r+   r   q  s,   




zSpeechT5HifiGan.forward)r   r   r   r   r  r  rm   r   r  r  r  r  r   r   r=  r   r   r*   r*   r|   r+   r  '  s   
 & r  )ri  r  r  rY  r  r  )r   Nr   r  )gr   r   typingr   r   numpyrH   r   r   torch.nnr   r   r   activationsr	   cache_utilsr
   r   r   
generationr   integrations.deepspeedr   integrations.fsdpr   modeling_attn_mask_utilsr   r   modeling_layersr   modeling_outputsr   r   r   r   r   modeling_utilsr   r   r   r   r   utils.deprecationr   configuration_speecht5r   r   
get_loggerr   r"  _HIDDEN_STATES_START_POSITIONr   r9   r,   r3   r  r   r<  ndarrayre   rg   r   r   r  r   r   r   r   r   r   r  r  r>  rS  r[  rc  rg  rq  r{  r  r  r  r  r  r  r  r  r  r+  r0  r1  r2  rI  rY  ri  r=  rQ   r  r  r  r  r  __all__r*   r*   r*   r+   <module>   s  



xE-) 4(, '=f+%*  B26-;=  :	

   j x>w