o
    GiC                     @   s:  d dl Z d dlZd dlmZ d dlmZ d dlZd dlZd dlm	Z	 d dl
m	  mZ ddlmZmZ ddlmZmZ ddlmZmZmZ ddlmZ d	d
lmZmZ d	dlmZmZ d	dlm Z  d	dl!m"Z" d	dl#m$Z$ d	dl%m&Z&m'Z' d	dl(m)Z) d	dl*m+Z+ d	dl,m-Z-m.Z. e/e0Z1				d7dej2de3de4de5de5de3dej2fddZ6		 d8d!ej2d"ej2e7ej2 B d#e4d$e3de7ej2ej2f f
d%d&Z8d'ej2d(ej2dB de7e3ej2dB ej2dB f fd)d*Z9G d+d, d,e	j:Z;G d-d. d.e	j:Z<G d/d0 d0e	j:Z=G d1d2 d2Z>eG d3d4 d4e	j:Z?G d5d6 d6e+eeee$eZ@dS )9    Nprod)Any   )ConfigMixinregister_to_config)FromOriginalModelMixinPeftAdapterMixin)apply_lora_scale	deprecatelogging)maybe_allow_in_graph   )ContextParallelInputContextParallelOutput)AttentionMixinFeedForward)dispatch_attention_fn)	Attention)
CacheMixin)TimestepEmbedding	Timesteps)Transformer2DModelOutput)
ModelMixin)AdaLayerNormContinuousRMSNormF   '  	timestepsembedding_dimflip_sin_to_cosdownscale_freq_shiftscale
max_periodreturnc           	      C   s   t | jdksJ d|d }t| tjd|tj| jd }|||  }t|	| j
}| dddf  |dddf  }|| }tjt|t|gdd}|rotj|dd|df |ddd|f gdd}|d dkr}tjj|d	}|S )
a&  
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    Args
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings
    Returns
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    r   zTimesteps should be a 1d-arrayr   r   )startenddtypedeviceNdim)r   r   r   r   )lenshapemathlogtorcharangefloat32r(   exptor'   floatcatsincosnn
functionalpad)	r   r   r    r!   r"   r#   half_dimexponentemb r?   g/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/models/transformers/transformer_qwenimage.pyget_timestep_embedding+   s   $2rA   Tr)   x	freqs_cisuse_realuse_real_unbind_dimc                 C   sX  |r|\}}|d }|d }| | j| | j}}|dkrC| jg | jdd ddR  d\}}tj| |gddd}n-|dkrh| jg | jdd ddR  d\}}tj| |gdd}nt	d| d	| 
 | |
 |   | j}	|	S t| 
 jg | jdd ddR  }|d
}t|| d}
|
| S )a3  
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
        freqs_cis (`tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        tuple[torch.Tensor, torch.Tensor]: tuple of modified query tensor and key tensor with rotary embeddings.
    )NNr)   Nr   r*   r   z`use_real_unbind_dim=z` but should be -1 or -2.r   )r4   r(   reshaper-   unbindr0   stackflattenr6   
ValueErrorr5   r'   view_as_complex	unsqueezeview_as_realtype_as)rB   rC   rD   rE   r8   r7   x_realx_imag	x_rotatedoutx_outr?   r?   r@   apply_rotary_emb_qwena   s$   ,, ,

rU   encoder_hidden_statesencoder_hidden_states_maskc                 C   s   | j dd \}}|du r|ddfS |j dd ||fkr,td|j  d| d| d|jtjkr8|tj}tj|| jtjd}t	|||
d}|jd	d
}t	||jd	d
jd	 tj|| jd}|||fS )z}
    Compute text sequence length without assuming contiguous masks. Returns length for RoPE and a normalized bool mask.
    Nr   z#`encoder_hidden_states_mask` shape z( must match (batch_size, text_seq_len)=(z, z).r(   r'   r?   r   r*   )r(   )r-   rK   r'   r0   boolr4   r1   r(   longwhere	new_zerosanymaxvalues	as_tensor)rV   rW   
batch_sizetext_seq_lenposition_idsactive_positions
has_activeper_sample_lenr?   r?   r@   compute_text_seq_len_from_mask   s,   


rg   c                       s(   e Zd Zd fdd	ZdddZ  ZS )	QwenTimestepProjEmbeddingsFc                    sJ   t    tddddd| _td|d| _|| _|r#td|| _	d S d S )N   Tr     )num_channelsr    r!   r"   )in_channelstime_embed_dimr   )
super__init__r   	time_projr   timestep_embedderuse_additional_t_condr9   	Embeddingaddition_t_embedding)selfr   rr   	__class__r?   r@   ro      s   
z#QwenTimestepProjEmbeddings.__init__Nc                 C   s\   |  |}| |j|jd}|}| jr,|d u rtd| |}|j|jd}|| }|S )N)r'   zAWhen additional_t_cond is True, addition_t_cond must be provided.)rp   rq   r4   r'   rr   rK   rt   )ru   timestephidden_statesaddition_t_condtimesteps_projtimesteps_embconditioningaddition_t_embr?   r?   r@   forward   s   

z"QwenTimestepProjEmbeddings.forwardFN)__name__
__module____qualname__ro   r   __classcell__r?   r?   rv   r@   rh      s    	rh   c                       s   e Zd Zddedee f fddZdddZ						dd
eeeeeeeeef  f dee d	B dej	deej
B d	B deej
ej
f f
ddZejdd		ddededededej	dej
fddZ  ZS )QwenEmbedRopeFthetaaxes_dimc                       t    || _|| _td}tddd d }tj| || jd | j| || jd | j| || jd | jgdd| _	tj| || jd | j| || jd | j| || jd | jgdd| _
|| _d S Ni   r   r)   r   r   r*   rn   ro   r   r   r0   r1   flipr6   rope_params	pos_freqs	neg_freqs
scale_roperu   r   r   r   	pos_index	neg_indexrv   r?   r@   ro      s(   



zQwenEmbedRope.__init__r   c                 C   V   |d dksJ t |dt |t d|dt j| }t t ||}|S zn
        Args:
            index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
        r   r   g      ?	r0   outerpowr1   r4   r2   divpolar	ones_likeru   indexr+   r   freqsr?   r?   r@   r         0zQwenEmbedRope.rope_paramsN	video_fhwtxt_seq_lensr(   max_txt_seq_lenr$   c                    sV  |durt ddddd |du rt|trt|n|}|du r#tdt|trIt|dkrI|d	  t fd
d|D sItd| d  d t|trR|d	 }t|tsZ|g}g }d	}t	|D ]+\}}|\}	}
}| 
|	|
|||}|| | jrt|
d |d |}qbt|
||}qbt|}| j|||| df }tj|d	d}||fS )a  
        Args:
            video_fhw (`tuple[int, int, int]` or `list[tuple[int, int, int]]`):
                A list of 3 integers [frame, height, width] representing the shape of the video.
            txt_seq_lens (`list[int]`, *optional*, **Deprecated**):
                Deprecated parameter. Use `max_txt_seq_len` instead. If provided, the maximum value will be used.
            device: (`torch.device`, *optional*):
                The device on which to perform the RoPE computation.
            max_txt_seq_len (`int` or `torch.Tensor`, *optional*):
                The maximum text sequence length for RoPE computation. This should match the encoder hidden states
                sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
        Nr   0.39.0zPassing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. Please use `max_txt_seq_len` instead. The new parameter accepts a single int or tensor value representing the maximum text sequence length.Fstandard_warnzIEither `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided.r   r   c                 3       | ]}| kV  qd S r   r?   ).0fhw	first_fhwr?   r@   	<genexpr>      z(QwenEmbedRope.forward.<locals>.<genexpr>zBatch inference with variable-sized images is not currently supported in QwenEmbedRope. All images in the batch should have the same dimensions (frame, height, width). Detected sizes: z%. Using the first image's dimensions Y for RoPE computation, which may lead to incorrect results for other images in the batch.r   .r*   )r   
isinstancelistr^   rK   r,   allloggerwarning	enumerate_compute_video_freqsappendr   intr   r4   r0   r6   )ru   r   r   r(   r   	vid_freqsmax_vid_indexidxr   frameheightwidth
video_freqmax_txt_seq_len_int	txt_freqsr?   r   r@   r      sL   



zQwenEmbedRope.forward   maxsizer   r   r   r   r   c                 C     || | }|d ur| j |n| j }|d ur| j|n| j}|jdd | jD dd}	|jdd | jD dd}
|	d |||  |ddd|||d}| jrtj	|
d ||d   d  |	d d |d  gdd}|d|dd|||d}tj	|
d ||d   d  |	d d |d  gdd}|dd|d|||d}n(|	d d | d|dd|||d}|	d d | dd|d|||d}tj	|||gdd
|d}|  S )	Nc                 S      g | ]}|d  qS r   r?   r   rB   r?   r?   r@   
<listcomp>>      z6QwenEmbedRope._compute_video_freqs.<locals>.<listcomp>r   r*   c                 S   r   r   r?   r   r?   r?   r@   r   ?  r   r   r)   r   r   r4   r   splitr   viewexpandr   r0   r6   rG   clone
contiguousru   r   r   r   r   r(   seq_lensr   r   	freqs_pos	freqs_negfreqs_framefreqs_heightfreqs_widthr   r?   r?   r@   r   6  s   ,88((z"QwenEmbedRope._compute_video_freqsr   r   NNNr   N)r   r   r   r   r   ro   r   tupler0   r(   Tensorr   	functools	lru_cacher   r   r?   r?   rv   r@   r      s@    



Ir   c                       s   e Zd Zddedee f fddZdddZ		dd
eeeeeeeeef  f deej	B dej
deej	ej	f fddZejd	dddej
fddZejd	dddej
fddZ  ZS )QwenEmbedLayer3DRopeFr   r   c                    r   r   r   r   rv   r?   r@   ro   P  s(   


	zQwenEmbedLayer3DRope.__init__r   c                 C   r   r   r   r   r?   r?   r@   r   i  r   z QwenEmbedLayer3DRope.rope_paramsNr   r   r(   r$   c                    s@  t |tr&t|dkr&|d  t fdd|D s&td| d  d t |tr/|d }t |ts7|g}g }d}t|d }t|D ]8\}}|\}	}
}||kr\| |	|
|||}n| |	|
||}|	| | j
rwt|
d |d |}qEt|
||}qEt||}t|}| j|||| d	f }tj|dd
}||fS )a  
        Args:
            video_fhw (`tuple[int, int, int]` or `list[tuple[int, int, int]]`):
                A list of 3 integers [frame, height, width] representing the shape of the video, or a list of layer
                structures.
            max_txt_seq_len (`int` or `torch.Tensor`):
                The maximum text sequence length for RoPE computation. This should match the encoder hidden states
                sequence length. Can be either an int or a scalar tensor (for torch.compile compatibility).
            device: (`torch.device`, *optional*):
                The device on which to perform the RoPE computation.
        r   r   c                 3   r   r   r?   )r   entryfirst_entryr?   r@   r     r   z/QwenEmbedLayer3DRope.forward.<locals>.<genexpr>zBatch inference with variable-sized images is not currently supported in QwenEmbedLayer3DRope. All images in the batch should have the same layer structure. Detected sizes: z*. Using the first image's layer structure r   r   .r*   )r   r   r,   r   r   r   r   r   _compute_condition_freqsr   r   r^   r   r   r4   r0   r6   )ru   r   r   r(   r   r   	layer_numr   r   r   r   r   r   r   r   r?   r   r@   r   s  s>   




