o
    ۷i+l                     @   s   d dl mZ d dlZd dlmZmZ ddlmZ ddlm	Z	m
Z
 ddlmZ ddlmZmZmZmZ dd	lmZ d
dlmZmZ e rOd dlm  mZ dZndZeeZdZdddZ G dd deeZ!dS )    )CallableN)T5EncoderModelT5Tokenizer   )StableDiffusionLoraLoaderMixin)Kandinsky3UNetVQModel)DDPMScheduler)	deprecateis_torch_xla_availableloggingreplace_example_docstring)randn_tensor   )DiffusionPipelineImagePipelineOutputTFa  
    Examples:
        ```py
        >>> from diffusers import AutoPipelineForText2Image
        >>> import torch

        >>> pipe = AutoPipelineForText2Image.from_pretrained(
        ...     "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
        ... )
        >>> pipe.enable_model_cpu_offload()

        >>> prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."

        >>> generator = torch.Generator(device="cpu").manual_seed(0)
        >>> image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
        ```

   c                 C   sX   | |d  }| |d  dkr|d7 }||d  }||d  dkr$|d7 }|| || fS )Nr   r       )heightwidthscale_factor
new_height	new_widthr   r   h/home/ubuntu/vllm_env/lib/python3.10/site-packages/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.pydownscale_height_and_width1   s   r   c                $       s  e Zd ZdZg dZdedededede	f
 fdd	Z
d
d Ze 									d3dejdB dejdB dejdB dejdB fddZdd Z						d4ddZedd Zedd Zedd Ze eedd d!ddd"d"dddddd#dddd$gfd%eee B d&ed'ed(eee B dB d)edB d*edB d+edB d,ejeej B dB dejdB dejdB dejdB dejdB d-edB d.ed/eeegdf dB d0ee f d1d2Z  Z S )5Kandinsky3Pipelineztext_encoder->unet->movq)latentsprompt_embedsnegative_prompt_embedsnegative_attention_maskattention_mask	tokenizertext_encoderunet	schedulermovqc                    s"   t    | j|||||d d S )N)r"   r#   r$   r%   r&   )super__init__register_modules)selfr"   r#   r$   r%   r&   	__class__r   r   r(   E   s   


zKandinsky3Pipeline.__init__c                 C   s`   |r,t ||dk ||dk< |d d }|d d d |f }|d d d |f }||fS )Nr   r   )torch
zeros_likesummax)r*   
embeddingsr!   cut_contextmax_seq_lengthr   r   r   process_embedsS   s   z!Kandinsky3Pipeline.process_embedsTr   NFr   r   r!   r    c              
   C   s  |dur|durt |t |urtdt | dt | d|du r&| j}|dur2t|tr2d}n|dur@t|tr@t|}n|jd }d}|du r|| j|d|d	d
d}|j	
|}|j
|}	| j||	d}|d }| ||	|\}}	||	d }| jdur| jj}nd}|j
||d}|j\}}}|d|d}||| |d}|	|d}	|r4|du r4|du rdg| }n$t|tr|g}n|t|krtd| dt| d| d| d	|}|dur*| j|ddd	d	d
d}|j	
|}|j
|}
| j||
d}|d }|ddd|jd f }|
ddd|jd f }
||
d }n
t|}t|	}
|ra|jd }|j
||d}|j|jkr`|d|d}||| |d}|
|d}
nd}d}
|||	|
fS )aX  
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
            negative_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
        Nz?`negative_prompt` should be the same type to `prompt`, but got z != .r   r      
max_lengthTpt)paddingr8   
truncationreturn_tensors)r!   r   )dtypedevicer-    z`negative_prompt`: z has batch size z, but `prompt`: zT. Please make sure that passed `negative_prompt` matches the batch size of `prompt`.)r:   r8   r;   return_attention_maskr<   )type	TypeError_execution_device
isinstancestrlistlenshaper"   	input_idstor!   r#   r5   	unsqueezer=   repeatview
ValueErrorr.   r/   )r*   promptdo_classifier_free_guidancenum_images_per_promptr>   negative_promptr   r   _cut_contextr!   r    
batch_sizer8   text_inputstext_input_idsr=   bs_embedseq_len_uncond_tokensuncond_inputr   r   r   encode_prompt[   s   *








z Kandinsky3Pipeline.encode_promptc                 C   sR   |d u rt ||||d}n|j|krtd|j d| ||}||j }|S )N)	generatorr>   r=   zUnexpected latents shape, got z, expected )r   rH   rN   rJ   init_noise_sigma)r*   rH   r=   r>   r]   r   r%   r   r   r   prepare_latents   s   


z"Kandinsky3Pipeline.prepare_latentsc	           	         s
  |d urt |tr|dkrtd| dt| d|d ur;t fdd|D s;td j d fd	d
|D  |d urN|d urNtd| d| d|d u rZ|d u rZtd|d urqt |tsqt |tsqtdt| |d ur|d urtd| d| d|d ur|d ur|j|jkrtd|j d|j d|d ur|d u rtd|d ur|d ur|jd d |jkrtd|jd d  d|j d|d ur|d u rtd|d ur|d ur|jd d |jkrtd|jd d  d|j dd S d S d S )Nr   z5`callback_steps` has to be a positive integer but is z	 of type r6   c                 3       | ]}| j v V  qd S N_callback_tensor_inputs.0kr*   r   r   	<genexpr>      

