o
    Gi                     @   sR  d dl Z d dlZd dlmZ d dlm  mZ ddlmZm	Z	 ddl
mZ ddlmZ ddlmZ ddlmZ dd	lmZ d
dlmZmZmZ eeZG dd dejZG dd dejZG dd dejZG dd dejZ G dd dejZ!G dd dejZ"G dd dejZ#G dd dejZ$G dd dejZ%G dd deeeZ&dS )     N   )ConfigMixinregister_to_config)logging)apply_forward_hook   )get_activation)AutoencoderKLOutput)
ModelMixin   )AutoencoderMixinDecoderOutputDiagonalGaussianDistributionc                       s   e Zd Z							ddededeeedf B d	eeedf B d
eeedf B deeedf B dededef fddZdd Zde	j
de	j
f fddZ  ZS )EasyAnimateCausalConv3dr   r   Tzerosin_channelsout_channelskernel_size.stridepaddingdilationgroupsbiaspadding_modec
                    s~  t |tr|n|fd }t|dksJ d| dt |tr!|n|fd }t|dks4J d| dt |tr;|n|fd }t|dksNJ d| d|\}
}}|\| _}}|\}}}|
d | }|d u rt|d | d|  d }t|d | d|  d }nt |tr| }}ntsJ || _t|
d | d|  d | _	d | _
t j|||||d||f|||	d		 d S )
Nr   z#Kernel size must be a 3-tuple, got z	 instead.zStride must be a 3-tuple, got z Dilation must be a 3-tuple, got r   r   r   )	r   r   r   r   r   r   r   r   r   )
isinstancetuplelent_stridemathceilintNotImplementedErrortemporal_paddingtemporal_padding_originprev_featuressuper__init__)selfr   r   r   r   r   r   r   r   r   t_ksh_ksw_ksh_stridew_stride
t_dilation
h_dilation
w_dilationt_padh_padw_pad	__class__ g/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_magvit.pyr&   #   s<   

 

 
z EasyAnimateCausalConv3d.__init__c                 C      | ` d | _ d S Nr$   r'   r5   r5   r6   _clear_conv_cache]      
z)EasyAnimateCausalConv3d._clear_conv_cachehidden_statesreturnc                    s  |j }| jd u rttj|dddd| jdfdd}|j|d}|   |d d d d | j d f  | _|d}g }d}|| j d |krnt	 
|d d d d ||| j d f }|| j7 }|| || j d |ksEt|dS | jdkrtj| jd d d d | jd  d f |gdd}n
tj| j|gdd}|j|d}|   |d d d d | j d f  | _|d}g }d}|| j d |krt	 
|d d d d ||| j d f }|| j7 }|| || j d |kst|dS )Nr   	replicate)padmode)dtyper   r   dim)rB   r$   Fr@   r"   tor;   clonesizer%   forwardr   appendtorchconcat)r'   r=   rB   
num_framesoutputsioutr3   r5   r6   rI   a   sH   
$
,


($
,

zEasyAnimateCausalConv3d.forward)r   r   r   r   r   Tr   )__name__
__module____qualname__r    r   boolstrr&   r;   rK   TensorrI   __classcell__r5   r5   r3   r6   r   "   s:    	
:"r   c                       sb   e Zd Z						ddeded	ed
ededededef fddZdej	dej	fddZ
  ZS )EasyAnimateResidualBlock3Dsilu    ư>T              ?r   r   non_linearitynorm_num_groupsnorm_epsspatial_group_normdropoutoutput_scale_factorc	           	         s   t    || _tj|||dd| _t|| _t||dd| _	tj|||dd| _
t|| _t||dd| _||krEtj||dd| _nt | _|| _d S )NT)
num_groupsnum_channelsepsaffiner   r   r   )r%   r&   rc   nn	GroupNormnorm1r   nonlinearityr   conv1norm2Dropoutrb   conv2Conv3dshortcutIdentityra   )	r'   r   r   r^   r_   r`   ra   rb   rc   r3   r5   r6   r&      s"   



z#EasyAnimateResidualBlock3D.__init__r=   r>   c                 C   s  |  |}| jr/|d}|ddddddd}| |}|d|dfddddd}n| |}| |}| |}| jrh|d}|ddddddd}| 	|}|d|dfddddd}n| 	|}| |}| 
|}| |}|| | j S Nr   r   r   r      )rr   ra   rH   permuteflattenrk   	unflattenrl   rm   rn   rb   rp   rc   )r'   r=   rr   
batch_sizer5   r5   r6   rI      s.   













