o
    ۷i                     @   sN  d dl Z d dlmZ d dlZd dlZd dlZd dlm	Z
 d dlmZmZmZmZ ddlmZmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZmZm Z m!Z! ddl"m#Z# ddl$m%Z% ddl&m'Z' e r{d dl(m)  m*Z+ dZ,ndZ,e -e.Z/e rd dl0Z0e -e.Z/dZ1dd Z2dd Z3dd Z4G dd de%eZ5dS )    N)Callable)
functional)CLIPTextModelCLIPTokenizer"Qwen2_5_VLForConditionalGenerationQwen2VLProcessor   )MultiPipelineCallbacksPipelineCallback)VaeImageProcessor)KandinskyLoraLoaderMixin)AutoencoderKL)Kandinsky5Transformer3DModel)FlowMatchEulerDiscreteScheduler)is_ftfy_availableis_torch_xla_availableloggingreplace_example_docstring)randn_tensor   )DiffusionPipeline   )KandinskyImagePipelineOutputTFa<  
    Examples:

        ```python
        >>> import torch
        >>> from diffusers import Kandinsky5T2IPipeline

        >>> # Available models:
        >>> # kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers
        >>> # kandinskylab/Kandinsky-5.0-T2I-Lite-pretrain-Diffusers

        >>> model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers"
        >>> pipe = Kandinsky5T2IPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        >>> pipe = pipe.to("cuda")

        >>> prompt = "A cat and a dog baking a cake together in a kitchen."

        >>> output = pipe(
        ...     prompt=prompt,
        ...     negative_prompt="",
        ...     height=1024,
        ...     width=1024,
        ...     num_inference_steps=50,
        ...     guidance_scale=3.5,
        ... ).frames[0]
        ```
c                 C   s(   t  rt| } tt| } |  S )z
    Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py

    Clean text using ftfy if available and unescape HTML entities.
    )r   ftfyfix_texthtmlunescapestriptext r    k/home/ubuntu/vllm_env/lib/python3.10/site-packages/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.pybasic_cleanW   s   
r"   c                 C   s   t dd| } |  } | S )z
    Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py

    Normalize whitespace in text by replacing multiple spaces with single space.
    z\s+ )resubr   r   r    r    r!   whitespace_cleanc   s   r&   c                 C   s   t t| } | S )z
    Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py

    Apply both basic cleaning and whitespace normalization to prompts.
    )r&   r"   r   r    r    r!   prompt_cleann   s   r'   c                +       s  e Zd ZdZdZg dZdededede	de
d	ed
ef fddZ			dAdee dejdB dedejdB fddZ		dBdeee B dejdB dejdB fddZ				dCdeee B dededejdB dejdB f
ddZ								dDddZ							dEded ed!ed"edejdB dejdB d#ejeej B dB d$ejdB d%ejfd&d'Zed(d) Zed*d+ Zed,d- Ze e e!ddddd.d/dddddddddd0d1dd$gdfdeee B d2eee B dB d!ed"ed3ed4e"dedB d#ejeej B dB d$ejdB d5ejdB d6ejdB d7ejdB d8ejdB d9ejdB d:ejdB d;edB d<e#d=e$eedge%e&B f dB d>ee def(d?d@Z'  Z(S )FKandinsky5T2IPipelinea  
    Pipeline for text-to-image generation using Kandinsky 5.0.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        transformer ([`Kandinsky5Transformer3DModel`]):
            Conditional Transformer to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder Model [black-forest-labs/FLUX.1-dev
            (vae)](https://huggingface.co/black-forest-labs/FLUX.1-dev) to encode and decode videos to and from latent
            representations.
        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
            Frozen text-encoder [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct).
        tokenizer ([`AutoProcessor`]):
            Tokenizer for Qwen2.5-VL.
        text_encoder_2 ([`CLIPTextModel`]):
            Frozen [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer_2 ([`CLIPTokenizer`]):
            Tokenizer for CLIP.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    z.text_encoder->text_encoder_2->transformer->vae)latentsprompt_embeds_qwenprompt_embeds_clipnegative_prompt_embeds_qwennegative_prompt_embeds_cliptransformervaetext_encoder	tokenizertext_encoder_2tokenizer_2	schedulerc              	      sP   t    | j|||||||d d| _d| _d| _t| jd| _g d| _d S )N)r.   r/   r0   r1   r2   r3   r4   z<|im_start|>system
