o
    Gi                  
   @   s  d dl Z d dlmZmZ d dlZd dlZd dlmZm	Z	m
Z
mZ ddlmZmZ ddlmZmZmZ ddlmZ ddlmZ dd	lmZ dd
lmZmZmZmZmZmZ ddl m!Z! ddl"m#Z# ddl$m%Z% e rud dl&m'  m(Z) dZ*ndZ*e+e,Z-dZ.				d+de/de/de0de0fddZ1	d,dej2dej3dB d e4fd!d"Z5				d-d#e/dB d$e4ej6B dB d%e7e/ dB d&e7e0 dB fd'd(Z8G d)d* d*e#eeZ9dS ).    N)AnyCallable)CLIPTextModelCLIPTokenizerT5EncoderModelT5TokenizerFast   )PipelineImageInputVaeImageProcessor)FluxLoraLoaderMixinFromSingleFileMixinTextualInversionLoaderMixin)AutoencoderKL)FluxTransformer2DModel)FlowMatchEulerDiscreteScheduler)USE_PEFT_BACKENDis_torch_xla_availableloggingreplace_example_docstringscale_lora_layersunscale_lora_layers)randn_tensor   )DiffusionPipeline   )FluxPipelineOutputTFa  
    Examples:
        ```py
        >>> import torch
        >>> from controlnet_aux import CannyDetector
        >>> from diffusers import FluxControlImg2ImgPipeline
        >>> from diffusers.utils import load_image

        >>> pipe = FluxControlImg2ImgPipeline.from_pretrained(
        ...     "black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16
        ... ).to("cuda")

        >>> prompt = "A robot made of exotic candies and chocolates of different kinds. Abstract background"
        >>> image = load_image(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/watercolor-painting.jpg"
        ... )
        >>> control_image = load_image(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png"
        ... )

        >>> processor = CannyDetector()
        >>> control_image = processor(
        ...     control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024
        ... )

        >>> image = pipe(
        ...     prompt=prompt,
        ...     image=image,
        ...     control_image=control_image,
        ...     strength=0.8,
        ...     height=1024,
        ...     width=1024,
        ...     num_inference_steps=50,
        ...     guidance_scale=30.0,
        ... ).images[0]
        >>> image.save("output.png")
        ```
            ?ffffff?base_seq_lenmax_seq_len
base_shift	max_shiftc                 C   s,   || ||  }|||  }| | | }|S N )image_seq_lenr    r!   r"   r#   mbmur%   r%   j/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/flux/pipeline_flux_control_img2img.pycalculate_shift[   s   r+   sampleencoder_output	generatorsample_modec                 C   sR   t | dr|dkr| j|S t | dr|dkr| j S t | dr%| jS td)Nlatent_distr,   argmaxlatentsz3Could not access latents of provided encoder_output)hasattrr0   r,   moder2   AttributeError)r-   r.   r/   r%   r%   r*   retrieve_latentsi   s   

r6   num_inference_stepsdevice	timestepssigmasc                 K   s  |dur|durt d|dur>dtt| jj v }|s(t d| j d| jd||d| | j}t	|}||fS |durpdtt| jj v }|sZt d| j d| jd||d	| | j}t	|}||fS | j|fd
|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    NzYOnly one of `timesteps` or `sigmas` can be passed. Please choose one to set custom valuesr9   zThe current scheduler class zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)r9   r8   r:   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)r:   r8   r8   r%   )

ValueErrorsetinspect	signatureset_timesteps
parameterskeys	__class__r9   len)	schedulerr7   r8   r9   r:   kwargsaccepts_timestepsaccept_sigmasr%   r%   r*   retrieve_timestepsw   s2   rH   c                -       s  e Zd ZdZdZg ZddgZdedede	de
d	ed
edef fddZ					dMdeee B dededejdB dejdB f
ddZ		dNdeee B dedejdB fddZ							dOdeee B deee B dB dejdB dedejdB dejdB dededB fddZdejd ejfd!d"Zd#d$ Z				dPd%d&Zed'd( Z ed)d* Z!ed+d, Z"	dQd-d.Z#	/	/dRd0d1Z$e%d2d3 Z&e%d4d5 Z'e%d6d7 Z(e%d8d9 Z)e* e+e,ddddddd:d;dd<dddddd=d>dddgdfdeee B deee B dB de-d?e-d@edB dAedB dBedCedDee dB dEededB d ejeej B dB dejdB dejdB dejdB dFedB dGe.dHe/ee0f dB dIe1eegdf dB dJee def*dKdLZ2  Z3S )SFluxControlImg2ImgPipelinea  
    The Flux pipeline for image inpainting.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Args:
        transformer ([`FluxTransformer2DModel`]):
            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        text_encoder_2 ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer_2 (`T5TokenizerFast`):
            Second Tokenizer of class
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
    z.text_encoder->text_encoder_2->transformer->vaer2   prompt_embedsrD   vaetext_encoder	tokenizertext_encoder_2tokenizer_2transformerc              	      s   t    | j|||||||d t| dd r"dt| jjjd  nd| _t	| jd d| _