z"EasyAnimateResidualBlock3D.forward)rY   rZ   r[   Tr\   r]   rQ   rR   rS   r    rU   floatrT   r&   rK   rV   rI   rW   r5   r5   r3   r6   rX      s2    	$rX   c                	       sF   e Zd Zddedededef fddZd	ejd
ejfddZ  Z	S )EasyAnimateDownsampler3Dr   r   r   r   r   r   r   r   c                    s"   t    t||||dd| _d S )Nr   )r   r   r   r   r   )r%   r&   r   conv)r'   r   r   r   r   r3   r5   r6   r&      s   

z!EasyAnimateDownsampler3D.__init__r=   r>   c                 C   s   t |d}| |}|S )N)r   r   r   r   )rE   r@   r   r'   r=   r5   r5   r6   rI      s   
z EasyAnimateDownsampler3D.forward)r   r~   )
rQ   rR   rS   r    r   r&   rK   rV   rI   rW   r5   r5   r3   r6   r}      s     r}   c                       sX   e Zd Z			ddededededef
 fd	d
Zdd ZdejdejfddZ	  Z
S )EasyAnimateUpsampler3Dr   FTr   r   r   temporal_upsamplera   c                    s8   t    |p|}|| _|| _t|||d| _d | _d S )N)r   r   r   )r%   r&   r   ra   r   r   r$   )r'   r   r   r   r   ra   r3   r5   r6   r&      s   

zEasyAnimateUpsampler3D.__init__c                 C   r7   r8   r9   r:   r5   r5   r6   r;      r<   z(EasyAnimateUpsampler3D._clear_conv_cacher=   r>   c                 C   sR   t j|ddd}| |}| jr'| jd u r|| _|S t j|d| js#dndd}|S )Nr   r   r   nearest)scale_factorrA   )r   r   r   	trilinear)rE   interpolater   r   r$   ra   r   r5   r5   r6   rI      s   

zEasyAnimateUpsampler3D.forward)r   FT)rQ   rR   rS   r    rT   r&   r;   rK   rV   rI   rW   r5   r5   r3   r6   r      s"    r   c                       st   e Zd Z									dded	ed
ededededededededef fddZdej	dej	fddZ
  ZS )EasyAnimateDownBlock3Dr   rY   rZ   r[   Tr\   r]   r   r   
num_layersact_fnr_   r`   ra   rb   rc   add_downsampleadd_temporal_downsamplec                    s   t    tg | _t|D ]}|dkr|n|}| jt||||||||	d q|
r?|r?t||ddd| _	d| _
d| _d S |
rT|sTt||ddd| _	d| _
d| _d S d | _	d| _
d| _d S )	Nr   r   r   r^   r_   r`   ra   rb   rc   r   r~   )r   r   r   r   r   )r%   r&   ri   
ModuleListconvsrangerJ   rX   r}   downsamplerspatial_downsample_factortemporal_downsample_factor)r'   r   r   r   r   r_   r`   ra   rb   rc   r   r   rO   r3   r5   r6   r&     s6   



zEasyAnimateDownBlock3D.__init__r=   r>   c                 C   ,   | j D ]}||}q| jd ur| |}|S r8   )r   r   r'   r=   r   r5   r5   r6   rI   =  
   



zEasyAnimateDownBlock3D.forward)	r   rY   rZ   r[   Tr\   r]   TTr{   r5   r5   r3   r6   r     sD    	
-r   c                       st   e Zd Z									dd	ed
edededededededededef fddZdej	dej	fddZ
  ZS )EasyAnimateUpBlock3dr   rY   rZ   r[   Fr\   r]   Tr   r   r   r   r_   r`   ra   rb   rc   add_upsampleadd_temporal_upsamplec                    sx   t    tg | _t|D ]}|dkr|n|}| jt||||||||	d q|
r7t||||d| _	d S d | _	d S )Nr   r   )r   ra   )
r%   r&   ri   r   r   r   rJ   rX   r   	upsampler)r'   r   r   r   r   r_   r`   ra   rb   rc   r   r   rO   r3   r5   r6   r&   F  s0   

