o
    Gi                     @   s  d dl Z d dlZd dlmZ ddlmZmZ ddlmZ ddlm	Z	 ddl
mZ ddlmZmZmZ dd	lmZ dd
lmZ ddlmZ eeZdd Zd ddZG dd dejZG dd dejZG dd dejZG dd dejZG dd deeZ G dd deeZ!dS )!    N)nn   )ConfigMixinregister_to_config)
ModelMixin)FeedForward)	Attention)TimestepEmbedding	Timestepsget_2d_sincos_pos_embed)Transformer2DModelOutput)AdaLayerNorm)loggingc                 C   s   dd }||d|  k s||d|  krt d t B ||| | }||| | }| d| d d| d  |   | |td  | 	| | j
||d | W  d    S 1 sbw   Y  d S )Nc                 S   s   dt | t d  d S )N      ?       @)matherfsqrt)x r   a/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/unidiffuser/modeling_uvit.pynorm_cdf   s   z(_no_grad_trunc_normal_.<locals>.norm_cdf   zjmean is more than 2 std from [a, b] in nn.init.trunc_normal_. The distribution of values may be incorrect.   r   )minmax)loggerwarningtorchno_graduniform_erfinv_mul_r   r   add_clamp_)tensormeanstdabr   lur   r   r   _no_grad_trunc_normal_   s    

$r,           r          r   c                 C   s   t | ||||S )a  Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean},
    \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for
    generating the random values works best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w)
    )r,   )r%   r&   r'   r(   r)   r   r   r   trunc_normal_8   s   r/   c                       s<   e Zd ZdZ									d fdd		Zd
d Z  ZS )
PatchEmbedz2D Image to Patch Embedding      r      FTc
                    s   t    || ||  }
|| _|| _tj||||f||d| _|r,tj|ddd| _nd | _|	| _	| j	rOt
|t|
d dd}| jd| d	dd
 d S d S )N)kernel_sizestridebiasFgư>)elementwise_affineeps      ?pt)output_type	pos_embedr   )
persistent)super__init__flatten
layer_normr   Conv2dproj	LayerNormnormuse_pos_embedr   intregister_bufferfloat	unsqueeze)selfheightwidth
patch_sizein_channels	embed_dimrA   r@   r6   rF   num_patchesr<   	__class__r   r   r?   N   s   
zPatchEmbed.__init__c                 C   sF   |  |}| jr|ddd}| jr| |}| jr!|| j S |S )Nr   r   )rC   r@   	transposerA   rE   rF   r<   )rK   latentr   r   r   forwardm   s   


zPatchEmbed.forward)	r1   r1   r2   r   r3   FTTT)__name__
__module____qualname____doc__r?   rV   __classcell__r   r   rR   r   r0   K   s    r0   c                       s*   e Zd Zdef fddZdd Z  ZS )	SkipBlockdimc                    s,   t    td| || _t|| _d S )Nr   )r>   r?   r   Linearskip_linearrD   rE   )rK   r]   rR   r   r   r?   z   s   
zSkipBlock.__init__c                 C   s&   |  tj||gdd}| |}|S )Nr]   )r_   r   catrE   )rK   r   skipr   r   r   rV      s   
zSkipBlock.forward)rW   rX   rY   rG   r?   rV   r[   r   r   rR   r   r\   y   s    r\   c                       s   e Zd ZdZ												dded	ed
ededB dededB dedededededededef fddZ						dddZ  Z	S )UTransformerBlockaS  
    A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        num_embeds_ada_norm (:obj: `int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:obj: `bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the query and key to float32 when performing the attention calculation.
        norm_elementwise_affine (`bool`, *optional*):
            Whether to use learnable per-element affine parameters during layer normalization.
        norm_type (`str`, defaults to `"layer_norm"`):
            The layer norm implementation to use.
        pre_layer_norm (`bool`, *optional*):
            Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
            as opposed to after ("post-LayerNorm"). Note that `BasicTransformerBlock` uses pre-LayerNorm, e.g.
            `pre_layer_norm = True`.
        final_dropout (`bool`, *optional*):
            Whether to use a final Dropout layer after the feedforward network.
    r-   NgegluFTrA   r]   num_attention_headsattention_head_dimcross_attention_dimactivation_fnnum_embeds_ada_normattention_biasonly_cross_attentiondouble_self_attentionupcast_attentionnorm_elementwise_affine	norm_typepre_layer_normfinal_dropoutc              	         t    |	| _|d uo|dk| _|| _|dv r'|d u r'td| d| dt||||||	r1|nd |d| _|d us=|
