o
    Gif                  
   @   s0  d dl Z d dlmZmZ d dlZd dlmZmZ ddlm	Z	 ddl
mZmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZ ddlmZ ddlmZ eeZdZ				d"de de de!de!fddZ"				d#de dB de#ej$B dB de%e  dB de%e! dB fddZ&G d d! d!eeeZ'dS )$    N)AnyCallable)AutoTokenizerPreTrainedModel   )VaeImageProcessor)FromSingleFileMixinZImageLoraLoaderMixin)AutoencoderKL)ZImageTransformer2DModel)DiffusionPipeline)FlowMatchEulerDiscreteScheduler)loggingreplace_example_docstring)randn_tensor   )ZImagePipelineOutputut  
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import ZImagePipeline

        >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
        >>> # (1) Use flash attention 2
        >>> # pipe.transformer.set_attention_backend("flash")
        >>> # (2) Use flash attention 3
        >>> # pipe.transformer.set_attention_backend("_flash_3")

        >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化：一辆复古蒸汽小火车化身为巨大的拉链头，正拉开厚厚的冬日积雪，展露出一个生机盎然的春天。"
        >>> image = pipe(
        ...     prompt,
        ...     height=1024,
        ...     width=1024,
        ...     num_inference_steps=9,
        ...     guidance_scale=0.0,
        ...     generator=torch.Generator("cuda").manual_seed(42),
        ... ).images[0]
        >>> image.save("zimage.png")
        ```
            ?ffffff?base_seq_lenmax_seq_len
base_shift	max_shiftc                 C   s,   || ||  }|||  }| | | }|S N )image_seq_lenr   r   r   r   mbmur   r   `/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/z_image/pipeline_z_image.pycalculate_shift@   s   r"   num_inference_stepsdevice	timestepssigmasc                 K   s  |dur|durt d|dur>dtt| jj v }|s(t d| j d| jd||d| | j}t	|}||fS |durpdtt| jj v }|sZt d| j d| jd||d	| | j}t	|}||fS | j|fd
|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    NzYOnly one of `timesteps` or `sigmas` can be passed. Please choose one to set custom valuesr%   zThe current scheduler class zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)r%   r$   r&   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)r&   r$   r$   r   )

ValueErrorsetinspect	signatureset_timesteps
parameterskeys	__class__r%   len)	schedulerr#   r$   r%   r&   kwargsaccepts_timestepsaccept_sigmasr   r   r!   retrieve_timestepsN   s2   r4   c                +       s8  e Zd ZdZg ZddgZdededede	de
f
 fd	d
Z						d;deee B dejdB dedeee B dB deej dB dejdB defddZ			d<deee B dejdB deej dB dedeej f
ddZ	d=ddZedd Zedd Zedd  Zed!d" Zed#d$ Ze eedddd%dd&d'd(dd)ddddd*ddddgdfdeee B d+edB d,edB d-ed.ee dB d/ed0ed1edeee B dB d2edB d3ej eej  B dB dejdB deej dB deej dB d4edB d5ed6e!ee"f dB d7e#eegdf dB d8ee def(d9d:Z$  Z%S )>ZImagePipelineztext_encoder->transformer->vaelatentsprompt_embedsr0   vaetext_encoder	tokenizertransformerc                    sd   t    | j|||||d t| dr$| jd ur$dt| jjjd  nd| _t	| jd d| _
d S )N)r8   r9   r:   r0   r;   r8      r      )vae_scale_factor)super__init__register_moduleshasattrr8   r/   configblock_out_channelsr>   r   image_processor)selfr0   r8   r9   r:   r;   r.   r   r!   r@      s   
,zZImagePipeline.__init__NT   promptr$   do_classifier_free_guidancenegative_promptnegative_prompt_embedsmax_sequence_lengthc                 C   s   t |tr|gn|}| j||||d}|rB|d u r!dd |D }n
t |tr)|gn|}t|t|ks5J | j||||d}||fS g }||fS )N)rI   r$   r7   rM   c                 S   s   g | ]}d qS ) r   ).0_r   r   r!   
<listcomp>   s    z0ZImagePipeline.encode_prompt.<locals>.<listcomp>)
isinstancestr_encode_promptr/   )rF   rI   r$   rJ   rK   r7   rL   rM   r   r   r!   encode_prompt   s*   
zZImagePipeline.encode_promptreturnc                 C   s   |p| j }|d ur|S t|tr|g}t|D ]\}}d|dg}| jj|dddd}|||< q| j|d|ddd}|j|}	|j|	 }
| j
|	|
dd	jd
 }g }tt|D ]}||| |