zQwenEmbedLayer3DRope.forwardr   r   c                 C   r   )	Nc                 S   r   r   r?   r   r?   r?   r@   r     r   z=QwenEmbedLayer3DRope._compute_video_freqs.<locals>.<listcomp>r   r*   c                 S   r   r   r?   r   r?   r?   r@   r     r   r   r)   r   r   r   r?   r?   r@   r     s   ,88((z)QwenEmbedLayer3DRope._compute_video_freqsc                 C   s  || | }|d ur| j |n| j }|d ur| j|n| j}|jdd | jD dd}|jdd | jD dd}	|	d dd  |ddd|||d}
| jrtj	|	d ||d   d  |d d |d  gdd}|d|dd|||d}tj	|	d ||d   d  |d d |d  gdd}|dd|d|||d}n(|d d | d|dd|||d}|d d | dd|d|||d}tj	|
||gdd
|d}|  S )	Nc                 S   r   r   r?   r   r?   r?   r@   r     r   zAQwenEmbedLayer3DRope._compute_condition_freqs.<locals>.<listcomp>r   r*   c                 S   r   r   r?   r   r?   r?   r@   r     r   r   r)   r   r   )ru   r   r   r   r(   r   r   r   r   r   r   r   r   r   r?   r?   r@   r     s   (88((z-QwenEmbedLayer3DRope._compute_condition_freqsr   r   r   r   )r   r   r   r   r   ro   r   r   r0   r   r(   r   r   r   r   r   r   r?   r?   rv   r@   r   O  s"    


<
r   c                   @   sd   e Zd ZdZdZdZdd Z				ddedej	dej	dej	d	ej	dB d
ej
dB dej	fddZdS ) QwenDoubleStreamAttnProcessor2_0z
    Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
    implements joint attention computation where text and image streams are processed together.
    Nc                 C   s   t tds	tdd S )Nscaled_dot_product_attentionz`QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.)hasattrFImportError)ru   r?   r?   r@   ro     s
   
z)QwenDoubleStreamAttnProcessor2_0.__init__attnry   rV   rW   attention_maskimage_rotary_embr$   c              
   C   s>  |d u rt d|jd }||}||}	||}
||}||}||}|d|j	df}|	d|j	df}	|
d|j	df}
|d|j	df}|d|j	df}|d|j	df}|j
d urk|
|}|jd uru||	}	|jd ur||}|jd ur||}|d ur|\}}t||dd}t|	|dd}	t||dd}t||dd}tj||gdd}tj||	gdd}tj||
gdd}t||||dd| j| jd}|d	d
}||j}|d d d |d d f }|d d |d d d f }|jd | }t|jdkr|jd |}|| }||fS )NzMQwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)r   r)   F)rD   r*   g        )	attn_mask	dropout_p	is_causalbackendparallel_configr   r   r   )rK   r-   to_qto_kto_v
add_q_proj
add_k_proj
add_v_proj	unflattenheadsnorm_qnorm_knorm_added_qnorm_added_krU   r0   r6   r   _attention_backend_parallel_configrJ   r4   r'   to_outr   r,   
to_add_out)ru   r   ry   rV   rW   r   r   seq_txt	img_queryimg_key	img_value	txt_querytxt_key	txt_value	img_freqsr   joint_query	joint_keyjoint_valuejoint_hidden_statestxt_attn_outputimg_attn_outputr?   r?   r@   __call__  sf   	














z)QwenDoubleStreamAttnProcessor2_0.__call__)NNNN)r   r   r   __doc__r   r   ro   r   r0   FloatTensorr   r  r?   r?   r?   r@   r     s0    
r   c                       s   e Zd Z			ddededededed	ef fd
dZdddZ			dde	j
de	j
de	j
de	j
dee	j
e	j
f dB deeef dB dee dB dee	j
e	j
f fddZ  ZS )QwenImageTransformerBlockrms_normư>Fr+   num_attention_headsattention_head_dimqk_normepszero_cond_tc                    s   t    || _|| _|| _tt tj|d| dd| _	tj
|d|d| _t|d ||||ddt ||d| _tj
|d|d| _t||dd| _tt tj|d| dd| _tj
|d|d| _tj
|d|d| _t||dd| _|| _d S )	N   TbiasFelementwise_affiner  )	query_dimcross_attention_dimadded_kv_proj_dimdim_headr   out_dimcontext_pre_onlyr  	processorr  r  zgelu-approximate)r+   dim_outactivation_fn)rn   ro   r+   r  r  r9   
SequentialSiLULinearimg_mod	LayerNorm	img_norm1r   r   r   	img_norm2r   img_mlptxt_mod	txt_norm1	txt_norm2txt_mlpr  )ru   r+   r  r  r  r  r  rv   r?   r@   ro   C  s@   
	
z"QwenImageTransformerBlock.__init__Nc                 C   s(  |j ddd\}}}|dur{|dd }|d| ||d }}	|d| ||d }
}|d| ||d }}|d}|d}|	d}|
d}|d}|d}|d}t|dk||}t|dk||}t|dk||}n|d}|d}|d}|d|  | |fS )z Apply modulation to input tensorr   r)   r*   Nr   r   r   )chunksizerM   r0   r[   )ru   rB   
mod_paramsr   shiftr"   gateactual_batchshift_0shift_1scale_0scale_1gate_0gate_1index_expandedshift_0_expshift_1_expscale_0_expscale_1_exp
gate_0_exp
gate_1_expshift_resultscale_resultgate_resultr?   r?   r@   	_modulatet  s(   









z#QwenImageTransformerBlock._modulatery   rV   rW   tembr   joint_attention_kwargsmodulate_indexr$   c                 C   sd  |  |}| jrtj|dddd }| |}	|jddd\}
}|	jddd\}}| |}| ||
|\}}| |}| ||\}}|pGi }| jd||||d|}|\}}|||  }|||  }| 	|}| |||\}}| 
|}|||  }| |}| ||\}}| |}|||  }|jtjkr|dd}|jtjkr|dd}||fS )	Nr   r   r*   r)   )ry   rV   rW   r   i  i  r?   )r*  r  r0   r3  r/  r,  rI  r0  r   r-  r.  r1  r2  r'   float16clip)ru   ry   rV   rW   rJ  r   rK  rL  img_mod_paramstxt_mod_paramsimg_mod1img_mod2txt_mod1txt_mod2
img_normedimg_modulated	img_gate1
txt_normedtxt_modulated	txt_gate1attn_outputr  r  img_normed2img_modulated2	img_gate2img_mlp_outputtxt_normed2txt_modulated2	txt_gate2txt_mlp_outputr?   r?   r@   r     sF   



	



z!QwenImageTransformerBlock.forward)r  r  Fr   r   )r   r   r   r   strr5   rY   ro   rI  r0   r   r   dictr   r   r   r   r?   r?   rv   r@   r  A  sL    
1*
	r  c                       st  e Zd ZdZdZdgZddgZdgZedddd	edddd	d
dedddd	iedddd	edddd	de	ddddZ
e												d3dedededB dedededed ed!eeeef d"ed#ed$ef fd%d&Zed'										d4d(ejd)ejd*ejd+ejd,eeeeef  dB d-ee dB d.ejd'eeef dB d/ed0ejeB fd1d2Z  ZS )5QwenImageTransformer2DModela  
    The Transformer model introduced in Qwen.

    Args:
        patch_size (`int`, defaults to `2`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of channels in the output. If not specified, it defaults to `in_channels`.
        num_layers (`int`, defaults to `60`):
            The number of layers of dual stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `24`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `3584`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        guidance_embeds (`bool`, defaults to `False`):
            Whether to use guidance embeddings for guidance-distilled variant of the model.
        axes_dims_rope (`tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
    Tr  	pos_embednormr   r   F)	split_dimexpected_dimssplit_output)ry   rV   rL  r   r   )r   r   )