rNt||
sC|nd |||||d| _nd | _| jr[t	||| _
ntj||d| _
|d usi|
rz| jrqt	||ntj||d| _nd | _tj||d| _t||||d	| _d S 
Nada_norm)ru   ada_norm_zeroz`norm_type` is set to zw, but `num_embeds_ada_norm` is not defined. Please make sure to define `num_embeds_ada_norm` if setting `norm_type` to .)	query_dimheadsdim_headdropoutr6   rh   rn   )rx   rh   ry   rz   r{   r6   rn   )r7   )r{   ri   rr   r>   r?   rl   use_ada_layer_normrq   
ValueErrorr   attn1attn2r   norm1r   rD   norm2norm3r   ffrK   r]   rf   rg   r{   rh   ri   rj   rk   rl   rm   rn   ro   rp   rq   rr   rR   r   r   r?      R   




zUTransformerBlock.__init__c                 C   s8  | j r| jr| ||}n| |}n|}|d ur|ni }| j|f| jr&|nd |d|}	| j s@| jr;| |	|}	n| |	}	|	| }| jd ur~| j r[| jrU| ||n| |}n|}| j|f||d|}	| j sz| jru| |	|n| |	}	|	| }| j r| |}n|}| |}
| j s| |
}
|
| }|S N)encoder_hidden_statesattention_mask	rq   r}   r   r   rl   r   r   r   r   )rK   hidden_statesr   r   encoder_attention_masktimestepcross_attention_kwargsclass_labelsnorm_hidden_statesattn_output	ff_outputr   r   r   rV      sV   



zUTransformerBlock.forward)r-   Nre   NFFFFTrA   TFNNNNNN
rW   rX   rY   rZ   rG   strboolr?   rV   r[   r   r   rR   r   rd      sd    &	
Prd   c                       s   e Zd ZdZ												dded	ed
ededB dededB dedededededededef fddZ						dddZ  Z	S )UniDiffuserBlocka@	  
    A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations and puts the
    LayerNorms on the residual backbone of the block. This matches the transformer block in the [original UniDiffuser
    implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104).

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        num_embeds_ada_norm (:obj: `int`, *optional*):
            The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:obj: `bool`, *optional*, defaults to `False`):
            Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the query and key to float() when performing the attention calculation.
        norm_elementwise_affine (`bool`, *optional*):
            Whether to use learnable per-element affine parameters during layer normalization.
        norm_type (`str`, defaults to `"layer_norm"`):
            The layer norm implementation to use.
        pre_layer_norm (`bool`, *optional*):
            Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
            as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
            (`pre_layer_norm = False`).
        final_dropout (`bool`, *optional*):
            Whether to use a final Dropout layer after the feedforward network.
    r-   Nre   FTrA   r]   rf   rg   rh   ri   rj   rk   rl   rm   rn   ro   rp   rq   rr   c              	      rs   rt   r|   r   rR   r   r   r?   q  r   zUniDiffuserBlock.__init__c           
      C   s&  | j r| jr| ||}n| |}|d ur|ni }| j|f| jr#|nd |d|}|| }| j sA| jr<| ||}n| |}| jd urx| j rW| jrR| ||n| |}| j|f||d|}|| }| j sx| jrs| ||n| |}| j r| |}| |}	|	| }| j s| |}|S r   r   )
rK   r   r   r   r   r   r   r   r   r   r   r   r   rV     sR   