|   q[|S )Nuser)rolecontentFT)tokenizeadd_generation_promptenable_thinking
max_lengthpt)paddingr]   
truncationreturn_tensors)	input_idsattention_maskoutput_hidden_states)_execution_devicerR   rS   	enumerater:   apply_chat_templaterb   torc   boolr9   hidden_statesranger/   append)rF   rI   r$   r7   rM   iprompt_itemmessagestext_inputstext_input_idsprompt_masksembeddings_listr   r   r!   rT      sF   


zZImagePipeline._encode_promptc	           
      C   s   dt || jd   }dt || jd   }||||f}	|d u r*t|	|||d}|S |j|	kr:td|j d|	 ||}|S )Nr<   )	generatorr$   dtypezUnexpected latents shape, got z, expected )intr>   r   shaper'   ri   )
rF   
batch_sizenum_channels_latentsheightwidthrv   r$   ru   r6   rx   r   r   r!   prepare_latents   s   

zZImagePipeline.prepare_latentsc                 C      | j S r   _guidance_scalerF   r   r   r!   guidance_scale     zZImagePipeline.guidance_scalec                 C   s
   | j dkS )Nr   r   r   r   r   r!   rJ     s   
z*ZImagePipeline.do_classifier_free_guidancec                 C   r~   r   )_joint_attention_kwargsr   r   r   r!   joint_attention_kwargs  r   z%ZImagePipeline.joint_attention_kwargsc                 C   r~   r   )_num_timestepsr   r   r   r!   num_timesteps  r   zZImagePipeline.num_timestepsc                 C   r~   r   )
_interruptr   r   r   r!   	interrupt!  r   zZImagePipeline.interrupt2   g      @Fg      ?r   pilr{   r|   r#   r&   r   cfg_normalizationcfg_truncationnum_images_per_promptru   output_typereturn_dictr   callback_on_step_end"callback_on_step_end_tensor_inputsc           :   
      s  |pd}|pd}| j d }|| dkr!td| d| d| d|| dkr5td| d| d	| d| j}|| _|| _d
