o
    pi                     @   s  d dl mZmZmZ d dlZd dlZd dlmZ d dl	m  m
Z ddlmZmZ ddlmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZ ddlmZ ddlm Z m!Z! e"e#Z$G dd dej%Z&G dd dej'Z(G dd dej'Z)G dd dej'Z*G dd dej'Z+G dd dej'Z,G dd dej'Z-G dd dej'Z.G d d! d!ej'Z/G d"d# d#eeeZ0dS )$    )OptionalTupleUnionN   )ConfigMixinregister_to_config)FromOriginalModelMixin)logging)apply_forward_hook   )get_activation)CogVideoXDownsample3D)AutoencoderKLOutput)
ModelMixin)CogVideoXUpsample3D   )DecoderOutputDiagonalGaussianDistributionc                       s.   e Zd ZdZdejdejf fddZ  ZS )CogVideoXSafeConv3dzq
    A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
    inputreturnc                    s   t t |j d d }|dkrY| jd t|d d }t j||dd dkr@ d g fddtdt	 D   g } D ]}|
t | qDt j|dd}|S t |S )Nr   i   @r   r   dimc                    sF   g | ]}t j |d   dddd d  df  | fddqS )r   Nr   r   )torchcat).0iinput_chunkskernel_size t/home/ubuntu/SoloSpeech/.venv/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
<listcomp>5   s    8z/CogVideoXSafeConv3d.forward.<locals>.<listcomp>)r   prodtensorshapeitemr   intchunkrangelenappendsuperforwardr   )selfr   memory_countpart_numoutput_chunksinput_chunkoutput	__class__r   r!   r-   +   s   
zCogVideoXSafeConv3d.forward)__name__
__module____qualname____doc__r   Tensorr-   __classcell__r    r    r4   r!   r   &   s    "r   c                       s   e Zd ZdZ			ddededeeeeeef f deded	ef fd
dZde	j
de	j
fddZdd Zde	j
de	j
fddZ  ZS )CogVideoXCausalConv3da&  A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.

    Args:
        in_channels (int): Number of channels in the input tensor.
        out_channels (int): Number of output channels.
        kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel.
        stride (int, optional): Stride of the convolution. Default is 1.
        dilation (int, optional): Dilation rate of the convolution. Default is 1.
        pad_mode (str, optional): Padding mode. Default is "constant".
    r   constantin_channelsout_channelsr   stridedilationpad_modec                    s   t    t|tr|fd }|\}}}	|| _||d  d|  }
|d }|	d }|| _|| _|
| _|||||
df| _d| _	|| _
|ddf}|ddf}t|||||d| _d | _d S )Nr   r   r   r   )r>   r?   r   r@   rA   )r,   __init__
isinstancer'   rB   
height_pad	width_padtime_padtime_causal_paddingtemporal_dimtime_kernel_sizer   conv
conv_cache)r.   r>   r?   r   r@   rA   rB   rJ   height_kernel_sizewidth_kernel_sizerG   rE   rF   r4   r    r!   rC   O   s0   
	





