o
    GiPA                 
   @   s  d dl Z d dlmZmZ d dlZd dlZd dlZd dl	m
  mZ d dlmZmZmZmZmZ d dlmZ ddlmZmZ ddlmZmZ ddlmZmZmZmZ dd	l m!Z!m"Z"m#Z#m$Z$m%Z% dd
l&m'Z' ddl(m)Z) ddl*m+Z+m,Z,m-Z-m.Z.m/Z/m0Z0 ddl1m2Z2m3Z3m4Z4 ddl5m6Z6m7Z7 ddl8m9Z9 e rddl:m;Z; ddl*m<Z< e< rd dl=m>  m?Z@ dZAndZAe-BeCZDdZE				ddeFdB deGejHB dB deIeF dB deIeJ dB fddZKG dd de6e7eeeeZLdS )    N)AnyCallable)CLIPImageProcessorCLIPTextModelCLIPTextModelWithProjectionCLIPTokenizerCLIPVisionModelWithProjection) is_invisible_watermark_available   )MultiPipelineCallbacksPipelineCallback)PipelineImageInputVaeImageProcessor)FromSingleFileMixinIPAdapterMixin StableDiffusionXLLoraLoaderMixinTextualInversionLoaderMixin)AutoencoderKLControlNetModelImageProjectionMultiControlNetModelUNet2DConditionModel)adjust_lora_scale_text_encoder)KarrasDiffusionSchedulers)USE_PEFT_BACKEND	deprecateloggingreplace_example_docstringscale_lora_layersunscale_lora_layers)is_compiled_moduleis_torch_versionrandn_tensor   )DiffusionPipelineStableDiffusionMixin)StableDiffusionXLPipelineOutput)StableDiffusionXLWatermarker)is_torch_xla_availableTFa/  
    Examples:
        ```py
        >>> # !pip install opencv-python transformers accelerate
        >>> from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL
        >>> from diffusers.utils import load_image
        >>> import numpy as np
        >>> import torch

        >>> import cv2
        >>> from PIL import Image

        >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
        >>> negative_prompt = "low quality, bad quality, sketches"

        >>> # download an image
        >>> image = load_image(
        ...     "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
        ... )

        >>> # initialize the models and pipeline
        >>> controlnet_conditioning_scale = 0.5  # recommended for good generalization
        >>> controlnet = ControlNetModel.from_pretrained(
        ...     "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
        ... )
        >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
        >>> pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        ...     "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
        ... )
        >>> pipe.enable_model_cpu_offload()

        >>> # get canny image
        >>> image = np.array(image)
        >>> image = cv2.Canny(image, 100, 200)
        >>> image = image[:, :, None]
        >>> image = np.concatenate([image, image, image], axis=2)
        >>> canny_image = Image.fromarray(image)

        >>> # generate image
        >>> image = pipe(
        ...     prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
        ... ).images[0]
        ```
num_inference_stepsdevice	timestepssigmasc                 K   s  |dur|durt d|dur>dtt| jj v }|s(t d| j d| jd||d| | j}t	|}||fS |durpdtt| jj v }|sZt d| j d| jd||d	| | j}t	|}||fS | j|fd
|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    NzYOnly one of `timesteps` or `sigmas` can be passed. Please choose one to set custom valuesr+   zThe current scheduler class zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)r+   r*   r,   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)r,   r*   r*    )

