o
    Giϲ                  
   @   s  d dl Z d dlmZmZ d dlZd dlZd dlmZm	Z	 ddl
mZmZ ddlmZ ddlmZmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZ ddlmZ ddlmZ ddl m!Z! ddl"m#Z# e rud dl$m%  m&Z' dZ(ndZ(e)e*Z+dZ,				d0de-de-de.de.fddZ/				d1de-dB d e0ej1B dB d!e2e- dB d"e2e. dB fd#d$Z3	%d2d&ej4d'ej5dB d(e0fd)d*Z6d3d,d-Z7G d.d/ d/e!eeZ8dS )4    N)AnyCallable)T5EncoderModelT5TokenizerFast   )MultiPipelineCallbacksPipelineCallback)PipelineImageInput)FromSingleFileMixinLTXVideoLoraLoaderMixin)AutoencoderKLLTXVideo)LTXVideoTransformer3DModel)FlowMatchEulerDiscreteScheduler)is_torch_xla_availableloggingreplace_example_docstring)randn_tensor)VideoProcessor   )DiffusionPipeline   )LTXPipelineOutputTFaS  
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import LTXImageToVideoPipeline
        >>> from diffusers.utils import export_to_video, load_image

        >>> pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
        >>> pipe.to("cuda")

        >>> image = load_image(
        ...     "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
        ... )
        >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background. Flames engulf the structure, with smoke billowing into the air. Firefighters in protective gear rush to the scene, a fire truck labeled '38' visible behind them. The girl's neutral expression contrasts sharply with the chaos of the fire, creating a poignant and emotionally charged scene."
        >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

        >>> video = pipe(
        ...     image=image,
        ...     prompt=prompt,
        ...     negative_prompt=negative_prompt,
        ...     width=704,
        ...     height=480,
        ...     num_frames=161,
        ...     num_inference_steps=50,
        ... ).frames[0]
        >>> export_to_video(video, "output.mp4", fps=24)
        ```
            ?ffffff?base_seq_lenmax_seq_len
base_shift	max_shiftc                 C   s,   || ||  }|||  }| | | }|S N )image_seq_lenr   r   r   r   mbmur!   r!   d/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/ltx/pipeline_ltx_image2video.pycalculate_shiftK   s   r'   num_inference_stepsdevice	timestepssigmasc                 K   s  |dur|durt d|dur>dtt| jj v }|s(t d| j d| jd||d| | j}t	|}||fS |durpdtt| jj v }|sZt d| j d| jd||d	| | j}t	|}||fS | j|fd
|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    NzYOnly one of `timesteps` or `sigmas` can be passed. Please choose one to set custom valuesr*   zThe current scheduler class zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)r*   r)   r+   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)r+   r)   r)   r!   )

ValueErrorsetinspect	signatureset_timesteps
parameterskeys	__class__r*   len)	schedulerr(   r)   r*   r+   kwargsaccepts_timestepsaccept_sigmasr!   r!   r&   retrieve_timestepsY   s2   r9   sampleencoder_output	generatorsample_modec                 C   sR   t | dr|dkr| j|S t | dr|dkr| j S t | dr%| jS td)Nlatent_distr:   argmaxlatentsz3Could not access latents of provided encoder_output)hasattrr>   r:   moder@   AttributeError)r;   r<   r=   r!   r!   r&   retrieve_latents   s   

rD           c                 C   sX   |j ttd|jdd}| j ttd| jdd}| ||  }|| d| |   } | S )a  
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://huggingface.co/papers/2305.08891).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    r   T)dimkeepdim)stdlistrangendim)	noise_cfgnoise_pred_textguidance_rescalestd_textstd_cfgnoise_pred_rescaledr!   r!   r&   rescale_noise_cfg   s
   rR   c                7       s  e Zd ZdZdZg Zg dZdedede	de
def
 fd	d
Z					d]deee B dededejdB dejdB f
ddZ										d^deee B deee B dB dededejdB dejdB dejdB dejdB dedejdB dejdB fddZ					d_ddZed`d ejd!ed"ed#ejfd$d%Ze	d`d ejd&ed'ed(ed!ed"ed#ejfd)d*Ze	+dad ejd,ejd-ejd.ed#ejf
d/d0Ze	+dad ejd,ejd-ejd.ed#ejf
d1d2Z				3	4	5				dbd6ejdB d7ed8ed'ed(ed&edejdB dejdB d9ejdB d ejdB d#ejfd:d;Ze d<d= Z!e d>d? Z"e d@dA Z#e dBdC Z$e dDdE Z%e dFdG Z&e dHdI Z'e( e)e*dddd3d4d5dJdKddLdMddddddddMddNdddd gdfd6e+deee B deee B dB d'ed(ed&edOedPedQee dRedSededB d9ejeej B dB d ejdB dejdB dejdB dejdB dejdB dTeee B dUeee B dB dVedB dWedXe,ee-f dB dYe.eegdf dB dZee def4d[d\Z/  Z0S )cLTXImageToVideoPipelinea  
    Pipeline for image-to-video generation.

    Reference: https://github.com/Lightricks/LTX-Video

    Args:
        transformer ([`LTXVideoTransformer3DModel`]):
            Conditional Transformer architecture to denoise the encoded video latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLLTXVideo`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer (`T5TokenizerFast`):
            Second Tokenizer of class
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
    ztext_encoder->transformer->vae)r@   prompt_embedsnegative_prompt_embedsr5   vaetext_encoder	tokenizertransformerc                    s   t    | j|||||d t| dd d ur| jjnd| _t| dd d ur*| jjnd| _t| dd d ur:| j	j
jnd| _t| dd urI| j	j
jnd| _t| jd| _t| dd d ur_| jjnd	| _d
| _d| _d| _d S )N)rV   rW   rX   rY   r5   rV          rY   r   )vae_scale_factorrX           y   )super__init__register_modulesgetattrrV   spatial_compression_ratiovae_spatial_compression_ratiotemporal_compression_ratiovae_temporal_compression_ratiorY   config
patch_sizetransformer_spatial_patch_sizepatch_size_ttransformer_temporal_patch_sizer   video_processorrX   model_max_lengthtokenizer_max_lengthdefault_heightdefault_widthdefault_frames)selfr5   rV   rW   rX   rY   r3   r!   r&   rb      s,   
	
z LTXImageToVideoPipeline.__init__Nr   r]   promptnum_videos_per_promptmax_sequence_lengthr)   dtypec                 C   s8  |p| j }|p
| jj}t|tr|gn|}t|}| j|d|dddd}|j}|j}	|		 
|}	| j|dddj}
|
jd |jd kret||
se| j|
d d |d df }td	| d
|  | |
|d }|j
||d}|j\}}}|d|d}||| |d}|	|d}	|	|d}	||	fS )N
max_lengthTpt)paddingrz   
truncationadd_special_tokensreturn_tensorslongest)r|   r   r   zXThe following part of your input was truncated because `max_sequence_length` is set to  z	 tokens: r   )ry   r)   )_execution_devicerW   ry   
isinstancestrr4   rX   	input_idsattention_maskbooltoshapetorchequalbatch_decodeloggerwarningrepeatview)rt   rv   rw   rx   r)   ry   
batch_sizetext_inputstext_input_idsprompt_attention_maskuntruncated_idsremoved_textrT   _seq_lenr!   r!   r&   _get_t5_prompt_embeds  sB   
  z-LTXImageToVideoPipeline._get_t5_prompt_embedsTnegative_promptdo_classifier_free_guidancerT   rU   r   negative_prompt_attention_maskc              
   C   s  |
p| j }
t|tr|gn|}|durt|}n|jd }|du r-| j|||	|
|d\}}|r|du r|p6d}t|trA||g n|}|dur^t|t|ur^tdt| dt| d|t|krwtd| d	t| d
| d	| d	| j|||	|
|d\}}||||fS )a"  
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        Nr   )rv   rw   rx   r)   ry    z?`negative_prompt` should be the same type to `prompt`, but got z != .z`negative_prompt`: z has batch size z, but `prompt`: zT. Please make sure that passed `negative_prompt` matches the batch size of `prompt`.)	r   r   r   r4   r   r   type	TypeErrorr,   )rt   rv   r   r   rw   rT   rU   r   r   rx   r)   ry   r   r!   r!   r&   encode_prompt2  sL   
(



z%LTXImageToVideoPipeline.encode_promptc	           	         st  |d dks|d dkrt d| d| d|d ur8t fdd|D s8t d j d	 fd
d|D  |d urK|d urKt d| d| d|d u rW|d u rWt d|d urnt|tsnt|tsnt dt| |d urz|d u rzt d|d ur|d u rt d|d ur|d ur|j|jkrt d|j d|j d|j|jkrt d|j d|j dd S d S d S )NrZ   r   z8`height` and `width` have to be divisible by 32 but are z and r   c                 3   s    | ]}| j v V  qd S r    _callback_tensor_inputs.0krt   r!   r&   	<genexpr>  s    

z7LTXImageToVideoPipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r!   r   r   r   r!   r&   
<listcomp>  s    z8LTXImageToVideoPipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is zEMust provide `prompt_attention_mask` when specifying `prompt_embeds`.zWMust provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.zu`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` z != `negative_prompt_embeds` z`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but got: `prompt_attention_mask` z% != `negative_prompt_attention_mask` )r,   allr   r   r   rI   r   r   )	rt   rv   heightwidth"callback_on_step_end_tensor_inputsrT   rU   r   r   r!   r   r&   check_inputs  sR   z$LTXImageToVideoPipeline.check_inputsr@   rj   rl   returnc              
   C   sl   | j \}}}}}|| }|| }	|| }