zCogVideoXCausalConv3d.__init__inputsr   c                 C   s   | j }| j}|dkr|S |d|}| jd ur*tj| jd||j|gdd}ntj|d d g|d  |g dd}|d| }|S )Nr   r   r   )	rI   rJ   	transposerL   r   r   todevice
contiguous)r.   rO   r   r   r    r    r!   fake_context_parallel_forwardx   s   
&&z3CogVideoXCausalConv3d.fake_context_parallel_forwardc                 C   s   | ` d | _ d S N)rL   )r.   r    r    r!   "_clear_fake_context_parallel_cache   s   
z8CogVideoXCausalConv3d._clear_fake_context_parallel_cachec                 C   s~   |  |}|   |d d d d | j d d f     | _| j| j| j	| j	f}t
j||ddd}| |}|}|S )Nr   r=   r   )modevalue)rT   rV   rJ   rS   detachclonecpurL   rF   rE   FpadrK   )r.   rO   input_parallel
padding_2doutput_parallelr3   r    r    r!   r-      s   
4
zCogVideoXCausalConv3d.forward)r   r   r=   )r6   r7   r8   r9   r'   r   r   strrC   r   r:   rT   rV   r-   r;   r    r    r4   r!   r<   C   s*    )r<   c                       sN   e Zd ZdZ	ddededef fddZdejd	ejd
ejfddZ  Z	S )CogVideoXSpatialNorm3Da  
    Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
    to 3D-video like data.

    CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.

    Args:
        f_channels (`int`):
            The number of channels for input to group normalization layer, and output of the spatial norm layer.
        zq_channels (`int`):
            The number of channels for the quantized vector as described in the paper.
        
f_channelszq_channelsgroupsc                    sF   t    tj||ddd| _t||ddd| _t||ddd| _d S )Nư>T)num_channels
num_groupsepsaffiner   )r   r@   )r,   rC   nn	GroupNorm
norm_layerr<   conv_yconv_b)r.   rd   re   rf   r4   r    r!   rC      s   
zCogVideoXSpatialNorm3D.__init__fzqr   c                 C   s  |j d dkrm|j d d dkrm|d d d d d df |d d d d dd f }}|j dd  |j dd  }}|d d d d d df |d d d d dd f }}tj||d}tj||d}tj||gdd}ntj||j dd  d}| |}	|	| | | | }
|
S )Nr   r   )sizer   )r%   r\   interpolater   r   rn   ro   rp   )r.   rq   rr   f_firstf_restf_first_sizef_rest_sizez_firstz_restnorm_fnew_fr    r    r!   r-      s    66
zCogVideoXSpatialNorm3D.forward)rc   )
r6   r7   r8   r9   r'   rC   r   r:   r-   r;   r    r    r4   r!   rb      s    $rb   c                       s   e Zd ZdZ										dd
edee dededededededee def fddZ			dde
jdee
j dee
j de
jfddZ  ZS )CogVideoXResnetBlock3Da  
    A 3D ResNet block used in the CogVideoX model.

    Args:
        in_channels (int): Number of input channels.
        out_channels (Optional[int], optional):
            Number of output channels. If None, defaults to `in_channels`. Default is None.
        dropout (float, optional): Dropout rate. Default is 0.0.
        temb_channels (int, optional): Number of time embedding channels. Default is 512.
        groups (int, optional): Number of groups for group normalization. Default is 32.
        eps (float, optional): Epsilon value for normalization layers. Default is 1e-6.
        non_linearity (str, optional): Activation function to use. Default is "swish".
        conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False.
        spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
        pad_mode (str, optional): Padding mode. Default is "first".
    N           rc   rg   swishFfirstr>   r?   dropouttemb_channelsrf   rj   non_linearityconv_shortcutspatial_norm_dimrB   c                    s  t    |p|}|| _|| _t|| _|| _|	d u r.tj|||d| _	tj|||d| _
nt||	|d| _	t||	|d| _
t||d|
d| _|dkrStj||d| _t|| _t||d|
d| _| j| jkr| jrvt||d|
d| _d S t||dddd| _d S d S )	N)rh   ri   rj   )rd   re   rf   r   )r>   r?   r   rB   r   )in_featuresout_featuresr   )r>   r?   r   r@   padding)r,   rC   r>   r?   r   nonlinearityuse_conv_shortcutrl   rm   norm1norm2rb   r<   conv1Linear	temb_projDropoutr   conv2r   r   )r.   r>   r?   r   r   rf   rj   r   r   r   rB   r4   r    r!   rC      sJ   


zCogVideoXResnetBlock3D.__init__rO   tembrr   r   c                 C   s   |}|d ur|  ||}n|  |}| |}| |}|d ur5|| | |d d d d d d d f  }|d ur@| ||}n| |}| |}| |}| |}| j| jkr_| 	|}|| }|S rU   )
r   r   r   r   r   r   r   r>   r?   r   )r.   rO   r   rr   hidden_statesr    r    r!   r-     s$   


*




zCogVideoXResnetBlock3D.forward)	Nr   r   rc   rg   r   FNr   NN)r6   r7   r8   r9   r'   r   floatra   boolrC   r   r:   r-   r;   r    r    r4   r!   r~      sV    	
>r~   c                       s   e Zd ZdZdZ										
d dedededededededededededef fddZ			d!de
jdee
j dee
j de
jfddZ  ZS )"CogVideoXDownBlock3Da1  
    A downsampling block used in the CogVideoX model.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        temb_channels (int): Number of time embedding channels.
        dropout (float, optional): Dropout rate. Default is 0.0.
        num_layers (int, optional): Number of layers in the block. Default is 1.
        resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
        resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
        resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
        add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True.
        downsample_padding (int, optional): Padding for the downsampling layer. Default is 0.
        compress_time (bool, optional): If True, apply temporal compression. Default is False.
        pad_mode (str, optional): Padding mode. Default is "first".
    Tr   r   rg   r   rc   r   Fr   r>   r?   r   r   
num_layers
resnet_epsresnet_act_fnresnet_groupsadd_downsampledownsample_paddingcompress_timerB   c                    s   t    g }t|D ]}|dkr|n|}|t||||||||d qt|| _d | _|	r=tt	|||
|dg| _d| _
d S )Nr   )r>   r?   r   r   rf   rj   r   rB   r   r   F)r,   rC   r)   r+   r~   rl   
ModuleListresnetsdownsamplersr   gradient_checkpointing)r.   r>   r?   r   r   r   r   r   r   r   r   r   rB   r   r   
in_channelr4   r    r!   rC   H  s4   

zCogVideoXDownBlock3D.__init__Nr   r   rr   r   c                 C   sh   | j D ]}| jr| jrdd }tjj|||||}q||||}q| jd ur2| jD ]}||}q+|S )Nc                        fdd}|S )Nc                         |  S rU   r    rO   moduler    r!   create_forward     zSCogVideoXDownBlock3D.forward.<locals>.create_custom_forward.<locals>.create_forwardr    r   r   r    r   r!   create_custom_forward     z;CogVideoXDownBlock3D.forward.<locals>.create_custom_forward)r   trainingr   r   utils
checkpointr   )r.   r   r   rr   resnetr   downsamplerr    r    r!   r-   w  s   



zCogVideoXDownBlock3D.forward)	r   r   rg   r   rc   Tr   Fr   r   )r6   r7   r8   r9    _supports_gradient_checkpointingr'   r   ra   r   rC   r   r:   r   r-   r;   r    r    r4   r!   r   3  s`    	
2r   c                       s   e Zd ZdZdZ								dd
edededededededee def fddZ			dde
jdee
j dee
j de
jfddZ  ZS )CogVideoXMidBlock3Da/  
    A middle block used in the CogVideoX model.

    Args:
        in_channels (int): Number of input channels.
        temb_channels (int): Number of time embedding channels.
        dropout (float, optional): Dropout rate. Default is 0.0.
        num_layers (int, optional): Number of layers in the block. Default is 1.
        resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
        resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
        resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
        spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
        pad_mode (str, optional): Padding mode. Default is "first".
    Tr   r   rg   r   rc   Nr   r>   r   r   r   r   r   r   r   rB   c
                    sR   t    g }
