o
    Gix                     @   s  d dl mZ d dlZd dlmZ ddlmZmZ ddlmZ ddl	m
Z
mZ dd	lmZ dd
lmZ ddlmZ ddlmZ ddlmZ eeZdededededejdejfddZdejdejdejfddZG dd dZG dd dejeZ G dd dejZ!G d d! d!ejZ"G d"d# d#ejZ#G d$d% d%ejZ$G d&d' d'ejZ%d(ejdedejfd)d*Z&d+ejded,ejdejfd-d.Z'G d/d0 d0eee
Z(dS )1    )AnyN)nn   )ConfigMixinregister_to_config)logging   )AttentionMixinAttentionModuleMixin)dispatch_attention_fn)get_timestep_embedding)Transformer2DModelOutput)
ModelMixin)RMSNorm
batch_sizeheightwidth
patch_sizedevicereturnc                 C   s   t j|| || d|d}t j|| |ddddf |d< t j|| |ddddf |d< ||| ||  dd| ddS )a  
    Generates 2D patch coordinate indices for a batch of images.

    Args:
        batch_size (`int`):
            Number of images in the batch.
        height (`int`):
            Height of the input images (in pixels).
        width (`int`):
            Width of the input images (in pixels).
        patch_size (`int`):
            Size of the square patches that the image is divided into.
        device (`torch.device`):
            The device on which to create the tensor.

    Returns:
        `torch.Tensor`:
            Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) coordinates of each patch in the
            image grid.
    r   )r   N.r   .   r   r   )torchzerosarangereshape	unsqueezerepeat)r   r   r   r   r   img_ids r    a/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/models/transformers/transformer_prx.pyget_image_ids!   s   ""(r"   xq	freqs_cisc                 C   sn   |   jg | jdd dddR  }|j| j|jd}|d |d  |d |d   }|j| j | S )a  
    Applies rotary positional embeddings (RoPE) to a query tensor.

    Args:
        xq (`torch.Tensor`):
            Input tensor of shape `(..., dim)` representing the queries.
        freqs_cis (`torch.Tensor`):
            Precomputed rotary frequency components of shape `(..., dim/2, 2)` containing cosine and sine pairs.

    Returns:
        `torch.Tensor`:
            Tensor of the same shape as `xq` with rotary embeddings applied.
    Nr   r   r   dtyper   r   )floatr   shapetor   r'   type_as)r#   r$   xq_xq_outr    r    r!   
