o
    ۷i                  
   @   s  d dl Z d dlmZ d dlZd dlZd dlm  mZ	 d dl
mZ d dlmZmZmZmZ ddlmZmZ ddlmZ ddlmZmZ dd	lmZ dd
lmZ ddlmZmZm Z  ddl!m"Z" ddl#m$Z$ ddl%m&Z& e rzd dl'm(  m)Z* dZ+ndZ+e,e-Z.dZ/dd Z0d&ddZ1dd Z2d'ddZ3d(ddZ4				d)de5dB de6ej7B dB d e8e5 dB d!e8e9 dB fd"d#Z:G d$d% d%eZ;dS )*    N)Callable)Image)	BertModelBertTokenizerQwen2TokenizerQwen2VLForConditionalGeneration   )MultiPipelineCallbacksPipelineCallback)VaeImageProcessor)AutoencoderKLMagvitEasyAnimateTransformer3DModel)DiffusionPipeline)FlowMatchEulerDiscreteScheduler)is_torch_xla_availableloggingreplace_example_docstring)randn_tensor)VideoProcessor   )EasyAnimatePipelineOutputTFaY  
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import EasyAnimateControlPipeline
        >>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_control import get_video_to_video_latent
        >>> from diffusers.utils import export_to_video, load_video

        >>> pipe = EasyAnimateControlPipeline.from_pretrained(
        ...     "alibaba-pai/EasyAnimateV5.1-12b-zh-Control-diffusers", torch_dtype=torch.bfloat16
        ... )
        >>> pipe.to("cuda")

        >>> control_video = load_video(
        ...     "https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control/blob/main/asset/pose.mp4"
        ... )
        >>> prompt = (
        ...     "In this sunlit outdoor garden, a beautiful woman is dressed in a knee-length, sleeveless white dress. "
        ...     "The hem of her dress gently sways with her graceful dance, much like a butterfly fluttering in the breeze. "
        ...     "Sunlight filters through the leaves, casting dappled shadows that highlight her soft features and clear eyes, "
        ...     "making her appear exceptionally elegant. It seems as if every movement she makes speaks of youth and vitality. "
        ...     "As she twirls on the grass, her dress flutters, as if the entire garden is rejoicing in her dance. "
        ...     "The colorful flowers around her sway in the gentle breeze, with roses, chrysanthemums, and lilies each "
        ...     "releasing their fragrances, creating a relaxed and joyful atmosphere."
        ... )
        >>> sample_size = (672, 384)
        >>> num_frames = 49

        >>> input_video, _, _ = get_video_to_video_latent(control_video, num_frames, sample_size)
        >>> video = pipe(
        ...     prompt,
        ...     num_frames=num_frames,
        ...     negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.",
        ...     height=sample_size[0],
        ...     width=sample_size[1],
        ...     control_video=input_video,
        ... ).frames[0]
        >>> export_to_video(video, "output.mp4", fps=8)
        ```
c                 C   s   t | tjrtjjj| d|dddd} n5t | tjr/| 	|d |d f} t
| } nt | t
jrIt| 	|d |d f} t
| } ntdt | tjsat| ddd d } | S )	zd
    Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor.
    r   bilinearFsizemodealign_cornersr   zKUnsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.   g     o@)
isinstancetorchTensornn
functionalinterpolate	unsqueezesqueezer   resizenparrayndarray	fromarray
ValueError
from_numpypermutefloat)imagesample_size r0   r/home/ubuntu/vllm_env/lib/python3.10/site-packages/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.pypreprocess_image]   s    r2   c                    s0  | d urw fdd| D } t | d | } | ddddd} |d ur\t| d}t |dk d	d
}|ddg dd}t |dd|  d ddg}|| j	| j
}nt | d d d df }d
|d d d d d d f< nd\} }|d urt| d}|ddddd}nd }| ||fS )Nc                    s   g | ]}t | d qS )r/   )r2   ).0framer3   r0   r1   
<listcomp>{   s    z-get_video_to_video_latent.<locals>.<listcomp>r   r   r   r   )r   g?           )r   r   r   r   NN)r   stackr,   r#   r2   wheretiler   todevicedtype
zeros_like)input_video
num_framesr/   validation_video_mask	ref_imageinput_video_maskr0   r3   r1   get_video_to_video_latentx   s$    
rG   c                 C   s   |}|}| \}}|| }||| kr|}t t|| | }	n|}	t t|| | }t t|| d }
t t||	 d }|
|f|
| ||	 ffS )Ng       @)intround)src	tgt_width
tgt_heighttwthhwrresize_heightresize_widthcrop_top	crop_leftr0   r0   r1   get_resize_crop_region_for_grid   s   rV   r7   c                 C   sX   |j ttd|jdd}| j ttd| jdd}| ||  }|| d| |   } | S )a  
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://huggingface.co/papers/2305.08891).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    r   T)dimkeepdim)stdlistrangendim)	noise_cfgnoise_pred_textguidance_rescalestd_textstd_cfgnoise_pred_rescaledr0   r0   r1   rescale_noise_cfg   s
   rc   c                 C   s   |  }|rkt|dd  }d|d< tj| d d d d ddd d d d f |ddd}t|dd  }|d d |d< |d dkrgtj| d d d d dd d d d d f |ddd}tj||gdd}|S |}|S t|dd  }tj| |ddd}|S )Nr   r   r   	trilinearFr   rW   )r   rZ   Fr"   r   cat)masklatentprocess_first_frame_onlylatent_sizetarget_sizefirst_frame_resizedremaining_frames_resizedresized_maskr0   r0   r1   resize_mask   s(   **rp   num_inference_stepsr?   	timestepssigmasc                 K   s  |dur|durt d|dur>dtt| jj v }|s(t d| j d| jd||d| | j}t	|}||fS |durpdtt| jj v }|sZt d| j d| jd||d	| | j}t	|}||fS | j|fd
|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    NzYOnly one of `timesteps` or `sigmas` can be passed. Please choose one to set custom valuesrr   zThe current scheduler class zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)rr   r?   rs   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)rs   r?   r?   r0   )
r*   setinspect	signatureset_timesteps
parameterskeys	__class__rr   len)	schedulerrq   r?   rr   rs   kwargsaccepts_timestepsaccept_sigmasr0   r0   r1   retrieve_timesteps   s2   r   c                4       sv  e Zd ZdZdZg dZdedeeB de	e
B dedef
 fd	d
Z										dGdeee B dededeee B dB dejdB dejdB dejdB dejdB dejdB dejdB defddZdd Z						dHddZ	dId d!Zd"d# Zed$d% Zed&d' Zed(d) Zed*d+ Zed,d- Z e! e"e#dd.d/d/dddd0d1ddd2ddddddd3ddd4gd2dfdeee B d5edB d6edB d7edB d8ej$d9ej$d:ej$d;edB d<e%dB deee B dB dedB d=e%dB d>ej&eej& B dB d4ejdB dejdB dejdB dejdB dejdB d?edB d@edAe'eegdf e(B e)B dB dBee dCe%dDee dB f0dEdFZ*  Z+S )JEasyAnimateControlPipelinea  
    Pipeline for text-to-video generation using EasyAnimate.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.

    Args:
        vae ([`AutoencoderKLMagvit`]):
            Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
        text_encoder (`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel` | None):
            EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
        tokenizer (`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer` | None):
            A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
        transformer ([`EasyAnimateTransformer3DModel`]):
            The EasyAnimate model designed by EasyAnimate Team.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
    ztext_encoder->transformer->vae)latentsprompt_embedsnegative_prompt_embedsvaetext_encoder	tokenizertransformerr|   c                    s   t    | j|||||d t| dd d ur| jjjnd| _t| dd d ur+| jjnd| _	t| dd d ur:| jj
nd| _t| j	d| _t| j	dddd	| _t| j	d| _d S )
N)r   r   r   r   r|   r   Tr         )vae_scale_factorF)r   do_normalizedo_binarizedo_convert_grayscale)super__init__register_modulesgetattrr   configenable_text_attention_maskr   spatial_compression_ratiovae_spatial_compression_ratiotemporal_compression_ratiovae_temporal_compression_ratior   image_processormask_processorr   video_processor)selfr   r   r   r   r|   rz   r0   r1   r   9  s0   


z#EasyAnimateControlPipeline.__init__r   TN   promptnum_images_per_promptdo_classifier_free_guidancenegative_promptr   r   prompt_attention_masknegative_prompt_attention_maskr?   r@   max_sequence_lengthc              	      sZ  |
p j j}
|	p j j}	|durt|trd}n|dur&t|tr&t|}n|jd }|du rt|tr?dd|dgdg}ndd	 |D } fd
d	|D } j|d|ddddd}|	 j j}|j
}|j} jrw j ||ddjd }ntd||d}|j	|
|	d}|j\}}}|d|d}||| |d}|j	|	d}|r|du r|durt|trdd|dgdg}ndd	 |D } fdd	|D } j|d|ddddd}|	 j j}|j
}|j} jr j ||ddjd }ntd||d}|r'|jd }|j	|
|	d}|d|d}||| |d}|j	|	d}||||fS )a[  
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            dtype (`torch.dtype`):
                torch dtype
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
            max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
        Nr   r   usertexttyper   rolecontentc                 S      g | ]}d d|dgdqS r   r   r   r   r0   )r4   _promptr0   r0   r1   r6         
z<EasyAnimateControlPipeline.encode_prompt.<locals>.<listcomp>c                        g | ]} j j|gd ddqS FT)tokenizeadd_generation_promptr   apply_chat_templater4   mr   r0   r1   r6         
max_lengthTrightpt)r   paddingr   
truncationreturn_attention_maskpadding_sidereturn_tensors)	input_idsattention_maskoutput_hidden_stateszLLM needs attention_mask)r@   r?   r9   r?   c                 S   r   r   r0   )r4   _negative_promptr0   r0   r1   r6     r   c                    r   r   r   r   r   r0   r1   r6     r   )r   r@   r?   r   strrZ   r{   shaper   r>   r   r   r   hidden_statesr*   repeatview)r   r   r   r   r   r   r   r   r   r?   r@   r   
batch_sizemessagesr   text_inputstext_input_idsbs_embedseq_len_r0   r   r1   encode_prompt`  s   -




	

	
z(EasyAnimateControlPipeline.encode_promptc                 C   sX   dt t| jjj v }i }|r||d< dt t| jjj v }|r*||d< |S )Neta	generator)rt   ru   rv   r|   steprx   ry   )r   r   r   accepts_etaextra_step_kwargsaccepts_generatorr0   r0   r1   prepare_extra_step_kwargs  s   z4EasyAnimateControlPipeline.prepare_extra_step_kwargsc
           
         st  |d dks|d dkrt d| d| d|	d ur8t fdd|	D s8t d j d	 fd