t|D ]}|
t|||||||||	d	 qt|
| _d| _d S )N)	r>   r?   r   r   rf   rj   r   r   rB   F)	r,   rC   r)   r+   r~   rl   r   r   r   )r.   r>   r   r   r   r   r   r   r   rB   r   _r4   r    r!   rC     s$   

zCogVideoXMidBlock3D.__init__r   r   rr   r   c                 C   sJ   | j D ]}| jr| jrdd }tjj|||||}q||||}q|S )Nc                    r   )Nc                     r   rU   r    r   r   r    r!   r     r   zRCogVideoXMidBlock3D.forward.<locals>.create_custom_forward.<locals>.create_forwardr    r   r    r   r!   r     r   z:CogVideoXMidBlock3D.forward.<locals>.create_custom_forward)r   r   r   r   r   r   )r.   r   r   rr   r   r   r    r    r!   r-     s   
zCogVideoXMidBlock3D.forward)r   r   rg   r   rc   Nr   r   )r6   r7   r8   r9   r   r'   r   ra   r   rC   r   r:   r-   r;   r    r    r4   r!   r     sP    	
$r   c                       s   e Zd ZdZ											
d!dededededededededededededef fddZ		d"de	j
dee	j
 dee	j
 de	j
fdd Z  ZS )#CogVideoXUpBlock3Da  
    An upsampling block used in the CogVideoX model.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        temb_channels (int): Number of time embedding channels.
        dropout (float, optional): Dropout rate. Default is 0.0.
        num_layers (int, optional): Number of layers in the block. Default is 1.
        resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
        resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
        resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
        spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16.
        add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True.
        upsample_padding (int, optional): Padding for the upsampling layer. Default is 1.
        compress_time (bool, optional): If True, apply temporal compression. Default is False.
        pad_mode (str, optional): Padding mode. Default is "first".
    r   r   rg   r   rc      TFr   r>   r?   r   r   r   r   r   r   r   add_upsampleupsample_paddingr   rB   c                    s   t    g }t|D ]}|dkr|n|}|t||||||||	|d	 qt|| _d | _|
