o
    Gi]                     @   s   d dl mZ d dlZd dlZd dlZd dlmZm	Z	m
Z
mZ ddlmZ ddlmZ ddlmZmZmZmZ ddlmZ d	d
lmZ e rQd dlm  mZ dZndZeeZ dZ!dZ"eG dd deZ#G dd deZ$dS )    )	dataclassN)CLIPImageProcessorCLIPTextModelWithProjectionCLIPTokenizerCLIPVisionModelWithProjection   )PriorTransformer)UnCLIPScheduler)
BaseOutputis_torch_xla_availableloggingreplace_example_docstring)randn_tensor   )DiffusionPipelineTFav  
    Examples:
        ```py
        >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline
        >>> import torch

        >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior")
        >>> pipe_prior.to("cuda")

        >>> prompt = "red cat, 4k photo"
        >>> out = pipe_prior(prompt)
        >>> image_emb = out.image_embeds
        >>> negative_image_emb = out.negative_image_embeds

        >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1")
        >>> pipe.to("cuda")

        >>> image = pipe(
        ...     prompt,
        ...     image_embeds=image_emb,
        ...     negative_image_embeds=negative_image_emb,
        ...     height=768,
        ...     width=768,
        ...     num_inference_steps=100,
        ... ).images

        >>> image[0].save("cat.png")
        ```
a  
    Examples:
        ```py
        >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline
        >>> from diffusers.utils import load_image
        >>> import PIL

        >>> import torch
        >>> from torchvision import transforms

        >>> pipe_prior = KandinskyPriorPipeline.from_pretrained(
        ...     "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
        ... )
        >>> pipe_prior.to("cuda")

        >>> img1 = load_image(
        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        ...     "/kandinsky/cat.png"
        ... )

        >>> img2 = load_image(
        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
        ...     "/kandinsky/starry_night.jpeg"
        ... )

        >>> images_texts = ["a cat", img1, img2]
        >>> weights = [0.3, 0.3, 0.4]
        >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)

        >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
        >>> pipe.to("cuda")

        >>> image = pipe(
        ...     "",
        ...     image_embeds=image_emb,
        ...     negative_image_embeds=zero_image_emb,
        ...     height=768,
        ...     width=768,
        ...     num_inference_steps=150,
        ... ).images[0]

        >>> image.save("starry_cat.png")
        ```
c                   @   s2   e Zd ZU dZejejB ed< ejejB ed< dS )KandinskyPriorPipelineOutputa  
    Output class for KandinskyPriorPipeline.

    Args:
        image_embeds (`torch.Tensor`)
            clip image embeddings for text prompt
        negative_image_embeds (`list[PIL.Image.Image]` or `np.ndarray`)
            clip image embeddings for unconditional tokens
    image_embedsnegative_image_embedsN)	__name__
__module____qualname____doc__torchTensornpndarray__annotations__ r   r   j/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.pyr   x   s   
 
r   c                       s`  e Zd ZdZdgZdZdededede	de
def fd	d
Ze ee								d(deeejjB ejB  dee dededejeej B dB dejdB dedB dedefddZdd Zd)ddZ	d*dd Ze ee							!	"d+d#eee B deee B dB dededejeej B dB dejdB ded$edB d%efd&d'Z  Z S ),KandinskyPriorPipelinea  
    Pipeline for generating image prior for Kandinsky

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen image-encoder.
        text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen text-encoder.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        scheduler ([`UnCLIPScheduler`]):
            A scheduler to be used in combination with `prior` to generate image embedding.
    priorztext_encoder->priorimage_encodertext_encoder	tokenizer	schedulerimage_processorc                    s$   t    | j||||||d d S )N)r    r"   r#   r$   r!   r%   )super__init__register_modules)selfr    r!   r"   r#   r$   r%   	__class__r   r   r'      s   
	
zKandinskyPriorPipeline.__init__      N       @images_and_promptsweightsnum_images_per_promptnum_inference_steps	generatorlatentsnegative_prior_promptnegative_promptguidance_scalec              
   C   s8  |
p| j }
t|t|krtdt| dt| dg }t||D ]S\}}t|tr9| |||||||	dj}n6t|tjjt	j
frft|tjjr^| j|ddjd dj| jj|
d}| |d	 }n	td
t| |||  q#t	|jddd}| |||||||	d}|dkr|jn|j}t||dS )a	  
        Function invoked when using the prior pipeline for interpolation.

        Args:
            images_and_prompts (`list[str | PIL.Image.Image | torch.Tensor]`):
                list of prompts and images to guide the image generation.
            weights: (`list[float]`):
                list of weights for each condition in `images_and_prompts`
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            negative_prior_prompt (`str`, *optional*):
                The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.

        Examples:

        Returns:
            [`KandinskyPriorPipelineOutput`] or `tuple`
        z`images_and_prompts` contains z items and `weights` contains z, items - they should be lists of same length)r3   r2   r4   r5   r7   r8   pt)return_tensorsr   )dtypedevicer   zq`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor`  but is T)dimkeepdimr.   r   r   )r<   len
ValueErrorzip
isinstancestrr   PILImager   r   r%   pixel_values	unsqueezetor!   r;   typeappendcatsumr   r   )r)   r0   r1   r2   r3   r4   r5   r6   r7   r8   r<   image_embeddingscondweight	image_embout_zerozero_image_embr   r   r   interpolate   sZ   
6


	z"KandinskyPriorPipeline.interpolatec                 C   sR   |d u rt ||||d}n|j|krtd|j d| ||}||j }|S )N)r4   r<   r;   zUnexpected latents shape, got z, expected )r   shaperA   rI   init_noise_sigma)r)   rU   r;   r<   r4   r5   r$   r   r   r   prepare_latents   s   


z&KandinskyPriorPipeline.prepare_latentsc                 C   sR   |p| j }tdd| jjj| jjjj|| jjd}| |d }||d}|S )Nr,   r   )r<   r;   r   )	r<   r   zerosr!   config
image_sizerI   r;   repeat)r)   
batch_sizer<   zero_imgrS   r   r   r   get_zero_embed+  s   
z%KandinskyPriorPipeline.get_zero_embedc              
   C   sz  t |tr	t|nd}| j|d| jjddd}|j}|j |}	| j|dddj}
|
j	d |j	d krft
||
sf| j|
d d | jjd df }td	| jj d
|  |d d d | jjf }| ||}|j}|j}|j|dd}|j|dd}|	j|dd}	|r8|d u rdg| }n;t|t|urtdt| dt| dt |tr|g}n|t|krtd| dt| d| d| d	|}| j|d| jjddd}|j |}| |j|}|j}|j}|j	d }|d|}||| |}|j	d }|d|d}||| |d}|j|dd}t
||g}t
||g}t
||	g}	|||	fS )Nr,   
max_lengthTr9   )paddingr_   
truncationr:   longest)r`   r:   z\The following part of your input was truncated because CLIP can only handle sequences up to z	 tokens: r   )r=   r.   z?`negative_prompt` should be the same type to `prompt`, but got z != .z`negative_prompt`: z has batch size z, but `prompt`: zT. Please make sure that passed `negative_prompt` matches the batch size of `prompt`.)rC   listr@   r#   model_max_length	input_idsattention_maskboolrI   rU   r   equalbatch_decodeloggerwarningr"   text_embedslast_hidden_staterepeat_interleaverJ   	TypeErrorrD   rA   r[   viewrL   )r)   promptr<   r2   do_classifier_free_guidancer7   r\   text_inputstext_input_ids	text_maskuntruncated_idsremoved_texttext_encoder_outputprompt_embedstext_encoder_hidden_statesuncond_tokensuncond_inputuncond_text_mask*negative_prompt_embeds_text_encoder_outputnegative_prompt_embeds!uncond_text_encoder_hidden_statesseq_lenr   r   r   _encode_prompt4  s    $




