o
    GiG                     @   s  d dl Z d dlmZmZ d dlZd dlZd dlZd dlm	Z	m
Z
 ddlmZ ddlmZmZ ddlmZ ddlmZmZmZ dd	lmZ d
dlmZ ddlmZ ddlmZ ddlmZm Z m!Z! e rod dl"m#  m$Z% dZ&ndZ&e'e(Z)dZ*dZ+edfde,e- de-de,ej.j.e,e,ej.j.  f dB fddZ/de,e,ej.j.  e,ej.j. B dede0de,e,ej.j.  fddZ1de0de0de2fd d!Z3				d0d"e0dB d#e-ej4B dB d$e,e0 dB d%e,e2 dB fd&d'Z5	(d1d)ej6d*ej7dB d+e-fd,d-Z8G d.d/ d/eeZ9dS )2    N)AnyCallable)AutoProcessor Mistral3ForConditionalGeneration   )Flux2LoraLoaderMixin)AutoencoderKLFlux2Flux2Transformer2DModel)FlowMatchEulerDiscreteScheduler)is_torch_xla_availableloggingreplace_example_docstring)randn_tensor   )DiffusionPipeline   )Flux2ImageProcessor)Flux2PipelineOutput)SYSTEM_MESSAGESYSTEM_MESSAGE_UPSAMPLING_I2ISYSTEM_MESSAGE_UPSAMPLING_T2ITFaU  
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import Flux2Pipeline

        >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")
        >>> prompt = "A cat holding a sign that says hello world"
        >>> # Depending on the variant being used, the pipeline call will slightly vary.
        >>> # Refer to the pipeline documentation for more details.
        >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0]
        >>> image.save("flux.png")
        ```
i  	 promptssystem_messageimagesc                    s   dd | D }|du st |dkr fdd|D S t |t | ks&J d fdd|D }tt||D ]%\}\}}|durM|dd	d |D d
 |dd|| dgd
 q6|S )a  
    Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images
    to the input.

    Args:
        prompts: List of text prompts
        system_message: System message to use (default: CREATIVE_SYSTEM_MESSAGE)
        images (optional): List of images to add to the input.

    Returns:
        List of conversations, where each conversation is a list of message dicts
    c                 S   s   g | ]}| d dqS )z[IMG] )replace.0prompt r   \/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/flux2/pipeline_flux2.py
<listcomp>U       z format_input.<locals>.<listcomp>Nr   c                    s0   g | ]}d d dgddd|dgdgqS )systemtexttyper$   rolecontentuserr   r   r   r   r    r!   X   s    
z-Number of images must match number of promptsc                    s    g | ]}d d dgdgqS )r#   r$   r%   r'   r   )r   _r+   r   r    r!   d   s    
r*   c                 S   s   g | ]}d |dqS )image)r&   r-   r   )r   	image_objr   r   r    r!   t       r'   r$   r%   )len	enumeratezipappend)r   r   r   cleaned_txtmessagesielr   r+   r    format_inputA   s,   


r8   image_processorupsampling_max_image_sizereturnc                    sR   | sg S t | d tjjrdd | D }  fdd| D }  fdd| D } | S )Nr   c                 S   s   g | ]}|gqS r   r   )r   imr   r   r    r!      s    z0_validate_and_process_images.<locals>.<listcomp>c                    s(   g | ]}t |d kr |gn|qS )r   )r0   concatenate_imagesr   img_i)r9   r   r    r!      s   ( c                    s    g | ]} fd d|D qS )c                    s   g | ]}  |qS r   )_resize_if_exceeds_arear>   r9   r:   r   r    r!      r"   z;_validate_and_process_images.<locals>.<listcomp>.<listcomp>r   r>   rA   r   r    r!      s    )
isinstancePILImage)r   r9   r:   r   rA   r    _validate_and_process_images   s   rE   image_seq_len	num_stepsc                 C   sp   d\}}d\}}| dkr||  | }t |S ||  | }||  | }|| d }	|d|	  }
|	| |
 }t |S )N)gT	?gŒ_?)g w:/&?gDw:?i  g     g@g      i@)float)rF   rG   a1b1a2b2mum_200m_10abr   r   r    compute_empirical_mu   s   rR   num_inference_stepsdevice	timestepssigmasc                 K   s  |dur|durt d|dur>dtt| jj v }|s(t d| j d| jd||d| | j}t	|}||fS |durpdtt| jj v }|sZt d| j d| jd||d	| | j}t	|}||fS | j|fd
|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    NzYOnly one of `timesteps` or `sigmas` can be passed. Please choose one to set custom valuesrU   zThe current scheduler class zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)rU   rT   rV   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)rV   rT   rT   r   )

ValueErrorsetinspect	signatureset_timesteps
parameterskeys	__class__rU   r0   )	schedulerrS   rT   rU   rV   kwargsaccepts_timestepsaccept_sigmasr   r   r    retrieve_timesteps   s2   rc   sampleencoder_output	generatorsample_modec                 C   sR   t | dr|dkr| j|S t | dr|dkr| j S t | dr%| jS td)Nlatent_distrd   argmaxlatentsz3Could not access latents of provided encoder_output)hasattrrh   rd   moderj   AttributeError)re   rf   rg   r   r   r    retrieve_latents   s   

rn   c                )       s`  e Zd ZdZdZddgZdededede	d	e
f
 fd
