o
    Gi                     @   s  d dl Zd dlZd dlmZ d dlm  mZ ddlm	Z	m
Z
 ddlmZ ddlmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZ eeZ	d)dededejdej dedej!fddZ"G dd dej#Z$G dd dej#Z%G dd dej#Z&G dd dej#Z'G dd dej#Z(G dd  d ej#Z)G d!d" d"ej#Z*G d#d$ d$ej#Z+G d%d& d&ej#Z,G d'd( d(eee	Z-dS )*    N   )ConfigMixinregister_to_config)logging)apply_forward_hook   )get_activation)	Attention)AutoencoderKLOutput)
ModelMixin   )AutoencoderMixinDecoderOutputDiagonalGaussianDistribution
num_framesheight_widthdtypedevice
batch_sizereturnc           
      C   sv   t jd| d t j|d}||}t j||dd\}}t ||kdtd j|d}	|d ur9|	d	|dd}	|	S )	Nr   )r   r   xy)indexingr   inf)r   )
torcharangeint32repeat_interleavemeshgridwherefloatto	unsqueezeexpand)
r   r   r   r   r   indicesindices_blocksxymask r)   n/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.pyprepare_causal_attention_mask!   s   
 r+   c                       s   e Zd Z						ddededeeeeef B d	eeeeef B d
eeeeef B deeeeef B dededdf fddZdej	dej	fddZ
  ZS )HunyuanVideoCausalConv3dr   r   r   T	replicatein_channelsout_channelskernel_sizestridepaddingdilationbiaspad_moder   Nc	           	   	      s   t    t|tr|||fn|}|| _|d d |d d |d d |d d |d d df| _tj|||||||d| _d S )Nr   r   r   r4   )	super__init__
isinstanceintr5   time_causal_paddingnnConv3dconv)	selfr.   r/   r0   r1   r2   r3   r4   r5   	__class__r)   r*   r8   /   s   





	z!HunyuanVideoCausalConv3d.__init__hidden_statesc                 C   s   t j|| j| jd}| |S )N)mode)Fpadr;   r5   r>   r?   rB   r)   r)   r*   forwardJ   s   
z HunyuanVideoCausalConv3d.forward)r   r   r   r   Tr-   )__name__
__module____qualname__r:   tupleboolstrr8   r   TensorrG   __classcell__r)   r)   r@   r*   r,   .   s6    	
r,   c                       sj   e Zd Z					ddededB ded	ed
edeeeef ddf fddZdej	dej	fddZ
  ZS )HunyuanVideoUpsampleCausal3DNr   r   Tr   r   r   r.   r/   r0   r1   r4   upsample_factorr   c                    s0   t    |p|}|| _t|||||d| _d S Nr6   )r7   r8   rR   r,   r>   )r?   r.   r/   r0   r1   r4   rR   r@   r)   r*   r8   P   s   
	z%HunyuanVideoUpsampleCausal3D.__init__rB   c                 C   s   | d}|jd|d fdd\}}tj|d| jdd  ddd}|dkr@| }tj|| jdd}tj	||fdd}n|}| 
|}|S )Nr   r   dimnearest)scale_factorrC   )sizesplitrD   interpolatesqueezerR   r"   
contiguousr   catr>   )r?   rB   r   first_frameother_framesr)   r)   r*   rG   `   s   

z$HunyuanVideoUpsampleCausal3D.forward)Nr   r   TrQ   )rH   rI   rJ   r:   rL   rK   r    r8   r   rN   rG   rO   r)   r)   r@   r*   rP   O   s,    rP   c                       s\   e Zd Z					ddededB ded	ed
eddf fddZdejdejfddZ  Z	S )HunyuanVideoDownsampleCausal3DNr   r   Tr   channelsr/   r2   r0   r4   r   c                    s,   t    |p|}t||||||d| _d S rS   )r7   r8   r,   r>   )r?   ra   r/   r2   r0   r4   r1   r@   r)   r*   r8   z   s   
	z'HunyuanVideoDownsampleCausal3D.__init__rB   c                 C   s   |  |}|S N)r>   rF   r)   r)   r*   rG      s   
z&HunyuanVideoDownsampleCausal3D.forward)Nr   r   Tr   )
rH   rI   rJ   r:   rL   r8   r   rN   rG   rO   r)   r)   r@   r*   r`   y   s(    r`   c                       s`   e Zd Z					ddededB ded	ed
ededdf fddZdejdejfddZ	  Z
S )HunyuanVideoResnetBlockCausal3DN            ư>swishr.   r/   dropoutgroupsepsnon_linearityr   c                    s   t    |p|}t|| _tj|||dd| _t||ddd| _tj|||dd| _	t
|| _t||ddd| _d | _||krLt||ddd| _d S d S )NT)rj   affiner   r   r   )r7   r8   r   nonlinearityr<   	GroupNormnorm1r,   conv1norm2Dropoutrh   conv2conv_shortcut)r?   r.   r/   rh   ri   rj   rk   r@   r)   r*   r8      s   
	