zUniDiffuserBlock.forward)r-   Nre   NFFFFTrA   FTr   r   r   r   rR   r   r   M  sd    (	
Pr   c                .       s   e Zd ZdZe														
										d+dedededB dedB dededededB dededB dedB dedB dededB dedededed ed!ed"ed#ef, fd$d%Z								d,d&ed'ed(efd)d*Z
  ZS )-UTransformer2DModelay  
    Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared
    to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion,
    similar to a U-Net. Supports only continuous (actual embeddings) inputs, which are embedded via a [`PatchEmbed`]
    layer and then reshaped to (b, t, d).

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input.
        out_channels (`int`, *optional*):
            The number of output channels; if `None`, defaults to `in_channels`.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            The number of groups to use when performing Group Normalization.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        patch_size (`int`, *optional*, defaults to 2):
            The patch size to use in the patch embedding.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than steps than `num_embeds_ada_norm`.
        use_linear_projection (int, *optional*): TODO: Not used
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used in each
            transformer block.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the query and key to float() when performing the attention calculation.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
        block_type (`str`, *optional*, defaults to `"unidiffuser"`):
            The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
            backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
            behavior in `diffusers`.)
        pre_layer_norm (`bool`, *optional*):
            Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
            as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
            (`pre_layer_norm = False`).
        norm_elementwise_affine (`bool`, *optional*):
            Whether to use learnable per-element affine parameters during layer normalization.
        use_patch_pos_embed (`bool`, *optional*):
            Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
        final_dropout (`bool`, *optional*):
            Whether to use a final Dropout layer after the feedforward network.
    r2   X   Nr   r-       Fr   re   rA   unidiffuserTrf   rg   rO   out_channels
num_layersr{   norm_num_groupsrh   rk   sample_sizenum_vector_embedsrN   ri   rj   use_linear_projectionrl   rn   rp   
block_typerq   ro   ff_final_dropoutc                    sJ  t    || _
| _| _
 |d ur|d usJ d|
d us&J d|
| _|
| _|| _t|
|
|||d| _	|dkrAt
ntt 	
fddt|d D | _
 	d| _t 	
fd	dt|d D | _|d u r|n|| _t| _d S )
Nz0Patch input requires in_channels and patch_size.z?UTransformer2DModel over patched input must provide sample_sizerL   rM   rN   rO   rP   rF   r   c                    s0   g | ]}
 	d qS )r{   rh   ri   rj   rk   rl   rn   rp   rq   ro   rr   r   .0dri   rk   rg   	block_clsrh   r{   r   	inner_dimro   rp   rf   rj   rl   rq   rn   r   r   
<listcomp>  s&    z0UTransformer2DModel.__init__.<locals>.<listcomp>r   r   c                    s@   g | ]}t t
 	d dqS )r   )rc   block)r   