z2Kandinsky3Pipeline.check_inputs.<locals>.<genexpr>2`callback_on_step_end_tensor_inputs` has to be in , but found c                       g | ]	}| j vr|qS r   rb   rd   rg   r   r   
<listcomp>      z3Kandinsky3Pipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is z'Cannot forward both `negative_prompt`: z and `negative_prompt_embeds`: zu`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` z != `negative_prompt_embeds` zLPlease provide `negative_attention_mask` along with `negative_prompt_embeds`r   z`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but got: `negative_prompt_embeds` z != `negative_attention_mask` z:Please provide `attention_mask` along with `prompt_embeds`z`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but got: `prompt_embeds` z != `attention_mask` )	rD   intrN   rA   allrc   rE   rF   rH   )	r*   rO   callback_stepsrR   r   r   "callback_on_step_end_tensor_inputsr!   r    r   rg   r   check_inputs   sz   zKandinsky3Pipeline.check_inputsc                 C      | j S ra   _guidance_scalerg   r   r   r   guidance_scaleB     z!Kandinsky3Pipeline.guidance_scalec                 C   s
   | j dkS )Nr   ru   rg   r   r   r   rP   F  s   
z.Kandinsky3Pipeline.do_classifier_free_guidancec                 C   rt   ra   )_num_timestepsrg   r   r   r   num_timestepsJ  rx   z Kandinsky3Pipeline.num_timesteps   g      @i   pilr   rO   num_inference_stepsrw   rR   rQ   r   r   r]   output_typereturn_dictcallback_on_step_endrr   c           &         s$  | dd}| dd}|durtddd |dur tddd |durAt fdd|D sAtd	 j d
 fdd|D  d} j} ||||	|
||| | _|durat|t	rad}n|durot|t
rot|}n|	jd } j| j||||	|
|||d
\}	}
}} jrt|
|	g}	t||g } jj||d  jj}t||d\}} || d||f|	j||| j}t drЈ jdurЈ j  t|| jj  }t| _ j|d!}t|D ]\}} jrt|gd n|} j|||	|ddd } jr| d\}} |d |  ||  } jj!||||dj"}|durai }!|D ]
}"t# |" |!|"< q1| |||!}#|# d|}|# d|	}	|# d|
}
|# d|}|# d|}|t|d ks||d |kr|d  jj dkr|$  |dur|| dkr|t% jd d }$||$|| t&rt'(  q|d!vrtd"| |d#ks j)j*|dd$d% }%|d&v r|%d' d' }%|%+dd}%|%, -ddd(d. / }%|d)kr 0|%}%n|}% 1  |s|%fW  d   S t2|%d*W  d   S 1 sw   Y  dS )+u6  
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`list[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
                timesteps are used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 3.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to self.unet.config.sample_size):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size):
                The width in pixels of the generated image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
                applies to [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask. Must provide if passing `prompt_embeds` directly.
            negative_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated negative attention mask. Must provide if passing `negative_prompt_embeds` directly.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            clean_caption (`bool`, *optional*, defaults to `True`):
                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
                be installed. If the dependencies are not installed, the embeddings will be created from the raw
                prompt.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`

        callbackNrq   z1.0.0zhPassing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`znPassing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`c                 3   r`   ra   rb   rd   rg   r   r   rh     ri   z.Kandinsky3Pipeline.__call__.<locals>.<genexpr>rj   rk   c                    rl   r   rb   rd   rg   r   r   rm     rn   z/Kandinsky3Pipeline.__call__.<locals>.<listcomp>Tr   r   )rQ   r>   rR   r   r   rS   r!   r    )r>   r      text_encoder_offload_hook)totalr   F)encoder_hidden_statesencoder_attention_maskr   g      ?)r]   r   r   r   r!   r    order)r9   npr|   latentzSOnly the output types `pt`, `pil`, `np` and `latent` are supported not output_type=r   )force_not_quantizesample)r   r|   g      ?r   r|   )images)3popr
   rp   rN   rc   rC   rs   rv   rD   rE   rF   rG   rH   r\   rP   r.   catboolr%   set_timesteps	timestepsr   r_   r=   hasattrr   offloadr   ry   progress_bar	enumerater$   chunkstepprev_samplelocalsupdategetattrXLA_AVAILABLExm	mark_stepr&   decodeclampcpupermutefloatnumpynumpy_to_pilmaybe_free_model_hooksr   )&r*   rO   r}   rw   rR   rQ   r   r   r]   r   r   r!   r    r~   r   r   r   rr   kwargsr   rq   r3   r>   rT   r   num_warmup_stepsr   itlatent_model_input
noise_prednoise_pred_uncondnoise_pred_textcallback_kwargsrf   callback_outputsstep_idximager   rg   r   __call__N  s  `

	


6




H&zKandinsky3Pipeline.__call__)	Tr   NNNNFNN)NNNNNN)!__name__
__module____qualname__model_cpu_offload_seqrc   r   r   r   r	   r   r(   r5   r.   no_gradTensorr\   r_   rs   propertyrw   rP   rz   r   EXAMPLE_DOC_STRINGrE   rF   ro   r   	Generatorr   r   r   __classcell__r   r   r+   r   r   ;   s    
 
F



	
r   )r   )"typingr   r.   transformersr   r   loadersr   modelsr   r   
schedulersr	   utilsr
   r   r   r   utils.torch_utilsr   pipeline_utilsr   r   torch_xla.core.xla_modelcore	xla_modelr   r   
get_loggerr   loggerr   r   r   r   r   r   r   <module>   s"    


