o
    id                     @   s   d Z ddlmZ ddlmZmZ ddlZddlmZ ddlm	Z	 ddl
mZ dd	lmZmZ d
dlmZ eeZeeddG dd de	ZG dd dejZG dd dejZG dd dejZG dd dejZeG dd deZdgZdS )zPyTorch UnivNetModel model.    )	dataclass)OptionalUnionN)nn   )ModelOutput)PreTrainedModel)auto_docstringlogging   )UnivNetConfigz
    Output class for the [`UnivNetModel`], which includes the generated audio waveforms and the original unpadded
    lengths of those waveforms (so that the padding can be removed by [`UnivNetModel.batch_decode`]).
    )custom_introc                   @   s6   e Zd ZU dZdZeej ed< dZ	eej ed< dS )UnivNetModelOutputa"  
    waveforms (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
        Batched 1D (mono-channel) output audio waveforms.
    waveform_lengths (`torch.FloatTensor` of shape `(batch_size,)`):
        The batched length in samples of each unpadded waveform in `waveforms`.
    N	waveformswaveform_lengths)
__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r    r   r   i/home/ubuntu/veenaModal/venv/lib/python3.10/site-packages/transformers/models/univnet/modeling_univnet.pyr      s   
 r   c                       sF   e Zd ZdZdef fddZdejfddZdd	 Z	d
d Z
  ZS )#UnivNetKernelPredictorResidualBlockz
    Implementation of the residual block for the kernel predictor network inside each location variable convolution
    block (LVCBlock).

    Parameters:
        config: (`UnivNetConfig`):
            Config for the `UnivNetModel` model.
    configc                    s   t    |j| _|j| _|j| _|j| _| jd d }t	
| j| _t	j| j| j| j|dd| _t	j| j| j| j|dd| _d S )Nr      Tpaddingbias)super__init__model_in_channelschannelskernel_predictor_conv_sizekernel_sizekernel_predictor_dropoutdropout_probleaky_relu_sloper   DropoutdropoutConv1dconv1conv2)selfr   r   	__class__r   r   r!   <   s   
 z,UnivNetKernelPredictorResidualBlock.__init__hidden_statesc                 C   sJ   |}|  |}| |}tj|| j}| |}tj|| j}|| S N)r*   r,   r   
functional
leaky_relur(   r-   )r.   r1   residualr   r   r   forwardL   s   


z+UnivNetKernelPredictorResidualBlock.forwardc                 C   s8   t jj}tt jjdrt jjj}|| j || j d S Nweight_norm)r   utilsr8   hasattrparametrizationsr,   r-   r.   r8   r   r   r   apply_weight_normV   s
   

z5UnivNetKernelPredictorResidualBlock.apply_weight_normc                 C   s    t j| j t j| j d S r2   )r   r9   remove_weight_normr,   r-   r.   r   r   r   r>   ^   s   z6UnivNetKernelPredictorResidualBlock.remove_weight_norm)r   r   r   r   r   r!   r   r   r6   r=   r>   __classcell__r   r   r/   r   r   2   s    	
r   c                       sT   e Zd ZdZ		ddededef fddZd	ejfd
dZ	dd Z
dd Z  ZS )UnivNetKernelPredictora  
    Implementation of the kernel predictor network which supplies the kernel and bias for the location variable
    convolutional layers (LVCs) in each UnivNet LVCBlock.

    Based on the KernelPredictor implementation in
    [maum-ai/univnet](https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/model/lvcnet.py#L7).

    Parameters:
        config: (`UnivNetConfig`):
            Config for the `UnivNetModel` model.
        conv_kernel_size (`int`, *optional*, defaults to 3):
            The kernel size for the location variable convolutional layer kernels (convolutional weight tensor).
        conv_layers (`int`, *optional*, defaults to 4):
            The number of location variable convolutional layers to output kernels and biases for.
    r      r   conv_kernel_sizeconv_layersc                    s   t     j| _d j | _|| _|| _| j| j | j | j | _| j| j | _ j	| _
 j| _ j| _ j| _ j| _| jd d }tj| j
| jdddd| _t fddt| jD | _tj| j| j| j|dd| _tj| j| j| j|dd| _d S )Nr   r      Tr   c                    s   g | ]}t  qS r   )r   ).0_r   r   r   
<listcomp>   s    z3UnivNetKernelPredictor.__init__.<locals>.<listcomp>)r    r!   model_hidden_channelsconv_in_channelsconv_out_channelsrC   rD   kernel_channelsbias_channelsnum_mel_binsresnet_in_channels kernel_predictor_hidden_channelsresnet_hidden_channelsr$   resnet_kernel_sizekernel_predictor_num_blocks
num_blocksr(   r   r+   
input_conv
ModuleListrange	resblockskernel_conv	bias_conv)r.   r   rC   rD   r   r/   rH   r   r!   t   s,   
 zUnivNetKernelPredictor.__init__spectrogramc                 C   s   |j \}}}| |}tj|| j}| jD ]}||}q| |}| |}|	|| j