ModuleDictr\   r   r   r   r   r     s2    )r>   r?   r   rf   rg   rL   rM   rN   r0   r<   r   rd   r   
ModuleListrangetransformer_in_blockstransformer_mid_blocktransformer_out_blocksr   rD   norm_out)rK   rf   rg   rO   r   r   r{   r   rh   rk   r   r   rN   ri   rj   r   rl   rn   rp   r   rq   ro   use_patch_pos_embedr   rR   r   r   r?   L  sb   
&
&
zUTransformer2DModel.__init__return_dicthidden_states_is_embedding
unpatchifyc	                 C   s&  |s|rt d| d| d| d|s| |}g }	| jD ]}
|
|||||d}|	| q| |}| jD ]}|d ||	 }|d |||||d}q7| |}|rt|j	d d	  }}|j
d
||| j| j| jfd}td|}|j
d
| j|| j || j fd}n|}|s|fS t|dS )a  
        Args:
            hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
                When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep ( `torch.long`, *optional*):
                Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
            class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
                Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
                conditioning.
            cross_attention_kwargs (*optional*):
                Keyword arguments to supply to the cross attention layers, if used.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.
            hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
                Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
                ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the
                transformer blocks.
            unpatchify (`bool`, *optional*, defaults to `True`):
                Whether to unpatchify the transformer output.

        Returns:
            [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
            [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        z!Cannot both define `unpatchify`: z and `return_dict`: z since when `unpatchify` is zy the returned output is of shape (batch_size, seq_len, hidden_dim) rather than (batch_size, num_channels, height, width).)r   r   r   r   rc   r   r   r9   r`   shapenhwpqc->nchpwq)sample)r~   r<   r   appendr   r   popr   rG   r   reshaperN   r   r   einsumr   )rK   r   r   r   r   r   r   r   r   skipsin_block	out_blockrL   rM   outputr   r   r   rV     sT   *






zUTransformer2DModel.forward)r2   r   NNr   r-   r   NFNNr   re   NFFFrA   r   FTFF)NNNNTFT)rW   rX   rY   rZ   r   rG   rI   r   r   r?   rV   r[   r   r   rR   r   r     s    9	
 	r   c                6       sF  e Zd ZdZe											
																		d6dededededededB dedB dededededB dededB dedB dedB ded edB d!ed"ed#ed$ed%ed&ed'ed(ed)ef4 fd*d+Z	e
jjd,d- Z			d7d.e
jd/e
jd0e
jd1e
jeB eB d2e
jeB eB d3e
jeB eB dB fd4d5Z  ZS )8UniDiffuserModela  
    Transformer model for a image-text [UniDiffuser](https://huggingface.co/papers/2303.06555) model. This is a
    modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the
    CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details).

    Parameters:
        text_dim (`int`): The hidden dimension of the CLIP text model used to embed images.
        clip_img_dim (`int`): The hidden dimension of the CLIP vision model used to embed prompts.
        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
        in_channels (`int`, *optional*):
            Pass if the input is continuous. The number of channels in the input.
        out_channels (`int`, *optional*):
            The number of output channels; if `None`, defaults to `in_channels`.
        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            The number of groups to use when performing Group Normalization.
        cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
        attention_bias (`bool`, *optional*):
            Configure if the TransformerBlocks' attention should contain a bias parameter.
        sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
            Note that this is fixed at training time as it is used for learning a number of position embeddings. See
            `ImagePositionalEmbeddings`.
        num_vector_embeds (`int`, *optional*):
            Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
            Includes the class for the masked latent pixel.
        patch_size (`int`, *optional*, defaults to 2):
            The patch size to use in the patch embedding.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
            The number of diffusion steps used during training. Note that this is fixed at training time as it is used
            to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
            up to but not more than steps than `num_embeds_ada_norm`.
        use_linear_projection (int, *optional*): TODO: Not used
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used in each
            transformer block.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the query and key to float32 when performing the attention calculation.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`.
        block_type (`str`, *optional*, defaults to `"unidiffuser"`):
            The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual
            backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard
            behavior in `diffusers`.)
        pre_layer_norm (`bool`, *optional*):
            Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"),
            as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm
            (`pre_layer_norm = False`).
        norm_elementwise_affine (`bool`, *optional*):
            Whether to use learnable per-element affine parameters during layer normalization.
        use_patch_pos_embed (`bool`, *optional*):
            Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`).
        ff_final_dropout (`bool`, *optional*):
            Whether to use a final Dropout layer after the feedforward network.
        use_data_type_embedding (`bool`, *optional*):
            Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1
            is continue-trained from UniDiffuser-v0 on non-publically-available data and accepts a `data_type`
            argument, which can either be `1` to use the weights trained on non-publically-available data or `0`
            otherwise. This argument is subsequently embedded by the data type embedding, if used.
    r3      M   r2   r   Nr   r-   r   Fre   rA   r   Ttext_dimclip_img_dimnum_text_tokensrf   rg   rO   r   r   r{   r   rh   rk   r   r   rN   ri   rj   r   rl   rn   rp   r   rq   ro   r   use_data_type_embeddingc                    s`  t    || | _|d usJ d|| _|| _|d u r|n|| _|| _| j| | j|  | _t||||| j|d| _	t
|| j| _t
|| j| _t| jddd| _|rbt| jd| j | jdnt
 | _t| jddd| _|r~t| jd| j | jdnt
 | _|| _d| d	 | j | _t
td	| j| j| _t
j|	d
| _t| jdd || _| jrt
d| j| _ t
td	d	| j| _!t"d$i d|d|d|d|d|d|	d|
d|d|d|d|d|d|d|d|d|d|d|d|d |d!|d"|d#|| _#|d | }t
| j|| _$t
| j|| _%t
| j|| _&d S )%Nz<UniDiffuserModel over patched input must provide sample_sizer   Tr   )flip_sin_to_cosdownscale_freq_shift   )out_dimr   r   )pg{Gz?)r'   rf   rg   rO   r   r   r{   r   rh   rk   r   r   rN   ri   rj   r   rl   rn   rp   r   rq   ro   r   r   r   )'r>   r?   r   r   rO   r   rN   rQ   r0   
vae_img_inr   r^   clip_img_intext_inr
   timestep_img_projr	   Identitytimestep_img_embedtimestep_text_projtimestep_text_embedr   