z(HunyuanVideoResnetBlockCausal3D.__init__rB   c                 C   sr   |  }|}| |}| |}| |}| |}| |}| |}| |}| jd ur3| |}|| }|S rb   )r\   ro   rm   rp   rq   rh   rs   rt   )r?   rB   residualr)   r)   r*   rG      s   








z'HunyuanVideoResnetBlockCausal3D.forward)Nrd   re   rf   rg   )rH   rI   rJ   r:   r    rM   r8   r   rN   rG   rO   r)   r)   r@   r*   rc      s,    rc   c                       sh   e Zd Z							ddeded	ed
edededededdf fddZdej	dej	fddZ
  ZS )HunyuanVideoMidBlock3Drd   r   rf   rg   re   Tr.   rh   
num_layers
resnet_epsresnet_act_fnresnet_groupsadd_attentionattention_head_dimr   Nc	                    s   t    |d ur|nt|d d}|| _t||||||dg}	g }
t|D ]*}| jr>|
t||| |||ddddd	 n|
d  |	t||||||d q&t	|
| _
t	|	| _d| _d S )N   re   r.   r/   rj   ri   rh   rk   T)headsdim_headrj   norm_num_groupsresidual_connectionr4   upcast_softmax_from_deprecated_attn_blockF)r7   r8   minr{   rc   rangeappendr	   r<   
ModuleList
attentionsresnetsgradient_checkpointing)r?   r.   rh   rw   rx   ry   rz   r{   r|   r   r   _r@   r)   r*   r8      sT   



zHunyuanVideoMidBlock3D.__init__rB   c           
      C   sz  t  rd| jrd| | jd |}t| j| jdd  D ]F\}}|d ur[|j\}}}}}|ddddd	dd}t
||| |j|j|d}	|||	d}|d|||fddddd}| ||}q|S | jd |}t| j| jdd  D ]D\}}|d ur|j\}}}}}|ddddd	dd}t
||| |j|j|d}	|||	d}|d|||fddddd}||}qv|S )Nr   r   r   r   r}   )r   )attention_mask)r   is_grad_enabledr   _gradient_checkpointing_funcr   zipr   shapepermuteflattenr+   r   r   	unflatten)
r?   rB   attnresnetr   num_channelsr   heightwidthr   r)   r)   r*   rG      s2     
zHunyuanVideoMidBlock3D.forward)rd   r   rf   rg   re   Tr   rH   rI   rJ   r:   r    rM   rL   r8   r   rN   rG   rO   r)   r)   r@   r*   rv      s8    	
>rv   c                       sr   e Zd Z								dded	ed
ededededededededdf fddZdej	dej	fddZ
  ZS )HunyuanVideoDownBlock3Drd   r   rf   rg   re   Tr   r.   r/   rh   rw   rx   ry   rz   add_downsampledownsample_stridedownsample_paddingr   Nc                    s   t    g }t|D ]}|dkr|n|}|t||||||d qt|| _|r9tt|||