ValueErrorsetinspect	signatureset_timesteps
parameterskeys	__class__r+   len)	schedulerr)   r*   r+   r,   kwargsaccepts_timestepsaccept_sigmasr-   r-   l/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.pyretrieve_timestepsy   s2   r<   c                P       s   e Zd ZdZdZg dZg dZ				dqdeded	e	d
e
de
dedeee B ee B eB dedededB dedef fddZ												drdededB dejdB dedededB dedB dejdB dejdB dejdB d ejdB d!edB d"edB fd#d$Zdsd%d&Zd'd( Zd)d* Z									+	,	+	dtd-d.Zd/d0 Z 	1	1dud2d3Z!dsd4d5Z"	dsd6d7Z#d8d9 Z$d:ej%fd;ejd<ed=ej&d>ejfd?d@Z'e(dAdB Z)e(dCdD Z*e(dEdF Z+e(dGdH Z,e(dIdJ Z-e(dKdL Z.e(dMdN Z/e0 e1e2ddddddOddddPdddd,dddddddddQddd+d1d,d+ddRdddRddddSgf&deee B deee B dB dTe3dUedB dVedB dWedXee dYee dZedB d[edeee B dB deee B dB dedB d\ed]ej4eej4 B dB dSejdB dejdB dejdB dejdB d ejdB d^e3dB d_eej dB d`edB daedbe5ee6f dB dceee B ddedeeee B dfeee B dgeeef dheeef dieeef djeeef dB dkeeef dleeef dB d"edB dme7eegdf e8B e9B dB dnee fLdodpZ:  Z;S )v#StableDiffusionXLControlNetPipelinea  
    Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
            Second frozen text-encoder
            ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        tokenizer_2 ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        controlnet ([`ControlNetModel`] or `list[ControlNetModel]`):
            Provides additional conditioning to the `unet` during the denoising process. If you set multiple
            ControlNets as a list, the outputs from each ControlNet are added together to create one combined
            additional conditioning.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
            Whether the negative prompt embeddings should always be set to 0. Also see the config of
            `stabilityai/stable-diffusion-xl-base-1-0`.
        add_watermarker (`bool`, *optional*):
            Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
            watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
            watermarker is used.
    z6text_encoder->text_encoder_2->image_encoder->unet->vae)	tokenizertokenizer_2text_encodertext_encoder_2feature_extractorimage_encoder)latentsprompt_embedsnegative_prompt_embedsadd_text_embedsadd_time_idsnegative_pooled_prompt_embedsnegative_add_time_idsimageTNvaer@   rA   r>   r?   unet
controlnetr7   force_zeros_for_empty_promptadd_watermarkerrB   rC   c                    s   t    t|ttfrt|}| j||||||||||d
 t| dd r0dt| j	j
jd  nd| _t| jdd| _t| jddd	| _|
d urJ|
nt }
|
rTt | _nd | _| j|	d
 d S )N)
rL   r@   rA   r>   r?   rM   rN   r7   rB   rC   rL   r#         T)vae_scale_factordo_convert_rgbF)rS   rT   do_normalize)rO   )super__init__
isinstancelisttupler   register_modulesgetattrr6   rL   configblock_out_channelsrS   r   image_processorcontrol_image_processorr	   r'   	watermarkregister_to_config)selfrL   r@   rA   r>   r?   rM   rN   r7   rO   rP   rB   rC   r5   r-   r;   rW      s2   
(
z,StableDiffusionXLControlNetPipeline.__init__rQ   promptprompt_2r*   num_images_per_promptdo_classifier_free_guidancenegative_promptnegative_prompt_2rE   rF   pooled_prompt_embedsrI   
lora_scale	clip_skipc           !   
   C   s:  |p| j }|dur9t| tr9|| _| jdur%tst| j| nt| j| | jdur9ts3t| j| nt| j| t|t	rA|gn|}|durLt
|}n|jd }| jdur\| j| jgn| jg}| jdurk| j| jgn| jg}|du r|pw|}t|t	r|gn|}g }||g}t|||D ]\}}}t| tr| ||}||d|jddd}|j}||dddj}|jd	 |jd	 krt||s||dd|jd
 d	f }td|j d|  |||dd}|