| j| j| j| }	|	|| j
| j| }
|	|
fS )a  
        Maps a conditioning log-mel spectrogram to a tensor of convolutional kernels and biases, for use in location
        variable convolutional layers. Note that the input spectrogram should have shape (batch_size, input_channels,
        seq_length).

        Args:
            spectrogram (`torch.FloatTensor` of shape `(batch_size, input_channels, seq_length)`):
                Tensor containing the log-mel spectrograms.

        Returns:
            tuple[`torch.FloatTensor, `torch.FloatTensor`]: tuple of tensors where the first element is the tensor of
            location variable convolution kernels of shape `(batch_size, self.conv_layers, self.conv_in_channels,
            self.conv_out_channels, self.conv_kernel_size, seq_length)` and the second element is the tensor of
            location variable convolution biases of shape `(batch_size, self.conv_layers. self.conv_out_channels,
            seq_length)`.
        )shaperV   r   r3   r4   r(   rY   rZ   r[   viewrD   rK   rL   rC   
contiguous)r.   r\   
batch_sizerG   
seq_lengthr1   resblockkernel_hidden_statesbias_hidden_stateskernelsbiasesr   r   r   r6      s4   




zUnivNetKernelPredictor.forwardc                 C   sV   t jj}tt jjdrt jjj}|| j | jD ]}|  q|| j || j	 d S r7   )
r   r9   r8   r:   r;   rV   rY   r=   rZ   r[   r.   r8   layerr   r   r   r=      s   




z(UnivNetKernelPredictor.apply_weight_normc                 C   sB   t j| j | jD ]}|  q
t j| j t j| j d S r2   )r   r9   r>   rV   rY   rZ   r[   r.   rh   r   r   r   r>      s
   

z)UnivNetKernelPredictor.remove_weight_norm)r   rB   r   r   r   r   r   intr!   r   r   r6   r=   r>   r@   r   r   r/   r   rA   c   s    &.rA   c                       sr   e Zd ZdZdededef fddZddd	Z	
	ddej	dej	dej	dedef
ddZ
dd Zdd Z  ZS )UnivNetLvcResidualBlocka  
    Implementation of the location variable convolution (LVC) residual block for the UnivNet residual network.

    Parameters:
        config: (`UnivNetConfig`):
            Config for the `UnivNetModel` model.
        kernel_size (`int`):
            The kernel size for the dilated 1D convolutional layer.
        dilation (`int`):
            The dilation for the dilated 1D convolutional layer.
    r   r%   dilationc                    s\   t    |j| _|| _|| _|j| _| j| jd  d }tj| j| j| j|| jd| _	d S )Nr   r   )r   rm   )
r    r!   rJ   hidden_channelsr%   rm   r(   r   r+   conv)r.   r   r%   rm   r   r/   r   r   r!      s   
z UnivNetLvcResidualBlock.__init__   c                 C   s   |}t j|| j}| |}t j|| j}| j||||d}t|d d d | jd d f t	|d d | jd d d f  }|| }|S N)hop_size)
r   r3   r4   r(   ro   location_variable_convolutionr   sigmoidrn   tanh)r.   r1   kernelr   rr   r5   r   r   r   r6      s   
$zUnivNetLvcResidualBlock.forwardr   r1   rv   r   rr   c                 C   sB  |j \}}}|j \}}}	}
}||| kr!td||  d| d|t|
d d  }tj|||fdd}|d|d|  |}||k rPtj|d|fdd}|d||}|d	d	d	d	d	d	d	d	d	|f }|dd
}|d
|
d}t	d||}|j
tjd}|ddj
tjd}|| }| ||	d}|S )u  
        Performs location-variable convolution operation on the input sequence (hidden_states) using the local
        convolution kernel. This was introduced in [LVCNet: Efficient Condition-Dependent Modeling Network for Waveform
        Generation](https://huggingface.co/papers/2102.10815) by Zhen Zheng, Jianzong Wang, Ning Cheng, and Jing Xiao.

        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, in_channels, in_length)`):
                The input sequence of shape (batch, in_channels, in_length).
            kernel (`torch.FloatTensor` of shape `(batch_size, in_channels, out_channels, kernel_size, kernel_length)`):
                The local convolution kernel of shape (batch, in_channels, out_channels, kernel_size, kernel_length).
            bias (`torch.FloatTensor` of shape `(batch_size, out_channels, kernel_length)`):
                The bias for the local convolution of shape (batch, out_channels, kernel_length).
            dilation (`int`, *optional*, defaults to 1):
                The dilation of convolution.
            hop_size (`int`, *optional*, defaults to 256):
                The hop_size of the conditioning sequence.
        Returns:
            `torch.FloatTensor`: the output sequence after performing local convolution with shape (batch_size,
            out_channels, in_length).
        z#Dim 2 of `hidden_states` should be z
) but got zX. Please check `hidden_states` or `kernel` and `hop_size` to make sure they are correct.r   r   constantr   r   NrB   zbildsk,biokl->bolsd)memory_format)r]   
ValueErrorrk   r   r3   padunfold	transposer   einsumtochannels_last_3d	unsqueezer_   r^   )r.   r1   rv   r   rm   rr   batchrG   	in_lengthout_channelsr%   kernel_lengthr   output_hidden_statesr   r   r   rs     s*   &z5UnivNetLvcResidualBlock.location_variable_convolutionc                 C   s.   t jj}tt jjdrt jjj}|| j d S r7   )r   r9   r8   r:   r;   ro   r<   r   r   r   r=   O  s   
z)UnivNetLvcResidualBlock.apply_weight_normc                 C   s   t j| j d S r2   )r   r9   r>   ro   r?   r   r   r   r>   V  s   z*UnivNetLvcResidualBlock.remove_weight_normrp   )r   rp   )r   r   r   r   r   rk   r!   r6   r   r   rs   r=   r>   r@   r   r   r/   r   rl      s2    