dZedddedfdede	deee B dejdB dejdB dededee fddZe	d[dejdejdB fddZedejfddZe	d\deej defd d!Zed"d# Zed$d% Zed&d' Zedejd(ejd)eej fd*d+Z		,	d]deee B d-eejjeeejj  f d.edejd)ee f
d/d0Z 		1			d^deee B dejdB d2edejdB ded3e!e fd4d5Z"d6ejd7ej#fd8d9Z$	d[d7ej#dejdB fd:d;Z%d-eej d7ej#fd<d=Z&		d_d>d?Z'e(d@dA Z)e(dBdC Z*e(dDdE Z+e(dFdG Z,e(dHdI Z-e. e/e0dddddJddKd1ddddLdMdddgdddfd6eejjejjf dB deee B dNedB dOedB dPedQee dB dRedB d2ed7ej#eej# B dB dejdB dejdB dSedB dTe1dUe2ee3f dB dVe4eegdf dB dWee ded3e!e dXef&dYdZZ5  Z6S )`Flux2Pipelinea  
    The Flux2 pipeline for text-to-image generation.

    Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2)

    Args:
        transformer ([`Flux2Transformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLFlux2`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`Mistral3ForConditionalGeneration`]):
            [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration)
        tokenizer (`AutoProcessor`):
            Tokenizer of class
            [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor).
    ztext_encoder->transformer->vaerj   prompt_embedsr_   vaetext_encoder	tokenizertransformerc                    s   t    | j|||||d t| dd r dt| jjjd  nd| _t	| jd d| _
d| _d| _t| _t| _t| _t| _d S )	N)rq   rr   rs   r_   rt   rq   r   r      )vae_scale_factor      )super__init__register_modulesgetattrr0   rq   configblock_out_channelsrv   r   r9   tokenizer_max_lengthdefault_sample_sizer   r   r   system_message_upsampling_t2ir   system_message_upsampling_i2iUPSAMPLING_MAX_IMAGE_SIZEr:   )selfr_   rq   rr   rs   rt   r^   r   r    rz     s    
(
zFlux2Pipeline.__init__Nrw   
         r   dtyperT   max_sequence_lengthr   hidden_states_layersc              
      s   |d u r| j n|}|d u r| jn|}t|tr|gn|}t||d}|j|dddddd|d}	|	d |}
|	d |}| |
|ddd	 tj fd
d|D dd}|j||d}|j	\}}}}|
dddd|||| }|S )N)r   r   FTpt
max_lengthadd_generation_prompttokenizereturn_dictreturn_tensorspadding
truncationr   	input_idsattention_mask)r   r   output_hidden_states	use_cachec                    s   g | ]} j | qS r   )hidden_statesr   koutputr   r    r!   [  r/   zDFlux2Pipeline._get_mistral_3_small_prompt_embeds.<locals>.<listcomp>r   dim)r   rT   r   r   r   )r   rT   rB   strr8   apply_chat_templatetotorchstackshapepermutereshape)rr   rs   r   r   rT   r   r   r   messages_batchinputsr   r   out
batch_sizenum_channelsseq_len
hidden_dimrp   r   r   r    "_get_mistral_3_small_prompt_embeds/  s6   z0Flux2Pipeline._get_mistral_3_small_prompt_embedsxt_coordc                 C   sz   | j \}}}g }t|D ]+}|d u rtdn|| }td}td}	t|}
t|||	|
}|| qt|S )Nr   )r   ranger   arangecartesian_prodr3   r   )r   r   BLr,   out_idsr6   thwlcoordsr   r   r    _prepare_text_idsc  s   



