import inspect
from typing import Any, Callable

import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput

from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
from ...models.embeddings import get_timestep_embedding
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    USE_PEFT_BACKEND,
    deprecate,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import StableUnCLIPPipeline

        >>> pipe = StableUnCLIPPipeline.from_pretrained(
        ...     "fusing/stable-unclip-2-1-l", torch_dtype=torch.float16
        ... )  # TODO update model path
        >>> pipe = pipe.to("cuda")

        >>> prompt = "a photo of an astronaut riding a horse on mars"
        >>> images = pipe(prompt).images
        >>> images[0].save("astronaut_horse.png")
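        >>> # optional unCLIP-specific knobs (values here are illustrative; the
        >>> # defaults are defined on `__call__` below):
        >>> images = pipe(prompt, noise_level=0, prior_num_inference_steps=25).images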
        ```
"""


class StableUnCLIPPipeline(
    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
):
    """
    Pipeline for text-to-image generation using stable unCLIP.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
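
        For example (hypothetical repository ids, shown only for illustration):

        ```py
        pipe.load_textual_inversion("some-user/some-concept")
        pipe.load_lora_weights("some-user/some-lora")
        ```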

    Args:
        prior_tokenizer ([`CLIPTokenizer`]):
            A [`CLIPTokenizer`].
        prior_text_encoder ([`CLIPTextModelWithProjection`]):
            Frozen [`CLIPTextModelWithProjection`] text-encoder.
        prior ([`PriorTransformer`]):
            The canonical unCLIP prior to approximate the image embedding from the text embedding.
        prior_scheduler ([`KarrasDiffusionSchedulers`]):
            Scheduler used in the prior denoising process.
        image_normalizer ([`StableUnCLIPImageNormalizer`]):
            Used to normalize the predicted image embeddings before the noise is applied and un-normalize the image
            embeddings after the noise has been applied.
        image_noising_scheduler ([`KarrasDiffusionSchedulers`]):
            Noise schedule for adding noise to the predicted image embeddings. The amount of noise to add is determined
            by the `noise_level`.
        tokenizer ([`CLIPTokenizer`]):
            A [`CLIPTokenizer`].
        text_encoder ([`CLIPTextModel`]):
            Frozen [`CLIPTextModel`] text-encoder.
        unet ([`UNet2DConditionModel`]):
            A [`UNet2DConditionModel`] to denoise the encoded image latents.
        scheduler ([`KarrasDiffusionSchedulers`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
    """

    _exclude_from_cpu_offload = ["prior", "image_normalizer"]
    model_cpu_offload_seq = "text_encoder->prior_text_encoder->unet->vae"

    # prior components
    prior_tokenizer: CLIPTokenizer
    prior_text_encoder: CLIPTextModelWithProjection
    prior: PriorTransformer
    prior_scheduler: KarrasDiffusionSchedulers

    # image noising components
    image_normalizer: StableUnCLIPImageNormalizer
    image_noising_scheduler: KarrasDiffusionSchedulers

    # regular denoising components
    tokenizer: CLIPTokenizer
    text_encoder: CLIPTextModel
    unet: UNet2DConditionModel
    scheduler: KarrasDiffusionSchedulers
    vae: AutoencoderKL

    def __init__(
        self,
        # prior components
        prior_tokenizer: CLIPTokenizer,
        prior_text_encoder: CLIPTextModelWithProjection,
        prior: PriorTransformer,
        prior_scheduler: KarrasDiffusionSchedulers,
        # image noising components
        image_normalizer: StableUnCLIPImageNormalizer,
        image_noising_scheduler: KarrasDiffusionSchedulers,
        # regular denoising components
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vae: AutoencoderKL,
    ):
        super().__init__()

        self.register_modules(
            prior_tokenizer=prior_tokenizer,
            prior_text_encoder=prior_text_encoder,
            prior=prior,
            prior_scheduler=prior_scheduler,
            image_normalizer=image_normalizer,
            image_noising_scheduler=image_noising_scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            unet=unet,
            scheduler=scheduler,
            vae=vae,
        )

        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder
    def _encode_prior_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        text_model_output: CLIPTextModelOutput | tuple | None = None,
        text_attention_mask: torch.Tensor | None = None,
    ):
        if text_model_output is None:
            batch_size = len(prompt) if isinstance(prompt, list) else 1
            # get prompt text embeddings
            text_inputs = self.prior_tokenizer(
                prompt,
                padding="max_length",
                max_length=self.prior_tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            text_mask = text_inputs.attention_mask.bool().to(device)

            untruncated_ids = self.prior_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.prior_tokenizer.batch_decode(
                    untruncated_ids[:, self.prior_tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.prior_tokenizer.model_max_length} tokens: {removed_text}"
                )
                text_input_ids = text_input_ids[:, : self.prior_tokenizer.model_max_length]

            prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device))

            prompt_embeds = prior_text_encoder_output.text_embeds
            text_enc_hid_states = prior_text_encoder_output.last_hidden_state
        else:
            batch_size = text_model_output[0].shape[0]
            prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
            text_mask = text_attention_mask

        prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
        text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            uncond_tokens = [""] * batch_size

            uncond_input = self.prior_tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=self.prior_tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_text_mask = uncond_input.attention_mask.bool().to(device)
            negative_prompt_embeds_prior_text_encoder_output = self.prior_text_encoder(
                uncond_input.input_ids.to(device)
            )

            negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds
            uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state

            # duplicate unconditional embeddings for each generation per prompt
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)

            seq_len = uncond_text_enc_hid_states.shape[1]
            uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
            uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )

            uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)

            # for classifier-free guidance, concatenate the unconditional and text
            # embeddings into a single batch to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
            text_mask = torch.cat([uncond_text_mask, text_mask])

        return prompt_embeds, text_enc_hid_states, text_mask

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: torch.Tensor | None = None,
        negative_prompt_embeds: torch.Tensor | None = None,
        lora_scale: float | None = None,
        **kwargs,
    ):
        deprecation_message = (
            "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()`"
            " instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        )
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # concatenate for backwards comp
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: torch.Tensor | None = None,
        negative_prompt_embeds: torch.Tensor | None = None,
        lora_scale: float | None = None,
        clip_skip: int | None = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that the monkey-patched LoRA
        # function of the text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # `clip_skip` counts layers back from the penultimate layer, so index into
                # `hidden_states` from the end and re-apply the final layer norm so the
                # representation stays compatible with `last_hidden_state`
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device), attention_mask=attention_mask
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds

    def decode_latents(self, latents):
        deprecation_message = (
            "The decode_latents method is deprecated and will be removed in 1.0.0. Please use"
            " VaeImageProcessor.postprocess(...) instead"
        )
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # always cast to float32, as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_prior_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the prior scheduler step, since not all prior schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler and is ignored by other schedulers; it corresponds to η in the
        # DDIM paper (https://huggingface.co/papers/2010.02502) and should be in [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.prior_scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.prior_scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def prepare_extra_step_kwargs(self, generator, eta):
        # same as above, but for the main denoising scheduler
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        height,
        width,
        callback_steps,
        noise_level,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
    ):
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if callback_steps is None or not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Please make sure to define only one of the two."
            )
        if prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        if prompt is not None and not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                "Provide either `negative_prompt` or `negative_prompt_embeds`. Cannot leave both `negative_prompt` and"
                " `negative_prompt_embeds` undefined."
            )

        if prompt is not None and negative_prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
            raise ValueError(
                f"`noise_level` must be between 0 and {self.image_noising_scheduler.config.num_train_timesteps - 1},"
                " inclusive."
            )

    def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * scheduler.init_noise_sigma
        return latents

    def noise_image_embeddings(
        self,
        image_embeds: torch.Tensor,
        noise_level: int,
        noise: torch.Tensor | None = None,
        generator: torch.Generator | None = None,
    ):
        r"""
        Add noise to the image embeddings. The amount of noise is controlled by a `noise_level` input. A higher
        `noise_level` increases the variance in the final un-noised images.

        The noise is applied in two ways:
        1. A noise schedule is applied directly to the embeddings.
        2. A vector of sinusoidal time embeddings are appended to the output.

        In both cases, the amount of noise is controlled by the same `noise_level`.

        The embeddings are normalized before the noise is applied and un-normalized after the noise is applied.
        """
        if noise is None:
            noise = randn_tensor(
                image_embeds.shape, generator=generator, device=image_embeds.device, dtype=image_embeds.dtype
            )

        noise_level = torch.tensor([noise_level] * image_embeds.shape[0], device=image_embeds.device)

        self.image_normalizer.to(image_embeds.device)
        image_embeds = self.image_normalizer.scale(image_embeds)

        image_embeds = self.image_noising_scheduler.add_noise(image_embeds, timesteps=noise_level, noise=noise)

        image_embeds = self.image_normalizer.unscale(image_embeds)

        noise_level = get_timestep_embedding(
            timesteps=noise_level, embedding_dim=image_embeds.shape[-1], flip_sin_to_cos=True, downscale_freq_shift=0
        )

        # `get_timestep_embedding` does not contain any weights and always returns f32 tensors,
        # but we might actually be running in fp16, so cast here
        noise_level = noise_level.to(image_embeds.dtype)

        image_embeds = torch.cat((image_embeds, noise_level), 1)

        return image_embeds

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        # regular denoising process args
        prompt: str | list[str] | None = None,
        height: int | None = None,
        width: int | None = None,
        num_inference_steps: int = 20,
        guidance_scale: float = 10.0,
        negative_prompt: str | list[str] | None = None,
        num_images_per_prompt: int | None = 1,
        eta: float = 0.0,
        generator: torch.Generator | None = None,
        latents: torch.Tensor | None = None,
        prompt_embeds: torch.Tensor | None = None,
        negative_prompt_embeds: torch.Tensor | None = None,
        output_type: str | None = "pil",
        return_dict: bool = True,
        callback: Callable[[int, int, torch.Tensor], None] | None = None,
        callback_steps: int = 1,
        cross_attention_kwargs: dict[str, Any] | None = None,
        noise_level: int = 0,
        # prior args
        prior_num_inference_steps: int = 25,
        prior_guidance_scale: float = 4.0,
        prior_latents: torch.Tensor | None = None,
        clip_skip: int | None = None,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 20):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 10.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
                applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            noise_level (`int`, *optional*, defaults to `0`):
                The amount of noise to add to the image embeddings. A higher `noise_level` increases the variance in
                the final un-noised images. See [`StableUnCLIPPipeline.noise_image_embeddings`] for more details.
            prior_num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps in the prior denoising process. More denoising steps usually lead to a
                higher quality image at the expense of slower inference.
            prior_guidance_scale (`float`, *optional*, defaults to 4.0):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            prior_latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                embedding generation in the prior denoising process. Can be used to tweak the same generation with
                different prompts. If not provided, a latents tensor is generated by sampling using the supplied random
                `generator`.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        Examples:

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                [`~pipelines.ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning
                a tuple, the first element is a list with the generated images.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt=prompt,
            height=height,
            width=width,
            callback_steps=callback_steps,
            noise_level=noise_level,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        batch_size = batch_size * num_images_per_prompt

        device = self._execution_device

        # here `prior_guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . A guidance scale of 1
        # corresponds to doing no classifier free guidance.
        prior_do_classifier_free_guidance = prior_guidance_scale > 1.0

        # 3. Encode input prompt for the prior
        prior_prompt_embeds, prior_text_encoder_hidden_states, prior_text_mask = self._encode_prior_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=prior_do_classifier_free_guidance,
        )

        # 4. Prepare prior timesteps
        self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
        prior_timesteps_tensor = self.prior_scheduler.timesteps

        # 5. Prepare prior latent variables
        embedding_dim = self.prior.config.embedding_dim
        prior_latents = self.prepare_latents(
            (batch_size, embedding_dim),
            prior_prompt_embeds.dtype,
            device,
            generator,
            prior_latents,
            self.prior_scheduler,
        )

        # 6. Prepare extra step kwargs for the prior steps
        prior_extra_step_kwargs = self.prepare_prior_extra_step_kwargs(generator, eta)

        # 7. Prior denoising loop
        for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([prior_latents] * 2) if prior_do_classifier_free_guidance else prior_latents
            latent_model_input = self.prior_scheduler.scale_model_input(latent_model_input, t)

            predicted_image_embedding = self.prior(
                latent_model_input,
                timestep=t,
                proj_embedding=prior_prompt_embeds,
                encoder_hidden_states=prior_text_encoder_hidden_states,
                attention_mask=prior_text_mask,
            ).predicted_image_embedding

            if prior_do_classifier_free_guidance:
                predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
                predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
                    predicted_image_embedding_text - predicted_image_embedding_uncond
                )

            prior_latents = self.prior_scheduler.step(
                predicted_image_embedding,
                timestep=t,
                sample=prior_latents,
                **prior_extra_step_kwargs,
                return_dict=False,
            )[0]

            if callback is not None and i % callback_steps == 0:
                callback(i, t, prior_latents)

        prior_latents = self.prior.post_process_latents(prior_latents)

        image_embeds = prior_latents

        # done prior

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . A guidance scale of 1
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 8. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=clip_skip,
        )

        # for classifier-free guidance, concatenate the unconditional and text embeddings
        # into a single batch to avoid doing two forward passes
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        # 9. Prepare image embeddings
        image_embeds = self.noise_image_embeddings(
            image_embeds=image_embeds,
            noise_level=noise_level,
            generator=generator,
        )

        if do_classifier_free_guidance:
            negative_prompt_embeds = torch.zeros_like(image_embeds)
            image_embeds = torch.cat([negative_prompt_embeds, image_embeds])

        # 10. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 11. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        latents = self.prepare_latents(
            shape=shape,
            dtype=prompt_embeds.dtype,
            device=device,
            generator=generator,
            latents=latents,
            scheduler=self.scheduler,
        )

        # 12. Prepare extra step kwargs for the scheduler steps
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 13. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                class_labels=image_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

            if callback is not None and i % callback_steps == 0:
                step_idx = i // getattr(self.scheduler, "order", 1)
                callback(step_idx, t, latents)

            if XLA_AVAILABLE:
                xm.mark_step()

        # 14. Post-processing
        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)

        # offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
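

# Rough stage map of `StableUnCLIPPipeline.__call__` (orientation comments only; the method
# names are the real ones defined above, the shorthand descriptions are illustrative):
#
#   1. _encode_prior_prompt(...)     CLIP text model -> prior conditioning (embeds, hidden states, mask)
#   2. prior denoising loop          PriorTransformer + prior_scheduler turn text embeds into an image embed
#   3. noise_image_embeddings(...)   normalize, add `noise_level` noise, append sinusoidal level embedding
#   4. unet denoising loop           image embed enters the UNet via `class_labels`, text via cross-attention
#   5. vae.decode(...)               latents -> pixels, then image_processor.postprocess(...)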