d|	D  |d urK|d urKt d| d| d|d u rW|d u rWt d|d urnt|tsnt|tsnt dt| |d urz|d u rzt d|d ur|d urt d| d| d|d ur|d u rt d|d ur|d ur|j|jkrt d|j d|j dd S d S d S )N   r   z8`height` and `width` have to be divisible by 16 but are z and .c                 3   s    | ]}| j v V  qd S N_callback_tensor_inputsr4   kr   r0   r1   	<genexpr>#  s    

z:EasyAnimateControlPipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r0   r   r   r   r0   r1   r6   '  s    z;EasyAnimateControlPipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is zEMust provide `prompt_attention_mask` when specifying `prompt_embeds`.z'Cannot forward both `negative_prompt`: z and `negative_prompt_embeds`: zWMust provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.zu`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` z != `negative_prompt_embeds` )r*   allr   r   r   rZ   r   r   )
r   r   heightwidthr   r   r   r   r   "callback_on_step_end_tensor_inputsr0   r   r1   check_inputs  sN   z'EasyAnimateControlPipeline.check_inputsc
                 C   s   |	d ur|	j ||dS |||d | j d || j || j f}
t|tr7t||kr7tdt| d| dt|
|||d}	t| j	drK|	| j	j
 }	|	S )Nr?   r@   r   z/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.)r   r?   r@   init_noise_sigma)r>   r   r   r   rZ   r{   r*   r   hasattrr|   r   )r   r   num_channels_latentsrC   r   r   r@   r?   r   r   r   r0   r0   r1   prepare_latentsJ  s$   z*EasyAnimateControlPipeline.prepare_latentsc
                 C   s  |d urB|j ||d}d}