zFlux2Pipeline._prepare_text_idsc           
      C   s^   | j \}}}}td}t|}t|}td}t||||}	|	d|dd}	|	S )a  
        Generates 4D position coordinates (T, H, W, L) for latent tensors.

        Args:
            latents (torch.Tensor):
                Latent tensor of shape (B, C, H, W)

        Returns:
            torch.Tensor:
                Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
                H=[0..H-1], W=[0..W-1], L=0
        r   r   )r   r   r   r   	unsqueezeexpand)
rj   r   r,   heightwidthr   r   r   r   
latent_idsr   r   r    _prepare_latent_idsv  s   



z!Flux2Pipeline._prepare_latent_idsr   image_latentsscalec           
   	      s   t | tstdt|  d fddtdt| D }dd |D }g }t| |D ]%\}}|d}|j	\}}}t
|t|t|td}	||	 q,tj|dd}|d}|S )	a  
        Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.

        This function creates a unique coordinate for every pixel/patch across all input latent with different
        dimensions.

        Args:
            image_latents (list[torch.Tensor]):
                A list of image latent feature tensors, typically of shape (C, H, W).
            scale (int, optional):
                A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
                latent is: 'scale + scale * i'. Defaults to 10.

        Returns:
            torch.Tensor:
                The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
                input latents.

        Coordinate Components (Dimension 4):
            - T (Time): The unique index indicating which latent image the coordinate belongs to.
            - H (Height): The row index within that latent image.
            - W (Width): The column index within that latent image.
            - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
        z+Expected `image_latents` to be a list, got .c                    s   g | ]}  |  qS r   r   r   r   r   r   r    r!     r"   z4Flux2Pipeline._prepare_image_ids.<locals>.<listcomp>r   c                 S   s   g | ]}| d qS )r   )viewr   r   r   r    r!     r/   r   r   )rB   listrW   r&   r   r   r0   r2   squeezer   r   r3   catr   )
r   r   t_coordsimage_latent_idsr   r   r,   r   r   x_idsr   r   r    _prepare_image_ids  s   

"
z Flux2Pipeline._prepare_image_idsc                 C   s^   | j \}}}}| |||d d|d d} | dddddd} | ||d |d |d } | S )Nr   r   r   r         )r   r   r   r   rj   r   num_channels_latentsr   r   r   r   r    _patchify_latents  s
   zFlux2Pipeline._patchify_latentsc                 C   sZ   | j \}}}}| ||d dd||} | dddddd} | ||d |d |d } | S )Nr   r   r   r   r   r   r   r   r   r   r   r   r    _unpatchify_latents  s
   z!Flux2Pipeline._unpatchify_latentsc                 C   s.   | j \}}}}| |||| ddd} | S )zw
        pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
        r   r   r   r   )rj   r   r   r   r   r   r   r    _pack_latents  s   zFlux2Pipeline._pack_latentsr   r;   c                 C   s   g }t | |D ]b\}}|j\}}|dddf tj}|dddf tj}t|d }	t|d }
||
 | }tj|	|
 |f|j|jd}|	d|
dd|| ||	|
|ddd}|| qtj|ddS )zA
        using position ids to scatter tokens into place
        Nr   r   rT   r   r   r   r   )r2   r   r   r   int64maxzerosrT   r   scatter_r   r   r   r   r3   r   )r   r   x_listdataposr,   chh_idsw_idsr   r   flat_idsr   r   r   r    _unpack_latents_with_ids  s   
z&Flux2Pipeline._unpack_latents_with_ids333333?r   temperaturec              
   C   s.  t |tr|gn|}|d u r| jjn|}|d u s$t|dks$|d d u r't}nt}|r3t|| j| j	}t
|||d}| jj|dddddddd}|d ||d< |d	 ||d	< d
|v rj|d
 || jj|d
< | jjdi |dd|dd}|d jd }	|d d |	d f }
| jjj|
ddd}|S )Nr   )r   r   r   Tr   r   i   r   r   r   pixel_valuesrw   )max_new_tokens	do_sampler   r   r   )skip_special_tokensclean_up_tokenization_spacesr   )rB   r   rr   rT   r0   r   r   rE   r9   r:   r8   rs   r   r   r   generater   batch_decode)r   r   r   r   rT   r   r   r   generated_idsinput_lengthgenerated_tokensupsampled_promptr   r   r    upsample_prompt  sH    


