o
    ۷iݦ                  
   @   s  d dl Z d dlZd dlmZmZ d dlZd dlZd dlmZm	Z	 ddl
mZmZ ddlmZ ddlmZ ddlmZmZ dd	lmZ dd
lmZ ddlmZmZ ddlmZmZmZ ddlm Z  ddl!m"Z" ddl#m$Z$ e r{d dl%m&  m'Z( dZ)ndZ)e*e+Z,dZ-dd Z.				d$de/dB de0ej1B dB de2e/ dB de2e3 dB fddZ4	d%dej5dej6dB de0fd d!Z7G d"d# d#eeZ8dS )&    N)AnyCallable)T5EncoderModelT5Tokenizer   )MultiPipelineCallbacksPipelineCallback)PipelineImageInput)CogVideoXLoraLoaderMixin)AutoencoderKLCogVideoXCogVideoXTransformer3DModel)get_3d_rotary_pos_embed)DiffusionPipeline)CogVideoXDDIMSchedulerCogVideoXDPMScheduler)is_torch_xla_availableloggingreplace_example_docstring)randn_tensor)VideoProcessor   )CogVideoXPipelineOutputTFaa  
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import CogVideoXImageToVideoPipeline
        >>> from diffusers.utils import export_to_video, load_image

        >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
        >>> image = load_image(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
        ... )
        >>> video = pipe(image, prompt, use_dynamic_cfg=True)
        >>> export_to_video(video.frames[0], "output.mp4", fps=8)
        ```
c                 C   s   |}|}| \}}|| }||| kr|}t t|| | }	n|}	t t|| | }t t|| d }
t t||	 d }|
|f|
| ||	 ffS )N       @)intround)src	tgt_width
tgt_heighttwthhwrresize_heightresize_widthcrop_top	crop_left r'   q/home/ubuntu/vllm_env/lib/python3.10/site-packages/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.pyget_resize_crop_region_for_gridD   s   r)   num_inference_stepsdevice	timestepssigmasc                 K   s  |dur|durt d|dur>dtt| jj v }|s(t d| j d| jd||d| | j}t	|}||fS |durpdtt| jj v }|sZt d| j d| jd||d	| | j}t	|}||fS | j|fd
|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    NzYOnly one of `timesteps` or `sigmas` can be passed. Please choose one to set custom valuesr,   zThe current scheduler class zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)r,   r+   r-   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)r-   r+   r+   r'   )

ValueErrorsetinspect	signatureset_timesteps
parameterskeys	__class__r,   len)	schedulerr*   r+   r,   r-   kwargsaccepts_timestepsaccept_sigmasr'   r'   r(   retrieve_timestepsW   s2   r;   sampleencoder_output	generatorsample_modec                 C   sR   t | dr|dkr| j|S t | dr|dkr| j S t | dr%| jS td)Nlatent_distr<   argmaxlatentsz3Could not access latents of provided encoder_output)hasattrr@   r<   moderB   AttributeError)r=   r>   r?   r'   r'   r(   retrieve_latents   s   

rF   c                2       s   e Zd ZdZg ZdZg dZdedede	de
deeB f
 fd	d
Z					dUdeee B dededejdB dejdB f
ddZ								dVdeee B deee B dB dededejdB dejdB dedejdB dejdB fddZ									dWd ejd!ed"ed#ed$ed%edejdB dejdB d&ejdB d'ejdB fd(d)Zd'ejd*ejfd+d,Zd-d. Zd/d0 Z			dXd1d2ZdYd3d4ZdYd5d6Zd$ed%ed#edejd*e ejejf f
d7d8Z!e"d9d: Z#e"d;d< Z$e"d=d> Z%e"d?d@ Z&e"dAdB Z'e( e)e*dddddCdDddEdFddGdddddHdddd'gdfd e+deee B dB deee B dB d$edB d%edB d#edIedJee dB dKe,dLededMe,d&ejeej B dB d'ej-dB dej-dB dej-dB dNedOedPe.ee/f dB dQe0eegdf e1B e2B dB dRee ded*e3e B f.dSdTZ4  Z5S )ZCogVideoXImageToVideoPipelinea  
    Pipeline for image-to-video generation using CogVideoX.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. CogVideoX uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
        tokenizer (`T5Tokenizer`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CogVideoXTransformer3DModel`]):
            A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
    ztext_encoder->transformer->vae)rB   prompt_embedsnegative_prompt_embeds	tokenizertext_encodervaetransformerr7   c                    s   t    | j|||||d t| dd r dt| jjjd  nd| _t| dd r.| jjj	nd| _
