o
    Gi                  
   @   s<  d dl Z d dlmZmZ d dlZd dlmZmZ ddlm	Z	m
Z
 ddlmZ ddlmZ ddlmZmZ dd	lmZ dd
lmZmZmZmZmZmZmZ ddlmZ ddlmZm Z  e rid dl!m"  m#Z$ dZ%ndZ%e&e'Z(dZ)				dde*dB de+ej,B dB de-e* dB de-e. dB fddZ/G dd deeZ0dS )    N)AnyCallable)T5TokenizerUMT5EncoderModel   )MultiPipelineCallbacksPipelineCallback)VaeImageProcessor)AuraFlowLoraLoaderMixin)AuraFlowTransformer2DModelAutoencoderKL)FlowMatchEulerDiscreteScheduler)USE_PEFT_BACKEND	deprecateis_torch_xla_availableloggingreplace_example_docstringscale_lora_layersunscale_lora_layers)randn_tensor   )DiffusionPipelineImagePipelineOutputTFa  
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import AuraFlowPipeline

        >>> pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")
        >>> prompt = "A cat holding a sign that says hello world"
        >>> image = pipe(prompt).images[0]
        >>> image.save("aura_flow.png")
        ```
num_inference_stepsdevice	timestepssigmasc                 K   s  |dur|durt d|dur>dtt| jj v }|s(t d| j d| jd||d| | j}t	|}||fS |durpdtt| jj v }|sZt d| j d| jd||d	| | j}t	|}||fS | j|fd
|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    NzYOnly one of `timesteps` or `sigmas` can be passed. Please choose one to set custom valuesr   zThe current scheduler class zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)r   r   r   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)r   r   r    )

ValueErrorsetinspect	signatureset_timesteps
parameterskeys	__class__r   len)	schedulerr   r   r   r   kwargsaccepts_timestepsaccept_sigmasr   r   d/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/aura_flow/pipeline_aura_flow.pyretrieve_timesteps@   s2   r,   c                .       s(  e Zd ZdZg ZdZddgZdedede	de
d	ef
 fd
dZ					d:ddZ										d;deee B deee B dededejdB dejdB dejdB dejdB dejdB dededB fddZ	d<ddZd d! Zed"d# Zed$d% Zed&d' Ze eeddd(dd)dd*d*dddddddd+ddddgfdeee B deee B d,ed-ee d.ededB d/edB d0edB d1ej eej  B dB dejdB dejdB dejdB dejdB dejdB ded2edB d3ed4e!ee"f dB d5e#eegdf e$B e%B dB d6ee d7e&e'B f*d8d9Z(  Z)S )=AuraFlowPipelinea  
    Args:
        tokenizer (`T5TokenizerFast`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. AuraFlow uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
            [EleutherAI/pile-t5-xl](https://huggingface.co/EleutherAI/pile-t5-xl) variant.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        transformer ([`AuraFlowTransformer2DModel`]):
            Conditional Transformer (MMDiT and DiT) architecture to denoise the encoded image latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    ztext_encoder->transformer->vaelatentsprompt_embeds	tokenizertext_encodervaetransformerr'   c                    sX   t    | j|||||d t| dd r dt| jjjd  nd| _t	| jd| _
d S )N)r0   r1   r2   r3   r'   r2   r         )vae_scale_factor)super__init__register_modulesgetattrr&   r2   configblock_out_channelsr6   r	   image_processor)selfr0   r1   r2   r3   r'   r%   r   r+   r8      s   

(zAuraFlowPipeline.__init__Nc
           
         s  | j d  dks| j d  dkr#td j d  d| d| d|	d urDt fdd|	D sDtd	 j d
 fdd|	D  |d urW|d urWtd| d| d|d u rc|d u rctd|d urzt|tszt|tsztdt| |d ur|d urtd| d| d|d ur|d urtd| d| d|d ur|d u rtd|d ur|d u rtd|d ur|d ur|j|jkrtd|j d|j d|j|jkrtd|j d|j dd S d S d S )Nr   r   z-`height` and `width` have to be divisible by z	 but are z and .c                 3   s    | ]}| j v V  qd S N_callback_tensor_inputs.0kr>   r   r+   	<genexpr>   s    

z0AuraFlowPipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r   rB   rD   rG   r   r+   
<listcomp>   s    z1AuraFlowPipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is z and `negative_prompt_embeds`: z'Cannot forward both `negative_prompt`: zEMust provide `prompt_attention_mask` when specifying `prompt_embeds`.zWMust provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.zu`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` z != `negative_prompt_embeds` z`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but got: `prompt_attention_mask` z% != `negative_prompt_attention_mask` )	r6   r   allrC   
isinstancestrlisttypeshape)
r>   promptheightwidthnegative_promptr/   negative_prompt_embedsprompt_attention_masknegative_prompt_attention_mask"callback_on_step_end_tensor_inputsr   rG   r+   check_inputs   sn   $zAuraFlowPipeline.check_inputsTr4      rP   rS   do_classifier_free_guidancenum_images_per_promptr   rT   rU   rV   max_sequence_length
lora_scalec                    s  |durt | tr|| _| jdurtrt| j|  du r | j |dur,t |tr,d}n|dur:t |tr:t	|}n|j
d }|
}|du r| j|d|ddd}|d }| j|d	dd
j}|j
d |j
d krt||s| j|dd|d df }td| d|   fdd| D }| jdi |d }|d d|j
}|| }| jdur| jj}n| jdur| jj}nd}|j| d}|j
\}}}|d|d}||| |d}||d}||d}|r6|du r6|pd}t |tr|g| n|}|j
d }| j|d|ddd} fdd| D }| jdi |d }|d d|j
}	||	 }|rb|j
d }|j| d}|d|d}||| |d}|	|d}	|	|d}	nd}d}	| jdur{t | tr{tr{t| j| ||||	fS )a
  
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                whether to use classifier free guidance or not
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                number of images that should be generated per prompt
            device: (`torch.device`, *optional*):
                torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        Nr4   r   T
max_lengthpt)
truncationr^   paddingreturn_tensors	input_idslongest)ra   rb   zZThe following part of your input was truncated because T5 can only handle sequences up to z	 tokens: c                       i | ]
\}}||  qS r   torE   rF   vr   r   r+   
<dictcomp>9      z2AuraFlowPipeline.encode_prompt.<locals>.<dictcomp>attention_mask)dtyper    c                    rf   r   rg   ri   rk   r   r+   rl   Z  rm   r   )rK   r
   _lora_scaler1   r   r   _execution_devicerL   rM   r&   rO   r0   rc   torchequalbatch_decodeloggerwarningitems	unsqueezeexpandro   r3   rh   repeatviewreshaper   )r>   rP   rS   rZ   r[   r   r/   rT   rU   rV   r\   r]   