z%KandinskyPriorPipeline._encode_promptr9   Trs   output_typereturn_dictc
                 C   s\  t |tr	|g}nt |tstdt| t |tr |g}nt |ts2|dur2tdt| |dur>|| }d| }| j}
t|}|| }|dk}| ||
|||\}}}| jj	||
d | jj
}| jjj}| ||f|j|
||| j}t| |D ]P\}}|rt|gd n|}| j|||||dj}|r|d\}}||||   }|d |jd	 krd}n||d  }| jj|||||d
j}trt  q| j|}|}|du r| j|jd	 |jd}|   n|d\}}t | dr| j!dur| j"#  |dvrtd| |dkr!|$ % }|$ % }|	s(||fS t&||dS )a	  
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`):
                The prompt or prompts to guide the image generation.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                if `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            output_type (`str`, *optional*, defaults to `"pt"`):
                The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"`
                (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.

        Examples:

        Returns:
            [`KandinskyPriorPipelineOutput`] or `tuple`
        z2`prompt` has to be of type `str` or `list` but is Nz;`negative_prompt` has to be of type `str` or `list` but is r   g      ?)r<   )timestepproj_embeddingencoder_hidden_statesrh   r,   r   )r   sampler4   prev_timestepfinal_offload_hook)r9   r   zBOnly the output types `pt` and `np` are supported not output_type=r   r?   )'rC   rD   re   rA   rJ   _execution_devicer@   r   r$   set_timesteps	timestepsr    rY   embedding_dimrW   r;   	enumerateprogress_barr   rL   predicted_image_embeddingchunkrU   stepprev_sampleXLA_AVAILABLExm	mark_steppost_process_latentsr^   r<   maybe_free_model_hookshasattrr   
prior_hookoffloadcpunumpyr   )r)   rs   r7   r2   r3   r4   r5   r8   r   r   r<   r\   rt   r{   r|   rw   prior_timesteps_tensorr   itlatent_model_inputr    predicted_image_embedding_uncondpredicted_image_embedding_textr   rN   zero_embedsr   r   r   __call__  s   
5




	



zKandinskyPriorPipeline.__call__)r,   r-   NNNr.   r/   N)r,   N)N)Nr,   r-   NNr/   r9   T)!r   r   r   r   _exclude_from_cpu_offloadmodel_cpu_offload_seqr   r   r   r   r	   r   r'   r   no_gradr   EXAMPLE_INTERPOLATE_DOC_STRINGre   rD   rE   rF   r   floatint	GeneratorrT   rW   r^   r   EXAMPLE_DOC_STRINGri   r   __classcell__r   r   r*   r   r      s    	
j

a
	
r   )%dataclassesr   r   r   	PIL.ImagerE   r   transformersr   r   r   r   modelsr   
schedulersr	   utilsr
   r   r   r   utils.torch_utilsr   pipeline_utilsr   torch_xla.core.xla_modelcore	xla_modelr   r   
get_loggerr   rl   r   r   r   r   r   r   r   r   <module>   s(   
.