| _|| _|| _|durSt|t	rSd}n|durat|t
rat|}nt|}|dury|du ry| jrx|du rxtdn| j||	| j||||d\}}| jj}| |  |||tj|||} dkr fdd|D }| jr|r fdd|D }|  }|jd d |jd d  }t|| jjdd| jjdd| jjdd| jjdd}d| j_d|i}t| j||fd|i|\}}tt||| jj  d}t|| _| j|dn}t|D ]`\} }!| jr(q|! |jd }"d|" d }"|"d ! }#| j"}$| jrY| jdurYt#| jdkrY|#| jkrYd}$| jo`|$dk}%|%r}|$| jj%}&|&&dddd}'|| }(|"&d})n|$| jj%}'|}(|"})|''d}'t
|'j(dd }*| j|*|)|(d
d!d }+|%r|+d| },|+|d }-g }.t)|D ]F}/|,|/ # }0|-|/ # }1|0|$|0|1   }2| jrt#| jdkrtj*+|0}3tj*+|2}4|3t#| j }5|4|5kr|2|5|4  }2|.,|2 qtj-|.dd }.ntj-d"d |+D dd }.|..d}.|. }.| jj/|.$tj|!|d
d!d }|j%tjks1J |dur^i }6|D ]
}7t0 |7 |6|7< q:|| | |!|6}8|81d#|}|81d$|}|81d%|}| t|d ksy| d |kr}| d | jj dkr}|2  qW d   n	1 sw   Y  |d&kr|}9n%|$| j3j%}|| j3jj4 | j3jj5 }| j3j6|d
d!d }9| j7j8|9|d'}9| 9  |s|9fS t:|9d(S ))a  
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            height (`int`, *optional*, defaults to 1024):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 1024):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`list[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            cfg_normalization (`bool`, *optional*, defaults to False):
                Whether to apply configuration normalization.
            cfg_truncation (`float`, *optional*, defaults to 1.0):
                The truncation value for configuration.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`list[torch.FloatTensor]`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`list[torch.FloatTensor]`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
                tuple.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, *optional*, defaults to 512):
                Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
            generated images.
        i   r<   r   zHeight must be divisible by z (got z-). Please adjust the height to a multiple of .zWidth must be divisible by z,). Please adjust the width to a multiple of FNr   zWhen `prompt_embeds` is provided without `prompt`, `negative_prompt_embeds` must also be provided for classifier-free guidance.)rI   rK   rJ   r7   rL   r$   rM   c                       g | ]}t  D ]}|qqS r   rl   )rO   perP   r   r   r!   rQ         z+ZImagePipeline.__call__.<locals>.<listcomp>c                    r   r   r   )rO   nperP   r   r   r!   rQ     r   r   base_image_seq_lenr   max_image_seq_lenr   r   r   r   r   g        r    r&   )totali  )dim)r   c                 S   s   g | ]}|  qS r   )float)rO   tr   r   r!   rQ   +  s    r6   r7   rL   latent)r   )images);r>   r'   rf   r   r   r   _cfg_normalization_cfg_truncationrR   rS   listr/   rJ   rU   r;   in_channelsr}   torchfloat32rx   r"   r0   rC   get	sigma_minr4   maxorderr   progress_barrg   r   expanditemr   r   ri   rv   repeat	unsqueezeunbindrl   linalgvector_normrm   stacksqueezesteplocalspopupdater8   scaling_factorshift_factordecoderE   postprocessmaybe_free_model_hooksr   ):rF   rI   r{   r|   r#   r&   r   r   r   rK   r   ru   r6   r7   rL   r   r   r   r   r   rM   	vae_scaler$   ry   rz   actual_batch_sizer   r    scheduler_kwargsr%   num_warmup_stepsr   rn   r   timestept_normcurrent_guidance_scale	apply_cfglatents_typedlatent_model_inputprompt_embeds_model_inputtimestep_model_inputlatent_model_input_listmodel_out_listpos_outneg_out
noise_predjposnegpredori_pos_normnew_pos_normmax_new_normcallback_kwargskcallback_outputsimager   r   r!   __call__%  s6  f







 
6
X
zZImagePipeline.__call__)NTNNNrH   )NNrH   r   )&__name__
__module____qualname__model_cpu_offload_seq_optional_components_callback_tensor_inputsr   r
   r   r   r   r@   rS   r   r   r$   rj   FloatTensorrw   rU   rT   r}   propertyr   rJ   r   r   r   no_gradr   EXAMPLE_DOC_STRINGr   	Generatordictr   r   r   __classcell__r   r   rG   r!   r5      s    

%

<







	
r5   )r   r   r   r   )NNNN)(r)   typingr   r   r   transformersr   r   rE   r   loadersr   r	   models.autoencodersr
   models.transformersr   pipelines.pipeline_utilsr   
schedulersr   utilsr   r   utils.torch_utilsr   pipeline_outputr   
get_loggerr   loggerr   rw   r   r"   rS   r$   r   r4   r5   r   r   r   r!   <module>   sT   
 



;