o
    پiv                     @   sN  d dl Zd dlZd dlmZ d dlm  mZ d dlm	Z	 d dl
mZ d dlmZ 	d#dededejdejd	edB d
ejfddZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd  d ejZG d!d" d"eZeZdS )$    N)HunyuanVAEConfig)
get_act_fn)ParallelTiledVAE
num_framesheight_widthdtypedevice
batch_sizereturnc           
      C   sv   t jd| d t j|d}||}t j||dd\}}t ||kdtd j|d}	|d ur9|	d	|dd}	|	S )	N   )r   r   xy)indexingr   inf)r   )
torcharangeint32repeat_interleavemeshgridwherefloatto	unsqueezeexpand)
r   r   r   r   r	   indicesindices_blocksxymask r   h/home/ubuntu/.local/lib/python3.10/site-packages/sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.pyprepare_causal_attention_mask   s   
 r!   c                       sF   e Zd Z		d	 fddZ	d
dejdejdB dejfddZ  ZS )HunyuanVAEAttentionr
   Nc                    s   t    || _|| _|| _|| _|| _|| _|| }tj	|||d| _
tj	|||d| _tj	|||d| _ttj	|||d| _tj|||dd| _d S )NbiasTepsaffine)super__init__in_channelsheadsdim_headr&   norm_num_groupsr$   nnLinearto_qto_kto_v
Sequentialto_out	GroupNorm
group_norm)selfr*   r+   r,   r&   r-   r$   	inner_dim	__class__r   r    r)   1   s   
zHunyuanVAEAttention.__init__hidden_statesattention_maskc                 C   s   |}|j \}}}| |dddd}| |}| |}| |}	| j}
||d| j|
dd}||d| j|
dd}|	|d| j|
dd}	t	j
|||	|ddd}|dd|d| j|
 }||j}| |}|| }|S )Nr      r           F)	attn_mask	dropout_p	is_causal)shaper6   	transposer0   r1   r2   r,   viewr+   Fscaled_dot_product_attentionreshaper   r   r4   )r7   r;   r<   residualr	   sequence_length_querykeyvaluehead_dimr   r   r    forwardI   s(   



zHunyuanVAEAttention.forward)r
   NN)__name__
__module____qualname__r)   r   TensorrO   __classcell__r   r   r9   r    r"   /   s    r"   c                       s   e Zd Z						ddededeeeeef B d	eeeeef B d
eeeeef B deeeeef B dededdf fddZdej	dej	fddZ
  ZS )HunyuanVideoCausalConv3d   r   r   T	replicater*   out_channelskernel_sizestridepaddingdilationr$   pad_moder
   Nc	           	   	      s   t    t|tr|||fn|}|| _|d d |d d |d d |d d |d d df| _tj|||||||d| _d S )Nr   r=   r   r#   )	r(   r)   
isinstanceintr^   time_causal_paddingr.   Conv3dconv)	r7   r*   rY   rZ   r[   r\   r]   r$   r^   r9   r   r    r)   t   s    





	z!HunyuanVideoCausalConv3d.__init__r;   c                 C   s   t j|| j| jd}| |S )N)mode)rE   padra   r^   rc   r7   r;   r   r   r    rO      s   

z HunyuanVideoCausalConv3d.forward)rW   r   r   r   TrX   )rQ   rR   rS   r`   tupleboolstrr)   r   rT   rO   rU   r   r   r9   r    rV   r   s6    	
!rV   c                       sh   e Zd Z					ddededB ded	ed
edeedf ddf fddZdejdejfddZ	  Z
S )HunyuanVideoUpsampleCausal3DNrW   r   Tr=   r=   r=   r*   rY   rZ   r[   r$   upsample_factor.r
   c                    s0   t    |p|}|| _t|||||d| _d S Nr#   )r(   r)   rl   rV   rc   )r7   r*   rY   rZ   r[   r$   rl   r9   r   r    r)      s   
	
z%HunyuanVideoUpsampleCausal3D.__init__r;   c                 C   s   | d}|jd|d fdd\}}tj|d| jdd  ddd}|dkr@| }tj|| jdd}tj	||fdd}n|}| 
|}|S )Nr=   r   )dimnearest)scale_factorrd   )sizesplitrE   interpolatesqueezerl   r   
contiguousr   catrc   )r7   r;   r   first_frameother_framesr   r   r    rO      s$   

z$HunyuanVideoUpsampleCausal3D.forward)NrW   r   Trk   )rQ   rR   rS   r`   rh   rg   r)   r   rT   rO   rU   r   r   r9   r    rj      s,    
rj   c                       s\   e Zd Z					ddededB ded	ed
eddf fddZdejdejfddZ  Z	S )HunyuanVideoDownsampleCausal3DNr   rW   Tr=   channelsrY   r\   rZ   r$   r
   c                    s,   t    |p|}t||||||d| _d S rm   )r(   r)   rV   rc   )r7   rz   rY   r\   rZ   r$   r[   r9   r   r    r)      s
   
	z'HunyuanVideoDownsampleCausal3D.__init__r;   c                 C   s   |  |}|S rP   )rc   rf   r   r   r    rO      s   