gather_dimrj  )ztransformer_blocks.0ztransformer_blocks.*rg  proj_out@      <   r         ro  8   rt  
patch_sizerl   out_channelsN
num_layersr  r  joint_attention_dimguidance_embedsaxes_dims_roper  rr   use_layer3d_ropec                    s   t    |p|_  _|stdt|	dd_n
tdt|	dd_tj|d_	t
|dd_t|j_t|j_t fddt|D _tjjd	dd
_tjj|| j dd_d	__d S )Nr   T)r   r   r   )r   rr   r  )r  c                    s   g | ]}t j d qS ))r+   r  r  r  )r  	inner_dim)r   _r  r  ru   r  r?   r@   r   3  s    z8QwenImageTransformer2DModel.__init__.<locals>.<listcomp>Fr  r  )rn   ro   rv  r|  r   r   rg  r   rh   time_text_embedr   txt_normr9   r)  img_intxt_in
ModuleListrangetransformer_blocksr   norm_outrm  gradient_checkpointingr  )ru   ru  rl   rv  rw  r  r  rx  ry  rz  r  rr   r{  rv   r~  r@   ro     s*   



z$QwenImageTransformer2DModel.__init__attention_kwargsry   rV   rW   rx   
img_shapesr   guidancereturn_dictr$   c                 C   s  |durt ddddd | |}||j}| jr5tj||d gdd}tjd	d
 |D |jtj	d}nd}| 