zFlux2Pipeline.upsample_promptr   num_images_per_prompttext_encoder_out_layersc              	   C   s   |p| j }|d u rd}t|tr|gn|}|d u r(| j| j| j|||| j|d}|j\}}}	|d|d}|	|| |d}| 
|}
|
|}
||
fS )Nr   )rr   rs   r   rT   r   r   r   r   r   )_execution_devicerB   r   r   rr   rs   r   r   repeatr   r   r   )r   r   rT   r   rp   r   r   r   r   r,   text_idsr   r   r    encode_prompt;  s(   
	


zFlux2Pipeline.encode_promptr-   rf   c                 C   s   |j dkrtd|j  dt| j||dd}| |}| jjjdddd	|j
|j}t| jjjdddd| jjj }|| | }|S )Nr   zExpected image dims 4, got r   ri   )rf   rg   r   r   )ndimrW   rn   rq   encoder   bnrunning_meanr   r   rT   r   r   sqrtrunning_varr}   batch_norm_eps)r   r-   rf   r   latents_bn_meanlatents_bn_stdr   r   r    _encode_vae_image^  s   

"&zFlux2Pipeline._encode_vae_imagec	                 C   s   dt || jd   }dt || jd   }||d |d |d f}	t|tr:t||kr:tdt| d| d|d u rGt|	|||d}n|j||d}| |}
|
|}
| 	|}||
fS )Nr   r   z/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.)rf   rT   r   r   )
intrv   rB   r   r0   rW   r   r   r   r   )r   r   num_latents_channelsr   r   r   rT   rf   rj   r   r   r   r   r    prepare_latentsk  s    


zFlux2Pipeline.prepare_latentsc                 C   s   g }|D ]}|j ||d}| j||d}|| q| |}	g }
|D ]}| |}|d}|
| q#tj|
dd}|d}|	|dd}|		|dd}	|	 |}	||	fS )Nr   )r-   rf   r   r   r   )
r   r  r3   r   r   r   r   r   r   r   )r   r   r   rf   rT   r   r   r-   imagge_latentr   packed_latentslatentpackedr   r   r    prepare_image_latents  s"   




z#Flux2Pipeline.prepare_image_latentsc              	      s  |d ur| j d  dks|d ur,| j d  dkr,td j d  d| d| d |d urMt fdd|D sMtd	 j d
 fdd|D  |d ur`|d ur`td| d| d|d u rl|d u rltd|d urt|tst|tstdt	| d S d S d S )Nr   r   z-`height` and `width` have to be divisible by z	 but are z and z(. Dimensions will be resized accordinglyc                 3   s    | ]}| j v V  qd S N_callback_tensor_inputsr   r   r   r    	<genexpr>  s    