du r|d jdkr|d }
|du r|jd }n|j|d   }|| qtj|d	d}|du o| jj}|r6|	du r6|r6t|}	t|
}n|r|	du r|pBd}|pG|}t|t	rS||g n|}t|t	r`||g n|}|durt |t |urt!dt | dt | d|t
|krt"d| dt
| d| d| d	||g}g }t|||D ]E\}}}t| tr| ||}|jd
 }||d|ddd}||j|dd}	|du r|	d jdkr|	d }|	jd }	||	 qtj|d	d}	| jdur|j| jj#|d}n	|j| j$j#|d}|j\}}} |%d
|d
}|&|| |d	}|rR|	jd
 }| jdur9|	j| jj#|d}	n	|	j| j$j#|d}	|	%d
|d
}	|	&|| |d	}	|
%d
|&|| d	}
|rm|%d
|&|| d	}| jdurt| trtrt'| j| | jdurt| trtrt'| j| ||	|
|fS )a\  
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `list[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in both text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            negative_prompt_2 (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                If not provided, pooled text embeddings will be generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        Nr   
max_lengthTpt)paddingrn   
truncationreturn_tensorslongest)rp   rr   rQ   z\The following part of your input was truncated because CLIP can only handle sequences up to z	 tokens: output_hidden_statesr#   dim z?`negative_prompt` should be the same type to `prompt`, but got z != .z`negative_prompt`: z has batch size z, but `prompt`: zT. Please make sure that passed `negative_prompt` matches the batch size of `prompt`.)dtyper*   )(_execution_devicerX   r   _lora_scaler@   r   r   r   rA   strr6   shaper>   r?   zipr   maybe_convert_promptmodel_max_length	input_idstorchequalbatch_decodeloggerwarningtondimhidden_statesappendconcatr]   rO   
zeros_liketype	TypeErrorr.   r|   rM   repeatviewr   )!rc   re   rf   r*   rg   rh   ri   rj   rE   rF   rk   rI   rl   rm   
batch_size
tokenizerstext_encodersprompt_embeds_listpromptsr>   r@   text_inputstext_input_idsuntruncated_idsremoved_textzero_out_negative_promptuncond_tokensnegative_prompt_embeds_listrn   uncond_inputbs_embedseq_len_r-   r-   r;   encode_prompt+  s   
:





 







z1StableDiffusionXLControlNetPipeline.encode_promptc           
      C   s   t | j j}t|tjs| j|ddj}|j	||d}|rH| j|ddj
d }|j|dd}| jt|ddj
d }|j|dd}||fS | |j}|j|dd}t|}	||	fS )	Nro   )rr   r*   r|   Tru   rw   r   rx   )nextrC   r3   r|   rX   r   TensorrB   pixel_valuesr   r   repeat_interleaver   image_embeds)
rc   rK   r*   rg   rv   r|   image_enc_hidden_statesuncond_image_enc_hidden_statesr   uncond_image_embedsr-   r-   r;   encode_image  s(   

z0StableDiffusionXLControlNetPipeline.encode_imagec                 C   sl  g }|rg }|d u ret |ts|g}t|t| jjjkr/tdt| dt| jjj dt|| jjjD ],\}}	t |	t }
| 	||d|
\}}|
|d d d f  |rc|
|d d d f  q7n|D ]}|rw|d\}}|
| |
| qgg }t|D ]0\}}tj|g| dd}|rtj|| g| dd}tj||gdd}|j|d}|
| q|S )	NzK`ip_adapter_image` must have same length as the number of IP Adapters. Got  images and z IP Adapters.rQ   r#   r   rx   )r*   )rX   rY   r6   rM   encoder_hid_projimage_projection_layersr.   r   r   r   r   chunk	enumerater   catr   )rc   ip_adapter_imageip_adapter_image_embedsr*   rg   rh   r   negative_image_embedssingle_ip_adapter_imageimage_proj_layeroutput_hidden_statesingle_image_embedssingle_negative_image_embedsir-   r-   r;   prepare_ip_adapter_image_embeds3  sH   