apply_rope=   s   * r.   c                   @   s`   e Zd ZdZdZdZdd Z			ddddejdejdB d	ejdB d
ejdB dejfddZ	dS )PRXAttnProcessor2_0z
    Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention
    backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
    Nc                 C   s   t tjjdstdd S )Nscaled_dot_product_attentionzHPRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.)hasattrr   r   
functionalImportError)selfr    r    r!   __init__[   s   zPRXAttnProcessor2_0.__init__attnPRXAttentionhidden_statesencoder_hidden_statesattention_maskimage_rotary_embr   c           #      K   sp  |du rt d||}|j\}}	}
|||	d|j|j}|ddddd}|d |d |d }}}||}||}|	|}|j\}}}
|||d|j|j}|ddddd}|d |d }}|
|}|durzt||}t||}tj||fdd}tj||fdd}d}|dur|j\}}
}}
|jd }| dkrt d	|j |jd
 |krt d|jd
  d| |j}tj||ftj|d}|j|tjd}tj||gd
d}|ddddddf d
|j|d
}|dd}|dd}|dd}t||||| j| jd}|j\}} }!}"||| |!|" }|jd |}t|jdkr6|jd |}|S )a  
        Apply PRX attention using PRXAttention module.

        Args:
            attn: PRXAttention module containing projection layers
            hidden_states: Image tokens [B, L_img, D]
            encoder_hidden_states: Text tokens [B, L_txt, D]
            attention_mask: Boolean mask for text tokens [B, L_txt]
            image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2]
        NzLPRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.r   r   r   r      dimz"Unsupported attention_mask shape: r%   zattention_mask last dim z must equal text length r'   r   r&   )	attn_maskbackendparallel_config)
ValueErrorimg_qkv_projr)   r   headshead_dimpermutenorm_qnorm_ktxt_kv_projnorm_added_kr.   r   catr>   r   onesboolr*   expand	transposer   _attention_backend_parallel_configto_outlen)#r4   r6   r8   r9   r:   r;   kwargsimg_qkvBL_img_img_qimg_kimg_vtxt_kvL_txttxt_ktxt_vkvattn_mask_tensorbsl_imgl_txtr   ones_img
joint_maskquerykeyvalueattn_outputr   seq_len	num_headsrF   r    r    r!   __call___   sd   







&
zPRXAttnProcessor2_0.__call__NNN)
__name__
__module____qualname____doc__rQ   rR   r5   r   Tensorro   r    r    r    r!   r/   R   s*    r/   c                       s   e Zd ZdZeZegZ						ddeded	ed
edede	f fddZ
			ddejdejdB dejdB dejdB dejf
ddZ  ZS )r7   z
    PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
    PRX's architecture.
       @   Fư>N	query_dimrE   dim_headbiasout_biasepsc                    s   t    || _|| _|| | _|| _tj||d |d| _t	| j|dd| _
t	| j|dd| _tj||d |d| _t	| j|dd| _tg | _| jtj| j||d | jtd |d u rh|  }| | d S )Nr   r{   T)r}   elementwise_affiner           )superr5   rE   rF   	inner_dimry   r   LinearrD   r   rH   rI   rJ   rK   
ModuleListrS   appendDropout_default_processor_clsset_processor)r4   ry   rE   rz   r{   r|   r}   	processor	__class__r    r!   r5      s    


zPRXAttention.__init__r8   r9   r:   r;   r   c                 K   s   | j | |f|||d|S )N)r9   r:   r;   )r   )r4   r8   r9   r:   r;   rU   r    r    r!   forward   s   zPRXAttention.forward)rv   rw   FFrx   Nrp   )rq   rr   rs   rt   r/   r   _available_processorsintrN   r(   r5   r   ru   r   __classcell__r    r    r   r!   r7      sJ    $r7   c                       sf   e Zd ZdZdededee f fddZdejdededejfd	d
Z	dejdejfddZ
  ZS )
PRXEmbedNDa  
    N-dimensional rotary positional embedding.

    This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
    dimension. The embeddings are combined and returned as a single tensor

    Args:
        dim (int):
        Base embedding dimension (must be even).
        theta (int):
        Scaling factor that controls the frequency spectrum of the rotary embeddings.
        axes_dim (list[int]):
        list of embedding dimensions for each axis (each must be even).
    r>   thetaaxes_dimc                    s    t    || _|| _|| _d S N)r   r5   r>   r   r   )r4   r>   r   r   r   r    r!   r5     s   

zPRXEmbedND.__init__posr   c           
      C   s   |d dksJ |j jdk}|j jdk}|s|rtjntj}tjd|d||j d| }d||  }|d|d }	tjt|	t	|	 t	|	t|	gdd}	|	j
g |	jd d ddR  }	|	 S )	Nr   r   mpsnpur?   g      ?r%   r=   )r   typer   float32float64r   r   stackcossinr   r)   r(   )
r4   r   r>   r   is_mpsis_npur'   scaleomegaoutr    r    r!   rope  s   0"zPRXEmbedND.ropeidsc                    s6    j d }tj fddt|D dd}|dS )Nr%   c                    s4   g | ]}  d d d d |f j| jqS r   )r   r   r   .0ir   r4   r    r!   
<listcomp>&  s   4 z&PRXEmbedND.forward.<locals>.<listcomp>r=   r   )r)   r   rL   ranger   )r4   r   n_axesembr    r   r!   r   #  s   

zPRXEmbedND.forward)rq   rr   rs   rt   r   listr5   r   ru   r   r   r   r    r    r   r!   r      s
    r   c                       s@   e Zd ZdZdedef fddZdejdejfdd	Z  Z	S )
MLPEmbedderan  
    A simple 2-layer MLP used for embedding inputs.

    Args:
        in_dim (`int`):
            Dimensionality of the input features.
        hidden_dim (`int`):
            Dimensionality of the hidden and output embedding space.

    Returns:
        `torch.Tensor`:
            Tensor of shape `(..., hidden_dim)` containing the embedded representations.
    in_dim
hidden_dimc                    s<   t    tj||dd| _t | _tj||dd| _d S )NTr~   )r   r5   r   r   in_layerSiLUsilu	out_layer)r4   r   r   r   r    r!   r5   ;  s   

zMLPEmbedder.__init__xr   c                 C   s   |  | | |S r   )r   r   r   )r4   r   r    r    r!   r   A  s   zMLPEmbedder.forward
rq   rr   rs   rt   r   r5   r   ru   r   r   r    r    r   r!   r   ,  s    r   c                	       sb   e Zd ZdZdef fddZdejdeeejejejf eejejejf f fddZ	  Z
S )	
Modulationa  
    Modulation network that generates scale, shift, and gating parameters.

    Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
    two tuples `(shift, scale, gate)`.

    Args:
        dim (`int`):
            Dimensionality of the input vector. The output will have `6 * dim` features internally.

    Returns:
        ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
            Two tuples `(shift, scale, gate)`.
    r>   c                    sH   t    tj|d| dd| _tj| jjd tj| jjd d S )N   Tr~   r   )	r   r5   r   r   lininit	constant_weightr{   )r4   r>   r   r    r!   r5   U  s   
zModulation.__init__vecr   c                 C   sN   |  tj|d d d d d f jddd}t|d d t|dd  fS )Nr   r%   r=   r   )r   r   r2   r   chunktuple)r4   r   r   r    r    r!   r   [  s   . zModulation.forward)rq   rr   rs   rt   r   r5   r   ru   r   r   r   r    r    r   r!   r   E  s    *r   c                       s~   e Zd ZdZ		ddededededB f fdd	Z	dd
ejdejdejdejdejdB de	e
ef dejfddZ  ZS )PRXBlocku5  
    Multimodal transformer block with text–image cross-attention, modulation, and MLP.

    Args:
        hidden_size (`int`):
            Dimension of the hidden representations.
        num_heads (`int`):
            Number of attention heads.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Expansion ratio for the hidden dimension inside the MLP.
        qk_scale (`float`, *optional*):
            Scale factor for queries and keys. If not provided, defaults to ``head_dim**-0.5``.

    Attributes:
        img_pre_norm (`nn.LayerNorm`):
            Pre-normalization applied to image tokens before attention.
        attention (`PRXAttention`):
            Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
            image and text tokens.
        post_attention_layernorm (`nn.LayerNorm`):
            Normalization applied after attention.
        gate_proj / up_proj / down_proj (`nn.Linear`):
            Feedforward layers forming the gated MLP.
        mlp_act (`nn.GELU`):
            Nonlinear activation used in the MLP.
        modulation (`Modulation`):
            Produces scale/shift/gating parameters for modulated layers.

        Methods:
            The forward method performs cross-attention and the MLP with modulation.
          @Nhidden_sizern   	mlp_ratioqk_scalec              	      s   t    || _|| _|| | _|p| jd | _t|| | _|| _t	j
|ddd| _t||| jdddt d| _t	j
|ddd| _t	j|| jdd| _t	j|| jdd| _t	j| j|dd| _t	jdd| _t|| _d S )	Ng      Frx   r   r}   )ry   rE   rz   r{   r|   r}   r   r~   tanh)approximate)r   r5   r   rn   rF   r   r   mlp_hidden_dimr   r   	LayerNormimg_pre_normr7   r/   	attentionpost_attention_layernormr   	gate_projup_proj	down_projGELUmlp_actr   
modulation)r4   r   rn   r   r   r   r    r!   r5     s.   

zPRXBlock.__init__r8   r9   tembr;   r:   rU   r   c              	   K   s   |  |\}}|\}	}
}|\}}}d|
 | | |	 }| j||||d}|||  }d| | | | }||| | | || |   }|S )a  
        Runs modulation-gated cross-attention and MLP, with residual connections.

        Args:
            hidden_states (`torch.Tensor`):
                Image tokens of shape `(B, L_img, hidden_size)`.
            encoder_hidden_states (`torch.Tensor`):
                Text tokens of shape `(B, L_txt, hidden_size)`.
            temb (`torch.Tensor`):
                Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or
                broadcastable).
            image_rotary_emb (`torch.Tensor`):
                Rotary positional embeddings applied inside attention.
            attention_mask (`torch.Tensor`, *optional*):
                Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
            **kwargs:
                Additional keyword arguments for API compatibility.

        Returns:
            `torch.Tensor`:
                Updated image tokens of shape `(B, L_img, hidden_size)`.
        r   )r8   r9   r:   r;   )r   r   r   r   r   r   r   r   )r4   r8   r9   r   r;   r:   rU   mod_attnmod_mlp
attn_shift
attn_scale	attn_gate	mlp_shift	mlp_scalemlp_gatehidden_states_modattn_outr   r    r    r!   r     s    

(zPRXBlock.forward)r   Nr   )rq   rr   rs   rt   r   r(   r5   r   ru   dictstrr   r   r   r    r    r   r!   r   b  s:    $.
r   c                       sJ   e Zd ZdZdededef fddZdejdejd	ejfd
dZ  Z	S )
FinalLayera  
    Final projection layer with adaptive LayerNorm modulation.

    This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
    outputs.

    Args:
        hidden_size (`int`):
            Dimensionality of the input tokens.
        patch_size (`int`):
            Size of the square image patches.
        out_channels (`int`):
            Number of output channels per pixel (e.g. RGB = 3).

    Forward Inputs:
        x (`torch.Tensor`):
            Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches.
        vec (`torch.Tensor`):
            Conditioning vector of shape `(B, hidden_size)` used to generate shift and scale parameters for adaptive
            LayerNorm.

    Returns:
        `torch.Tensor`:
            Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`.
    r   r   out_channelsc                    s\   t    tj|ddd| _tj||| | dd| _tt tj|d| dd| _	d S )NFrx   r   Tr~   r   )