z&HunyuanVideoDownsampleCausal3D.forward)Nr   rW   Tr=   )
rQ   rR   rS   r`   rh   r)   r   rT   rO   rU   r   r   r9   r    ry      s(    ry   c                       s`   e Zd Z					ddededB ded	ed
ededdf fddZdejdejfddZ	  Z
S )HunyuanVideoResnetBlockCausal3DNr>       ư>silur*   rY   dropoutgroupsr&   non_linearityr
   c                    s   t    |p|}t|| _tj|||dd| _t||ddd| _tj|||dd| _	t
|| _t||ddd| _d | _||krLt||ddd| _d S d S )NTr%   rW   r   r   )r(   r)   r   nonlinearityr.   r5   norm1rV   conv1norm2Dropoutr   conv2conv_shortcut)r7   r*   rY   r   r   r&   r   r9   r   r    r)      s   
	


z(HunyuanVideoResnetBlockCausal3D.__init__r;   c                 C   sr   |  }|}| |}| |}| |}| |}| |}| |}| |}| jd ur3| |}|| }|S rP   )ru   r   r   r   r   r   r   r   )r7   r;   rH   r   r   r    rO     s   








z'HunyuanVideoResnetBlockCausal3D.forward)Nr>   r|   r}   r~   )rQ   rR   rS   r`   r   ri   r)   r   rT   rO   rU   r   r   r9   r    r{      s,    r{   c                       sh   e Zd Z							ddeded	ed
edededededdf fddZdej	dej	fddZ
  ZS )HunyuanVideoMidBlock3Dr>   r   r}   r~   r|   Tr*   r   
num_layers
resnet_epsresnet_act_fnresnet_groupsadd_attentionattention_head_dimr
   Nc	                    s   t    |d ur|nt|d d}|| _t||||||dg}	g }
t|D ]'}| jr;|
t||| |||dd n|
d  |	t||||||d q&t	|
| _
t	|	| _d| _d S )N   r|   r*   rY   r&   r   r   r   T)r+   r,   r&   r-   r$   F)r(   r)   minr   r{   rangeappendr"   r.   
ModuleList
attentionsresnetsgradient_checkpointing)r7   r*   r   r   r   r   r   r   r   r   r   rJ   r9   r   r    r)     sP   



zHunyuanVideoMidBlock3D.__init__r;   c           
      C   s  t  rf| jrf| | jd |}t| j| jdd  ddD ]F\}}|d ur]|j\}}}}}|ddddd	dd}t
||| |j|j|d}	|||	d	}|d|||fddddd}| ||}q|S | jd |}t| j| jdd  ddD ]D\}}|d ur|j\}}}}}|ddddd	dd}t
||| |j|j|d}	|||	d	}|d|||fddddd}||}qz|S )
Nr   r   T)strictr=   rW   r   )r	   )r<   )r   is_grad_enabledr   _gradient_checkpointing_funcr   zipr   rB   permuteflattenr!   r   r   	unflatten)
r7   r;   attnresnetr	   num_channelsr   heightwidthr<   r   r   r    rO   T  sZ   
"
"