zCStableDiffusionXLControlNetPipeline.prepare_ip_adapter_image_embedsc                 C   sX   dt t| jjj v }i }|r||d< dt t| jjj v }|r*||d< |S )Neta	generator)r/   r0   r1   r7   stepr3   r4   )rc   r   r   accepts_etaextra_step_kwargsaccepts_generatorr-   r-   r;   prepare_extra_step_kwargsa  s   z=StableDiffusionXLControlNetPipeline.prepare_extra_step_kwargs      ?        c              
      sn  |d urt |tr|dkrtd| dt| d|d ur;t fdd|D s;td j d fd	d
|D  |d urN|d urNtd| d| d|d ura|d uratd| d| d|d u rm|d u rmtd|d urt |tst |tstdt| |d urt |tst |tstdt| |d ur|d urtd| d| d|d ur|d urtd| d| d|d ur|d ur|j|jkrtd|j d|j d|d ur|	d u rtd|d ur|d u rtdt  j	t
rt |trtdt j	j dt| d ttdo!t  j	tjjj}t  j	ts4|r<t  j	jtr< ||| nVt  j	t
sN|rt  j	jt
rt |tsXtdtdd |D rftdt|t j	jkrtd t| d!t j	j d"|D ]
} ||| qnJ t  j	ts|rt  j	jtrt |tstd$n?t  j	t
s|rt  j	jt
rt |trtd%d |D rtdnt |trt|t j	jkrtd&nJ t |ttfs|g}t |ttfs|g}t|t|krtd't| d(t| d)t  j	t
rIt|t j	jkrItd*| d+t| d,t j	j d-t j	j d	t||D ]/\}}||krbtd.| d/| d|d0k rotd.| d1|d2kr|td3| d4qN|
d ur|d urtd5|d urt |tstd6t| |d jd7vrtd8|d j d9d S d S ):Nr   z5`callback_steps` has to be a positive integer but is z	 of type r{   c                 3   s    | ]}| j v V  qd S N_callback_tensor_inputs.0krc   r-   r;   	<genexpr>  s    

zCStableDiffusionXLControlNetPipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r-   r   r   r   r-   r;   
<listcomp>  s    zDStableDiffusionXLControlNetPipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.z Cannot forward both `prompt_2`: zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is z4`prompt_2` has to be of type `str` or `list` but is z'Cannot forward both `negative_prompt`: z and `negative_prompt_embeds`: z)Cannot forward both `negative_prompt_2`: zu`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` z != `negative_prompt_embeds` zIf `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.zIf `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`.z	You have z! ControlNets and you have passed z= prompts. The conditionings will be fixed across the prompts.scaled_dot_product_attentionz5For multiple controlnets: `image` must be type `list`c                 s       | ]}t |tV  qd S r   rX   rY   r   r   r-   r-   r;   r         zEA single batch of multiple conditionings are supported at the moment.zbFor multiple controlnets: `image` must have the same length as the number of controlnets, but got r   z ControlNets.FzLFor single controlnet: `controlnet_conditioning_scale` must be type `float`.c                 s   r   r   r   r   r-   r-   r;   r     r   zFor multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have the same length as the number of controlnetsz`control_guidance_start` has z* elements, but `control_guidance_end` has zI elements. Make sure to provide the same number of elements to each list.z`control_guidance_start`: z has z elements but there are z- controlnets available. Make sure to provide zcontrol guidance start: z4 cannot be larger or equal to control guidance end: r   z can't be smaller than 0.r   zcontrol guidance end: z can't be larger than 1.0.zProvide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.z:`ip_adapter_image_embeds` has to be of type `list` but is )r
      zF`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is D)rX   intr.   r   allr   r   rY   r   rN   r   r   r   r6   netshasattrFr   _dynamo
eval_frameOptimizedModuler   	_orig_modcheck_imager   anyfloatrZ   r   r   )rc   re   rf   rK   callback_stepsri   rj   rE   rF   rk   r   r   rI   controlnet_conditioning_scalecontrol_guidance_startcontrol_guidance_end"callback_on_step_end_tensor_inputsis_compiledimage_startendr-   r   r;   check_inputsr  s6  



0



z0StableDiffusionXLControlNetPipeline.check_inputsc                 C   s$  t |tjj}t |tj}t |tj}t |to t |d tjj}t |to-t |d tj}t |to:t |d tj}	|sP|sP|sP|sP|sP|	sPtdt	| |rUd}
nt
|}
|d uret |tred}n|d urst |trst
|}n	|d ur||jd }|
dkr|
|krtd|
 d| d S d S )Nr   zimage must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is rQ   zdIf image batch size is not 1, image batch size must be same as prompt batch size. image batch size: z, prompt batch size: )rX   PILImager   r   npndarrayrY   r   r   r6   r   r   r.   )rc   rK   re   rE   image_is_pilimage_is_tensorimage_is_npimage_is_pil_listimage_is_tensor_listimage_is_np_listimage_batch_sizeprompt_batch_sizer-   r-   r;   r   0  sF   

z/StableDiffusionXLControlNetPipeline.check_imageFc
                 C   sp   | j j|||djtjd}|jd }
|
dkr|}n|}|j|dd}|j||d}|r6|	s6t|gd }|S )N)heightwidthr|   r   rQ   rx   r   r#   )r`   
preprocessr   r   float32r   r   r   )rc   rK   r   r   r   rg   r*   r|   rh   
guess_moder   	repeat_byr-   r-   r;   prepare_imageV  s   
z1StableDiffusionXLControlNetPipeline.prepare_imagec	           
      C   s   ||t || j t || j f}	t|tr(t||kr(tdt| d| d|d u r5t|	|||d}n||}|| jj	 }|S )Nz/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.)r   r*   r|   )