| |d|||	||
|} | dddddddd	dd	dd} | S )
Nr   r   r         r   r         )r   reshapepermuteflatten)r@   rj   rl   r   num_channels
num_framesr   r   post_patch_num_framespost_patch_heightpost_patch_widthr!   r!   r&   _pack_latents  s    (
z%LTXImageToVideoPipeline._pack_latentsr   r   r   c              
   C   sV   |  d}| ||||d|||} | dddddddd	dd	dddd} | S )
Nr   r   r   r   r   r   r   r   r   )sizer   r   r   )r@   r   r   r   rj   rl   r   r!   r!   r&   _unpack_latents  s   
0z'LTXImageToVideoPipeline._unpack_latents      ?latents_meanlatents_stdscaling_factorc                 C   sP   | ddddd| j| j}| ddddd| j| j}| | | | } | S Nr   r   r   r   r)   ry   r@   r   r   r   r!   r!   r&   _normalize_latents     z*LTXImageToVideoPipeline._normalize_latentsc                 C   sP   | ddddd| j| j}| ddddd| j| j}| | | | } | S r   r   r   r!   r!   r&   _denormalize_latents  r   z,LTXImageToVideoPipeline._denormalize_latentsr^   r_      imager   num_channels_latentsr<   c                    s  |j  }|j  }|d j d }|||||f}|d|||f}|
d uri|
|}d|d d d d df< |jjd}|
jdksP|
jd d |jkr`t	d|
j d|j|f  d	|
j
||d
|fS t trt |krt	dt  d| d fddt|D }n
 fddD }tj|dd
|}|jjjj}|dd|dd}tj|||d
}d|d d d d df< t| ||d}|| |d|   }
|jjd}|
jj}
|
|fS )Nr   r   r   r   r   r   z$Provided `latents` tensor has shape z, but the expected shape is r   r)   ry   z/You have passed a list of generators of length z+, but requested an effective batch size of z@. Make sure the batch size matches the length of the generators.c                    s2   g | ]}t j| d d | qS r   r   rD   rV   encode	unsqueeze)r   ir<   r   rt   r!   r&   r     s    $z;LTXImageToVideoPipeline.prepare_latents.<locals>.<listcomp>c                    s*   g | ]}t j|d d qS r   r   )r   img)r<   rt   r!   r&   r   "  s    rF   r<   r)   ry   )rf   rh   	new_zerosr   rk   rm   squeezerK   r   r,   r   r   rI   r4   rJ   r   catr   rV   r   r   r   zerosr   )rt   r   r   r   r   r   r   ry   r)   r<   r@   r   
mask_shapeconditioning_maskinit_latentsnoiser!   r   r&   prepare_latents  s^   






z'LTXImageToVideoPipeline.prepare_latentsc                 C      | j S r    _guidance_scaler   r!   r!   r&   guidance_scale8     z&LTXImageToVideoPipeline.guidance_scalec                 C   r   r    )_guidance_rescaler   r!   r!   r&   rN   <  r   z(LTXImageToVideoPipeline.guidance_rescalec                 C   s
   | j dkS )Nr   r   r   r!   r!   r&   r   @  s   
z3LTXImageToVideoPipeline.do_classifier_free_guidancec                 C   r   r    )_num_timestepsr   r!   r!   r&   num_timestepsD  r   z%LTXImageToVideoPipeline.num_timestepsc                 C   r   r    )_current_timestepr   r!   r!   r&   current_timestepH  r   z(LTXImageToVideoPipeline.current_timestepc                 C   r   r    )_attention_kwargsr   r!   r!   r&   attention_kwargsL  r   z(LTXImageToVideoPipeline.attention_kwargsc                 C   r   r    )
_interruptr   r!   r!   r&   	interruptP  r   z!LTXImageToVideoPipeline.interrupt   2   r   rE   pil
frame_rater(   r*   r   rN   decode_timestepdecode_noise_scaleoutput_typereturn_dictr   callback_on_step_endr   c           8      C   s  t |ttfr
|j}| j||||||||d |
| _|| _|| _d| _d| _	|dur2t |t
r2d}n|dur@t |tr@t|}n|jd }| j}| j||| j|||||||d
\}}}}| jrqtj||gdd}tj||gdd}|du r| jj|||d}|j||jd	}| jjj}| ||| ||||tj|||
\}}| jrt||g}|d | j d }|| j } || j }!||  |! }"td
d| |}#t |"| j!j"dd| j!j"dd| j!j"dd| j!j"dd}$t#rd}%n|}%t$| j!||%|	|#|$d\}	}t%t|	|| j!j&  d}&t|	| _'| j| | j| jf}'| j(|d>}(t)|	D ]0\})}*| j*r1q&|*| _	| jr@t|gd n|}+|+|j}+|*+|+jd },|,,dd|  },| j-d | j|+||,||| |!|'|dd
d }-W d   n	1 s|w   Y  |-. }-| jr|-/d\}.}/|.| j0|/|.   }-|,/d\},}0| j1dkrt2|-|/| j1d}-| 3|-|| |!| j4| j5}-| 3||| |!| j4| j5}|-ddddddf }-|ddddddf }1| j!j6|-|*|1ddd }2tj|ddddddf |2gdd}| 7|| j4| j5}|dur0i }3|D ]
}4t8 |4 |3|4< q|| |)|*|3}5|59d|}|59d|}|)t|	d ksK|)d |&krO|)d | j!j& dkrO|(:  t#rVt;<  q&W d   n	1 scw   Y  |dkrp|}6n| 3||| |!| j4| j5}| =|| j>j?| j>j@| j>jjA}||j}| j>jjBsd},nNtjC|j|||jd}7t |ts|g| }|du r|}nt |ts|g| }tjD|||jd	},tjD|||jd	ddddddf }d| | ||7  }| j>jE||,ddd }6| jjF|6|d }6| G  |s|6fS tH|6d!S )"u?  
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PipelineImageInput`):
                The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            height (`int`, defaults to `512`):
                The height in pixels of the generated image. This is set to 480 by default for the best results.
            width (`int`, defaults to `704`):
                The width in pixels of the generated image. This is set to 848 by default for the best results.
            num_frames (`int`, defaults to `161`):
                The number of video frames to generate
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            timesteps (`list[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, defaults to `3 `):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
                [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            decode_timestep (`float`, defaults to `0.0`):
                The timestep at which generated video is decoded.
            decode_noise_scale (`float`, defaults to `None`):
                The interpolation factor between random noise and denoised latents at the decode timestep.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`list`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int` defaults to `128 `):
                Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        )rv   r   r   r   rT   rU   r   r   FNr   r   )