zHunyuanVideoMidBlock3D.forward)r>   r   r}   r~   r|   Tr   )rQ   rR   rS   r`   r   ri   rh   r)   r   rT   rO   rU   r   r   r9   r    r     s8    	
=r   c                       s~   e Zd Z								dded	ed
ededededededeedf eB deddf fddZde	j
de	j
fddZ  ZS )HunyuanVideoDownBlock3Dr>   r   r}   r~   r|   Tr=   r*   rY   r   r   r   r   r   add_downsampledownsample_stride.downsample_paddingr
   Nc                    s   t    g }t|D ]}|dkr|n|}|t||||||d qt|| _|r9tt|||
|	dg| _	nd | _	d| _
d S )Nr   r   )rY   r\   r[   F)r(   r)   r   r   r{   r.   r   r   ry   downsamplersr   )r7   r*   rY   r   r   r   r   r   r   r   r   r   ir9   r   r    r)     s6   

z HunyuanVideoDownBlock3D.__init__r;   c                 C   ^   t  r| jr| jD ]}| ||}q
n
| jD ]}||}q| jd ur-| jD ]}||}q&|S rP   )r   r   r   r   r   r   )r7   r;   r   downsamplerr   r   r    rO     s   





zHunyuanVideoDownBlock3D.forward)r>   r   r}   r~   r|   Tr=   r   rQ   rR   rS   r`   r   ri   rh   rg   r)   r   rT   rO   rU   r   r   r9   r    r     sB    	
/r   c                       st   e Zd Z							dded	ed
ededededededeedf ddf fddZde	j
de	j
fddZ  ZS )HunyuanVideoUpBlock3Dr>   r   r}   r~   r|   Trk   r*   rY   r   r   r   r   r   add_upsampleupsample_scale_factor.r
   Nc
                    s   t    g }
t|D ]}|dkr|n|}|
t||||||d qt|
| _|r8tt|||	dg| _	nd | _	d| _
d S )Nr   r   )rY   rl   F)r(   r)   r   r   r{   r.   r   r   rj   
upsamplersr   )r7   r*   rY   r   r   r   r   r   r   r   r   r   input_channelsr9   r   r    r)     s4   


zHunyuanVideoUpBlock3D.__init__r;   c                 C   r   rP   )r   r   r   r   r   r   )r7   r;   r   	upsamplerr   r   r    rO     s   





zHunyuanVideoUpBlock3D.forward)r>   r   r}   r~   r|   Trk   r   r   r   r9   r    r     s<    	

.r   c                       s   e Zd ZdZ												
ddededeedf deedf dededededededdf fddZde	j
de	j
fddZ  ZS )HunyuanVideoEncoder3Dzx
    Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
    rW   r   r   r   r            r   r=   r|   r~   Tr      r*   rY   down_block_types.block_out_channelslayers_per_blockr-   act_fndouble_ztemporal_compression_ratiospatial_compression_ratior
   Nc                    s  t    t||d ddd| _d | _tg | _|d }t|D ]\}}|dkr0t	d| |}|| }|t
|d k}tt|}tt|
}|
dkrft||k }t|t
|d | koc| }n|
dkrwt||k }t||k }nt	d	|
 |rd
nd}|rdnd}t|| }t|||t|p|d|||dd	}| j| q!t|d d||d ||	d| _tj|d |dd| _t | _|rd| n|}t|d |dd| _d| _d S )Nr   rW   r   rZ   r[   r   zUnsupported down_block_type: r   r   $Unsupported time_compression_ratio: r=   r=   r   r   r=   r   r}   )	r   r*   rY   r   r   r   r   r   r   r   r*   r   r   r   r   r   r   
num_groupsr&   r=   rZ   F)r(   r)   rV   conv_in	mid_blockr.   r   down_blocks	enumerate
ValueErrorlenr`   nplog2rh   rg   r   r   r   r5   conv_norm_outSiLUconv_actconv_outr   )r7   r*   rY   r   r   r   r-   r   r   mid_block_add_attentionr   r   output_channelr   down_block_typeinput_channelis_final_blocknum_spatial_downsample_layersnum_time_downsample_layersadd_spatial_downsampleadd_time_downsampledownsample_stride_HWdownsample_stride_Tr   
down_blockconv_out_channelsr9   r   r    r)     sv   

	



zHunyuanVideoEncoder3D.__init__r;   c                 C   s   |  |}t r | jr | jD ]}| ||}q| | j|}n| jD ]}||}q#| jd us1J | |}| |}| |}| 	|}|S rP   )
r   r   r   r   r   r   r   r   r   r   )r7   r;   r   r   r   r    rO   h  s"   







zHunyuanVideoEncoder3D.forward)rW   rW   r   r   r=   r|   r~   TTr   r   )rQ   rR   rS   __doc__r`   rg   ri   rh   r)   r   rT   rO   rU   r   r   r9   r    r     sJ    