r>tt	||||dg| _d| _
d S )Nr   )	r>   r?   r   r   rf   rj   r   r   rB   r   F)r,   rC   r)   r+   r~   rl   r   r   
upsamplersr   r   )r.   r>   r?   r   r   r   r   r   r   r   r   r   r   rB   r   r   r   r4   r    r!   rC     s6   

zCogVideoXUpBlock3D.__init__Nr   r   rr   r   c                 C   sh   | j D ]}| jr| jrdd }tjj|||||}q||||}q| jdur2| jD ]}||}q+|S )z1Forward method of the `CogVideoXUpBlock3D` class.c                    r   )Nc                     r   rU   r    r   r   r    r!   r   .  r   zQCogVideoXUpBlock3D.forward.<locals>.create_custom_forward.<locals>.create_forwardr    r   r    r   r!   r   -  r   z9CogVideoXUpBlock3D.forward.<locals>.create_custom_forwardN)r   r   r   r   r   r   r   )r.   r   r   rr   r   r   	upsamplerr    r    r!   r-   #  s   



zCogVideoXUpBlock3D.forward)
r   r   rg   r   rc   r   Tr   Fr   r   )r6   r7   r8   r9   r'   r   ra   r   rC   r   r:   r   r-   r;   r    r    r4   r!   r     sd    	
4r   c                       s   e Zd ZdZdZ										
		d!dededeedf deedf dededededededef fddZ	d"de
jdee
j de
jfdd Z  ZS )#CogVideoXEncoder3Da  
    The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
            options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
    Tr   r   r   r   r   r         r   r   silurg   rc   r   r      r>   r?   down_block_types.block_out_channelslayers_per_blockact_fnnorm_epsnorm_num_groupsr   rB   temporal_compression_ratioc                    s  t    tt|}t||d d|