r   rS   rX   rY   r6   r.   r"   r   r7   init_noise_sigma)
rc   r   num_channels_latentsr   r   r|   r*   r   rD   r   r-   r-   r;   prepare_latentsu  s    
z3StableDiffusionXLControlNetPipeline.prepare_latentsc           	      C   sd   t || | }| jjjt| | }| jjjj}||kr(td| d| dt	j
|g|d}|S )Nz7Model expects an added time embedding vector of length z, but a vector of z was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`.r   )rY   rM   r]   addition_time_embed_dimr6   add_embeddinglinear_1in_featuresr.   r   tensor)	rc   original_sizecrops_coords_top_lefttarget_sizer|   text_encoder_projection_dimrH   passed_add_embed_dimexpected_add_embed_dimr-   r-   r;   _get_add_time_ids  s   z5StableDiffusionXLControlNetPipeline._get_add_time_idsc                 C   s    t ddd | jjtjd d S )N
upcast_vae1.0.0z`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.r   )r   rL   r   r   r   r   r-   r-   r;   r    s   z.StableDiffusionXLControlNetPipeline.upcast_vaei   wembedding_dimr|   returnc                 C   s   t |jdks	J |d }|d }ttd|d  }ttj||d|  }||dddf |dddf  }tjt	|t
|gdd}|d dkrZtjj|d}|j|jd	 |fksfJ |S )
a  
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
            embedding_dim (`int`, *optional*, defaults to 512):
                Dimension of the embeddings to generate.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                Data type of the generated embeddings.

        Returns:
            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
        rQ   g     @@r#   g     @r   Nrx   )r   rQ   r   )r6   r   r   logr	  exparanger   r   sincosnn