g }td|jd |
D ]}||||
  }| j|d }| }|| qtj|dd}|| jj	j
 }|d ur|j ||d}d}
g }td|jd |
D ]}||||
  }| j|d }| }|| qZtj|dd}|| jj	j
 }||fS d }||fS )Nr   r   r   re   )r>   r[   r   r   encoder   appendr   rg   r   scaling_factor)r   controlcontrol_imager   r   r   r@   r?   r   r   bsnew_controli
control_bsnew_control_pixel_valuescontrol_pixel_values_bscontrol_image_latentsr0   r0   r1   prepare_control_latentsd  s2   z2EasyAnimateControlPipeline.prepare_control_latentsc                 C      | j S r   _guidance_scaler   r0   r0   r1   guidance_scale     z)EasyAnimateControlPipeline.guidance_scalec                 C   r   r   )_guidance_rescaler   r0   r0   r1   r_     r   z+EasyAnimateControlPipeline.guidance_rescalec                 C   s
   | j dkS )Nr   r   r   r0   r0   r1   r     s   
z6EasyAnimateControlPipeline.do_classifier_free_guidancec                 C   r   r   )_num_timestepsr   r0   r0   r1   num_timesteps  r   z(EasyAnimateControlPipeline.num_timestepsc                 C   r   r   )
_interruptr   r0   r0   r1   	interrupt  r   z$EasyAnimateControlPipeline.interrupt1   i   2   g      @r7   pilr   rC   r   r   control_videocontrol_camera_videorE   rq   r   r   r   output_typereturn_dictcallback_on_step_endr   r_   rr   c           4      C   s  t |ttfr
|j}t|d d }t|d d }| ||||
|||||	 |	| _|| _d| _|dur<t |t	r<d}n|durJt |t
rJt|}n|jd }| j}| jdur\| jj}n| jj}| j||||| j|
||||dd\}}}}trzd}n|}t | jtrt| j|||dd\}}n
t| j|||\}}| jj}| jjj}| || ||||||||	}|durt||d	d
}|d }| jrt|gd n|||}nz|dur1|j\}} }}!}"| j j!|"ddddd#|| | |!|"||d}|jtj$d}|#||| ||"ddddd}| %d|||||||| j	d }| jr*t|gd n|||}nt&|||}| jrFt|gd n|||}|dur|j\}} }}!}"| j j!|"ddddd#|| | |!|"||d}|jtj$d}|#||| ||"ddddd}| %d|||||j||| j	d }#t&|}$|' d dkr|#|$ddddddf< | jrt|$gd n|$||}$tj||$gdd}n t&|}$| jrt|$gd n|$||}$tj||$gdd}| (||}%| jrt||g}t||g}|j|d}|j|d}t||| jj)  }&t|| _*| j+|d}'t,|D ]\}(})| j-r8q.| jrDt|gd n|}*t.| jdrT| j/|*|)}*tj0|)g|*jd  |dj|*jd}+| j|*|+||ddd },|,' d | jjjkr|,j1ddd\},}-| jr|,1d\}.}/|.|	|/|.   },| jr|dkrt2|,|/|d},| jj3|,|)|fi |%ddid }|duri }0|D ]
}1t4 |1 |0|1< q|| |(|)|0}2|25d|}|25d|}|25d|}|(t|d ks|(d |&kr|(d | jj) dkr|'6  trt78  q.W d   n	1 sw   Y  |dks4| 9|}3| j:j;|3|d}3n|}3| <  |s@|3fS t=|3dS )a  
        Generates images or video using the EasyAnimate pipeline based on the provided prompts.

        Examples:
            prompt (`str` or `list[str]`, *optional*):
                Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
            num_frames (`int`, *optional*):
                Length of the generated video (in frames).
            height (`int`, *optional*):
                Height of the generated image in pixels.
            width (`int`, *optional*):
                Width of the generated image in pixels.
            num_inference_steps (`int`, *optional*, defaults to 50):
                Number of denoising steps during generation. More steps generally yield higher quality images but slow
                down inference.
            guidance_scale (`float`, *optional*, defaults to 5.0):
                Encourages the model to align outputs with prompts. A higher value may decrease image quality.
            negative_prompt (`str` or `list[str]`, *optional*):
                Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                Number of images to generate for each prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                A generator to ensure reproducibility in image generation.
            latents (`torch.Tensor`, *optional*):
                Predefined latent tensors to condition generation.
            prompt_embeds (`torch.Tensor`, *optional*):
                Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Embeddings for negative prompts. Overrides string inputs if defined.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for the primary prompt embeddings.
            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
                Attention mask for negative prompt embeddings.
            output_type (`str`, *optional*, defaults to "latent"):
                Format of the generated output, either as a PIL image or as a NumPy array.
            return_dict (`bool`, *optional*, defaults to `True`):
                If `True`, returns a structured output. Otherwise returns a simple tuple.
            callback_on_step_end (`Callable`, *optional*):
                Functions called at the end of each denoising step.
            callback_on_step_end_tensor_inputs (`list[str]`, *optional*):
                Tensor names to be included in callback function calls.
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Adjusts noise levels based on guidance scale.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        r   FNr   r   )r   r?   r@   r   r   r   r   r   r   r   text_encoder_indexcpu)muT)rj      r   r   r   )r   r   )r@   re   r   )totalscale_model_input)encoder_hidden_statescontrol_latentsr  r7   )r_   r  r   r   r   ri   )videor  )frames)>r   r
   r	   tensor_inputsrH   r   r   r  r  r   rZ   r{   r   _execution_devicer   r@   r   r   r   XLA_AVAILABLEr|   r   r   rr   r   r   latent_channelsr   rp   r   rg   r>   r   
preprocessr,   reshapefloat32r   rA   r   r   orderr  progress_bar	enumerater  r   r  tensorchunkrc   r   localspopupdatexm	mark_stepdecode_latentsr   postprocess_videomaybe_free_model_hooksr   )4r   r   rC   r   r   r	  r
  rE   rq   r   r   r   r   r   r   r   r   r   r   r  r  r  r   r_   rr   r   r?   r@   timestep_devicer   control_video_latentsr  channelsheight_videowidth_videoref_image_latentsref_image_latents_conv_inr   num_warmup_stepsr   r   tlatent_model_inputt_expand
