o
    ۷i                  
   @   s<  d dl Z d dlmZ d dlZd dlmZmZmZmZ ddl	m
Z
mZ ddlmZmZ ddlmZ ddlmZ dd	lmZmZmZ dd
lmZ ddlmZ ddlmZ e rad dlm  mZ  dZ!ndZ!e"e#Z$dZ%dd Z&dddZ'				dde(dB de)ej*B dB de+e( dB de+e, dB fddZ-G dd deZ.dS )     N)Callable)	BertModelBertTokenizerQwen2TokenizerQwen2VLForConditionalGeneration   )MultiPipelineCallbacksPipelineCallback)AutoencoderKLMagvitEasyAnimateTransformer3DModel)DiffusionPipeline)FlowMatchEulerDiscreteScheduler)is_torch_xla_availableloggingreplace_example_docstring)randn_tensor)VideoProcessor   )EasyAnimatePipelineOutputTFa  
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import EasyAnimatePipeline
        >>> from diffusers.utils import export_to_video

        >>> # Models: "alibaba-pai/EasyAnimateV5.1-12b-zh"
        >>> pipe = EasyAnimatePipeline.from_pretrained(
        ...     "alibaba-pai/EasyAnimateV5.1-7b-zh-diffusers", torch_dtype=torch.float16
        ... ).to("cuda")
        >>> prompt = (
        ...     "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        ...     "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        ...     "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        ...     "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        ...     "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        ...     "atmosphere of this unique musical performance."
        ... )
        >>> sample_size = (512, 512)
        >>> video = pipe(
        ...     prompt=prompt,
        ...     guidance_scale=6,
        ...     negative_prompt="bad detailed",
        ...     height=sample_size[0],
        ...     width=sample_size[1],
        ...     num_inference_steps=50,
        ... ).frames[0]
        >>> export_to_video(video, "output.mp4", fps=8)
        ```
c                 C   s   |}|}| \}}|| }||| kr|}t t|| | }	n|}	t t|| | }t t|| d }
t t||	 d }|
|f|
| ||	 ffS )Ng       @)intround)src	tgt_width
tgt_heighttwthhwrresize_heightresize_widthcrop_top	crop_left r#   j/home/ubuntu/vllm_env/lib/python3.10/site-packages/diffusers/pipelines/easyanimate/pipeline_easyanimate.pyget_resize_crop_region_for_gridQ   s   r%           c                 C   sX   |j ttd|jdd}| j ttd| jdd}| ||  }|| d| |   } | S )a  
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://huggingface.co/papers/2305.08891).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    r   T)dimkeepdim)stdlistrangendim)	noise_cfgnoise_pred_textguidance_rescalestd_textstd_cfgnoise_pred_rescaledr#   r#   r$   rescale_noise_cfgd   s
   r3   num_inference_stepsdevice	timestepssigmasc                 K   s  |dur|durt d|dur>dtt| jj v }|s(t d| j d| jd||d| | j}t	|}||fS |durpdtt| jj v }|sZt d| j d| jd||d	| | j}t	|}||fS | j|fd
|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    NzYOnly one of `timesteps` or `sigmas` can be passed. Please choose one to set custom valuesr6   zThe current scheduler class zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)r6   r5   r7   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)r7   r5   r5   r#   )

ValueErrorsetinspect	signatureset_timesteps
parameterskeys	__class__r6   len)	schedulerr4   r5   r6   r7   kwargsaccepts_timestepsaccept_sigmasr#   r#   r$   retrieve_timesteps   s2   rE   c                -       sV  e Zd ZdZdZg dZdedeeB de	e
B dedef
 fd	d
Z										dBdeee B dededeee B dB dejdB dejdB dejdB dejdB dejdB dejdB defddZdd Z						dCddZ	dDd d!Zed"d# Zed$d% Zed&d' Zed(d) Zed*d+ Ze  e!e"dd,d-d-d.d/ddd0dddddddd1ddd2gd0fdeee B d3edB d4edB d5edB d6edB d7e#dB deee B dB dedB d8e#dB d9ej$eej$ B dB d2ejdB dejdB d:ee dB dejdB dejdB dejdB d;edB d<ed=e%eegdf e&B e'B dB d>ee d?e#f*d@dAZ(  Z)S )EEasyAnimatePipelinea  
    Pipeline for text-to-video generation using EasyAnimate.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.

    Args:
        vae ([`AutoencoderKLMagvit`]):
            Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
        text_encoder (`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel` | None):
            EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
        tokenizer (`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer` | None):
            A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
        transformer ([`EasyAnimateTransformer3DModel`]):
            The EasyAnimate model designed by EasyAnimate Team.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
    ztext_encoder->transformer->vae)latentsprompt_embedsnegative_prompt_embedsvaetext_encoder	tokenizertransformerrA   c                    s   t    | j|||||d t| dd d ur| jjjnd| _t| dd d ur+| jjnd| _	t| dd d ur:| jj
nd| _t| j	d| _d S )N)rJ   rK   rL   rM   rA   rM   TrJ         )vae_scale_factor)super__init__register_modulesgetattrrM   configenable_text_attention_maskrJ   spatial_compression_ratiovae_spatial_compression_ratiotemporal_compression_ratiovae_temporal_compression_ratior   video_processor)selfrJ   rK   rL   rM   rA   r?   r#   r$   rR      s"   
	
zEasyAnimatePipeline.__init__r   TN   promptnum_images_per_promptdo_classifier_free_guidancenegative_promptrH   rI   prompt_attention_masknegative_prompt_attention_maskr5   dtypemax_sequence_lengthc              	      sZ  |
p j j}
|	p j j}	|durt|trd}n|dur&t|tr&t|}n|jd }|du rt|tr?dd|dgdg}ndd	 |D } fd
d	|D } j|d|ddddd}|	 j j}|j
}|j} jrw j ||ddjd }ntd||d}|j	|
|	d}|j\}}}|d|d}||| |d}|j	|	d}|r|du r|durt|trdd|dgdg}ndd	 |D } fdd	|D } j|d|ddddd}|	 j j}|j
}|j} jr j ||ddjd }ntd||d}|r'|jd }|j	|
|	d}|d|d}||| |d}|j	|	d}||||fS )a[  
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            dtype (`torch.dtype`):
                torch dtype
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
            max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
        Nr   r   usertexttyperh   rolecontentc                 S      g | ]}d d|dgdqS rg   rh   ri   rk   r#   ).0_promptr#   r#   r$   
<listcomp>1      
z5EasyAnimatePipeline.encode_prompt.<locals>.<listcomp>c                        g | ]} j j|gd ddqS FT)tokenizeadd_generation_promptrL   apply_chat_templaterp   mr\   r#   r$   rr   8      
max_lengthTrightpt)rh   paddingr~   
truncationreturn_attention_maskpadding_sidereturn_tensors)	input_idsattention_maskoutput_hidden_stateszLLM needs attention_mask)re   r5   r5   c                 S   rn   ro   r#   )rp   _negative_promptr#   r#   r$   rr   d  rs   c                    rt   ru   rx   rz   r|   r#   r$   rr   k  r}   )rK   re   r5   
isinstancestrr*   r@   shaperL   tor   r   rV   hidden_statesr8   repeatview)r\   r_   r`   ra   rb   rH   rI   rc   rd   r5   re   rf   
batch_sizemessagesrh   text_inputstext_input_idsbs_embedseq_len_r#   r|   r$   encode_prompt   s   -




	

	
z!EasyAnimatePipeline.encode_promptc                 C   sX   dt t| jjj v }i }|r||d< dt t| jjj v }|r*||d< |S )Neta	generator)r9   r:   r;   rA   stepr=   r>   )r\   r   r   accepts_etaextra_step_kwargsaccepts_generatorr#   r#   r$   prepare_extra_step_kwargs  s   z-EasyAnimatePipeline.prepare_extra_step_kwargsc
           
         st  |d dks|d dkrt d| d| d|	d ur8t fdd|	D s8t d j d	 fd
