o
    ۷iΉ                  
   @   sT  d dl Z d dlmZmZ d dlZd dlZd dlmZm	Z	 ddl
mZmZ ddlmZmZ ddlmZmZ ddlmZ dd	lmZ dd
lmZmZmZ ddlmZ ddlmZ e red dlm   m!Z" dZ#ndZ#e$e%Z&dZ'			d"de(de)de)de)fddZ*				d#de(dB de+ej,B dB de-e( dB de-e) dB fddZ.G d d! d!eZ/dS )$    N)AnyCallable)AutoTokenizerGlmModel   )MultiPipelineCallbacksPipelineCallback)PipelineImageInputVaeImageProcessor)AutoencoderKLCogView4Transformer2DModel)DiffusionPipeline)FlowMatchEulerDiscreteScheduler)is_torch_xla_availableloggingreplace_example_docstring)randn_tensor   )CogView4PipelineOutputTFaw  
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import CogView4ControlPipeline

        >>> pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B-Control", torch_dtype=torch.bfloat16)
        >>> control_image = load_image(
        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
        ... )
        >>> prompt = "A bird in space"
        >>> image = pipe(prompt, control_image=control_image, height=1024, width=1024, guidance_scale=3.5).images[0]
        >>> image.save("cogview4-control.png")
        ```
         ?      ?base_seq_len
base_shift	max_shiftreturnc                 C   s   | | d }|| | }|S )Ng      ? )image_seq_lenr   r   r   mmur   r   l/home/ubuntu/vllm_env/lib/python3.10/site-packages/diffusers/pipelines/cogview4/pipeline_cogview4_control.pycalculate_shift<   s   r!   num_inference_stepsdevice	timestepssigmasc                 K   sT  dt t| jj v }dt t| jj v }|durF|durF|s/|s/td| j d| jd|||d| | j}t	|}||fS |duro|du ro|sYtd| j d| jd||d| | j}t	|}||fS |du r|dur|std| j d	| jd||d
| | j}t	|}||fS | j|fd|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    r$   r%   NzThe current scheduler class z's `set_timesteps` does not support custom timestep or sigma schedules. Please check whether you are using the correct scheduler.)r$   r%   r#   zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)r$   r#   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)r%   r#   r#   r   )
setinspect	signatureset_timesteps
parameterskeys
ValueError	__class__r$   len)	schedulerr"   r#   r$   r%   kwargsaccepts_timestepsaccepts_sigmasr   r   r    retrieve_timestepsH   s@   r3   c                2       s  e Zd ZdZg ZdZg dZdedede	de
def
 fd	d
Z				dFdeee B dedejdB dejdB fddZ								dGdeee B deee B dB dededejdB dejdB dejdB dejdB defddZdHddZ		dIdd Z		dJd!d"Zed#d$ Zed%d& Zed'd( Zed)d* Zed+d, Zed-d. Z e! e"e#dddddd/ddd0ddddddd1d2dddd3gdfdeee B dB deee B dB d4e$d5edB d6edB d7ed8ee dB d9ee% dB d:e%ded;ej&eej& B dB d3ej'dB dej'dB dej'dB d<e(eef dB d=e(eef d>ed?ed@e)ee*f dB dAe+eegdf e,B e-B dB dBee dedCe.e(B f.dDdEZ/  Z0S )KCogView4ControlPipelineaR  
    Pipeline for text-to-image generation using CogView4.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`GLMModel`]):
            Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
        tokenizer (`PreTrainedTokenizer`):
            Tokenizer of class
            [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
        transformer ([`CogView4Transformer2DModel`]):
            A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    ztext_encoder->transformer->vae)latentsprompt_embedsnegative_prompt_embeds	tokenizertext_encodervaetransformerr/   c                    sX   t    | j|||||d t| dd r dt| jjjd  nd| _t	| jd| _
d S )N)r8   r9   r:   r;   r/   r:      r      )vae_scale_factor)super__init__register_modulesgetattrr.   r:   configblock_out_channelsr>   r
   image_processor)selfr8   r9   r:   r;   r/   r-   r   r    r@      s   

(z CogView4ControlPipeline.__init__N   promptmax_sequence_lengthr#   dtypec                 C   s2  |p| j }|p
| jj}t|tr|gn|}| j|d|dddd}|j}| j|dddj}|jd |jd krWt	||sW| j
|d d |d df }td| d	|  |jd }	d
|	d
  d
 }
|
dkrtj|jd |
f| jj|j|jd}tj||gdd}| j||ddjd }|j||d}|S )NlongestTpt)padding
max_length
truncationadd_special_tokensreturn_tensors)rN   rR   r   zXThe following part of your input was truncated because `max_sequence_length` is set to  z	 tokens:    r   )
fill_valuerK   r#   dim)output_hidden_statesrK   r#   )_execution_devicer9   rK   
isinstancestrr8   	input_idsshapetorchequalbatch_decodeloggerwarningfullpad_token_idr#   cattohidden_states)rF   rI   rJ   r#   rK   text_inputstext_input_idsuntruncated_idsremoved_textcurrent_length
pad_lengthpad_idsr6   r   r   r    _get_glm_embeds   sF   
  
z'CogView4ControlPipeline._get_glm_embedsTr   negative_promptdo_classifier_free_guidancenum_images_per_promptr6   r7   c
              
   C   sR  |p| j }t|tr|gn|}|durt|}
n|jd }
|du r)| ||	||}|d}|d|d}||
| |d}|r|du r|pGd}t|trR|
|g n|}|durot	|t	|urot
dt	| dt	| d|
t|krtd	| d
t| d| d
|
 d	| ||	||}|d}|d|d}||
| |d}||fS )a  
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                Number of images that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
            max_sequence_length (`int`, defaults to `1024`):
                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
        Nr   r   rS    z?`negative_prompt` should be the same type to `prompt`, but got z != .z`negative_prompt`: z has batch size z, but `prompt`: zT. Please make sure that passed `negative_prompt` matches the batch size of `prompt`.)r[   r\   r]   r.   r_   rq   sizerepeatviewtype	TypeErrorr,   )rF   rI   rr   rs   rt   r6   r7   r#   rK   rJ   
batch_sizeseq_lenr   r   r    encode_prompt   s@   
(



z%CogView4ControlPipeline.encode_promptc	           
      C   sv   |d ur	| |S ||t|| j t|| j f}	t|tr1t||kr1tdt| d| dt|	|||d}|S )Nz/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.)	generatorr#   rK   )rh   intr>   r\   listr.   r,   r   )
rF   r|   num_channels_latentsheightwidthrK   r#   r   r5   r_   r   r   r    prepare_latents0  s   
z'CogView4ControlPipeline.prepare_latentsFc
                 C   s   t |tjrn	| jj|||d}|jd }
|
dkr|}n|}|j|d|jd | d}|j||d}|r>|	s>t|gd }|S )N)r   r   r   r   )rW   output_size)r#   rK   r<   )	r\   r`   TensorrE   
preprocessr_   repeat_interleaverh   rg   )rF   imager   r   r|   rt   r#   rK   rs   
guess_modeimage_batch_size	repeat_byr   r   r    prepare_imageB  s   
z%CogView4ControlPipeline.prepare_imagec                    sj  |d dks|d dkrt d| d| d|d ur8t fdd|D s8t d j d	 fd
d|D  |d urK|d urKt d| d| d|d u rW|d u rWt d|d urnt|tsnt|tsnt dt| |d ur|d urt d| d| d|d ur|d urt d| d| d|d ur|d ur|j|jkrt d|j d|j dd S d S d S )NrT   r   z8`height` and `width` have to be divisible by 16 but are z and rv   c                 3   s    | ]}| j v V  qd S N_callback_tensor_inputs.0krF   r   r    	<genexpr>q  s    