z-Flux2Pipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r   r  r   r  r   r    r!     s    z.Flux2Pipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is )
rv   loggerwarningallrW   r  rB   r   r   r&   )r   r   r   r   rp   "callback_on_step_end_tensor_inputsr   r  r    check_inputs  s0   	zFlux2Pipeline.check_inputsc                 C      | j S r  )_guidance_scaler  r   r   r    guidance_scale     zFlux2Pipeline.guidance_scalec                 C   r  r  )_attention_kwargsr  r   r   r    attention_kwargs  r  zFlux2Pipeline.attention_kwargsc                 C   r  r  )_num_timestepsr  r   r   r    num_timesteps  r  zFlux2Pipeline.num_timestepsc                 C   r  r  )_current_timestepr  r   r   r    current_timestep  r  zFlux2Pipeline.current_timestepc                 C   r  r  )
_interruptr  r   r   r    	interrupt  r  zFlux2Pipeline.interrupt2   g      @pilTr   r   rS   rV   r  output_typer   r  callback_on_step_endr  caption_upsample_temperaturec           2      C   sJ  | j |||||d || _|| _d| _d| _|dur"t|tr"d}n|dur0t|tr0t|}n|j	d }| j
}|rC| j||||d}| j||||||d\}}|dur\t|ts\|g}d}|dur|D ]}| j| qdg }|D ]A}|j\}}|| dkr| j|d}|j\}}| jd	 }|| | }|| | }| jj|||d
d}|| |p|}|p|}qq|p| j| j }|p| j| j }| jjjd }| j|| ||||j||	|
d\}
}d}d}|dur| j||| |	|| jjd\}}|du rtdd| |n|}t| jjdr| jjj rd}|
j	d } t!| |d}!t"| j||||!d\}"}t#t|"|| jj$  d}#t|"| _%t&j'dg||t&j(d}$|$)|
j	d }$| j*d | j+|d}%t,|"D ]\}&}'| j-rkqa|'| _|')|
j	d .|
j}(|
.| jj})|}*|durt&j/|
|gdd.| jj})t&j/||gdd}*| j|)|(d |$|||*| j0ddd }+|+ddd|
df }+|
j},| jj1|+|'|
ddd }
|
j|,krt&j2j34 r|
.|,}
|duri }-|D ]
}.t5 |. |-|.< q|| |&|'|-}/|/6d|
}
|/6d|}|&t|"d ks!|&d |#kr%|&d | jj$ dkr%|%7  t8r,t9:  qaW d   n	1 s9w   Y  d| _|dkrI|
}nM| ;|
|}
| jj<j=>dddd.|
j?|
j}0t&@| jj<jA>dddd| jjjB .|
j?|
j}1|
|1 |0 }
| C|
}
| jjD|
ddd }| jjE||d}| F  |s|fS tG|dS )a  
        Function invoked when calling the pipeline for generation.

        Args:
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
                or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            guidance_scale (`float`, *optional*, defaults to 1.0):
                Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
                a model to generate images more aligned with `prompt` at the expense of lower image quality.

                Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
                the [paper](https://huggingface.co/papers/2210.03142) to learn more.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`list[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
            text_encoder_out_layers (`tuple[int]`):
                Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
            caption_upsample_temperature (`float`):
                When specified, we will try to perform caption upsampling for potentially improved outputs. We
                recommend setting it to 0.15 if caption upsampling is to be performed.

        Examples:

        Returns:
            [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        )r   r   r   rp   r  NFr   r   )r   r   rT   )r   rp   rT   r   r   r   i   r   crop)r   r   resize_moder   )r   r  r   r   r   rT   rf   rj   )r   r   rf   rT   r   g      ?use_flow_sigmas)rF   rG   )rV   rM   r   )totalr   i  )r   timestepguidanceencoder_hidden_statestxt_idsimg_idsjoint_attention_kwargsr   )r   rj   rp   r  r   )r&  )r   )Hr  r  r  r   r"  rB   r   r   r0   r   r   r   r   r9   check_image_inputsize_resize_to_target_arearv   
preprocessr3   r   rt   r}   in_channelsr  r   r  rq   nplinspacerk   r_   r+  rR   rc   r   orderr  r   fullfloat32r   set_begin_indexprogress_barr1   r#  r   r   r  stepbackendsmpsis_availablelocalspopupdateXLA_AVAILABLExm	mark_stepr   r   r   r   rT   r   r  r  r   decodepostprocessmaybe_free_model_hooksr   )2r   r-   r   r   r   rS   rV   r  r   rf   rj   rp   r&  r   r  r'  r  r   r   r(  r   rT   r   condition_imagesimgimage_widthimage_heightmultiple_ofr   r   r   r   rF   rM   rU   num_warmup_stepsr.  r>  r6   r   r-  latent_model_inputlatent_image_ids
noise_predlatents_dtypecallback_kwargsr   callback_outputsr  r  r   r   r    __call__  s&  d










 	



	

66
"&

zFlux2Pipeline.__call__r  )r   )Nr   N)Nr   Nrw   r   )NN)7__name__
__module____qualname____doc__model_cpu_offload_seqr  r
   r   r   r   r	   rz   staticmethodr   r   r   r   r   rT   r  r   Tensorr   r   r   r   r   r   r   rC   rD   rH   r   tupler   	Generatorr  r  r  r  propertyr  r  r  r!  r#  no_gradr   EXAMPLE_DOC_STRINGbooldictr   r   rX  __classcell__r   r   r   r    ro      sj   
31



"

A

#	
!
'
%






	
ro   )NNNN)Nrd   ):rY   typingr   r   numpyr8  rC   r   transformersr   r   loadersr   modelsr   r	   
schedulersr
   utilsr   r   r   utils.torch_utilsr   pipeline_utilsr   r9   r   pipeline_outputr   system_messagesr   r   r   torch_xla.core.xla_modelcore	xla_modelrG  rF  
get_loggerrY  r  rd  r   r   r   rD   r8   r  rE   rH   rR   rT   rc   r_  ra  rn   ro   r   r   r   r    <module>   s~   

C



=