d| _tg | _	|d }t
|D ]4\}}|}|| }|t|d k}||k }|dkrOt||d|	||||| |d
}ntd| j	| q%t|d d|	d	||||
d
| _tj||d dd| _t | _t|d d	| d|
d| _d| _d S )Nr   r   r   rB   r   r   )
r>   r?   r   r   r   r   r   r   r   r   zEInvalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`r   )r>   r   r   r   r   r   r   rB   rg   )rj   F)r,   rC   r'   nplog2r<   conv_inrl   r   down_blocks	enumerater*   r   
ValueErrorr+   r   	mid_blockrm   norm_outSiLUconv_actconv_outr   )r.   r>   r?   r   r   r   r   r   r   r   rB   r   temporal_compress_leveloutput_channelr   down_block_typeinput_channelis_final_blockr   
down_blockr4   r    r!   rC   Z  sR   


zCogVideoXEncoder3D.__init__Nsampler   r   c                 C   s   |  |}| jr/| jr/dd }| jD ]}tjj||||d}qtjj|| j||d}n| jD ]}|||d}q2| ||d}| |}| 	|}| 
|}|S )z5The forward method of the `CogVideoXEncoder3D` class.c                    r   )Nc                     r   rU   r    r   r   r    r!   custom_forward  r   zQCogVideoXEncoder3D.forward.<locals>.create_custom_forward.<locals>.custom_forwardr    r   r   r    r   r!   r     r   z9CogVideoXEncoder3D.forward.<locals>.create_custom_forwardN)r   r   r   r   r   r   r   r   r   r   r   )r.   r   r   r   r   r   r    r    r!   r-     s"   





zCogVideoXEncoder3D.forward)r   r   r   r   r   r   rg   rc   r   r   r   rU   r6   r7   r8   r9   r   r'   r   ra   r   rC   r   r:   r   r-   r;   r    r    r4   r!   r   @  sL    


*Ir   c                       s   e Zd ZdZdZ										
		d!dededeedf deedf dededededededef fddZ	d"de
jdee
j de
jfdd Z  ZS )#CogVideoXDecoder3Da  
    The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
    sample.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        norm_type (`str`, *optional*, defaults to `"group"`):
            The normalization type to use. Can be either `"group"` or `"spatial"`.
    Tr   r   r   r   r   r   r   r   rg   rc   r   r   r   r>   r?   up_block_types.r   r   r   r   r   r   rB   r   c                    s(  t    tt|}t||d d|
d| _t|d dd|||||
d| _t	g | _
|d }tt|}t|D ]:\}}|}|| }|t|d k}||k }|dkrjt||d|	|d ||||| ||
d}|}ntd	| j
| q:t|d
 ||d| _t | _t|d
 |d|
d| _d| _d S )Nr   r   r   r   )r>   r   r   r   r   r   r   rB   r   r   )r>   r?   r   r   r   r   r   r   r   r   r   rB   zAInvalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`r   )rf   F)r,   rC   listreversedr<   r   r   r   rl   r   	up_blocksr'   r   r   r   r*   r   r   r+   rb   r   r   r   r   r   )r.   r>   r?   r   r   r   r   r   r   r   rB   r   reversed_block_out_channelsr   r   r   up_block_typeprev_output_channelr   r   up_blockr4   r    r!   rC     s^   


zCogVideoXDecoder3D.__init__Nr   r   r   c                 C   s   |  |}| jr/| jr/dd }tjj|| j|||}| jD ]}tjj|||||}qn| |||}| jD ]}||||}q9| ||}| 	|}| 
|}|S )z5The forward method of the `CogVideoXDecoder3D` class.c                    r   )Nc                     r   rU   r    r   r   r    r!   r   9  r   zQCogVideoXDecoder3D.forward.<locals>.create_custom_forward.<locals>.custom_forwardr    r   r    r   r!   r   8  r   z9CogVideoXDecoder3D.forward.<locals>.create_custom_forward)r   r   r   r   r   r   r   r   r   r   r   )r.   r   r   r   r   r   r    r    r!   r-   2  s$   




zCogVideoXDecoder3D.forward)r   r   r   r   r   r   rg   rc   r   r   r   rU   r   r    r    r4   r!   r     sL    