[r   c                       s   e Zd ZdZ											
ddededeedf deedf dededededef fddZdej	dej	fddZ
  ZS )HunyuanVideoDecoder3Dzx
    Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
    rW   r   r   r   r   r   r=   r|   r~   Tr   r   r*   rY   up_block_types.r   r   r-   r   time_compression_ratior   c                    s  t    || _t||d ddd| _tg | _t|d d||d ||d| _	t
t|}|d }t|D ]y\}}|dkrEtd	| |}|| }|t|d k}tt|
}tt|	}|	d
kr{t||k }t|t|d | kox| }ntd|	 |rdnd}|rdnd}t|| }t| jd ||t|p||d||d}| j| |}q6tj|d |dd| _t | _t|d |dd| _d| _d S )Nr   rW   r   r   r}   r   r   r   zUnsupported up_block_type: r   r   r   r   r   r   )r   r*   rY   r   r   r   r   r   r   r   F)r(   r)   r   rV   r   r.   r   	up_blocksr   r   listreversedr   r   r   r`   r   r   rh   rg   r   r   r5   r   r   r   r   r   )r7   r*   rY   r   r   r   r-   r   r   r   r   reversed_block_out_channelsr   r   up_block_typeprev_output_channelr   num_spatial_upsample_layersnum_time_upsample_layersadd_spatial_upsampleadd_time_upsampleupsample_scale_factor_HWupsample_scale_factor_Tr   up_blockr9   r   r    r)     st   






zHunyuanVideoDecoder3D.__init__r;   r
   c                 C   s   |  |}t r | jr | | j|}| jD ]}| ||}qn| |}| jD ]}||}q(| |}| |}| 	|}|S rP   )
r   r   r   r   r   r   r   r   r   r   )r7   r;   r   r   r   r    rO     s"   







zHunyuanVideoDecoder3D.forward)
rW   rW   r   r   r=   r|   r~   Tr   r   )rQ   rR   rS   r   r`   rg   ri   r)   r   rT   rO   rU   r   r   r9   r    r     s@    


[r   c                	   @   s|   e Zd ZdZdZdeddfddZdejdejfd	d
Z	dejdejfddZ
		ddejdedejdB dejfddZdS )AutoencoderKLHunyuanVideoaj  
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
    Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).
    Tconfigr
   Nc                 C   s   t j|  t| | |j| _|jr<t|j|j|j	|j|j
|j|jd|j|j|jd| _t jd|j d|j dd| _|jret|j|j|j|j|j
|j|j|j|j|jd
| _t j|j|jdd| _d S d S )NT)r*   rY   r   r   r   r-   r   r   r   r   r   r=   r   r   )
r*   rY   r   r   r   r-   r   r   r   r   )r.   Moduler)   r   r   load_encoderr   r*   latent_channelsr   r   r-   r   r   r   r   encoderrb   
quant_convload_decoderr   rY   r   decoderpost_quant_conv)r7   r   r   r   r    r)     sJ   
z"AutoencoderKLHunyuanVideo.__init__r   c                 C      |  |}| |}|S rP   )r   r   )r7   r   encr   r   r    _encode4     

z!AutoencoderKLHunyuanVideo._encodezc                 C   r  rP   )r   r   )r7   r  decr   r   r    _decode9  r  z!AutoencoderKLHunyuanVideo._decodeFsamplesample_posterior	generatorc                 C   s8   |}|  |j}|r|j|d}n| }| |}|S )z
        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
        )r
  )encodelatent_distr  rd   decode)r7   r  r	  r
  r   	posteriorr  r  r   r   r    rO   >  s   
z!AutoencoderKLHunyuanVideo.forward)FN)rQ   rR   rS   r    _supports_gradient_checkpointingr   r)   r   rT   r  r  rh   	GeneratorrO   r   r   r   r    r     s*    
.r   rP   ) numpyr   r   torch.nnr.   torch.nn.functional
functionalrE   )sglang.multimodal_gen.configs.models.vaesr   /sglang.multimodal_gen.runtime.layers.activationr   0sglang.multimodal_gen.runtime.models.vaes.commonr   r`   r   r   rT   r!   r   r"   rV   rj   ry   r{   r   r   r   r   r   r   
EntryClassr   r   r   r    <module>   sB   
C*11s@@yzY