t| dr<| jd ur<| jjnd| _d	| _d S )
N)rK   rL   rN   rM   rO   rP   rD   rK   r   r      )vae_scale_factorrM   M      )super__init__register_modulesgetattrrC   rK   configblock_out_channelsrR   r
   image_processorr3   rM   model_max_lengthtokenizer_max_lengthdefault_sample_size)selfrD   rK   rL   rM   rN   rO   rP   rB   r%   r*   rV      s   

(	
z#FluxControlImg2ImgPipeline.__init__Nr      promptnum_images_per_promptmax_sequence_lengthr8   dtypec              	   C   s0  |p| j }|p
| jj}t|tr|gn|}t|}t| tr%| || j}| j|d|ddddd}|j	}| j|dddj	}	|	j
d |j
d krit||	si| j|	d d | jd	 df }
td
| d|
  | j||ddd }| jj}|j||d}|j
\}}}|d	|d	}||| |d}|S )N
max_lengthTFpt)paddingrf   
truncationreturn_lengthreturn_overflowing_tokensreturn_tensorslongestrh   rl   r   zXThe following part of your input was truncated because `max_sequence_length` is set to  	 tokens: output_hidden_statesr   re   r8   )_execution_devicerL   re   
isinstancestrrC   r   maybe_convert_promptrO   	input_idsshapetorchequalbatch_decoder]   loggerwarningrN   torepeatview)r_   rb   rc   rd   r8   re   
batch_sizetext_inputstext_input_idsuntruncated_idsremoved_textrJ   _seq_lenr%   r%   r*   _get_t5_prompt_embeds   sB   

	 "z0FluxControlImg2ImgPipeline._get_t5_prompt_embedsc           
   	   C   s  |p| j }t|tr|gn|}t|}t| tr| || j}| j|d| jddddd}|j}| j|dddj}|j	d |j	d kret
||se| j|d d | jd	 df }td
| j d|  | j||dd}	|	j}	|	j| jj|d}	|	d	|}	|	|| d}	|	S )Nrf   TFrg   )rh   rf   ri   rk   rj   rl   rm   rn   ro   r   z\The following part of your input was truncated because CLIP can only handle sequences up to rp   rq   rs   )rt   ru   rv   rC   r   rw   rM   r]   rx   ry   rz   r{   r|   r}   r~   rL   r   pooler_outputre   r   r   )
r_   rb   rc   r8   r   r   r   r   r   rJ   r%   r%   r*   _get_clip_prompt_embeds"  s>   


 "z2FluxControlImg2ImgPipeline._get_clip_prompt_embedsprompt_2pooled_prompt_embeds