num_tokens	Parameterr   zerosr<   Dropoutpos_embed_dropr/   r   	Embeddingdata_type_token_embeddingdata_type_pos_embed_tokenr   transformervae_img_outclip_img_outtext_out)rK   r   r   r   rf   rg   rO   r   r   r{   r   rh   rk   r   r   rN   ri   rj   r   rl   rn   rp   r   rq   use_timestep_embeddingro   r   r   r   	patch_dimrR   r   r   r?     s   
 

	
zUniDiffuserModel.__init__c                 C   s   dhS )Nr<   r   )rK   r   r   r   no_weight_decay	  s   z UniDiffuserModel.no_weight_decaylatent_image_embedsimage_embedsprompt_embedstimestep_imgtimestep_text	data_typec	              
   C   s  |j d }	| |}
| |}| |}|d|
d}}t|s/tj|gtj|
j	d}|tj
|	|j|j	d }| |}|j| jd}| |}|jdd}t|sbtj|gtj|
j	d}|tj
|	|j|j	d }| |}|j| jd}| |}|jdd}| jr|dusJ dt|stj|gtj|
j	d}|tj
|	|j|j	d }| |jdd}tj||||||
gdd}ntj|||||
gdd}| jrtj| jddddddf | j| jddddddf gdd}n| j}|| }| |}| j||dd|d	d
d	dd }| jr)|jddd|d|fdd\}}}}}}n|jdd|d|fdd\}}}}}| |}t|j d d  }}|jd||| j| j| jfd}td|}|jd| j|| j || j fd}|  |}| !|}|||fS )am  
        Args:
            latent_image_embeds (`torch.Tensor` of shape `(batch size, latent channels, height, width)`):
                Latent image representation from the VAE encoder.
            image_embeds (`torch.Tensor` of shape `(batch size, 1, clip_img_dim)`):
                CLIP-embedded image representation (unsqueezed in the first dimension).
            prompt_embeds (`torch.Tensor` of shape `(batch size, seq_len, text_dim)`):
                CLIP-embedded text representation.
            timestep_img (`torch.long` or `float` or `int`):
                Current denoising step for the image.
            timestep_text (`torch.long` or `float` or `int`):
                Current denoising step for the text.
            data_type: (`torch.int` or `float` or `int`, *optional*, defaults to `1`):
                Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data,
                or `0` otherwise.
            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            cross_attention_kwargs (*optional*):
                Keyword arguments to supply to the cross attention layers, if used.


        Returns:
            `tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is tbe VAE
            image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text
            embedding.
        r   r   )dtypedevice)r   ra   NzBdata_type must be supplied if the model uses a data type embeddingr   FT)r   r   r   r   r   r   r   r9   r`   r   r   )"r   r   r   r   sizer   	is_tensorr%   longr   onesr   r   tor   rJ   r   r   r   rG   r   rb   r<   r   r   r   splitr   r   rN   r   r   r   r   )rK   r   r   r   r   r   r   r   r   
batch_sizevae_hidden_statesclip_hidden_statestext_hidden_statesr   num_img_tokenstimestep_img_tokentimestep_text_tokendata_type_tokenr   r<   t_img_token_outt_text_token_outdata_type_token_outr   img_clip_outimg_vae_outrL   rM   r   r   r   rV     s   
&









<
	



zUniDiffuserModel.forward)r3   r   r   r2   r   NNr   r-   r   NFNNNre   NFFFrA   r   FFTFTF)r   NN)rW   rX   rY   rZ   r   rG   rI   r   r   r?   r   jitignorer   TensorrV   r[   r   r   rR   r   r   @  s    ?	
 	

r   )r-   r   r.   r   )"r   r   r   configuration_utilsr   r   modelsr   models.attentionr   models.attention_processorr   models.embeddingsr	   r
   r   models.modeling_outputsr   models.normalizationr   utilsr   
get_loggerrW   r   r,   r/   Moduler0   r\   rd   r   r   r   r   r   r   r   <module>   s0    

%. B F  0