rv   r   r   rw   rT   rU   r   r   rx   r)   r   )r   r   r   r   base_image_seq_lenr   max_image_seq_lenr   r   r   r   r   cpu)r+   r%   )totalr   r   cond_uncond)
hidden_statesencoder_hidden_statestimestepencoder_attention_maskr   r   r   rope_interpolation_scaler   r   )rN   )r   r@   rT   latentr   )r   )frames)Ir   r   r   tensor_inputsr   r   r   r   r   r   r   rI   r4   r   r   r   r   r   r   rn   
preprocessr   ry   rY   ri   in_channelsr   float32rh   rf   nplinspacer'   r5   getXLA_AVAILABLEr9   maxorderr   progress_bar	enumerater   expandr   cache_contextfloatchunkr   rN   rR   r   rk   rm   stepr   localspopupdatexm	mark_stepr   rV   r   r   r   timestep_conditioningrandntensordecodepostprocess_videomaybe_free_model_hooksr   )8rt   r   rv   r   r   r   r   r   r(   r*   r   rN   rw   r<   r@   rT   r   rU   r   r   r   r   r   r   r   r   rx   r   r)   r   r   latent_num_frameslatent_heightlatent_widthvideo_sequence_lengthr+   r%   timestep_devicenum_warmup_stepsr   r  r   tlatent_model_inputr   
noise_prednoise_pred_uncondrM   r   noise_latentspred_latentscallback_kwargsr   callback_outputsvideor   r!   r!   r&   __call__T  s  u






	(

6
S



z LTXImageToVideoPipeline.__call__)Nr   r]   NN)
NTr   NNNNr]   NN)NNNNN)r   r   )r   )
Nr   r]   r^   r_   r   NNNN)1__name__
__module____qualname____doc__model_cpu_offload_seq_optional_componentsr   r   r   r   r   r   rb   r   rI   intr   r)   ry   r   r   Tensorr   r   staticmethodr   r   r  r   r   	Generatorr   propertyr   rN   r   r   r   r   r   no_gradr   EXAMPLE_DOC_STRINGr	   dictr   r   r-  __classcell__r!   r!   ru   r&   rS      s   *

4
	

Y
5"		

B







	

rS   )r   r   r   r   )NNNN)Nr:   )rE   )9r.   typingr   r   numpyr  r   transformersr   r   	callbacksr   r   image_processorr	   loadersr
   r   models.autoencodersr   models.transformersr   
schedulersr   utilsr   r   r   utils.torch_utilsr   rn   r   pipeline_utilsr   pipeline_outputr   torch_xla.core.xla_modelcore	xla_modelr  r	  
get_loggerr.  r   r:  r4  r  r'   r   r)   rI   r9   r5  r7  rD   rR   rS   r!   r!   r!   r&   <module>   st   
!



=