zEasyAnimateUpBlock3d.__init__r=   r>   c                 C   r   r8   )r   r   r   r5   r5   r6   rI   p  r   zEasyAnimateUpBlock3d.forward)	r   rY   rZ   r[   Fr\   r]   TTr{   r5   r5   r3   r6   r   E  sD    	
*r   c                       sd   e Zd Z							dded	ed
edededededef fddZdej	dej	fddZ
  ZS )EasyAnimateMidBlock3dr   rY   rZ   r[   Tr\   r]   r   r   r   r_   r`   ra   rb   rc   c	           
         s|   t    |d ur|nt|d d}tt||||||||dg| _t|d D ]}	| jt||||||||d q)d S )Nru   rZ   r   r   )	r%   r&   minri   r   rX   r   r   rJ   )
r'   r   r   r   r_   r`   ra   rb   rc   _r3   r5   r6   r&   y  s:   
zEasyAnimateMidBlock3d.__init__r=   r>   c                 C   s.   | j d |}| j dd  D ]}||}q|S )Nr   r   )r   )r'   r=   resnetr5   r5   r6   rI     s   
zEasyAnimateMidBlock3d.forward)r   rY   rZ   r[   Tr\   r]   r{   r5   r5   r3   r6   r   x  s4    	,r   c                       s   e Zd ZdZdZdddg dddd	dd
f	dededeedf deedf dededededef fddZ	de
jde
jfddZ  ZS )EasyAnimateEncoderzp
    Causal encoder for 3D video-like data used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
    Tr      SpatialDownBlock3DSpatialTemporalDownBlock3Dr   r            r   r   rZ   rY   Fr   r   down_block_types.block_out_channelslayers_per_blockr_   r   double_zra   c
                    s8  t    t||d dd| _tg | _|d }
t|D ]E\}}|
}|| }
|t|d k}|dkrBt	||
|||d|	| dd	}n|d	krUt	||
|||d|	| d
d	}nt
d| | j| qt|d |||	|dddd| _|	| _tj|d |dd| _t|| _|rd| n|}t|d |dd| _d| _d S )Nr   r   rh   r   r   r[   F)	r   r   r   r   r_   r`   ra   r   r   r   TUnknown up block type: rv   )r   r   r   ra   r_   r`   rb   rc   re   rd   rf   r   )r%   r&   r   conv_inri   r   down_blocks	enumerater   r   
ValueErrorrJ   r   	mid_blockra   rj   conv_norm_outr   conv_actconv_outgradient_checkpointing)r'   r   r   r   r   r   r_   r   r   ra   output_channelsrO   down_block_typeinput_channelsis_final_block
down_blockconv_out_channelsr3   r5   r6   r&     sl   


zEasyAnimateEncoder.__init__r=   r>   c                 C   s   |  |}| jD ]}t r| jr| ||}q||}q| |}| jrL|d}|	ddddd
dd}| |}|d|df	ddddd}n| |}| |}| |}|S rt   )r   r   rK   is_grad_enabledr   _gradient_checkpointing_funcr   ra   rH   rw   rx   r   ry   r   r   )r'   r=   r   rz   r5   r5   r6   rI     s   





 


zEasyAnimateEncoder.forwardrQ   rR   rS   __doc__ _supports_gradient_checkpointingr    r   rU   rT   r&   rK   rV   rI   rW   r5   r5   r3   r6   r     s@    


Tr   c                       s   e Zd ZdZdZdddg dddd	d
fdededeedf deedf dedededef fddZ	de
jde
jfddZ  ZS )EasyAnimateDecoderzp
    Causal decoder for 3D video-like data used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
    Tr   r   SpatialUpBlock3DSpatialTemporalUpBlock3Dr   r   r   r   rZ   rY   Fr   r   up_block_types.r   r   r_   r   ra   c	                    s:  t    t||d dd| _t|d |||dddd| _tg | _t	t
|}	|	d }
t|D ]I\}}|
}|	| }
|t|d k}|dkrXt||
|d ||d|| d	d
	}n|dkrmt||
|d ||d|| dd
	}ntd| | j| q1|| _tj|d |dd| _t|| _t|d |dd| _d	| _d S )Nrv   r   rh   r[   r   r   )r   r   r   r_   r`   rb   rc   r   F)	r   r   r   r   r_   r`   ra   r   r   r   Tr   r   )r%   r&   r   r   r   r   ri   r   	up_blockslistreversedr   r   r   r   rJ   ra   rj   r   r   r   r   r   )r'   r   r   r   r   r   r_   r   ra   reversed_block_out_channelsr   rO   up_block_typer   r   up_blockr3   r5   r6   r&   '  sj   


zEasyAnimateDecoder.__init__r=   r>   c                 C   s   |  |}t r| jr| | j|}n| |}| jD ]}t r,| jr,| ||}q||}q| jr[|d}|	ddddd
dd}| |}|d|df	ddddd}n| |}| |}| |}|S rt   )r   rK   r   r   r   r   r   ra   rH   rw   rx   r   ry   r   r   )r'   r=   r   rz   r5   r5   r6   rI   {  s&   









zEasyAnimateDecoder.forwardr   r5   r5   r3   r6   r      s:    