You are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>
<|im_start|>user
{}<|im_end|>)      )vae_scale_factor))   r8   )    )r:   r9   )      )r<   r;   )    )r>   r=   )	super__init__register_modulesprompt_template prompt_template_encode_start_idxvae_scale_factor_spatialr   image_processorresolutions)selfr.   r/   r0   r1   r2   r3   r4   	__class__r    r!   r@      s   


zKandinsky5T2IPipeline.__init__N   promptdevicemax_sequence_lengthdtypec              	      sn  |p j }|p
 jj} fdd|D } j| } j|dddddd }|jd |krft|D ]5\}}	||  jd	 }
 j|
|d
 d }t|dkre|	dt|  ||< t	
d| d|  q0 j|dd|dddd|} j|d dddd d dd jdf }|d dd jdf }tj|ddd}tj|dddjtjd}|||fS )aI  
        Encode prompt using Qwen2.5-VL text encoder.

        This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for
        image generation.

        Args:
            prompt list[str]: Input list of prompts
            device (torch.device): Device to run encoding on
            max_sequence_length (int): Maximum sequence length for tokenization
            dtype (torch.dtype): Data type for embeddings

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths
        c                    s   g | ]} j |qS r    )rB   format.0prG   r    r!   
<listcomp>   s    z=Kandinsky5T2IPipeline._encode_prompt_qwen.<locals>.<listcomp>Nptlongest)r   imagesvideosreturn_tensorspadding	input_idsr   r   zXThe following part of your input was truncated because `max_sequence_length` is set to  z	 tokens: T)r   rW   rX   
max_length
truncationrY   rZ   )r[   return_dictoutput_hidden_stateshidden_statesattention_maskr   )dim)r   r   )value)rN   )_execution_devicer0   rN   rC   r1   shape	enumeratedecodelenloggerwarningtotorchcumsumsumFpadint32)rG   rK   rL   rM   rN   
full_textsmax_allowed_lenuntruncated_idsir   tokensremoved_textinputsembedsrc   
cu_seqlensr    rS   r!   _encode_prompt_qwen   sj   


z)Kandinsky5T2IPipeline._encode_prompt_qwenc                 C   sP   |p| j }|p
| jj}| j|dddddd|}| jdi |d }||S )	a  
        Encode prompt using CLIP text encoder.

        This method processes the input prompt through the CLIP model to generate pooled embeddings that capture
        semantic information.

        Args:
            prompt (str | list[str]): Input prompt or list of prompts
            device (torch.device): Device to run encoding on
            dtype (torch.dtype): Data type for embeddings

        Returns:
            torch.Tensor: Pooled text embeddings from CLIP
        M   Tr^   rU   )r^   r_   add_special_tokensrZ   rY   pooler_outputNr    )rf   r2   rN   r3   rm   )rG   rK   rL   rN   rz   pooled_embedr    r    r!   _encode_prompt_clip   s   
	
z)Kandinsky5T2IPipeline._encode_prompt_clipr   num_images_per_promptc                 C   s   |p| j }|p
| jj}t|ts|g}t|}dd |D }| j||||d\}}| j|||d}	|d|d}|	|| d|j
d }|	d|d}	|		|| d}	| }
|
|}ttjdg|tjd|dg}||	|fS )	a  
        Encodes a single prompt (positive or negative) into text encoder hidden states.

        This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text
        representations for image generation.

        Args:
            prompt (`str` or `list[str]`):
                Prompt to be encoded.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                Number of images to generate per prompt.
            max_sequence_length (`int`, *optional*, defaults to 512):
                Maximum sequence length for text encoding. Must be less than 1024
            device (`torch.device`, *optional*):
                Torch device.
            dtype (`torch.dtype`, *optional*):
                Torch dtype.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - Qwen text embeddings of shape (batch_size * num_images_per_prompt, sequence_length, embedding_dim)
                - CLIP pooled embeddings of shape (batch_size * num_images_per_prompt, clip_embedding_dim)
                - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size *
                  num_images_per_prompt + 1,)
        c                 S   s   g | ]}t |qS r    )r'   rP   r    r    r!   rT   J  s    z7Kandinsky5T2IPipeline.encode_prompt.<locals>.<listcomp>)rK   rL   rM   rN   )rK   rL   rN   r   r\   r   rL   rN   )rf   r0   rN   
isinstancelistrj   r}   r   repeatviewrg   diffrepeat_interleavern   cattensorrs   ro   )rG   rK   r   rM   rL   rN   
batch_sizer*   prompt_cu_seqlensr+   original_lengthsrepeated_lengthsrepeated_cu_seqlensr    r    r!   encode_prompt!  sF   
!

		
z#Kandinsky5T2IPipeline.encode_promptc              	      s  |dur|dkrt d||f jvr-ddd  jD }td| d| d	| d
 |durNt fdd|D sNt d j d fdd|D  |dusZ|dusZ|	durj|du sf|du sf|	du rjt d|dusv|dusv|