Arl   c                       sX   e Zd ZdZ	ddededef fddZdejd	ejfd
dZ	dd Z
dd Z  ZS )UnivNetLvcBlocka#  
    Implementation of the location variable convolution (LVC) residual block of the UnivNet residual block. Includes a
    `UnivNetKernelPredictor` inside to predict the kernels and biases of the LVC layers.

    Based on LVCBlock in
    [maum-ai/univnet](https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/model/lvcnet.py#L98)

    Parameters:
        config (`UnivNetConfig`):
            Config for the `UnivNetModel` model.
        layer_id (`int`):
            An integer corresponding to the index of the current LVC resnet block layer. This should be between 0 and
            `len(config.resblock_stride_sizes) - 1)` inclusive.
        lvc_hop_size (`int`, *optional*, defaults to 256):
            The hop size for the location variable convolutional layers.
    rp   r   layer_idlvc_hop_sizec                    s   t     j_ j| _ j| _ j| _	|_
 j_tj	_tjjjdj jjd jd  jd d_t jj_t fddtjD _d S )Nr   )strider   output_paddingc                    s    g | ]}t  jj| qS r   )rl   r%   	dilationsrF   ir   r.   r   r   rI     s     z,UnivNetLvcBlock.__init__.<locals>.<listcomp>)r    r!   rJ   rn   resblock_kernel_sizesr%   resblock_stride_sizesr   resblock_dilation_sizesr   cond_hop_lengthr(   lenrU   r   ConvTranspose1d	convt_prerA   kernel_predictorrW   rX   rY   )r.   r   r   r   r/   r   r   r!   l  s(   
	
zUnivNetLvcBlock.__init__r1   r\   c           	   	   C   s   t j|| j}| |}| |\}}t| jD ]/\}}|d d |d d d d d d d d f }|d d |d d d d f }||||| jd}q|S rq   )	r   r3   r4   r(   r   r   	enumeraterY   r   )	r.   r1   r\   re   rf   r   rb   rv   r   r   r   r   r6     s   
(zUnivNetLvcBlock.forwardc                 C   sL   t jj}tt jjdrt jjj}|| j | j  | jD ]}|  qd S r7   )	r   r9   r8   r:   r;   r   r   r=   rY   rg   r   r   r   r=     s   




z!UnivNetLvcBlock.apply_weight_normc                 C   s0   t j| j | j  | jD ]}|  qd S r2   )r   r9   r>   r   r   rY   ri   r   r   r   r>     s
   