Tr   c                       s(  e Zd ZdZdZedddg dg dg ddd	d
ddfdedededeedf deedf deedf dededede	de
f fddZdd Z						dBdedB dedB dedB d e	dB d!e	dB d"e	dB d#dfd$d%Ze	dCd&ejd'e
d#eee B fd(d)Ze	dCd&ejd'e
d#eee B fd*d+ZdCd,ejd'e
d#eejB fd-d.ZedCd,ejd'e
d#eejB fd/d0Zd1ejd2ejd3ed#ejfd4d5Zd1ejd2ejd3ed#ejfd6d7ZdCd&ejd'e
d#efd8d9ZdCd,ejd'e
d#eejB fd:d;Z	<		dDd=ejd>e
d'e
d?ejdB d#eejB f
d@dAZ  ZS )EAutoencoderKLMagvitaq  
    A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
    model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).
    Tr      r   r   r   r   rY   rZ   g?r   latent_channelsr   r   .r   r   r   r   r_   scaling_factorra   c                    s   t    t||||||	|d|d	| _t||||||	||d| _tjd| d| dd| _tj||dd| _	dt
|d  | _dt
|d  | _d| _d| _d| _d| _d| _d| _d	| _d	| _d| _d
| _d
| _d| _d S )NT)	r   r   r   r   r   r_   r   r   ra   )r   r   r   r   r   r_   r   ra   r   r   rh   Fru   r   i  r   )r%   r&   r   encoderr   decoderri   rq   
quant_convpost_quant_convr   spatial_compression_ratiotemporal_compression_ratiouse_slicing
use_tilinguse_framewise_encodinguse_framewise_decodingnum_sample_frames_batch_sizenum_latent_frames_batch_sizetile_sample_min_heighttile_sample_min_widthtile_sample_min_num_framestile_sample_stride_heighttile_sample_stride_widthtile_sample_stride_num_frames)r'   r   r   r   r   r   r   r   r   r_   r   ra   r3   r5   r6   r&     sL   

zAutoencoderKLMagvit.__init__c                 C   s:   |   D ]\}}t|tr|  t|tr|  qd S r8   )named_modulesr   r   r;   r   )r'   namemoduler5   r5   r6   r;     s   

z%AutoencoderKLMagvit._clear_conv_cacheNr   r   r   r   r   r   r>   c                 C   s^   d| _ d| _d| _|p| j| _|p| j| _|p| j| _|p| j| _|p%| j| _|p+| j| _dS )aX  
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.

        Args:
            tile_sample_min_height (`int`, *optional*):
                The minimum height required for a sample to be separated into tiles across the height dimension.
            tile_sample_min_width (`int`, *optional*):
                The minimum width required for a sample to be separated into tiles across the width dimension.
            tile_sample_stride_height (`int`, *optional*):
                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
                no tiling artifacts produced across the height dimension.
            tile_sample_stride_width (`int`, *optional*):
                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
                artifacts produced across the width dimension.
        TN)	r   r   r   r   r   r   r   r   r   )r'   r   r   r   r   r   r   r5   r5   r6   enable_tiling  s   z!AutoencoderKLMagvit.enable_tilingxreturn_dictc              
   C   s   | j r|jd | jks|jd | jkr| j||dS | |ddddddddddf }|g}td|jd | jD ] }| |dddd||| j ddddf }|| q=t	j
|dd}| |}|   |S )a  
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded images. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        rv   r   Nr   r   rC   )r   shaper   r   tiled_encoder   r   r   rJ   rK   catr   r;   )r'   r   r   first_frameshrO   next_framesmomentsr5   r5   r6   _encode'  s   &,2
zAutoencoderKLMagvit._encodec                    s^    j r|jd dkr fdd|dD }t|}n |}t|}|s*|fS t|dS )a  
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded videos. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        r   r   c                    s   g | ]}  |qS r5   )r   ).0x_slicer:   r5   r6   
<listcomp>V  s    z.AutoencoderKLMagvit.encode.<locals>.<listcomp>)latent_dist)r   r   splitrK   r   r   r   r	   )r'   r   r   encoded_slicesr   	posteriorr5   r:   r6   encodeE  s   

zAutoencoderKLMagvit.encodezc              
   C   s  |j \}}}}}| j| j }| j| j }	| jr,|j d |ks%|j d |	kr,| j||dS | |}| |d d d d d dd d d d f }