r   r5   r   r   
norm_finalr   linear
Sequentialr   adaLN_modulation)r4   r   r   r   r   r    r!   r5     s   
&zFinalLayer.__init__r   r   r   c                 C   s`   |  |jddd\}}d|d d d d d f  | | |d d d d d f  }| |}|S )Nr   r   r=   )r   r   r   r   )r4   r   r   shiftr   r    r    r!   r      s   :
zFinalLayer.forwardr   r    r    r   r!   r     s    $r   imgc                 C   sT   | j \}}}}|}| |||| ||| |} td| } | |d|| | } | S )a  
    Flattens an image tensor into a sequence of non-overlapping patches.

    Args:
        img (`torch.Tensor`):
            Input image tensor of shape `(B, C, H, W)`.
        patch_size (`int`):
            Size of each square patch. Must evenly divide both `H` and `W`.

    Returns:
        `torch.Tensor`:
            Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
            // patch_size)` is the number of patches.
    znchpwq->nhwcpqr%   )r)   r   r   einsum)r   r   bchwpr    r    r!   img2seq  s   r   seqr)   c           
      C   s   t |tr|dd \}}nt |tjr"t|d t|d }}n
tdt| d| j\}}}|}|||  }	| ||| || |	||} t	d| } | ||	||} | S )av  
    Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).

    Args:
        seq (`torch.Tensor`):
            Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W //
            patch_size)`.
        patch_size (`int`):
            Size of each square patch.
        shape (`tuple` or `torch.Tensor`):
            The original image spatial shape `(H, W)`. If a tensor is provided, the first two values are interpreted as
            height and width.

    Returns:
        `torch.Tensor`:
            Reconstructed image tensor of shape `(B, C, H, W)`.
    Nr   r   zshape type z not supportedznhwcpq->nchpwq)

isinstancer   r   ru   r   NotImplementedErrorr   r)   r   r   )
r   r   r)   r   r   r   ldr   r   r    r    r!   seq2img%  s   
r   c                       s   e Zd ZdZdZdZe									
			d'dededededededede	dededef fddZ
dejdejdejfddZ	
	
	d(dejdejd ejd!ejd
B d"eeef d
B d#edeejd$f eB fd%d&Z  ZS ))PRXTransformer2DModela  
    Transformer-based 2D model for text to image generation.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            Number of input channels in the latent image.
        patch_size (`int`, *optional*, defaults to 2):
            Size of the square patches used to flatten the input image.
        context_in_dim (`int`, *optional*, defaults to 2304):
            Dimensionality of the text conditioning input.
        hidden_size (`int`, *optional*, defaults to 1792):
            Dimension of the hidden representation.
        mlp_ratio (`float`, *optional*, defaults to 3.5):
            Expansion ratio for the hidden dimension inside MLP blocks.
        num_heads (`int`, *optional*, defaults to 28):
            Number of attention heads.
        depth (`int`, *optional*, defaults to 16):
            Number of transformer blocks.
        axes_dim (`list[int]`, *optional*):
            list of dimensions for each positional embedding axis. Defaults to `[32, 32]`.
        theta (`int`, *optional*, defaults to 10000):
            Frequency scaling factor for rotary embeddings.
        time_factor (`float`, *optional*, defaults to 1000.0):
            Scaling factor applied in timestep embeddings.
        time_max_period (`int`, *optional*, defaults to 10000):
            Maximum frequency period for timestep embeddings.

    Attributes:
        pe_embedder (`EmbedND`):
            Multi-axis rotary embedding generator for positional encodings.
        img_in (`nn.Linear`):
            Projection layer for image patch tokens.
        time_in (`MLPEmbedder`):
            Embedding layer for timestep embeddings.
        txt_in (`nn.Linear`):
            Projection layer for text conditioning.
        blocks (`nn.ModuleList`):
            Stack of transformer blocks (`PRXBlock`).
        final_layer (`LastLayer`):
            Projection layer mapping hidden tokens back to patch outputs.

    Methods:
        attn_processors:
            Returns a dictionary of all attention processors in the model.
        set_attn_processor(processor):
            Replaces attention processors across all attention layers.
        process_inputs(image_latent, txt):
            Converts inputs into patch tokens, encodes text, and produces positional encodings.
        compute_timestep_embedding(timestep, dtype):
            Creates a timestep embedding of dimension 256, scaled and projected.
        forward_transformers(image_latent, cross_attn_conditioning, timestep, time_embedding, attention_mask,
        **block_kwargs):
            Runs the sequence of transformer blocks over image and text tokens.
        forward(image_latent, timestep, cross_attn_conditioning, micro_conditioning, cross_attn_mask=None,
        attention_kwargs=None, return_dict=True):
            Full forward pass from latent input to reconstructed output image.

    Returns:
        `Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing:
            - `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`.
    zconfig.jsonT   r    	           @   N'       @@in_channelsr   context_in_dimr   r   rn   depthr   r   time_factortime_max_periodc                    s$  t    |d u rddg}|_|_jjd  _|