functionalpad)rc   r  r  r|   half_dimembr-   r-   r;   get_guidance_scale_embedding  s   &z@StableDiffusionXLControlNetPipeline.get_guidance_scale_embeddingc                 C      | j S r   )_guidance_scaler   r-   r-   r;   guidance_scale     z2StableDiffusionXLControlNetPipeline.guidance_scalec                 C   r!  r   )
_clip_skipr   r-   r-   r;   rm     r$  z-StableDiffusionXLControlNetPipeline.clip_skipc                 C   s   | j dko| jjjd u S )NrQ   )r"  rM   r]   time_cond_proj_dimr   r-   r-   r;   rh     s   z?StableDiffusionXLControlNetPipeline.do_classifier_free_guidancec                 C   r!  r   )_cross_attention_kwargsr   r-   r-   r;   cross_attention_kwargs  r$  z:StableDiffusionXLControlNetPipeline.cross_attention_kwargsc                 C   r!  r   )_denoising_endr   r-   r-   r;   denoising_end  r$  z1StableDiffusionXLControlNetPipeline.denoising_endc                 C   r!  r   )_num_timestepsr   r-   r-   r;   num_timesteps  r$  z1StableDiffusionXLControlNetPipeline.num_timestepsc                 C   r!  r   )
_interruptr   r-   r-   r;   	interrupt  r$  z-StableDiffusionXLControlNetPipeline.interrupt2   g      @pil)r   r   rD   rK   r   r   r)   r+   r,   r*  r#  r   r   r   r   output_typereturn_dictr(  r   r   r   r   r
  r  r  negative_original_sizenegative_crops_coords_top_leftnegative_target_sizecallback_on_step_endr   c'           Y         s
  |' dd}(|' dd})|(durtddd |)dur tddd t|%ttfr*|%j}&t| jr3| jjn| j}*t|t	sHt|t	rHt