dur|du s|du s|
du rt d|du r|du rt d|durt|tst|t	st dt
| |durt|tst|t	st dt
| dS dS dS )al  
        Validate input parameters for the pipeline.

        Args:
            prompt: Input prompt
            negative_prompt: Negative prompt for guidance
            height: Image height
            width: Image width
            prompt_embeds_qwen: Pre-computed Qwen prompt embeddings
            prompt_embeds_clip: Pre-computed CLIP prompt embeddings
            negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings
            negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings
            prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt
            negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt
            callback_on_step_end_tensor_inputs: Callback tensor inputs

        Raises:
            ValueError: If inputs are invalid
        Nr8   z*max_sequence_length must be less than 1024,c                 S   s"   g | ]\}}d | d| dqS )(r   )r    )rQ   whr    r    r!   rT     s   " z6Kandinsky5T2IPipeline.check_inputs.<locals>.<listcomp>z'`height` and `width` have to be one of z
, but are z and z(. Dimensions will be resized accordinglyc                 3   s    | ]}| j v V  qd S )N_callback_tensor_inputsrQ   krS   r    r!   	<genexpr>  s    

z5Kandinsky5T2IPipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r    r   r   rS   r    r!   rT     s    zuIf any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, all three must be provided.zIf any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, all three must be provided.zProvide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined.z2`prompt` has to be of type `str` or `list` but is z;`negative_prompt` has to be of type `str` or `list` but is )
ValueErrorrF   joinrk   rl   allr   r   strr   type)rG   rK   negative_promptheightwidthr*   r+   r,   r-   r   negative_prompt_cu_seqlens"callback_on_step_end_tensor_inputsrM   resolutions_strr    rS   r!   check_inputs|  sP   #z"Kandinsky5T2IPipeline.check_inputs   r8   r   num_channels_latentsr   r   	generatorr)   returnc	           
      C   s|   |dur|j ||dS |dt|| j t|| j |f}	t|tr4t||kr4tdt| d| dt|	|||d}|S )a  
        Prepare initial latent variables for text-to-image generation.

        This method creates random noise latents

        Args:
            batch_size (int): Number of images to generate
            num_channels_latents (int): Number of channels in latent space
            height (int): Height of generated image
            width (int): Width of generated image
            dtype (torch.dtype): Data type for latents
            device (torch.device): Device to create latents on
            generator (torch.Generator): Random number generator
            latents (torch.Tensor): Pre-existing latents to use

        Returns:
            torch.Tensor: Prepared latent tensor
        Nr   r   z/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.)r   rL   rN   )rm   intrD   r   r   rj   r   r   )