z7CogView4ControlPipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r   r   r   r   r   r    
<listcomp>u  s    z8CogView4ControlPipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is z and `negative_prompt_embeds`: z'Cannot forward both `negative_prompt`: zu`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` z != `negative_prompt_embeds` )r,   allr   r\   r]   r   rz   r_   )rF   rI   r   r   rr   "callback_on_step_end_tensor_inputsr6   r7   r   r   r    check_inputsd  sR   
z$CogView4ControlPipeline.check_inputsc                 C      | j S r   _guidance_scaler   r   r   r    guidance_scale     z&CogView4ControlPipeline.guidance_scalec                 C   s
   | j dkS )Nr   r   r   r   r   r    rs     s   
z3CogView4ControlPipeline.do_classifier_free_guidancec                 C   r   r   )_num_timestepsr   r   r   r    num_timesteps  r   z%CogView4ControlPipeline.num_timestepsc                 C   r   r   )_attention_kwargsr   r   r   r    attention_kwargs  r   z(CogView4ControlPipeline.attention_kwargsc                 C   r   r   )_current_timestepr   r   r   r    current_timestep  r   z(CogView4ControlPipeline.current_timestepc                 C   r   r   )
_interruptr   r   r   r    	interrupt  r   z!CogView4ControlPipeline.interrupt2   g      @)r   r   pilr5   control_imager   r   r"   r$   r%   r   r   original_sizecrops_coords_top_leftoutput_typereturn_dictr   callback_on_step_endr   r   c           -      C   s  t |ttfr
|j}|p| jjj| j }|p| jjj| j }|p#||f}||f}| ||||||| |	| _	|| _
d| _d| _|durKt |trKd}n|durYt |trYt|}n|jd }| j}| j||| j|
||||d\}}| jjjd }| j|||||
 |