|}| |}t||\}}}|durU||jd }|du r`| |||
n| ||||
}| j|||jd}|dury| ni }|dur|jdd \}}tj||ftj|jd}tj||gdd}||d< t| jD ]C\}}t r| jr| |||d||||\}}n|||d||||d\}}|	durt| jt|	 }t	t|}||	||   }q| jr|jdddd }| ||}| |}|s|fS t|dS )a@	  
        The [`QwenTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
                Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
                Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
                (not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
            img_shapes (`list[tuple[int, int, int]]`, *optional*):
                Image shapes for RoPE computation.
            txt_seq_lens (`list[int]`, *optional*, **Deprecated**):
                Deprecated parameter. Use `encoder_hidden_states_mask` instead. If provided, the maximum value will be
                used to compute RoPE sequence length.
            guidance (`torch.Tensor`, *optional*):
                Guidance tensor for conditional generation.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            controlnet_block_samples (*optional*):
                ControlNet block samples to add to the transformer blocks.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        Nr   r   zPassing `txt_seq_lens` is deprecated and will be removed in version 0.39.0. Please use `encoder_hidden_states_mask` instead. The mask-based approach is more flexible and supports variable-length sequences.Fr   r   r*   c              	   S   s>   g | ]}d gt |d   dgtdd |dd D   qS )r   r   c                 S   s   g | ]}t |qS r?   r   )r   sr?   r?   r@   r     r   zBQwenImageTransformer2DModel.forward.<locals>.<listcomp>.<listcomp>N)r   sum)r   sampler?   r?   r@   r     s   > z7QwenImageTransformer2DModel.forward.<locals>.<listcomp>rX   rj   )r   r(   r   )r'   r(   r   r   )ry   rV   rW   rJ  r   rK  rL  )r  )r   r  r4   r'   r  r0   r6   tensorr(   r   r  r  rg   r  rg  copyr-   onesrY   r   r  is_grad_enabledr  _gradient_checkpointing_funcr,   npceilr3  r  rm  r   )ru   ry   rV   rW   rx   r  r   r  r  controlnet_block_samplesadditional_t_condr  rL  rb   r}  rJ  r   block_attention_kwargsra   image_seq_len
image_maskjoint_attention_maskindex_blockblockinterval_controloutputr?   r?   r@   r   D  s   2
	






z#QwenImageTransformer2DModel.forward)r   rn  ro  rp  r   rq  rr  Frs  FFF)
NNNNNNNNNT)r   r   r   r   _supports_gradient_checkpointing_no_split_modules _skip_layerwise_casting_patterns_repeated_blocksr   r   _cp_planr   r   rY   r   ro   r
   r0   r   
LongTensorr   re  rd  r   r   r   r   r?   r?   rv   r@   rf    s    
	
3
	rf  )Fr   r   r   )Tr)   )Ar   r.   r   typingr   numpyr  r0   torch.nnr9   torch.nn.functionalr:   r   configuration_utilsr   r   loadersr   r	   utilsr
   r   r   utils.torch_utilsr   _modeling_parallelr   r   	attentionr   r   attention_dispatchr   attention_processorr   cache_utilsr   
embeddingsr   r   modeling_outputsr   modeling_utilsr   normalizationr   r   
get_loggerr   r   r   r   rY   r5   rA   r   rU   rg   Modulerh   r   r   r   r  rf  r?   r?   r?   r@   <module>   s   

9
0
  e 