*Pr   c                (       sd  e Zd ZdZdZdgZe										
										d7dededee	 dee	 dee dedede	de
dede
dede
dee
 deee
  d eee
  d!e
d"ed#ef& fd$d%Zd8d&d'Zd(d) Ze	d9d*ejd+ed,eeee f fd-d.Zed9d/ejd+ed,eeejf fd0d1Z			d:d2ejd3ed+ed4eej d,eejejf f
d5d6Z  ZS );AutoencoderKLCogVideoXa  
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
    [CogVideoX](https://github.com/THUDM/CogVideo).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int,  *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        force_upcast (`bool`, *optional*, default to `True`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without loosing too much precision in which case
            `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
    Tr~   r   r   r   r   r   r   rg   rc   r   r   Yi1p?NFr>   r?   r   r   r   latent_channelsr   r   r   r   r   sample_sizescaling_factorshift_factorlatents_meanlatents_stdforce_upcastuse_quant_convuse_post_quant_convc                    s   t    t|||||||	|
|d	| _t|||||||	|
|d	| _|r-td| d| dnd | _|r8t||dnd | _d| _	d| _
| jj| _t| jjttfrU| jjd n| jj}t|dt| jjd   | _d| _d S )N)	r>   r?   r   r   r   r   r   r   r   )	r>   r?   r   r   r   r   r   r   r   r   r   Fr   g      ?)r,   rC   r   encoderr   decoderr   
quant_convpost_quant_convuse_slicing
use_tilingconfigr   tile_sample_min_sizerD   r   tupler'   r*   r   tile_latent_min_sizetile_overlap_factor)r.   r>   r?   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r4   r    r!   rC   z  sD   
!

zAutoencoderKLCogVideoX.__init__c                 C   s   t |ttfr||_d S d S rU   )rD   r   r   r   )r.   r   rX   r    r    r!   _set_gradient_checkpointing  s   
z2AutoencoderKLCogVideoX._set_gradient_checkpointingc                 C   s8   |   D ]\}}t|trtd|  |  qd S )Nz0Clearing fake Context Parallel cache for layer: )named_modulesrD   r<   loggerdebugrV   )r.   namer   r    r    r!   !clear_fake_context_parallel_cache  s   
z8AutoencoderKLCogVideoX.clear_fake_context_parallel_cachexreturn_dictr   c                 C   s:   |  |}| jdur| |}t|}|s|fS t|dS )a  
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded images. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        N)latent_dist)r   r   r   r   )r.   r
  r  h	posteriorr    r    r!   encode  s   



zAutoencoderKLCogVideoX.encodezc                 C   s2   | j dur
|  |}| |}|s|fS t|dS )a  
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.

        N)r   )r   r   r   )r.   r  r  decr    r    r!   decode  s   



zAutoencoderKLCogVideoX.decoder   sample_posterior	generatorc           	      C   sB   |}|  |j}|r|j|d}n| }| |}|s|fS |S )N)r  )r  r  r   rW   r  )	r.   r   r  r  r  r
  r  r  r  r    r    r!   r-     s   
zAutoencoderKLCogVideoX.forward)r   r   r   r   r   r   r   r   rg   rc   r   r   r   NNNTFF)F)T)FTN)r6   r7   r8   r9   r   _no_split_modulesr   r'   r   ra   r   r   r   rC   r  r	  r
   r   r:   r   r   r   r  FloatTensorr   r  	Generatorr-   r;   r    r    r4   r!   r   W  s    



G&r   )1typingr   r   r   numpyr   r   torch.nnrl   torch.nn.functional
functionalr\   configuration_utilsr   r   loaders.single_file_modelr   r   r	   utils.accelerate_utilsr
   activationsr   downsamplingr   modeling_outputsr   modeling_utilsr   
upsamplingr   vaer   r   
get_loggerr6   r  Conv3dr   Moduler<   rb   r~   r   r   r   r   r   r   r    r    r    r!   <module>   s8   
W)p`Kb 	 