d|	D  |d urK|d urKt d| d| d|d u rW|d u rWt d|d urnt|tsnt|tsnt dt| |d urz|d u rzt d|d ur|d urt d| d| d|d ur|d u rt d|d ur|d ur|j|jkrt d|j d|j dd S d S d S )N   r   z8`height` and `width` have to be divisible by 16 but are z and .c                 3   s    | ]}| j v V  qd S N_callback_tensor_inputsrp   kr|   r#   r$   	<genexpr>  s    

z3EasyAnimatePipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r#   r   r   r|   r#   r$   rr     s    z4EasyAnimatePipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is zEMust provide `prompt_attention_mask` when specifying `prompt_embeds`.z'Cannot forward both `negative_prompt`: z and `negative_prompt_embeds`: zWMust provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.zu`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` z != `negative_prompt_embeds` )r8   allr   r   r   r*   rj   r   )
r\   r_   heightwidthrb   rH   rI   rc   rd   "callback_on_step_end_tensor_inputsr#   r|   r$   check_inputs  sN   z EasyAnimatePipeline.check_inputsc
                 C   s   |	d ur|	j ||dS |||d | j d || j || j f}
t|tr7t||kr7tdt| d| dt|
|||d}	t| j	drK|	| j	j
 }	|	S )N)r5   re   r   z/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.)r   r5   re   init_noise_sigma)r   rZ   rX   r   r*   r@   r8   r   hasattrrA   r   )r\   r   num_channels_latents