|	dg| _	nd | _	d| _
d S )Nr   r~   )r/   r2   r1   F)r7   r8   r   r   rc   r<   r   r   r`   downsamplersr   )r?   r.   r/   rh   rw   rx   ry   rz   r   r   r   r   ir@   r)   r*   r8     s6   

z HunyuanVideoDownBlock3D.__init__rB   c                 C   ^   t  r| jr| jD ]}| ||}q
n
| jD ]}||}q| jd ur-| jD ]}||}q&|S rb   )r   r   r   r   r   r   )r?   rB   r   downsamplerr)   r)   r*   rG   L  s   





zHunyuanVideoDownBlock3D.forward)rd   r   rf   rg   re   Tr   r   r   r)   r)   r@   r*   r     sB    	
/r   c                       sv   e Zd Z							dded	ed
ededededededeeeef ddf fddZde	j
de	j
fddZ  ZS )HunyuanVideoUpBlock3Drd   r   rf   rg   re   TrQ   r.   r/   rh   rw   rx   ry   rz   add_upsampleupsample_scale_factorr   Nc
                    s   t    g }
t|D ]}|dkr|n|}|
t||||||d qt|
| _|r8tt|||	dg| _	nd | _	d| _
d S )Nr   r~   )r/   rR   F)r7   r8   r   r   rc   r<   r   r   rP   
upsamplersr   )r?   r.   r/   rh   rw   rx   ry   rz   r   r   r   r   input_channelsr@   r)   r*   r8   \  s4   


zHunyuanVideoUpBlock3D.__init__rB   c                 C   r   rb   )r   r   r   r   r   r   )r?   rB   r   	upsamplerr)   r)   r*   rG     s   





zHunyuanVideoUpBlock3D.forward)rd   r   rf   rg   re   TrQ   )rH   rI   rJ   r:   r    rM   rL   rK   r8   r   rN   rG   rO   r)   r)   r@   r*   r   [  s<    	
.r   c                       s   e Zd ZdZ												
ddededeedf deedf dededededededdf fddZde	j
de	j
fddZ  ZS )HunyuanVideoEncoder3Dzx
    Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
    r   r   r   r   r            r   r   re   siluTr}      r.   r/   down_block_types.block_out_channelslayers_per_blockr   act_fndouble_ztemporal_compression_ratiospatial_compression_ratior   Nc                    s  t    t||d ddd| _d | _tg | _|d }t|D ]\}}|dkr0t	d| |}|| }|t
|d k}tt|}tt|
}|
dkrft||k }t|t
|d | koc| }n|
dkrwt||k }t||k }nt	d	|
 |rd
nd}|rdnd}t|| }t|||t|p|d|||dd	}| j| q!t|d d||d ||	d| _tj|d |dd| _t | _|rd| n|}t|d |dd| _d| _d S )Nr   r   r   r0   r1   r   zUnsupported down_block_type: r}   r   $Unsupported time_compression_ratio: r   r   r   r   r   r   rf   )	rw   r.   r/   r   rx   ry   rz   r   r   r   r.   rx   ry   r|   rz   r{   r   
num_groupsrj   r   r0   F)r7   r8   r,   conv_in	mid_blockr<   r   down_blocks	enumerate
ValueErrorlenr:   nplog2rL   rK   r   r   rv   rn   conv_norm_outSiLUconv_actconv_outr   )r?   r.   r/   r   r   r   r   r   r   mid_block_add_attentionr   r   output_channelr   down_block_typeinput_channelis_final_blocknum_spatial_downsample_layersnum_time_downsample_layersadd_spatial_downsampleadd_time_downsampledownsample_stride_HWdownsample_stride_Tr   
down_blockconv_out_channelsr@   r)   r*   r8     sd   

	

zHunyuanVideoEncoder3D.__init__rB   c                 C   s   |  |}t r | jr | jD ]}| ||}q| | j|}n| jD ]}||}q#| |}| |}| |}| 	|}|S rb   )
r   r   r   r   r   r   r   r   r   r   )r?   rB   r   r)   r)   r*   rG     s   







zHunyuanVideoEncoder3D.forward)r   r   r   r   r   re   r   TTr}   r   )rH   rI   rJ   __doc__r:   rK   rM   rL   r8   r   rN   rG   rO   r)   r)   r@   r*   r     sJ    


Rr   c                       s   e Zd ZdZ											
ddededeedf deedf dededededef fddZdej	dej	fddZ
  ZS )HunyuanVideoDecoder3Dzx
    Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
    r   r   r   r   r   r   r   re   r   Tr}   r   r.   r/   up_block_types.r   r   r   r   time_compression_ratior   c                    s  t    || _t||d ddd| _tg | _t|d d||d ||d| _	t
t|}|d }t|D ]y\}}|dkrEtd	| |}|| }|t|d k}tt|
}tt|	}|	d
kr{t||k }t|t|d | kox| }ntd|	 |rdnd}|rdnd}t|| }t| jd ||t|p||d||d}| j| |}q6tj|d |dd| _t | _t|d |dd| _d| _d S )Nr   r   r   r   rf   r   r   r   zUnsupported up_block_type: r}   r   r   r   r   r   )rw   r.   r/   r   r   rx   ry   rz   r   r   F)r7   r8   r   r,   r   r<   r   	up_blocksrv   r   listreversedr   r   r   r:   r   r   rL   rK   r   r   rn   r   r   r   r   r   )r?   r.   r/   r   r   r   r   r   r   r   r   reversed_block_out_channelsr   r   up_block_typeprev_output_channelr   num_spatial_upsample_layersnum_time_upsample_layersadd_spatial_upsampleadd_time_upsampleupsample_scale_factor_HWupsample_scale_factor_Tr   up_blockr@   r)   r*   r8     s^   




zHunyuanVideoDecoder3D.__init__rB   r   c                 C   s   |  |}t r | jr | | j|}| jD ]}| ||}qn| |}| jD ]}||}q(| |}| |}| 	|}|S rb   )
r   r   r   r   r   r   r   r   r   r   )r?   rB   r   r)   r)   r*   rG   [  s   







zHunyuanVideoDecoder3D.forward)
r   r   r   r   r   re   r   Tr}   r   )rH   rI   rJ   r   r:   rK   rM   r8   r   rN   rG   rO   r)   r)   r@   r*   r     s@    


Pr   c                       s\  e Zd ZdZdZe										
				dJdedededeedf deedf dee dededede	dedede
ddf fddZ						dKd edB d!edB d"edB d#e	dB d$e	dB d%e	dB ddfd&d'Zd(ejdejfd)d*Ze	dLd(ejd+e
deee B fd,d-ZdLd.ejd+e
deejB fd/d0ZedLd.ejd+e
deejB fd1d2Zd3ejd4ejd5edejfd6d7Zd3ejd4ejd5edejfd8d9Zd3ejd4ejd5edejfd:d;Zd(ejdefd<d=ZdLd.ejd+e
deejB fd>d?Zd(ejdefd@dAZdLd.ejd+e
deejB fdBdCZ	D		dMdEejdFe
d+e
dGejdB deejB f
dHdIZ  Z S )NAutoencoderKLHunyuanVideoaj  
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
    Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).
    Tr      r   r   r   r   r   re   >I?r   r}   r.   r/   latent_channelsr   .r   r   r   r   r   scaling_factorr   r   r   r   Nc                    s   t    || _t||||||	|d|||d| _t||||||	||||d
| _tjd| d| dd| _	tj||dd| _
|| _|| _d| _d| _d| _d| _d| _d| _d	| _d
| _d
| _d| _d S )NT)r.   r/   r   r   r   r   r   r   r   r   r   )
r.   r/   r   r   r   r   r   r   r   r   r   r   r   Fr   r         )r7   r8   r   r   encoderr   decoderr<   r=   
quant_convpost_quant_convr   r   use_slicing
use_tilinguse_framewise_encodinguse_framewise_decodingtile_sample_min_heighttile_sample_min_widthtile_sample_min_num_framestile_sample_stride_heighttile_sample_stride_widthtile_sample_stride_num_frames)r?   r.   r/   r   r   r   r   r   r   r   r   r   r   r   r@   r)   r*   r8   |  sR   

z"AutoencoderKLHunyuanVideo.__init__r   r   r   r   r   r  c                 C   sR   d| _ |p| j| _|p| j| _|p| j| _|p| j| _|p| j| _|p%| j| _dS )a  
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.

        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_min_num_frames (`int`, *optional*):
                The minimum number of frames required for a sample to be separated into tiles across the frame
                dimension.
            tile_sample_stride_height (`int`, *optional*):
                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
                no tiling artifacts produced across the height dimension.
            tile_sample_stride_width (`int`, *optional*):
                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
                artifacts produced across the width dimension.
            tile_sample_stride_num_frames (`int`, *optional*):
                The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts
                produced across the frame dimension.
        TN)r   r   r   r   r   r   r  )r?   r   r   r   r   r   r  r)   r)   r*   enable_tiling  s    z'AutoencoderKLHunyuanVideo.enable_tilingr&   c                 C   sf   |j \}}}}}| jr|| jkr| |S | jr'|| jks"|| jkr'| |S | |}| 	|}|S rb   )
r   r   r   _temporal_tiled_encoder   r   r   tiled_encoder   r   )r?   r&   r   r   r   r   r   encr)   r)   r*   _encode  s   



z!AutoencoderKLHunyuanVideo._encodereturn_dictc                    s^    j r|jd dkr fdd|dD }t|}n |}t|}|s*|fS t|dS )a  
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded videos. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        r   r   c                    s   g | ]}  |qS r)   )r  ).0x_slicer?   r)   r*   
<listcomp>  s    z4AutoencoderKLHunyuanVideo.encode.<locals>.<listcomp>)latent_dist)r   r   rY   r   r]   r  r   r
   )r?   r&   r  encoded_slicesh	posteriorr)   r
  r*   encode	  s   

z AutoencoderKLHunyuanVideo.encodezc                 C   s   |j \}}}}}| j| j }| j| j }	| j| j }
| jr(||
kr(| j||dS | jr:||	ks3||kr:| j	||dS | 
|}| |}|sI|fS t|dS )Nr  sample)r   r   r   r   r   r   r   _temporal_tiled_decoder   tiled_decoder   r   r   )r?   r  r  r   r   r   r   r   tile_latent_min_heighttile_latent_min_widthtile_latent_min_num_framesdecr)   r)   r*   _decode%  s   


z!AutoencoderKLHunyuanVideo._decodec                    sX    j r|jd dkr fdd|dD }t|}n |j}|s'|fS t|dS )a  
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        r   r   c                    s   g | ]}  |jqS r)   )r  r  )r  z_slicer
  r)   r*   r  I  s    z4AutoencoderKLHunyuanVideo.decode.<locals>.<listcomp>r  )r   r   rY   r   r]   r  r  r   )r?   r  r  decoded_slicesdecodedr)   r
  r*   decode9  s   
z AutoencoderKLHunyuanVideo.decodeabblend_extentc              	   C   s   t |jd |jd |}t|D ]@}|d d d d d d | | d d f d||   |d d d d d d |d d f ||   |d d d d d d |d d f< q|S )Nr   r   r   r   )r?   r   r!  r"  r'   r)   r)   r*   blend_vS     R&z!AutoencoderKLHunyuanVideo.blend_vc                 C   s   t |jd |jd |}t|D ]@}|d d d d d d d d | | f d||   |d d d d d d d d |f ||   |d d d d d d d d |f< q|S )Nr   r   r$  r?   r   r!  r"  r&   r)   r)   r*   blend_h[  r&  z!AutoencoderKLHunyuanVideo.blend_hc              	   C   s   t |jd |jd |}t|D ]@}|d d d d | | d d d d f d||   |d d d d |d d d d f ||   |d d d d |d d d d f< q|S )Nr   r$  r'  r)   r)   r*   blend_tc  r&  z!AutoencoderKLHunyuanVideo.blend_tc                 C   s  |j \}}}}}|| j }|| j }| j| j }	| j| j }
| j| j }| j| j }|	| }|
| }g }td|| jD ];}g }td|| jD ]*}|dddddd||| j ||| j f }| |}| |}|	| qF|	| q;g }t
|D ]O\}}g }t
|D ]:\}}|dkr| ||d  | ||}|dkr| ||d  ||}|	|ddddddd|d|f  q|	tj|dd q}tj|ddddddddd|d|f }|S )zEncode a batch of images using a tiled encoder.

        Args:
            x (`torch.Tensor`): Input batch of videos.

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        r   Nr   r}   rT   r   )r   r   r   r   r   r   r   r   r   r   r   r%  r(  r   r]   )r?   r&   r   r   r   r   r   latent_heightlatent_widthr  r  tile_latent_stride_heighttile_latent_stride_widthblend_heightblend_widthrowsr   rowjtileresult_rows
result_rowr  r)   r)   r*   r  k  s<   


2

.0z&AutoencoderKLHunyuanVideo.tiled_encodec                 C   s  |j \}}}}}|| j }|| j }	| j| j }
| j| j }| j| j }| j| j }| j| j }| j| j }g }td||D ]8}g }td||D ](}|dddddd|||
 ||| f }| |}| |}|	| qH|	| q>g }t
|D ]Q\}}g }t
|D ]<\}}|dkr| ||d  | ||}|dkr| ||d  ||}|	|ddddddd| jd| jf  q|	tj|dd q}tj|ddddddddd|d|	f }|s|fS t|dS )a  
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        r   Nr   r   rT   r   r  )r   r   r   r   r   r   r   r   r   r   r   r%  r(  r   r]   r   )r?   r  r  r   r   r   r   r   sample_heightsample_widthr  r  r-  r.  r/  r0  r1  r   r2  r3  r4  r  r5  r6  r  r)   r)   r*   r    s@   

.

20
z&AutoencoderKLHunyuanVideo.tiled_decodec              
   C   s  |j \}}}}}|d | j d }| j| j }| j| j }	||	 }
g }td|| jD ]S}|d d d d ||| j d d d d d f }| jrW|| jksQ|| jkrW| |}n
| 	|}| 
|}|dkrx|d d d d dd d d d d f }|| q*g }t|D ]B\}}|dkr| ||d  ||
}||d d d d d |	d d d d f  q||d d d d d |	d d d d d f  qtj|ddd d d d d |f }|S )Nr   r   r   rT   )r   r   r   r  r   r   r   r   r  r   r   r   r   r*  r   r]   )r?   r&   r   r   r   r   r   latent_num_framesr  tile_latent_stride_num_framesblend_num_framesr2  r   r4  r6  r  r)   r)   r*   r    s.   0

&.2$z0AutoencoderKLHunyuanVideo._temporal_tiled_encodec              
   C   s  |j \}}}}}|d | j d }| j| j }	| j| j }
| j| j }| j| j }| j| j }g }td||D ]Y}|d d d d ||| d d d d d f }| jrj|j d |
ksa|j d |	krj| j	|ddj
}n
| |}| |}|dkr|d d d d dd d d d d f }|| q7g }t|D ]D\}}|dkr| ||d  ||}||d d d d d | jd d d d f  q||d d d d d | jd d d d d f  qtj|ddd d d d d |f }|s|fS t|d	S )
Nr   r   r   r#  Tr  r   rT   r  )r   r   r   r   r   r   r  r   r   r  r  r   r   r   r   r*  r   r]   r   )r?   r  r  r   r   r   r   r   num_sample_framesr  r  r  r:  r;  r2  r   r4  r  r6  r  r)   r)   r*   r    s6   ."

&04$
z0AutoencoderKLHunyuanVideo._temporal_tiled_decodeFr  sample_posterior	generatorc           	      C   s<   |}|  |j}|r|j|d}n| }| j||d}|S )aa  
        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        )r>  r  )r  r  r  rC   r  )	r?   r  r=  r  r>  r&   r  r  r  r)   r)   r*   rG     s   z!AutoencoderKLHunyuanVideo.forward)r   r   r   r   r   r   r   r   re   r   r   r}   T)NNNNNN)T)FTN)!rH   rI   rJ   r    _supports_gradient_checkpointingr   r:   rK   rM   r    rL   r8   r  r   rN   r  r   r
   r   r  r   r  r  r%  r(  r*  r  r  r  r  	GeneratorrG   rO   r)   r)   r@   r*   r   q  s    

Y
(
 "    2: 'r   rb   ).numpyr   r   torch.nnr<   torch.nn.functional
functionalrD   configuration_utilsr   r   utilsr   utils.accelerate_utilsr   activationsr   attention_processorr	   modeling_outputsr
   modeling_utilsr   vaer   r   r   
get_loggerrH   loggerr:   r   r   rN   r+   Moduler,   rP   r`   rc   rv   r   r   r   r   r   r)   r)   r)   r*   <module>   sJ   

!*.a??lk