lora_scalec	                 C   s8  |p| j }|dur+t| tr+|| _| jdurtrt| j| | jdur+tr+t| j| t|tr3|gn|}|du rX|p<|}t|trE|gn|}| j	|||d}| j
||||d}| jdurjt| trjtrjt| j| | jdur|t| tr|tr|t| j| | jdur| jjn| jj}	t|jd dj||	d}
|||
fS )a  

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `list[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        N)rb   r8   rc   )rb   rc   rd   r8   r   r   r8   re   )rt   ru   r   _lora_scalerL   r   r   rN   rv   r   r   r   re   rP   rz   zerosry   r   )r_   rb   r   r8   rc   rJ   r   rd   r   re   text_idsr%   r%   r*   encode_promptO  s>   
 


z(FluxControlImg2ImgPipeline.encode_promptimager.   c                    sj   t  tr fddtjd D }tj|dd}n
tj d}|jj	j
 jj	j }|S )Nc              	      s0   g | ]}t j||d    | dqS )r   r.   )r6   rK   encode).0ir.   r   r_   r%   r*   
<listcomp>  s    "z@FluxControlImg2ImgPipeline._encode_vae_image.<locals>.<listcomp>r   dimr   )ru   listrangery   rz   catr6   rK   r   rY   shift_factorscaling_factor)r_   r   r.   image_latentsr%   r   r*   _encode_vae_image  s   
z,FluxControlImg2ImgPipeline._encode_vae_imagec                 C   sd   t || |}tt|| d}| jj|| jj d  }t| jdr,| j|| jj  ||| fS )Nr   set_begin_index)minintmaxrD   r9   orderr3   r   )r_   r7   strengthr8   init_timestept_startr9   r%   r%   r*   get_timesteps  s   z(FluxControlImg2ImgPipeline.get_timestepsc
           
   	      s  |dk s|dkrt d| | jd  dks!| jd  dkr3td jd  d| d| d |d urTt fd	d
|D sTt d j d fdd|D  |d urg|d urgt d| d| d|d urz|d urzt d| d| d|d u r|d u rt d|d urt|tst|tst dt	| |d urt|tst|tst dt	| |d ur|d u rt d|	d ur|	dkrt d|	 d S d S )Nr   r   z2The value of strength should in [0.0, 1.0] but is r   z-`height` and `width` have to be divisible by z	 but are z and z(. Dimensions will be resized accordinglyc                 3   s    | ]}| j v V  qd S r$   _callback_tensor_inputsr   kr_   r%   r*   	<genexpr>  s    

z:FluxControlImg2ImgPipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r%   r   r   r   r%   r*   r     s    z;FluxControlImg2ImgPipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.z Cannot forward both `prompt_2`: zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is z4`prompt_2` has to be of type `str` or `list` but is zIf `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.ra   z8`max_sequence_length` cannot be greater than 512 but is )
r;   rR   r}   r~   allr   ru   rv   r   type)
r_   rb   r   r   heightwidthrJ   r   "callback_on_step_end_tensor_inputsrd   r%   r   r*   check_inputs  sF   $z'FluxControlImg2ImgPipeline.check_inputsc           	      C   s|   t ||d}|d t |d d d f  |d< |d t |d d d f  |d< |j\}}}||| |}|j||dS )Nr   ).r   ).r   r   )rz   r   arangery   reshaper   )	r   r   r   r8   re   latent_image_idslatent_image_id_heightlatent_image_id_widthlatent_image_id_channelsr%   r%   r*   _prepare_latent_image_ids  s   ""z4FluxControlImg2ImgPipeline._prepare_latent_image_idsc                 C   sR   |  |||d d|d d} | dddddd} | ||d |d  |d } | S )Nr   r      r   r      )r   permuter   )r2   r   num_channels_latentsr   r   r%   r%   r*   _pack_latents  s   z(FluxControlImg2ImgPipeline._pack_latentsc                 C   s   | j \}}}dt||d   }dt||d   }| ||d |d |d dd} | dddddd} | ||d ||} | S )Nr   r   r   r   r   r   )ry   r   r   r   r   )r2   r   r   rR   r   num_patcheschannelsr%   r%   r*   _unpack_latents  s    z*FluxControlImg2ImgPipeline._unpack_latentsc                 C   s  t |	trt|	|krtdt|	 d| ddt|| jd   }dt|| jd   }||||f}| ||d |d ||}|
d urN|
j||d|fS |j||d}| j||	d}||j	d kr~||j	d  dkr~||j	d  }t
j|g| dd}n&||j	d kr||j	d  dkrtd	|j	d  d
| dt
j|gdd}t||	||d}| j|||}
| |
||||}
|
|fS )Nz/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.r   r   )r   r.   r   r   z'Cannot duplicate `image` of batch size z to z text prompts.)r.   r8   re   )ru   r   rC   r;   r   rR   r   r   r   ry   rz   r   r   rD   scale_noiser   )r_   r   timestepr   r   r   r   re   r8   r.   r2   ry   r   r   additional_image_per_promptnoiser%   r%   r*   prepare_latents  s4     z*FluxControlImg2ImgPipeline.prepare_latentsFc
                 C   st   t |tjrn	| jj|||d}|jd }