z"UnivNetLvcBlock.remove_weight_normr   rj   r   r   r/   r   r   Z  s    
r   c                       s   e Zd ZU eed< dZdef fddZe				ddej	de
ej	 de
ej	 de
ej d	e
e d
eeej	 ef fddZdd Zdd Zdd Z  ZS )UnivNetModelr   input_featuresc                    s   t    t j| _ j| _tj j j	ddddd| _
t j}d}g  jD ]}|| }| q*t fddt|D | _tj j	ddddd| _|   d S )	N   r   r   reflect)r%   r   r   padding_modec                    s   g | ]}t  || d qS ))r   r   )r   r   r   hop_lengthsr   r   rI     s    z)UnivNetModel.__init__.<locals>.<listcomp>)r   r   )r    r!   r   r   num_kernelsr(   r   r+   r"   rJ   conv_prer   appendrW   rX   rY   	conv_post	post_init)r.   r   
num_layers
hop_lengthr   r/   r   r   r!     s0   


zUnivNetModel.__init__Nnoise_sequencepadding_mask	generatorreturn_dictreturnc                 C   s  |dur|n| j j}| dk}|s|d}|j\}}}	|dur/| dk}
|
s.|d}n||| j jf}tj|||j|j	d}|jd }|dkrV|dkrV|
|dd}n|dkre|dkre|
|dd}||krttd| d| d|dur| dkr|d}|jd }||krtd	| d| d|d
d}|d
d}| |}| jD ]}|||}qtj|| j}| |}t|}|d}d}|durtj|dd}|s||f}|S t||dS )a  
        noise_sequence (`torch.FloatTensor`, *optional*):
            Tensor containing a noise sequence of standard Gaussian noise. Can be batched and of shape `(batch_size,
            sequence_length, config.model_in_channels)`, or un-batched and of shape (sequence_length,
            config.model_in_channels)`. If not supplied, will be randomly generated.
        padding_mask (`torch.BoolTensor`, *optional*):
            Mask indicating which parts of each sequence are padded. Mask values are selected in `[0, 1]`:

            - 1 for tokens that are **not masked**
            - 0 for tokens that are **masked**

            The mask can be batched and of shape `(batch_size, sequence_length)` or un-batched and of shape
            `(sequence_length,)`.
        generator (`torch.Generator`, *optional*):
            A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
            deterministic.
            return_dict:
            Whether to return a [`~utils.ModelOutput`] subclass instead of a plain tuple.

        Example:

         ```python
         >>> from transformers import UnivNetFeatureExtractor, UnivNetModel
         >>> from datasets import load_dataset, Audio

         >>> model = UnivNetModel.from_pretrained("dg845/univnet-dev")
         >>> feature_extractor = UnivNetFeatureExtractor.from_pretrained("dg845/univnet-dev")

         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         >>> # Resample the audio to the feature extractor's sampling rate.
         >>> ds = ds.cast_column("audio", Audio(sampling_rate=feature_extractor.sampling_rate))
         >>> inputs = feature_extractor(
         ...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt"
         ... )
         >>> audio = model(**inputs).waveforms
         >>> list(audio.shape)
         [1, 140288]
         ```
        Nr   r   )r   dtypedevicer   z&The batch size of `noise_sequence` is z+ and the batch size of `input_features` is z', but the two are expected to be equal.z$The batch size of `padding_mask` is r   )dim)r   r   )r   use_return_dictr   r   r]   r"   r   randnr   r   repeatrz   r}   r   rY   r   r3   r4   r(   r   ru   squeezesumr   )r.   r   r   r   r   r   spectrogram_batchedspectrogram_batch_sizespectrogram_lengthrG   noise_sequence_batchednoise_sequence_shapenoise_sequence_batch_sizepadding_mask_batch_sizer1   rb   waveformr   outputsr   r   r   r6     sl   0









zUnivNetModel.forwardc                 C   sN   t |tjtjtjfr#|jjjd| jj	d |j
dur%|j
j  dS dS dS )zInitialize the weights.g        )meanstdN)
isinstancer   Linearr+   r   weightdatanormal_r   initializer_ranger   zero_)r.   moduler   r   r   _init_weightsQ  s   
zUnivNetModel._init_weightsc                 C   sL   t jj}tt jjdrt jjj}|| j | jD ]}|  q|| j d S r7   )	r   r9   r8   r:   r;   r   rY   r=   r   rg   r   r   r   r=   X  s   



zUnivNetModel.apply_weight_normc                 C   s4   t j| j | jD ]}|  q
t j| j d S r2   )r   r9   r>   r   rY   r   ri   r   r   r   r>   b  s   

zUnivNetModel.remove_weight_norm)NNNN)r   r   r   r   r   main_input_namer!   r	   r   r   r   	Generatorboolr   tupler   r6   r   r=   r>   r@   r   r   r/   r   r     s2   
 'z
r   )r   dataclassesr   typingr   r   r   r   modeling_outputsr   modeling_utilsr   r9   r	   r
   configuration_univnetr   
get_loggerr   loggerr   Moduler   rA   rl   r   r   __all__r   r   r   r   <module>   s.   
1xP 
?