noise_predr   noise_pred_uncondr^   callback_kwargsr   callback_outputsr  r0   r0   r1   __call__  s  S









  



$
6
6

z#EasyAnimateControlPipeline.__call__)
r   TNNNNNNNr   )NNNNNNr   ),__name__
__module____qualname____doc__model_cpu_offload_seqr   r   r   r   r   r   r   r   r   r   rZ   rH   boolr   r   r?   r@   r   r   r   r   r   propertyr   r_   r   r  r  no_gradr   EXAMPLE_DOC_STRINGFloatTensorr-   	Generatorr   r
   r	   r;  __classcell__r0   r0   r   r1   r      s"   *
	

 $
7
#





	

r   r:   )r7   )T)NNNN)<ru   typingr   numpyr&   r   torch.nn.functionalr    r!   rf   PILr   transformersr   r   r   r   	callbacksr	   r
   r   r   modelsr   r   pipelines.pipeline_utilsr   
schedulersr   utilsr   r   r   utils.torch_utilsr   r   r   pipeline_outputr   torch_xla.core.xla_modelcore	xla_modelr'  r  
get_loggerr<  loggerrD  r2   rG   rV   rc   rp   rH   r   r?   rZ   r-   r   r   r0   r0   r0   r1   <module>   sR   
*
%




;