batch_sizer^   text_inputstext_input_idsuntruncated_idsremoved_textro   bs_embedseq_len_uncond_tokensuncond_inputr   rk   r+   encode_prompt   s   ,

 





zAuraFlowPipeline.encode_promptc	           
      C   sz   |d ur|j ||dS ||t|| j t|| j f}	t|tr3t||kr3tdt| d| dt|	|||d}|S )N)r   ro   z/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.)	generatorr   ro   )rh   intr6   rK   rM   r&   r   r   )
r>   r~   num_channels_latentsrQ   rR   ro   r   r   r.   rO   r   r   r+   prepare_latentsx  s   z AuraFlowPipeline.prepare_latentsc                 C   s    t ddd | jjtjd d S )N
upcast_vaez1.0.0z`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`. For more details, please refer to: https://github.com/huggingface/diffusers/pull/12619#issue-3606633695.ro   )r   r2   rh   rs   float32rG   r   r   r+   r     s   zAuraFlowPipeline.upcast_vaec                 C      | j S rA   )_guidance_scalerG   r   r   r+   guidance_scale     zAuraFlowPipeline.guidance_scalec                 C   r   rA   )_attention_kwargsrG   r   r   r+   attention_kwargs  r   z!AuraFlowPipeline.attention_kwargsc                 C   r   rA   )_num_timestepsrG   r   r   r+   num_timesteps  r   zAuraFlowPipeline.num_timesteps2   g      @i   pilr   r   r   rQ   rR   r   output_typereturn_dictr   callback_on_step_endrW   returnc           *      C   s|  |p	| j jj| j }|p| j jj| j }| j|||||||||d	 || _|| _|dur4t|tr4d}n|durBt|t	rBt
|}n|jd }| j}| jdurV| jddnd}|dk}| j|||||||||||d\}}}}|r{tj||gdd}trd	}n|}t| j|||d
\}}| j jj}| || ||||j||	|
}
tt
||| jj  d}t
|| _| j|d}t|D ]\}}|rt|
gd n|
} t|d g| jd }!|!j|
j |
jd}!| j | ||!d| jdd }"|r|"!d\}#}$|#||$|#   }"| jj"|"||
ddd }
|dur7i }%|D ]
}&t# |& |%|&< q|| |||%}'|'$d|
}
|'$d|}|t
|d ksR|d |krV|d | jj dkrV|%  tr]t&'  qW d   n	1 siw   Y  |dkrv|
}(n9| j(jtj)ko| j(jj*})|)r| +  |
t,t-| j(j./ j}
| j(j0|
| j(jj1 ddd }(| j2j3|(|d}(| 4  |s|(fS t5|(dS )a{  
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. This is set to 1024 by default for best results.
            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. This is set to 1024 by default for best results.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`list[float]`, *optional*):
                Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
                `num_inference_steps` and `timesteps` must be `None`.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`list`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.

        Examples:

        Returns: [`~pipelines.ImagePipelineOutput`] or `tuple`:
            If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned
            where the first element is a list with the generated images.
        )rW   Nr4   r   scaleg      ?)rP   rS   rZ   r[   r   r/   rT   rU   rV   r\   r]   )dimcpu)r   )totalr   i  r   F)encoder_hidden_statestimestepr   r   )r   r.   r/   latent)r   )images)6r3   r;   sample_sizer6   rX   r   r   rK   rL   rM   r&   rO   rr   r   getr   rs   catXLA_AVAILABLEr,   r'   in_channelsr   ro   maxorderr   progress_bar	enumeratetensorrz   rh   r   chunksteplocalspopupdatexm	mark_stepr2   float16force_upcastr   nextiterpost_quant_convr#   decodescaling_factorr=   postprocessmaybe_free_model_hooksr   )*r>   rP   rS   r   r   r   r[   rQ   rR   r   r.   r/   rU   rT   rV   r\   r   r   r   r   rW   r~   r   r]   rZ   timestep_devicer   latent_channelsnum_warmup_stepsr   itlatent_model_inputr   
noise_prednoise_pred_uncondnoise_pred_textcallback_kwargsrF   callback_outputsimageneeds_upcastingr   r   r+   __call__  s   d





	
6
+
zAuraFlowPipeline.__call__)NNNNN)
NTr4   NNNNNrY   NrA   )*__name__
__module____qualname____doc___optional_componentsmodel_cpu_offload_seqrC   r   r   r   r   r   r8   rX   rL   rM   boolr   rs   r   Tensorfloatr   r   r   propertyr   r   r   no_gradr   EXAMPLE_DOC_STRING	Generatordictr   r   r   r   r   tupler   __classcell__r   r   r?   r+   r-   {   s   
F

	

 
 




	
r-   )NNNN)1r    typingr   r   rs   transformersr   r   	callbacksr   r   r=   r	   loadersr
   modelsr   r   
schedulersr   utilsr   r   r   r   r   r   r   utils.torch_utilsr   pipeline_utilsr   r   torch_xla.core.xla_modelcore	xla_modelr   r   
get_loggerr   rv   r   r   rL   r   rM   r   r,   r-   r   r   r   r+   <module>   s@   $	



;