||g }n3t|t	sZt|t	rZt
||g }n!t|t	s{t|t	s{t|*trnt
|*jnd}+|+|g |+|g }}| ||||)||||||||||||& |
| _|$| _|| _|	| _d| _|durt|trd},n|durt|t	rt
|},n|jd	 },| j}-t|*trt|tr|gt
|*j }t|*tr|*jjn|*jd	 jj}.|p|.}| jdur| jd
dnd}/| j|||-|| j|||||||/| jd\}}}}|dus|dur"| |||-|,| | j}0t|*trD| j ||||,| ||-|*j!| j|d	}|jdd \}}n5t|*trwg }1|D ]}2| j |2|||,| ||-|*j!| j|d	}2|1"|2 qN|1}|d	 jdd \}}nJ t#rd}3n|-}3t$| j%||3|\}t
| _&| j'jj(}4| )|,| |4|||j!|-||}d}5| j'jj*durt+,| j-d .|,| }6| j/|6| j'jj*dj0|-|j!d}5| 1||}7g }8t2t
D ]fddt3||D }9|8"t|*tr|9d	 n|9 qt|t	r|p|d	 jdd }n
|p|jdd }| p||f} |}:| j4du r.t5|jd };n| j4jj6};| j7||| |j!|;d}<|!durT|#durT| j7|!|"|#|j!|;d}=n|<}=| jrut+j8||gd	d}t+j8||:gd	d}:t+j8|=|<gd	d}<|0|-}|:0|-}:|<0|-.|,| d}<t
|| j%j9  }>| j:durt| j:tr| j:d	kr| j:dk rt5t;| j%jj<| j:| j%jj<   t
t	t= fdd}d| t| j'}?t| j}@t>dd}A| j?|d}Bt@D ]\}C| jArqt+jBC r|?r|@r|Art+jDE  | jrt+8|gd n|}D| j%F|D|C}D|:|<d}E|rL| jrL|}F| j%F|F|C}F|Gdd }G|:Gdd |<Gdd d}Hn|D}F|}G|E}Ht|8 t	rgdd t3||8 D }In|}Jt|Jt	rs|Jd	 }J|J|8  }I| j|F|C|G||I||Hdd\}K}L|r| jrdd |KD }Kt+8t+H|L|Lg}L|dus|dur|0|Ed < | j'|D|C||5| j|K|L|Edd!	d	 }M| jr|MGd\}N}O|N|
|O|N   }M| j%jI|M|C|fi |7d"did	 }|%dur/i }P|&D ]
}QtJ |Q |P|Q< q|%| |C|P}R|R d#|}|R d$|}|R d%|}|R d&|:}:|R d'|}|R d(|<}<|R d)|=}=|R d*|}t
d ksJd |>krid | j%j9 d	kri|BK  |(duri|) d	kritL| j%d+d }S|(|S|C| t#rptMN  qW d   n	1 s}w   Y  |d,ks| jOj!t+jPko| jOjjQ}T|Tr| R  |0tStT| jOjUV j!}tW| jOjd-o| jOjjXdu}UtW| jOjd.o| jOjjYdu}V|Ur|Vrt+,| jOjjXZdd/dd0|j[|j!}Wt+,| jOjjYZdd/dd0|j[|j!}X||X | jOjj\ |W }n|| jOjj\ }| jOj]|dd0d	 }|Tr| jOj0t+jPd1 n|}|d,ks9| j^dur1| j^_|}| j`ja||d2}| b  |sC|fS tc|d3S )4uZ5  
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            prompt_2 (`str` or `list[str]`, *optional*):
                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                used in both text-encoders.
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `list[torch.Tensor]`, `list[PIL.Image.Image]`, `list[np.ndarray]`,:
                    `list[list[torch.Tensor]]`, `list[list[np.ndarray]]` or `list[list[PIL.Image.Image]]`):
                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
                as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
                images must be passed as a list such that each element of the list can be correctly batched for input
                to a single ControlNet.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image. Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                and checkpoints that are not specifically fine-tuned on low resolutions.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image. Anything below 512 pixels won't work well for
                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
                and checkpoints that are not specifically fine-tuned on low resolutions.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`list[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`list[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            denoising_end (`float`, *optional*):
                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
                completed before it is intentionally prematurely terminated. As a result, the returned sample will
                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
            guidance_scale (`float`, *optional*, defaults to 5.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            negative_prompt_2 (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
                and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
                applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, pooled text embeddings are generated from `prompt` input argument.
            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
                weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
                argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`list[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            controlnet_conditioning_scale (`float` or `list[float]`, *optional*, defaults to 1.0):
                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
                the corresponding scale as a list.
            guess_mode (`bool`, *optional*, defaults to `False`):
                The ControlNet encoder tries to recognize the content of the input image even if you remove all
                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
            control_guidance_start (`float` or `list[float]`, *optional*, defaults to 0.0):
                The percentage of total steps at which the ControlNet starts applying.
            control_guidance_end (`float` or `list[float]`, *optional*, defaults to 1.0):
                The percentage of total steps at which the ControlNet stops applying.
            original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
                explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)):
                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)):
                For most cases, `target_size` should be set to the desired height and width of the generated image. If
                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            negative_original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)):
                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
                micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
            negative_crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)):
                To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
                micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
            negative_target_size (`tuple[int]`, *optional*, defaults to (1024, 1024)):
                To negatively condition the generation process based on a target image resolution. It should be as same
                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`list`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned containing the output images.
        callbackNr   r  zjPassing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`zpPassing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`rQ   Fr   scale)rE   rF   rk   rI   rl   rm   )	rK   r   r   r   rg   r*   r|   rh   r   rw   cpu)r  r   c                    s<   g | ]\}}d t  t |k p d t |k qS )r   rQ   )r   r6   )r   se)r   r+   r-   r;   r   e  s    *z@StableDiffusionXLControlNetPipeline.__call__.<locals>.<listcomp>rt   )r|   r  rx   c                    s   |  kS r   r-   )ts)discrete_timestep_cutoffr-   r;   <lambda>  s    z>StableDiffusionXLControlNetPipeline.__call__.<locals>.<lambda>z>=z2.1)totalr#   )text_embedstime_idsc                 S   s   g | ]\}}|| qS r-   r-   )r   cr:  r-   r-   r;   r     s    )encoder_hidden_statescontrolnet_condconditioning_scaler   added_cond_kwargsr2  c                 S   s    g | ]}t t ||gqS r-   )r   r   r   )r   dr-   r-   r;   r     s     r   )rC  timestep_condr(  down_block_additional_residualsmid_block_additional_residualrF  r2  r2  rD   rE   rF   rG   rI   rH   rJ   rK   orderlatentlatents_meanlatents_stdr   )r2  r   )r1  )images)dpopr   rX   r   r   tensor_inputsr    rN   r   rY   r6   r   r   r   r"  r%  r'  r)  r-  r   r   r}   r   r   r]   global_pool_conditionsr(  getr   rh   rm   r   r  r|   r   XLA_AVAILABLEr<   r7   r+  rM   in_channelsr  r&  r   r	  r#  r   r   r   r   ranger   rA   r   projection_dimr  r   rK  r*  roundnum_train_timestepsfilterr!   progress_barr   r.  cudais_available	_inductorcudagraph_mark_step_beginscale_model_inputr   r   r   localsupdater\   xm	mark_steprL   float16force_upcastr  r   iterpost_quant_convr3   r   rM  rN  r   r*   scaling_factordecodera   apply_watermarkr_   postprocessmaybe_free_model_hooksr&   )Yrc   re   rf   rK   r   r   r)   r+   r,   r*  r#  ri   rj   rg   r   r   rD   rE   rF   rk   rI   r   r   r1  r2  r(  r   r   r   r   r
  r  r  r3  r4  r5  rm   r6  r   r8   r7  r   rN   multr   r*   rR  text_encoder_lora_scaler   rO  r   timestep_devicer  rH  guidance_scale_tensorr   controlnet_keepkeepsrG   r  rH   rJ   num_warmup_stepsis_unet_compiledis_controlnet_compiledis_torch_higher_equal_2_1r[  tlatent_model_inputrF  control_model_inputcontrolnet_prompt_embedscontrolnet_added_cond_kwargs
cond_scalecontrolnet_cond_scaledown_block_res_samplesmid_block_res_sample
noise_prednoise_pred_uncondnoise_pred_textcallback_kwargsr   callback_outputsstep_idxneeds_upcastinghas_latents_meanhas_latents_stdrM  rN  r-   )r=  r   r+   r;   __call__  s   E


	


"










$
6
o&&

z,StableDiffusionXLControlNetPipeline.__call__)TNNN)NNrQ   TNNNNNNNNr   )NNNNNNNNr   r   r   N)FF)<__name__
__module____qualname____doc__model_cpu_offload_seq_optional_componentsr   r   r   r   r   r   r   rY   rZ   r   r   boolr   r   rW   r   r   r*   r   r   r   r   r   r   r   r   r   r  r  r  r  r   r|   r   propertyr#  rm   rh   r(  r*  r,  r.  no_gradr   EXAMPLE_DOC_STRINGr   	Generatordictr   r   r   r   r  __classcell__r-   r-   rd   r;   r=      s   ,	
2	

 
p.
 ?/












	





 
!"
#$%&'r=   )NNNN)Mr0   typingr   r   numpyr   	PIL.Imager   r   torch.nn.functionalr  r  r   transformersr   r   r   r   r   diffusers.utils.import_utilsr	   	callbacksr   r   r_   r   r   loadersr   r   r   r   modelsr   r   r   r   r   models.lorar   
schedulersr   utilsr   r   r   r   r   r   utils.torch_utilsr    r!   r"   pipeline_utilsr$   r%   #stable_diffusion_xl.pipeline_outputr&   stable_diffusion_xl.watermarkr'   r(   torch_xla.core.xla_modelcore	xla_modelrc  rT  
get_loggerr  r   r  r   r   r*   rY   r   r<   r=   r-   r-   r-   r;   <module>   s`    
1



;