|
g}td|j d | j	D ] }| |d d d d ||| j	 d d d d f }|
| qTtj|dd}|s|fS t|dS )Nrv   r   r   r   r   rC   sample)r   r   r   r   r   tiled_decoder   r   r   r   rJ   rK   r   r   )r'   r   r   rz   re   rM   heightwidthtile_latent_min_heighttile_latent_min_widthr   decrO   r   r5   r5   r6   _decodea  s   "
,2
zAutoencoderKLMagvit._decodec                    s`    j r|jd dkr fdd|dD }t|}n |j}   |s+|fS t|dS )a  
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        r   r   c                    s   g | ]}  |jqS r5   )r  r   )r   z_slicer:   r5   r6   r     s    z.AutoencoderKLMagvit.decode.<locals>.<listcomp>r   )	r   r   r   rK   r   r  r   r;   r   )r'   r   r   decoded_slicesdecodedr5   r:   r6   decode{  s   
zAutoencoderKLMagvit.decodeabblend_extentc              	   C   s   t |jd |jd |}t|D ]@}|d d d d d d | | d d f d||   |d d d d d d |d d f ||   |d d d d d d |d d f< q|S )Nr   r   r   r   r   )r'   r	  r
  r  yr5   r5   r6   blend_v     R&zAutoencoderKLMagvit.blend_vc                 C   s   t |jd |jd |}t|D ]@}|d d d d d d d d | | f d||   |d d d d d d d d |f ||   |d d d d d d d d |f< q|S )Nru   r   r  )r'   r	  r
  r  r   r5   r5   r6   blend_h  r  zAutoencoderKLMagvit.blend_hc                 C   s\  |j \}}}}}|| j }|| j }	| j| j }
| j| j }| j| j }| j| j }|
| }|| }g }td|| jD ]}g }td|| jD ]q}|d d d d d d ||| j ||| j f }| |d d d d ddd d d d f }|g}td|| jD ] }| |d d d d ||| j d d d d f }|	| qt
j|dd}| |}|   |	| qF|	| q;g }t|D ]O\}}g }t|D ]:\}}|dkr| ||d  | ||}|dkr| ||d  ||}|	|d d d d d d d |d |	f  q|	t
j|dd qt
j|ddd d d d d d d |d |	f }|S )Nr   r   r   rC   ru   r   )r   r   r   r   r   r   r   r   r   rJ   rK   r   r   r;   r   r  r  )r'   r   r   rz   re   rM   r   r   latent_heightlatent_widthr  r  tile_latent_stride_heighttile_latent_stride_widthblend_heightblend_widthrowsrO   rowjtiler   tile_hkr   result_rows
result_rowr   r5   r5   r6   r     sV   

,2
.0z AutoencoderKLMagvit.tiled_encodec                 C   sr  |j \}}}}}|| j }|| j }	| j| j }
| j| j }| j| j }| j| j }| j| j }| j| j }g }td||D ]}g }td||D ]o}|d d d d d d |||
 ||| f }| |}| |d d d d d dd d d d f }|g}td|| j	D ] }| |d d d d ||| j	 d d d d f }|
| qtj|dd}|   |
| qH|
| q>g }t|D ]Q\}}g }t|D ]<\}}|dkr| ||d  | ||}|dkr| ||d  ||}|
|d d d d d d d | jd | jf  q|
tj|dd qtj|ddd d d d d d d |d |	f }|s4|fS t|dS )Nr   r   r   rC   ru   r   r   )r   r   r   r   r   r   r   r   r   r   rJ   rK   r   r;   r   r  r  r   )r'   r   r   rz   re   rM   r   r   sample_heightsample_widthr  r  r  r  r  r  r  rO   r  r  r  r   tile_decr  r   r  r  r  r  r5   r5   r6   r     sZ   




,220
z AutoencoderKLMagvit.tiled_decodeFr   sample_posterior	generatorc           	      C   sJ   |}|  |j}|r|j|d}n| }| |j}|s |fS t|dS )aa  
        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        )r#  r   )r   r   r   rA   r  r   )	r'   r   r"  r   r#  r   r   r   r  r5   r5   r6   rI     s   
zAutoencoderKLMagvit.forward)NNNNNN)T)FTN)rQ   rR   rS   r   r   r   r    r   rU   r|   rT   r&   r;   r   r   rK   rV   r	   r   r   r   r   r  r  r  r  r   r   	GeneratorrI   rW   r5   r5   r3   r6   r     s    


V

$

 "   4@r   )'r   rK   torch.nnri   torch.nn.functional
functionalrE   configuration_utilsr   r   utilsr   utils.accelerate_utilsr   activationsr   modeling_outputsr	   modeling_utilsr
   vaer   r   r   
get_loggerrQ   loggerrq   r   ModulerX   r}   r   r   r   r   r   r   r   r5   r5   r5   r6   <module>   s,   
qF(634ty