|
dkr|}n|}|j|dd}|j||d}|r8|	s8t|gd }|S )Nr   r   r   r   r   r   r   )	ru   rz   Tensorr[   
preprocessry   repeat_interleaver   r   )r_   r   r   r   r   rc   r8   re   do_classifier_free_guidance
guess_modeimage_batch_size	repeat_byr%   r%   r*   prepare_imageH  s   
z(FluxControlImg2ImgPipeline.prepare_imagec                 C      | j S r$   )_guidance_scaler   r%   r%   r*   guidance_scalej     z)FluxControlImg2ImgPipeline.guidance_scalec                 C   r   r$   )_joint_attention_kwargsr   r%   r%   r*   joint_attention_kwargsn  r   z1FluxControlImg2ImgPipeline.joint_attention_kwargsc                 C   r   r$   )_num_timestepsr   r%   r%   r*   num_timestepsr  r   z(FluxControlImg2ImgPipeline.num_timestepsc                 C   r   r$   )
_interruptr   r%   r%   r*   	interruptv  r   z$FluxControlImg2ImgPipeline.interruptg333333?   g      @pilTcontrol_imager   r   r   r7   r:   r   output_typereturn_dictr   callback_on_step_endr   c           0      C   s  |p| j | j }|p| j | j }| j|||||||||d	 |
| _|| _d| _| jj|||d}|jt	j
d}|durCt|trCd}n|durQt|trQt|}n|jd }| j}| jdure| jddnd}| j||||||||d	\}}}|	du rtd
d| |n|	}	t|| j d t|| j d  }t|| jjdd| jjdd| jjdd| jjdd}trd}n|}t| j|||	|d\}}| |||\}}|dk rtd| d| d|dd || }| jjj d } | j!||||| ||| j"j#d}|j$dkr6| j"%|j&j'|d}|| j"jj( | j"jj) }|jdd \}!}"| *||| | |!|"}| +|||| | |||j#|||
\}}#t,t||| jj-  d}$t|| _.| jjj/rut	j0dg|
|t	j
d}%|%1|jd }%nd}%| j2|d}&t3|D ]\}'}(| j4rqt	j5||gdd})|(1|jd |j#}*| j|)|*d  |%||||#| jdd!	d }+|j#},| jj6|+|(|dd"d }|j#|,krt	j7j89 r||,}|duri }-|D ]
}.t: |. |-|.< q|| |'|(|-}/|/;d#|}|/;d$|}|'t|d ks|'d |$kr|'d | jj- dkr|&<  tr"t=>  qW d   n	1 s/w   Y  |d%kr<|}n'| ?|||| j}|| j"jj) | j"jj( }| j"j@|dd"d }| jjA||d&}| B  |sm|fS tC|d'S )(ag  
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            prompt_2 (`str` or `list[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                will be used instead
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, or `list[np.ndarray]`):
                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
                numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
                or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
                latents as `image`, but if passing latents directly it is not encoded again.
            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,:
                    `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`):
                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
                images must be passed as a list such that each element of the list can be correctly batched for input
                to a single ControlNet.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for the best results.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for the best results.
            strength (`float`, *optional*, defaults to 1.0):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
                essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`list[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`list`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
            images.
        )rJ   r   r   rd   Fr   )re   Nr   r   scale)rb   r   rJ   r   r8   rc   rd   r   g      ?r   base_image_seq_lenr   max_image_seq_lenr   r"   r   r#   r   cpu)r:   r)   z?After adjusting the num_inference_steps by strength parameter: z!, the number of pipelinesteps is z4 which is < 1 and not appropriate for this pipeline.rQ   )r   r   r   r   rc   r8   re   r   r   r   )totalr   i  )	hidden_statesr   guidancepooled_projectionsencoder_hidden_statestxt_idsimg_idsr   r   )r   r2   rJ   latent)r   )images)Dr^   rR   r   r   r   r   r[   r   r   rz   float32ru   rv   r   rC   ry   rt   r   getr   nplinspacer   r+   rD   rY   XLA_AVAILABLErH   r   r;   r   rP   in_channelsr   rK   re   ndimr   r0   r,   r   r   r   r   r   r   r   guidance_embedsfullexpandprogress_bar	enumerater   r   stepbackendsmpsis_availablelocalspopupdatexm	mark_stepr   decodepostprocessmaybe_free_model_hooksr   )0r_   rb   r   r   r   r   r   r   r7   r:   r   rc   r.   r2   rJ   r   r   r   r   r   r   rd   
init_imager   r8   r   r   r&   r)   timestep_devicer9   latent_timestepr   height_control_imagewidth_control_imager   num_warmup_stepsr   r   r   tlatent_model_inputr   
noise_predlatents_dtypecallback_kwargsr   callback_outputsr%   r%   r*   __call__z  s4  t

$





6
/
z#FluxControlImg2ImgPipeline.__call__)Nr   ra   NN)r   N)NNr   NNra   NNNNNr$   )FF)4__name__
__module____qualname____doc__model_cpu_offload_seq_optional_componentsr   r   r   r   r   r   r   r   rV   rv   r   r   rz   r8   re   r   r   FloatTensorfloatr   r   	Generatorr   r   r   staticmethodr   r   r   r   r   propertyr   r   r   r   no_gradr   EXAMPLE_DOC_STRINGr	   booldictr   r   r  __classcell__r%   r%   r`   r*   rI      sB   !

5

0
	
P
6



9
"




	

rI   )r   r   r   r   )Nr,   r  ):r=   typingr   r   numpyr   rz   transformersr   r   r   r   r[   r	   r
   loadersr   r   r   models.autoencodersr   models.transformersr   
schedulersr   utilsr   r   r   r   r   r   utils.torch_utilsr   pipeline_utilsr   pipeline_outputr   torch_xla.core.xla_modelcore	xla_modelr  r   
get_loggerr  r}   r%  r   r   r+   r   r!  rv   r6   r8   r   rH   rI   r%   r%   r%   r*   <module>   sn    
+




;