t| dd r<| jjjnd| _t| jd| _d S )	N)rJ   rK   rL   rM   r7   rL      r         gffffff?)vae_scale_factor)super__init__register_modulesgetattrr6   rL   configblock_out_channelsvae_scale_factor_spatialtemporal_compression_ratiovae_scale_factor_temporalscaling_factorvae_scaling_factor_imager   video_processor)selfrJ   rK   rL   rM   r7   r5   r'   r(   rS      s   
$z&CogVideoXImageToVideoPipeline.__init__Nr      promptnum_videos_per_promptmax_sequence_lengthr+   dtypec                 C   s  |p| j }|p
| jj}t|tr|gn|}t|}| j|d|dddd}|j}| j|dddj}	|	jd |jd kr[t	
||	s[| j|	d d |d df }
td	| d
|
  | ||d }|j||d}|j\}}}|d|d}||| |d}|S )N
max_lengthTpt)paddingre   
truncationadd_special_tokensreturn_tensorslongest)rg   rj   r   zXThe following part of your input was truncated because `max_sequence_length` is set to  z	 tokens: r   )rd   r+   )_execution_devicerK   rd   
isinstancestrr6   rJ   	input_idsshapetorchequalbatch_decodeloggerwarningtorepeatview)r^   ra   rb   rc   r+   rd   
batch_sizetext_inputstext_input_idsuntruncated_idsremoved_textrH   _seq_lenr'   r'   r(   _get_t5_prompt_embeds   s:   
  z3CogVideoXImageToVideoPipeline._get_t5_prompt_embedsTnegative_promptdo_classifier_free_guidancerH   rI   c
              
   C   s  |p| j }t|tr|gn|}|durt|}
n|jd }
|du r+| j|||||	d}|r|du r|p4d}t|tr?|
|g n|}|dur\t|t|ur\tdt| dt| d|
t|krutd| d	t| d
| d	|
 d	| j|||||	d}||fS )a"  
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        Nr   )ra   rb   rc   r+   rd    z?`negative_prompt` should be the same type to `prompt`, but got z != .z`negative_prompt`: z has batch size z, but `prompt`: zT. Please make sure that passed `negative_prompt` matches the batch size of `prompt`.)	rm   rn   ro   r6   rq   r   type	TypeErrorr.   )r^   ra   r   r   rb   rH   rI   rc   r+   rd   rz   r'   r'   r(   encode_prompt  sL   
&

z+CogVideoXImageToVideoPipeline.encode_prompt      <   Z   imagerz   num_channels_latents
num_framesheightwidthr>   rB   c                    s  t  trt |krtdt  d| d|d j d }||||j |j f}jjjd urO|d d |d |d jjj  f |dd   }	dt  trg fddt
|D }n
 fddD }tj|d	d
|d	dddd}jjjsj| }ndj | }||d ||j |j f}tj|||d}tj||gdd
}jjjd ur|d d d |djjj df }tj||gdd
}|
d u rt| ||d}
n|
|}
|
jj }
|
|fS )Nz/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.r   rN   c                    s,   g | ]}t j| d  | qS r   rF   rL   encode	unsqueeze).0ir>   r   r^   r'   r(   
<listcomp>{  s     zACogVideoXImageToVideoPipeline.prepare_latents.<locals>.<listcomp>c                    s$   g | ]}t j|d  qS r   r   )r   img)r>   r^   r'   r(   r     s   $ r   dimr   rP   )r+   rd   .)r>   r+   rd   )rn   listr6   r.   rZ   rX   rM   rV   patch_size_tr   rangerr   catrw   permuterL   invert_scale_latentsr\   zerossizer   r7   init_noise_sigma)r^   r   rz   r   r   r   r   rd   r+   r>   rB   rq   image_latentspadding_shapelatent_paddingfirst_framer'   r   r(   prepare_latentsX  sR   	4

"
&
z-CogVideoXImageToVideoPipeline.prepare_latentsreturnc                 C   s2   | ddddd}d| j | }| j|j}|S )Nr   rN   r   r   rP   )r   r\   rL   decoder<   )r^   rB   framesr'   r'   r(   decode_latents  s   z,CogVideoXImageToVideoPipeline.decode_latentsc                 C   s@   t t|| |}t|| d}||| jj d  }||| fS )Nr   )minr   maxr7   order)r^   r*   r,   strengthr+   init_timestept_startr'   r'   r(   get_timesteps  s   z+CogVideoXImageToVideoPipeline.get_timestepsc                 C   sX   dt t| jjj v }i }|r||d< dt t| jjj v }|r*||d< |S )Netar>   )r/   r0   r1   r7   stepr3   r4   )r^   r>   r   accepts_etaextra_step_kwargsaccepts_generatorr'   r'   r(   prepare_extra_step_kwargs  s   z7CogVideoXImageToVideoPipeline.prepare_extra_step_kwargsc
           
         s  t |tjst |tjjst |tstdt| |d dks'|d dkr2td| d| d|d urSt fdd|D sStd	 j	 d
 fdd|D  |d urf|d urftd| d| d|d u rr|d u rrtd|d urt |t
st |tstdt| |d ur|	d urtd| d|	 d|d ur|	d urtd| d|	 d|d ur|	d ur|j|	jkrtd|j d|	j dd S d S d S )Nz``image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `list[PIL.Image.Image]` but is rO   r   z7`height` and `width` have to be divisible by 8 but are z and r   c                 3   s    | ]}| j v V  qd S N_callback_tensor_inputsr   kr^   r'   r(   	<genexpr>  s    

z=CogVideoXImageToVideoPipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r'   r   r   r   r'   r(   r     s    z>CogVideoXImageToVideoPipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is z and `negative_prompt_embeds`: z'Cannot forward both `negative_prompt`: zu`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` z != `negative_prompt_embeds` )rn   rr   TensorPILImager   r.   r   allr   ro   rq   )
r^   r   ra   r   r   r   "callback_on_step_end_tensor_inputsrB   rH   rI   r'   r   r(   check_inputs  sh   
z*CogVideoXImageToVideoPipeline.check_inputsc                 C   s   d| _ | j  dS )zEnables fused QKV projections.TN)fusing_transformerrM   fuse_qkv_projectionsr   r'   r'   r(   r     s   z2CogVideoXImageToVideoPipeline.fuse_qkv_projectionsc                 C   s(   | j s
td dS | j  d| _ dS )z)Disable QKV projection fusion if enabled.zKThe Transformer was not initially fused for QKV projections. Doing nothing.FN)r   ru   rv   rM   unfuse_qkv_projectionsr   r'   r'   r(   r     s   

z4CogVideoXImageToVideoPipeline.unfuse_qkv_projectionsc              	   C   s   || j | jjj  }|| j | jjj  }| jjj}| jjj}| jjj| }	| jjj| }
|d u rLt||f|	|
}t| jjj	|||f||d\}}||fS || d | }t| jjj	d ||f|d|
|	f|d\}}||fS )N)	embed_dimcrops_coords	grid_sizetemporal_sizer+   r   slice)r   r   r   r   	grid_typemax_sizer+   )
rX   rM   rV   
patch_sizer   sample_widthsample_heightr)   r   attention_head_dim)r^   r   r   r   r+   grid_height
grid_widthpp_tbase_size_widthbase_size_heightgrid_crops_coords	freqs_cos	freqs_sinbase_num_framesr'   r'   r(   %_prepare_rotary_positional_embeddings  s:   





zCCogVideoXImageToVideoPipeline._prepare_rotary_positional_embeddingsc                 C      | j S r   )_guidance_scaler   r'   r'   r(   guidance_scaleB     z,CogVideoXImageToVideoPipeline.guidance_scalec                 C   r   r   )_num_timestepsr   r'   r'   r(   num_timestepsF  r   z+CogVideoXImageToVideoPipeline.num_timestepsc                 C   r   r   )_attention_kwargsr   r'   r'   r(   attention_kwargsJ  r   z.CogVideoXImageToVideoPipeline.attention_kwargsc                 C   r   r   )_current_timestepr   r'   r'   r(   current_timestepN  r   z.CogVideoXImageToVideoPipeline.current_timestepc                 C   r   r   )
_interruptr   r'   r'   r(   	interruptR  r   z'CogVideoXImageToVideoPipeline.interrupt1   2      Fg        pilr*   r,   r   use_dynamic_cfgr   output_typereturn_dictr   callback_on_step_endr   c           2      C   s0  t |ttfr
|j}|p| jjj| j }|p| jjj| j }|p$| jjj	}d}| j
|||||||||d	 |	| _d| _|| _d| _|durMt |trMd}n|dur[t |tr[t|}n|jd }| j}|	dk}| j||||||||d\}}|rtj||gdd}trd	}n|}t| j|||\}}t|| _|d | j d }| jjj}d}|dur|| dkr|||  }||| j 7 }| jj|||d
j||j d}| jjj!d }| "||| |||||j |||
\}}| #||} | jjj$r| %|||&d|nd}!| jjj'du r	dn|j(ddd}"t)t||| jj*  d}#| j+|d:}$d}%t,|D ]*\}&}'| j-r6q+|'| _|rDt|gd n|}(| j.|(|'}(|rXt|gd n|})tj|(|)gdd}(|'/|(jd }*| j0d | j|(||*|"|!|ddd }+W d   n	1 sw   Y  |+1 }+|
rd|	dt23t2j4||'5  | d   d   | _|r|+6d\},}-|,| j7|-|,   }+t | jt8s| jj9|+|'|fi | ddid }n| jj9|+|%|'|&dkr||&d  nd|fi | ddi\}}%||j }|dur/i }.|D ]
}/t: |/ |.|/< q|| |&|'|.}0|0;d|}|0;d|}|0;d|}|&t|d ksJ|&d |#krN|&d | jj* dkrN|$<  trUt=>  q+W d   n	1 sbw   Y  d| _|dks|dd|df }| ?|}1| jj@|1|d}1n|}1| A  |s|1fS tB|1dS )a7  
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PipelineImageInput`):
                The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
                The height in pixels of the generated image. This is set to 480 by default for the best results.
            width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
                The width in pixels of the generated image. This is set to 720 by default for the best results.
            num_frames (`int`, defaults to `48`):
                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
                contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
                num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
                needs to be satisfied is that of divisibility mentioned above.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`list[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to 7.0):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`list`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `226`):
                Maximum sequence length in encoded prompt. Must be consistent with
                `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.

        Examples:

        Returns:
            [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
            [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images.
        r   )	r   ra   r   r   r   r   rB   rH   rI   NFr   g      ?)ra   r   r   rb   rH   rI   rc   r+   r   cpu)r   r   )rd   rN   )r   r   )
fill_value)totalcond_uncond)hidden_statesencoder_hidden_statestimestepofsimage_rotary_embr   r   g      @r   rB   rH   rI   latent)videor   )r   )Crn   r   r   tensor_inputsrM   rV   r   rX   r   sample_framesr   r   r   r   r   ro   r   r6   rq   rm   r   rr   r   XLA_AVAILABLEr;   r7   r   rZ   r   r]   
preprocessrw   rd   in_channelsr   r    use_rotary_positional_embeddingsr   r   ofs_embed_dimnew_fullr   r   progress_bar	enumerater   scale_model_inputexpandcache_contextfloatmathcospiitemchunkr   r   r   localspopupdatexm	mark_stepr   postprocess_videomaybe_free_model_hooksr   )2r^   r   ra   r   r   r   r   r*   r,   r   r   rb   r   r>   rB   rH   rI   r   r   r   r   r   rc   rz   r+   r   timestep_devicelatent_framesr   additional_frameslatent_channelsr   r   r   ofs_embnum_warmup_stepsr	  old_pred_original_sampler   tlatent_model_inputlatent_image_inputr   
noise_prednoise_pred_uncondnoise_pred_textcallback_kwargsr   callback_outputsr   r'   r'   r(   __call__V  s&  m






"
&&	
6G


z&CogVideoXImageToVideoPipeline.__call__)Nr   r`   NN)NTr   NNr`   NN)	r   r   r   r   r   NNNN)NNN)r   N)6__name__
__module____qualname____doc___optional_componentsmodel_cpu_offload_seqr   r   r   r   r   r   r   rS   ro   r   r   rr   r+   rd   r   boolr   r   	Generatorr   r   r   r   r   r   r   tupler   propertyr   r   r   r   r   no_gradr   EXAMPLE_DOC_STRINGr	   r  FloatTensordictr   r   r   r   r   r*  __classcell__r'   r'   r_   r(   rG      sz   

.
	

T	

L


@
	
,





	
rG   )NNNN)Nr<   )9r0   r  typingr   r   r   rr   transformersr   r   	callbacksr   r   image_processorr	   loadersr
   modelsr   r   models.embeddingsr   pipelines.pipeline_utilsr   
schedulersr   r   utilsr   r   r   utils.torch_utilsr   r]   r   pipeline_outputr   torch_xla.core.xla_modelcore	xla_modelr  r  
get_loggerr+  ru   r6  r)   r   ro   r+   r   r  r;   r   r2  rF   rG   r'   r'   r'   r(   <module>   s\   



=