_|_|| dkr2td| d| || }t||krFtd| d| |_	|_
t||	|d_tjjjd  j	d	d
_tdj	d_t|j	_t fddt|D _tj	dj_d_d S )N    r   r   zHidden size z  must be divisible by num_heads zGot z but expected positional dim )r>   r   r   Tr~      )r   r   c                    s   g | ]}t jj d qS ))r   )r   r   rn   r   r   r4   r    r!   r     s    z2PRXTransformer2DModel.__init__.<locals>.<listcomp>r   F)r   r5   r   r   r   r  r  rC   sumr   rn   r   pe_embedderr   r   img_inr   time_intxt_inr   r   blocksr   final_layergradient_checkpointing)r4   r   r   r   r   r   rn   r   r   r   r  r  pe_dimr   r  r!   r5     s4   
 
zPRXTransformer2DModel.__init__timestepr'   r   c              
   C   s$   |  t|d| j| jddd|S )Nr  Tr   )	timestepsembedding_dim
max_periodr   flip_sin_to_cosdownscale_freq_shift)r	  r   r  r  r*   )r4   r  r'   r    r    r!   _compute_timestep_embedding  s   z1PRXTransformer2DModel._compute_timestep_embeddingr8   r9   r:   attention_kwargsreturn_dict.c              	   C   s   |  |}t|| j}| |}|j\}	}
}}t|	||| j|jd}| |}| j||j	d}| j
D ]}t rG| jrG| |j|||||}q2||||||d}q2| ||}t|| j|j}|sd|fS t|dS )a  
        Forward pass of the PRXTransformer2DModel.

        The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
        transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.

        Args:
            hidden_states (`torch.Tensor`):
                Input latent image tensor of shape `(B, C, H, W)`.
            timestep (`torch.Tensor`):
                Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
            encoder_hidden_states (`torch.Tensor`):
                Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
            attention_mask (`torch.Tensor`, *optional*):
                Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
            attention_kwargs (`dict`, *optional*):
                Additional arguments passed to attention layers.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `Transformer2DModelOutput` or a tuple.

        Returns:
            `Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple:

                - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
        )r   r   )r'   )r8   r9   r   r;   r:   )sample)r
  r   r   r  r)   r"   r   r  r  r'   r  r   is_grad_enabledr  _gradient_checkpointing_funcro   r  r   r   )r4   r8   r  r9   r:   r  r  txtr   rd   rY   r   r   r   per   blockoutputr    r    r!   r     s:   
#


		
zPRXTransformer2DModel.forward)r   r   r   r   r   r   r   Nr   r   r   )NNT)rq   rr   rs   rt   config_name _supports_gradient_checkpointingr   r   r(   r   r5   r   ru   r'   r  r   r   r   rN   r   r   r   r   r    r    r   r!   r   N  st    >	
9r   ))typingr   r   r   configuration_utilsr   r   utilsr   r   r	   r
   attention_dispatchr   
embeddingsr   modeling_outputsr   modeling_utilsr   normalizationr   
get_loggerrq   loggerr   r   ru   r"   r.   r/   Moduler7   r   r   r   r   r   r   r   r   r    r    r    r!   <module>   s0   
&n=/}( )