num_framesr   r   re   r5   r   rG   r   r#   r#   r$   prepare_latents  s$   z#EasyAnimatePipeline.prepare_latentsc                 C      | j S r   _guidance_scaler|   r#   r#   r$   guidance_scale     z"EasyAnimatePipeline.guidance_scalec                 C   r   r   )_guidance_rescaler|   r#   r#   r$   r/     r   z$EasyAnimatePipeline.guidance_rescalec                 C   s
   | j dkS )Nr   r   r|   r#   r#   r$   ra      s   
z/EasyAnimatePipeline.do_classifier_free_guidancec                 C   r   r   )_num_timestepsr|   r#   r#   r$   num_timesteps  r   z!EasyAnimatePipeline.num_timestepsc                 C   r   r   )
_interruptr|   r#   r#   r$   	interrupt  r   zEasyAnimatePipeline.interrupt1   i   2   g      @r&   pilrG   r   r   r   r4   r   r   r   r6   output_typereturn_dictcallback_on_step_endr   r/   c           *      C   s  t |ttfr
|j}t|d d }t|d d }| |||||||||	 || _|| _d| _|dur<t |t	r<d}n|durJt |t
rJt|}n|jd }| j}| jdur\| jj}n| jj}| j||||| j|||||d
\}}}}tryd}n|}t | jtrt| j|||dd\}}n
t| j|||\}}| jjj}| || |||||||
|	}| |
|	}| jrt||g}t||g}|j|d	}|j|d	}t||| jj  }t|| _| j |d
}t!|D ]\}}| j"rq| jrt|gd n|} t#| jdr| j$| |} tj%|g| jd  |d	j| jd}!| j| |!|ddd }"|"& d | j'jj(kr>|"j)ddd\}"}#| jrQ|")d\}$}%|$||%|$   }"| jra|dkrat*|"|%|d}"| jj+|"||fi |ddid }|duri }&|D ]
}'t, |' |&|'< q||| |||&}(|(-d|}|(-d|}|(-d|}|t|d ks|d |kr|d | jj dkr|.  trt/0  qW d   n	1 sw   Y  |dksd| j'jj1 | }| j'j2|ddd })| j3j4|)|d})n|})| 5  |s|)fS t6|)dS )aK  
        Generates images or video using the EasyAnimate pipeline based on the provided prompts.

        Examples:
            prompt (`str` or `list[str]`, *optional*):
                Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
            num_frames (`int`, *optional*):
                Length of the generated video (in frames).
            height (`int`, *optional*):
                Height of the generated image in pixels.
            width (`int`, *optional*):
                Width of the generated image in pixels.
            num_inference_steps (`int`, *optional*, defaults to 50):
                Number of denoising steps during generation. More steps generally yield higher quality images but slow
                down inference.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Encourages the model to align outputs with prompts. A higher value may decrease image quality.
            negative_prompt (`str` or `list[str]`, *optional*):
                Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                Number of images to generate for each prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                A generator to ensure reproducibility in image generation.
            latents (`torch.Tensor`, *optional*):
                Predefined latent tensors to condition generation.
            prompt_embeds (`torch.Tensor`, *optional*):
                Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Embeddings for negative prompts. Overrides string inputs if defined.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the primary prompt embeddings.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for negative prompt embeddings.
            output_type (`str`, *optional*, defaults to "latent"):
                Format of the generated output, either as a PIL image or as a NumPy array.
            return_dict (`bool`, *optional*, defaults to `True`):
                If `True`, returns a structured output. Otherwise returns a simple tuple.
            callback_on_step_end (`Callable`, *optional*):
                Functions called at the end of each denoising step.
            callback_on_step_end_tensor_inputs (`list[str]`, *optional*):
                Tensor names to be included in callback function calls.
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Adjusts noise levels based on guidance scale.
            original_size (`tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
                Original dimensions of the output.
            target_size (`tuple[int, int]`, *optional*):
                Desired output dimensions for calculations.
            crops_coords_top_left (`tuple[int, int]`, *optional*, defaults to `(0, 0)`):
                Coordinates for cropping.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        r   FNr   r   )
r_   r5   re   r`   ra   rb   rH   rI   rc   rd   cpu)mur   )total   scale_model_input)re   )encoder_hidden_statesr   )r'   r&   )r/   r   rG   rH   rI   latent)r   )videor   )frames)7r   r	   r   tensor_inputsr   r   r   r   r   r   r*   r@   r   _execution_devicerK   re   rM   r   ra   XLA_AVAILABLErA   r   rE   rU   in_channelsr   r   torchcatr   orderr   progress_bar	enumerater   r   r   tensorsizerJ   latent_channelschunkr3   r   localspopupdatexm	mark_stepscaling_factordecoder[   postprocess_videomaybe_free_model_hooksr   )*r\   r_   r   r   r   r4   r   rb   r`   r   r   rG   rH   r6   rI   rc   rd   r   r   r   r   r/   r   r5   re   timestep_devicer   r   num_warmup_stepsr   itlatent_model_inputt_expand
noise_predr   noise_pred_uncondr.   callback_kwargsr   callback_outputsr   r#   r#   r$   __call__  s   V






$
6
6
zEasyAnimatePipeline.__call__)
r   TNNNNNNNr^   )NNNNNNr   )*__name__
__module____qualname____doc__model_cpu_offload_seqr   r
   r   r   r   r   r   r   rR   r   r*   r   boolr   Tensorr5   re   r   r   r   r   propertyr   r/   ra   r   r   no_gradr   EXAMPLE_DOC_STRINGfloat	Generatorr   r	   r   r   __classcell__r#   r#   r]   r$   rF      s   !
	

 $
7






	

rF   )r&   )NNNN)/r:   typingr   r   transformersr   r   r   r   	callbacksr   r	   modelsr
   r   pipelines.pipeline_utilsr   
schedulersr   utilsr   r   r   utils.torch_utilsr   r[   r   pipeline_outputr   torch_xla.core.xla_modelcore	xla_modelr   r   
get_loggerr   loggerr   r%   r3   r   r   r5   r*   r   rE   rF   r#   r#   r#   r$   <module>   sD   
"



;