rG   r   r   r   r   rN   rL   r   r)   rg   r    r    r!   prepare_latents  s    z%Kandinsky5T2IPipeline.prepare_latentsc                 C      | j S )z%Get the current guidance scale value.)_guidance_scalerS   r    r    r!   guidance_scale     z$Kandinsky5T2IPipeline.guidance_scalec                 C   r   )z&Get the number of denoising timesteps.)_num_timestepsrS   r    r    r!   num_timesteps  r   z#Kandinsky5T2IPipeline.num_timestepsc                 C   r   )z)Check if generation has been interrupted.)
_interruptrS   r    r    r!   	interrupt  r   zKandinsky5T2IPipeline.interrupt2   g      @pilTr   num_inference_stepsr   r*   r+   r,   r-   r   r   output_typer`   callback_on_step_endr   c           *         s  t |ttfr
|j}| j|| |
|||||||d  f| jvr5| jt fdd| jD  \ || _d| _	| j
}| jj}|durQt |trQd}|g}n|dur_t |tr_t|}n|
jd }|
du ru| j|||||d\}
}}| jd	kr|du rd
}t |tr|dur|gt| n|g}nt|t|krtdt| dt| d|du r| j|||||d\}}}| jj||d | jj}| jjj}| j|| | ||||	d}	tjd|dtj | j d |dtj| j d |dg}tj|   |d}|durtj|   |dnd}g d}d}t||| jj   }t|| _!| j"|d} t#|D ]\}!}"| j$rEq;|"%d&|| }#| j|	'||
'||'||#'|||||dd	j(}$| jd	kr|dur| j|	'||'||'||#'|||||dd	j(}%|%||$|%   }$| jj)|$ddddf |"|	ddd }	|duri }&|D ]
}'t* |' |&|'< q|| |!|"|&}(|(+d|	}	|(+d|
}
|(+d|}|(+d|}|(+d|}|!t|d ks |!d |kr|!d | jj  dkr| ,  t-rt./  q;W d   n	1 sw   Y  |	ddddddddd|f }	|dkr}|	'| j0j}	|	1||d | j | j |}	|	2dddddd}	|	1|| | | j | j }	|	| j0jj3 }	| j04|	j(})| j5j6|)|d})n|	})| 7  |s|)fS t8|)d S )!aX  
        The call function to the pipeline for text-to-image generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (`guidance_scale` < `1`).
            height (`int`, defaults to `1024`):
                The height in pixels of the generated image.
            width (`int`, defaults to `1024`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps.
            guidance_scale (`float`, defaults to `5.0`):
                Guidance scale as defined in classifier-free guidance.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                A torch generator to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents.
            prompt_embeds_qwen (`torch.Tensor`, *optional*):
                Pre-generated Qwen text embeddings.
            prompt_embeds_clip (`torch.Tensor`, *optional*):
                Pre-generated CLIP text embeddings.
            negative_prompt_embeds_qwen (`torch.Tensor`, *optional*):
                Pre-generated Qwen negative text embeddings.
            negative_prompt_embeds_clip (`torch.Tensor`, *optional*):
                Pre-generated CLIP negative text embeddings.
            prompt_cu_seqlens (`torch.Tensor`, *optional*):
                Pre-generated cumulative sequence lengths for Qwen positive prompt.
            negative_prompt_cu_seqlens (`torch.Tensor`, *optional*):
                Pre-generated cumulative sequence lengths for Qwen negative prompt.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`KandinskyImagePipelineOutput`].
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function that is called at the end of each denoising step.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function.
            max_sequence_length (`int`, defaults to `512`):
                The maximum sequence length for text encoding.

        Examples:

        Returns:
            [`~KandinskyImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`KandinskyImagePipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        )rK   r   r   r   r*   r+   r,   r-   r   r   r   rM   c                    s(   g | ]}t |d  |d     qS )r   r   )abs)rQ   rw   r   r   r    r!   rT   v  s   ( z2Kandinsky5T2IPipeline.__call__.<locals>.<listcomp>FNr   r   )rK   r   rM   rL   rN         ? z9`negative_prompt` must have same length as `prompt`. Got z vs .)rL   )r   r   r   r   rN   rL   r   r)   r   )r   r   r   )totalT)	rb   encoder_hidden_statespooled_projectionstimestepvisual_rope_postext_rope_posscale_factorsparse_paramsr`   )r`   r)   r*   r+   r,   r-   latent   r      )r   )image)9r   r
   r	   tensor_inputsr   rF   npargminr   r   rf   r.   rN   r   r   rj   rg   r   r   r   r4   set_timesteps	timestepsconfigin_visual_dimr   rn   arangerD   r   maxitemorderr   progress_barrh   r   	unsqueezer   rm   samplesteplocalspopupdateXLA_AVAILABLExm	mark_stepr/   reshapepermutescaling_factorri   rE   postprocessmaybe_free_model_hooksr   )*rG   rK   r   r   r   r   r   r   r   r)   r*   r+   r,   r-   r   r   r   r`   r   r   rM   rL   rN   r   r   r   r   r   negative_text_rope_posr   r   num_warmup_stepsr   rw   tr   pred_velocityuncond_pred_velocitycallback_kwargsr   callback_outputsr   r    r   r!   __call__  sF  N








(
6&<

zKandinsky5T2IPipeline.__call__)NrJ   N)NN)r   rJ   NN)NNNNNNNN)r   r8   r8   NNNN))__name__
__module____qualname____doc__model_cpu_offload_seqr   r   r   r   r   r   r   r   r@   r   r   rn   rL   r   rN   r}   r   r   r   	GeneratorTensorr   propertyr   r   r   no_gradr   EXAMPLE_DOC_STRINGfloatboolr   r
   r	   r   __classcell__r    r    rH   r!   r(   x   s>    
G

'

a
\	

2



	
r(   )6r   typingr   numpyr   regexr$   rn   torch.nnr   rq   transformersr   r   r   r   	callbacksr	   r
   rE   r   loadersr   modelsr   models.transformersr   
schedulersr   utilsr   r   r   r   utils.torch_utilsr   pipeline_utilsr   pipeline_outputr   torch_xla.core.xla_modelcore	xla_modelr   r   
get_loggerr   rk   r   r   r"   r&   r'   r(   r    r    r    r!   <module>   s<   