|| jjd}|jdd \}}d}| j|j }|| | jjj }| ||
 |||tj|||}tj|g|j|d	}tj|g|j|d	}tj|g|j|d	}| ||
 d}| ||
 d}| ||
 d}|| j || j  | jjj!d  }|du r
t"#| j$jj%d
|nt"&|}|'t"j('t"j}|du r%|| j$jj% n|}t)|| j$j*dd| j$j*dd| j$j*dd}t+rFd}n|}t,| j$|||||d\}}t|| _-| jj}t.t||| j$j/  d} | j0|d}!t1|D ]\}"}#| j2rqv|#| _tj3||gdd4|}$|#5|jd }%| j|$||%||||ddd }&| jr| j|$||%||||ddd }'|'| j6|&|'   }(n|&}(| j$j7|(|#|ddd }|duri })|D ]
}*t8 |* |)|*< q|| |"| j$j9|" |)}+|+:d|}|+:d|}|+:d|}|"t|d ks|"d | kr!|"d | j$j/ dkr!|!;  t+r(t<=  qvW d   n	1 s5w   Y  d| _|dksZ|4| jj| jjj }| jj>|d|dd },n|},| j?j@|,|d},| A  |sn|,fS tB|,dS )a@  
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image. If not provided, it is set to 1024.
            width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image. If not provided it is set to 1024.
            num_inference_steps (`int`, *optional*, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`list[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`list[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to `5.0`):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to `1`):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            original_size (`tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
                explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            crops_coords_top_left (`tuple[int]`, *optional*, defaults to (0, 0)):
                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput`] instead of a plain
                tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`list`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `224`):
                Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
        Examples:

        Returns:
            [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] or `tuple`:
            [`~pipelines.cogview4.pipeline_CogView4.CogView4PipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        NFr   r   )rt   r6   r7   rJ   r#   r<   )r   r   r   r|   rt   r#   rK   rY   rZ   g      ?base_image_seq_lenr   r   r   r   r   cpu)r   )totalrV   )ri   encoder_hidden_statestimestepr   target_sizecrop_coordsr   r   )r   r5   r6   r7   latent)r   r   )r   )images)Cr\   r   r   tensor_inputsr;   rC   sample_sizer>   r   r   r   r   r   r]   r   r.   r_   r[   r~   rs   in_channelsr   r:   rK   encodelatent_distsamplescaling_factorr   r`   float32tensorrx   
patch_sizenplinspacer/   num_train_timestepsarrayastypeint64r!   getXLA_AVAILABLEr3   r   maxorderprogress_bar	enumerater   rg   rh   expandr   steplocalsr%   popupdatexm	mark_stepdecoderE   postprocessmaybe_free_model_hooksr   )-rF   rI   rr   r   r   r   r"   r$   r%   r   rt   r   r5   r6   r7   r   r   r   r   r   r   r   rJ   r   r|   r#   latent_channelsvae_shift_factorr   r   timestep_devicetransformer_dtypenum_warmup_stepsr   itlatent_model_inputr   noise_pred_condnoise_pred_uncond
noise_predcallback_kwargsr   callback_outputsr   r   r   r    __call__  s*  q	


	


		
68

z CogView4ControlPipeline.__call__)NrH   NN)NTr   NNNNrH   r   )FF)NN)1__name__
__module____qualname____doc___optional_componentsmodel_cpu_offload_seqr   r   r   r   r   r   r@   r]   r   r   r`   r#   rK   rq   boolr   r~   r   r   r   propertyr   rs   r   r   r   r   no_gradr   EXAMPLE_DOC_STRINGr	   float	GeneratorFloatTensortupledictr   r   r   r   r   r   __classcell__r   r   rG   r    r4      s$   

/
	


O
)
3







	

r4   )r   r   r   )NNNN)0r'   typingr   r   numpyr   r`   transformersr   r   	callbacksr   r   rE   r	   r
   modelsr   r   pipelines.pipeline_utilsr   
schedulersr   utilsr   r   r   utils.torch_utilsr   pipeline_outputr   torch_xla.core.xla_modelcore	xla_modelr   r   
get_loggerr   rc   r   r   r   r!   r]   r#   r   r3   r4   r   r   r   r    <module>   sZ   




C