o
    Giq                  
   @   s  d dl Z d dlZd dlmZmZ d dlZd dlZd dlm	Z	m
Z
mZ ddlmZmZ ddlmZmZ ddlmZmZ ddlmZ dd	lmZ dd
lmZmZmZ ddlmZ ddlm Z  ddl!m"Z" ddl#m$Z$ ddl%m&Z& ddl'm(Z( e rd dl)m*  m+Z, dZ-ndZ-e.e/Z0dZ1				d+de2de2de3de3fddZ4				d,d e2dB d!e5ej6B dB d"e7e2 dB d#e7e3 dB fd$d%Z8d-d'd(Z9G d)d* d*e"eeZ:dS ).    N)AnyCallable)Gemma3ForConditionalGenerationGemmaTokenizerGemmaTokenizerFast   )MultiPipelineCallbacksPipelineCallback)FromSingleFileMixinLTX2LoraLoaderMixin)AutoencoderKLLTX2AudioAutoencoderKLLTX2Video)LTX2VideoTransformer3DModel)FlowMatchEulerDiscreteScheduler)is_torch_xla_availableloggingreplace_example_docstring)randn_tensor)VideoProcessor   )DiffusionPipeline   )LTX2TextConnectors)LTX2PipelineOutput)LTX2VocoderTFa$  
    Examples:
        ```py
        >>> import torch
        >>> from diffusers import LTX2Pipeline
        >>> from diffusers.pipelines.ltx2.export_utils import encode_video

        >>> pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
        >>> pipe.enable_model_cpu_offload()

        >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
        >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

        >>> frame_rate = 24.0
        >>> video, audio = pipe(
        ...     prompt=prompt,
        ...     negative_prompt=negative_prompt,
        ...     width=768,
        ...     height=512,
        ...     num_frames=121,
        ...     frame_rate=frame_rate,
        ...     num_inference_steps=40,
        ...     guidance_scale=4.0,
        ...     output_type="np",
        ...     return_dict=False,
        ... )

        >>> encode_video(
        ...     video[0],
        ...     fps=frame_rate,
        ...     audio=audio[0].float().cpu(),
        ...     audio_sample_rate=pipe.vocoder.config.output_sampling_rate,  # should be 24000
        ...     output_path="video.mp4",
        ... )
        ```
            ?ffffff?base_seq_lenmax_seq_len
base_shift	max_shiftc                 C   s,   || ||  }|||  }| | | }|S N )image_seq_lenr   r    r!   r"   mbmur$   r$   Z/home/ubuntu/.local/lib/python3.10/site-packages/diffusers/pipelines/ltx2/pipeline_ltx2.pycalculate_shiftU   s   r*   num_inference_stepsdevice	timestepssigmasc                 K   s  |dur|durt d|dur>dtt| jj v }|s(t d| j d| jd||d| | j}t	|}||fS |durpdtt| jj v }|sZt d| j d| jd||d	| | j}t	|}||fS | j|fd
|i| | j}||fS )a  
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`list[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`list[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    NzYOnly one of `timesteps` or `sigmas` can be passed. Please choose one to set custom valuesr-   zThe current scheduler class zx's `set_timesteps` does not support custom timestep schedules. Please check whether you are using the correct scheduler.)r-   r,   r.   zv's `set_timesteps` does not support custom sigmas schedules. Please check whether you are using the correct scheduler.)r.   r,   r,   r$   )

ValueErrorsetinspect	signatureset_timesteps
parameterskeys	__class__r-   len)	schedulerr+   r,   r-   r.   kwargsaccepts_timestepsaccept_sigmasr$   r$   r)   retrieve_timestepsc   s2   r<           c                 C   sX   |j ttd|jdd}| j ttd| jdd}| ||  }|| d| |   } | S )a  
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://huggingface.co/papers/2305.08891).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    r   Tdimkeepdim)stdlistrangendim)	noise_cfgnoise_pred_textguidance_rescalestd_textstd_cfgnoise_pred_rescaledr$   r$   r)   rescale_noise_cfg   s
   rK   c                ;       s`  e Zd ZdZdZg Zg dZdedede	de
deeB d	ed
edef fddZe			d}dejdejdeejB dedededejfddZ					d~deee B dedededejdB d ejdB fd!d"Z		#									ddeee B d$eee B dB d%eded&ejdB d'ejdB d(ejdB d)ejdB dededejdB d ejdB fd*d+Z					dd,d-Zedd.ejd/ed0edejfd1d2Ze	dd.ejd3ed4ed5ed/ed0edejfd6d7Z e	8dd.ejd9ejd:ejd;edejf
d<d=Z!e	8dd.ejd9ejd:ejd;edejf
d>d?Z"ed.ejd9ejd:ejfd@dAZ#ed.ejd9ejd:ejfdBdCZ$e	dd.ejdDeejB dEej%dB fdFdGZ&e	dd.ejd/edB d0edB dejfdHdIZ'e		dd.ejdJedKed/edB d0edB dejfdLdMZ(		N	O	P	Q	R				ddSedTed4ed5ed3edDed ejdB dejdB dEej%dB d.ejdB dejfdUdVZ)				W	R				ddSedTedXedKedDed ejdB dejdB dEej%dB d.ejdB dejfdYdZZ*e+d[d\ Z,e+d]d^ Z-e+d_d` Z.e+dadb Z/e+dcdd Z0e+dedf Z1e+dgdh Z2e3 e4e5dddOdPdQdidjdddkdRdRdddddddddRddld#ddd.gdfdeee B d$eee B dB d4ed5ed3edmednedoee dB dpee dqedredDededEej%eej% B dB d.ejdB dsejdB d&ejdB d(ejdB d'ejdB d)ejdB dteee B dueee B dB dvedwedxe6ee7f dB dye8eegdf dB dzee def8d{d|Z9  Z:S )LTX2PipelineaK  
    Pipeline for text-to-video generation.

    Reference: https://github.com/Lightricks/LTX-Video

    Args:
        transformer ([`LTXVideoTransformer3DModel`]):
            Conditional Transformer architecture to denoise the encoded video latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
        vae ([`AutoencoderKLLTXVideo`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
        tokenizer (`T5TokenizerFast`):
            Second Tokenizer of class
            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
        connectors ([`LTX2TextConnectors`]):
            Text connector stack used to adapt text encoder hidden states for the video and audio branches.
    z>text_encoder->connectors->transformer->vae->audio_vae->vocoder)latentsprompt_embedsnegative_prompt_embedsr8   vae	audio_vaetext_encoder	tokenizer
connectorstransformervocoderc	           	   
      sP  t    | j||||||||d t| dd d ur| jjnd| _t| dd d ur-| jjnd| _t| dd d ur<| j	j
nd| _t| dd d urK| j	jnd| _t| dd d ur[| jjjnd| _t| dd urj| jjjnd| _t| dd d urz| j	jjnd	| _t| dd d ur| j	jjnd
| _t| jd| _t| dd d ur| jj| _d S d| _d S )N)rP   rQ   rR   rS   rT   rU   rV   r8   rP          rQ      rU   r   i>     )vae_scale_factorrS      )super__init__register_modulesgetattrrP   spatial_compression_ratiovae_spatial_compression_ratiotemporal_compression_ratiovae_temporal_compression_ratiorQ   mel_compression_ratioaudio_vae_mel_compression_ratio$audio_vae_temporal_compression_ratiorU   config
patch_sizetransformer_spatial_patch_sizepatch_size_ttransformer_temporal_patch_sizesample_rateaudio_sampling_ratemel_hop_lengthaudio_hop_lengthr   video_processorrS   model_max_lengthtokenizer_max_length)	selfr8   rP   rQ   rR   rS   rT   rU   rV   r6   r$   r)   r^      s@   
zLTX2Pipeline.__init__leftrX   ư>text_hidden_statessequence_lengthsr,   padding_sidescale_factorepsreturnc                 C   s^  | j \}}}}	| j}
tj||dd}|dkr#||dddf k }n|dkr6||dddf  }||k}ntd| |ddddddf }| | d}|| |ddd}|jd	d
d||  }| | t	dj
d	d
d}| | t	djd	d
d}| | || |  }|| }|d}|ddd||	 }|| d}|j|
d}|S )a&  
        Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
        per-layer in a masked fashion (only over non-padded positions).

        Args:
            text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
                Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
            sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
                The number of valid (non-padded) tokens for each batch instance.
            device: (`str` or `torch.device`, *optional*):
                torch device to place the resulting embeddings on
            padding_side: (`str`, *optional*, defaults to `"left"`):
                Whether the text tokenizer performs padding on the `"left"` or `"right"`.
            scale_factor (`int`, *optional*, defaults to `8`):
                Scaling factor to multiply the normalized hidden states by.
            eps (`float`, *optional*, defaults to `1e-6`):
                A small positive value for numerical stability when performing normalization.

        Returns:
            `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
                Normed and flattened text encoder hidden states.
        )r,   r   rightNrv   z,padding_side must be 'left' or 'right', got r=   r   )r   r   Tr>   infz-infr   dtype)shaper   torcharange	unsqueezer/   masked_fillviewsumfloataminamaxflattensqueezeexpandto)rx   ry   r,   rz   r{   r|   
batch_sizeseq_len
hidden_dim
num_layersoriginal_dtypetoken_indicesmaskstart_indicesmasked_text_hidden_statesnum_valid_positionsmasked_meanx_minx_maxnormalized_hidden_states	mask_flatr$   r$   r)   _pack_text_embeds  s,   

zLTX2Pipeline._pack_text_embedsr   r\   Npromptnum_videos_per_promptmax_sequence_lengthr   c                 C   sF  |p| j }|p
| jj}t|tr|gn|}t|}t| dddur1d| j_| jj	du r1| jj
| j_	dd |D }| j|d|dddd	}|j}	|j}
|	|}	|
|}
| j|	|
dd
}|j}tj|dd}|
jdd}| j|||| jj|d}|j|d}|j\}}}|d|d}||| |d}|
|d}
|
|d}
||
fS )a  
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            device: (`str` or `torch.device`):
                torch device to place the resulting embeddings on
            dtype: (`torch.dtype`):
                torch dtype to cast the prompt embeds to
            max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt.
        rS   Nrv   c                 S   s   g | ]}|  qS r$   )strip).0pr$   r$   r)   
<listcomp>s  s    z9LTX2Pipeline._get_gemma_prompt_embeds.<locals>.<listcomp>
max_lengthTpt)paddingr   
truncationadd_special_tokensreturn_tensors)	input_idsattention_maskoutput_hidden_statesr   r?   )r,   rz   r{   r   r   )_execution_devicerR   r   
isinstancestrr7   r`   rS   rz   	pad_token	eos_tokenr   r   r   hidden_statesr   stackr   r   r   repeatr   )rt   r   r   r   r{   r,   r   r   text_inputstext_input_idsprompt_attention_masktext_encoder_outputstext_encoder_hidden_statesry   rN   _r   r$   r$   r)   _get_gemma_prompt_embedsR  sR   


z%LTX2Pipeline._get_gemma_prompt_embedsTnegative_promptdo_classifier_free_guidancerN   rO   r   negative_prompt_attention_maskc              
   C   s  |p| j }t|tr|gn|}|durt|}n|jd }|du r.| j|||	|
||d\}}|r|du r|p7d}t|trB||g n|}|dur_t|t|ur_tdt| dt| d|t|krxtd| d	t| d
| d	| d	| j|||	|
||d\}}||||fS )a"  
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        Nr   )r   r   r   r{   r,   r    z?`negative_prompt` should be the same type to `prompt`, but got z != .z`negative_prompt`: z has batch size z, but `prompt`: zT. Please make sure that passed `negative_prompt` matches the batch size of `prompt`.)	r   r   r   r7   r   r   type	TypeErrorr/   )rt   r   r   r   r   rN   rO   r   r   r   r{   r,   r   r   r$   r$   r)   encode_prompt  sP   
)


	
	zLTX2Pipeline.encode_promptc	           	         st  |d dks|d dkrt d| d| d|d ur8t fdd|D s8t d j d	 fd
d|D  |d urK|d urKt d| d| d|d u rW|d u rWt d|d urnt|tsnt|tsnt dt| |d urz|d u rzt d|d ur|d u rt d|d ur|d ur|j|jkrt d|j d|j d|j|jkrt d|j d|j dd S d S d S )NrW   r   z8`height` and `width` have to be divisible by 32 but are z and r   c                 3   s    | ]}| j v V  qd S r#   _callback_tensor_inputsr   krt   r$   r)   	<genexpr>  s    

z,LTX2Pipeline.check_inputs.<locals>.<genexpr>z2`callback_on_step_end_tensor_inputs` has to be in z, but found c                    s   g | ]	}| j vr|qS r$   r   r   r   r$   r)   r     s    z-LTX2Pipeline.check_inputs.<locals>.<listcomp>zCannot forward both `prompt`: z and `prompt_embeds`: z2. Please make sure to only forward one of the two.zeProvide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.z2`prompt` has to be of type `str` or `list` but is zEMust provide `prompt_attention_mask` when specifying `prompt_embeds`.zWMust provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.zu`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but got: `prompt_embeds` z != `negative_prompt_embeds` z`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but got: `prompt_attention_mask` z% != `negative_prompt_attention_mask` )r/   allr   r   r   rB   r   r   )	rt   r   heightwidth"callback_on_step_end_tensor_inputsrN   rO   r   r   r$   r   r)   check_inputs  sR   zLTX2Pipeline.check_inputsrM   ri   rk   c              
   C   sl   | j \}}}}}|| }|| }	|| }
| |d|||	||
|} | dddddddd	dd	dd} | S )
Nr   r   r   rY      r   r         )r   reshapepermuter   )rM   ri   rk   r   num_channels
num_framesr   r   post_patch_num_framespost_patch_heightpost_patch_widthr$   r$   r)   _pack_latents&  s    (
zLTX2Pipeline._pack_latentsr   r   r   c              
   C   sV   |  d}| ||||d|||} | dddddddd	dd	dddd} | S )
Nr   r   rY   r   r   r   r   r   r   )sizer   r   r   )rM   r   r   r   ri   rk   r   r$   r$   r)   _unpack_latents=  s   
0zLTX2Pipeline._unpack_latents      ?latents_meanlatents_stdscaling_factorc                 C   sP   | ddddd| j| j}| ddddd| j| j}| | | | } | S Nr   r   r   r   r,   r   rM   r   r   r   r$   r$   r)   _normalize_latentsI  s   zLTX2Pipeline._normalize_latentsc                 C   sP   | ddddd| j| j}| ddddd| j| j}| | | | } | S r   r   r   r$   r$   r)   _denormalize_latentsT  s   z!LTX2Pipeline._denormalize_latentsc                 C   s,   | | j| j}| | j| j}| | | S r#   r   r,   r   rM   r   r   r$   r$   r)   _normalize_audio_latents^     z%LTX2Pipeline._normalize_audio_latentsc                 C   s,   | | j| j}| | j| j}| | | S r#   r   r   r$   r$   r)   _denormalize_audio_latentsd  r   z'LTX2Pipeline._denormalize_audio_latentsnoise_scale	generatorc                 C   s.   t | j|| j| jd}|| d| |   }|S )Nr   r,   r   r   )r   r   r,   r   )rM   r   r   noisenoised_latentsr$   r$   r)   _create_noised_statej  s   z!LTX2Pipeline._create_noised_statec           	      C   s   |d ur5|d ur5| j \}}}}|| }|| }| |d||||} | dddddddddd} | S | dddd} | S )Nr   r   r   rY   r   r   r   )r   r   r   r   	transpose)	rM   ri   rk   r   r   latent_lengthlatent_mel_binspost_patch_latent_lengthpost_patch_mel_binsr$   r$   r)   _pack_audio_latentsr  s   $z LTX2Pipeline._pack_audio_latentsr   num_mel_binsc                 C   sr   |d ur+|d ur+|  d}| |||d||} | dddddddddd} | S | dd|fdd} | S )Nr   r   r   r   rY   r   r   )r   r   r   r   	unflattenr   )rM   r   r   ri   rk   r   r$   r$   r)   _unpack_audio_latents  s   

$z"LTX2Pipeline._unpack_audio_latents         y   r=   r   num_channels_latentsc                 C   s  |
d ur=|
j dkr!| |
| jj| jj| jjj}
| |
| j| j	}
|
j dkr/t
d|
j d| |
||	}
|
j||dS || j }|| j }|d | j d }|||||f}t|	trot|	|krot
dt|	 d| d	t||	||d
}
| |
| j| j	}
|
S )Nr   r   $Provided `latents` tensor has shape @, but the expected shape is [batch_size, num_seq, num_features].r,   r   r   /You have passed a list of generators of length +, but requested an effective batch size of @. Make sure the batch size matches the length of the generators.r   )rD   r   rP   r   r   rh   r   r   rj   rl   r/   r   r   r   rb   rd   r   rB   r7   r   )rt   r   r  r   r   r   r   r   r,   r   rM   r   r$   r$   r)   prepare_latents  s:   





zLTX2Pipeline.prepare_latents@   audio_latent_lengthc
                 C   s   |	d ur5|	j dkr| |	}	|	j dkrtd|	j d| |	| jj| jj}	| |	||}	|	j	||dS || j
 }
||||
f}t|trXt||krXtdt| d| dt||||d	}	| |	}	|	S )
NrY   r   r  r  r  r  r  r  r   )rD   r   r/   r   r   rQ   r   r   r   r   rf   r   rB   r7   r   )rt   r   r  r  r   r   r   r,   r   rM   r   r   r$   r$   r)   prepare_audio_latents  s*   




z"LTX2Pipeline.prepare_audio_latentsc                 C      | j S r#   _guidance_scaler   r$   r$   r)   guidance_scale     zLTX2Pipeline.guidance_scalec                 C   r  r#   )_guidance_rescaler   r$   r$   r)   rG     r  zLTX2Pipeline.guidance_rescalec                 C   s
   | j dkS )Nr   r  r   r$   r$   r)   r     s   
z(LTX2Pipeline.do_classifier_free_guidancec                 C   r  r#   )_num_timestepsr   r$   r$   r)   num_timesteps  r  zLTX2Pipeline.num_timestepsc                 C   r  r#   )_current_timestepr   r$   r$   r)   current_timestep   r  zLTX2Pipeline.current_timestepc                 C   r  r#   )_attention_kwargsr   r$   r$   r)   attention_kwargs  r  zLTX2Pipeline.attention_kwargsc                 C   r  r#   )
_interruptr   r$   r$   r)   	interrupt  r  zLTX2Pipeline.interruptg      8@(   g      @pil
frame_rater+   r.   r-   r  rG   audio_latentsdecode_timestepdecode_noise_scaleoutput_typereturn_dictr  callback_on_step_endr   c           H      C   sZ  t |ttfr
|j}| j||||||||d |
| _|| _|| _d| _d| _	|dur2t |t
r2d}n|dur@t |tr@t|}n|jd }| j}| j||| j|||||||d
\}}}}| jrqtj||gdd}tj||gdd}d||j d }| j||d	d
\} }!}"|d | j d }#|| j }$|| j }%|dur|jdkrtd |j\}&}&}#}$}%n|jdkrtd|j d n	td|j d|#|$ |% }'| jjj}(|  || |(||||tj!|||
}|| })| j"| j# t$| j% }*t&|)|* }+|dur,|jdkrtd |j\}&}&}+}&n|jdkr#td|j d n	td|j dt'| dddur:| j(jj)nd},|,| j* }-t'| dddurO| j(jj+nd}.| j,|| |.|+|,|tj!|||d	}|du rpt-.dd| |n|}t/|'| j0j1dd| j0j1dd | j0j1d!d"| j0j1d#d$}/t23| j0}0t4|0|||	||/d%\}&}&t4| j0|||	||/d%\}	}t5t|	|| j0j6  d}1t|	| _7| j| | j| jf}2| jj8j9|jd |#|$|%|j:|d&}3| jj;<|jd |+|j:}4| jr|3=d'd(|3jd   }3|4=d'd(|4jd   }4| j>|d)A}5t?|	D ]3\}6}7| j@rq|7| _	| jr-t|gd* n|}8|8|j}8| jrAt|gd* n|}9|9|j}9|7A|8jd }:| jBd+@ | jdFi d,|8d-|9d.| d/|!d0|:d1|"d2|"d3|#d4|$d5|%d6|d7|+d8|3d9|4d:|d;d\};}<W d   n	1 sw   Y  |;$ };|<$ }<| jr|;Cd*\}=}>|=| jD|>|=   };|<Cd*\}?}@|?| jD|@|?   }<| jEdkrtF|;|>| jEd<};tF|<|@| jEd<}<| j0jG|;|7|dd=d }|0jG|<|7|dd=d }|dur i }A|D ]
}BtH |B |A|B< q|| |6|7|A}C|CId>|}|CId?|}|6t|	d ks;|6d |1kr?|6d | j0j6 dkr?|5J  tKrFtLM  qW d   n	1 sSw   Y  | N||#|$|%| jO| jP}| Q|| jRjS| jRjT| jRjjU}| V|| j(jS| j(jT}| jW||+|-d@}|dAkr|}D|}En||j}| jRjjXsd}:nMtY|j|||jdB}Ft |ts|g| }|du r|}nt |ts|g| }tjZ|||jdC}:tjZ|||jdCddddddf }d| | ||F  }|| jRj}| jRj[||:dd=d }D| j\j]|D|dD}D|| j(j}| j(j[|dd=d }G| ^|G}E| _  |s'|D|EfS t`|D|EdES )Gu=  
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                instead.
            height (`int`, *optional*, defaults to `512`):
                The height in pixels of the generated image. This is set to 480 by default for the best results.
            width (`int`, *optional*, defaults to `768`):
                The width in pixels of the generated image. This is set to 848 by default for the best results.
            num_frames (`int`, *optional*, defaults to `121`):
                The number of video frames to generate
            frame_rate (`float`, *optional*, defaults to `24.0`):
                The frames per second (FPS) of the generated video.
            num_inference_steps (`int`, *optional*, defaults to 40):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            timesteps (`list[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            guidance_scale (`float`, *optional*, defaults to `4.0`):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
                [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            noise_scale (`float`, *optional*, defaults to `0.0`):
                The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
                the `latents` and `audio_latents` before continue denoising.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            audio_latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            prompt_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            decode_timestep (`float`, defaults to `0.0`):
                The timestep at which generated video is decoded.
            decode_noise_scale (`float`, defaults to `None`):
                The interpolation factor between random noise and denoised latents at the decode timestep.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.ltx.LTX2PipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, *optional*):
                A function that calls at the end of each denoising steps during the inference. The function is called
                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*, defaults to `["latents"]`):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, *optional*, defaults to `1024`):
                Maximum sequence length to use with the `prompt`.

        Examples:

        Returns:
            [`~pipelines.ltx.LTX2PipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.ltx.LTX2PipelineOutput`] is returned, otherwise a `tuple` is
                returned where the first element is a list with the generated images.
        )r   r   r   r   rN   rO   r   r   FNr   r   )
r   r   r   r   rN   rO   r   r   r   r,   r   g    .T)additive_maskr   zGot latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred.r   z,You have supplied packed `latents` of shape zp, so the latent dims cannot be inferred. Make sure the supplied `height`, `width`, and `num_frames` are correct.r  z, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width].rY   zsGot audio_latents of shape [batch_size, num_channels, audio_length, mel_bins], `audio_num_frames` will be inferred.z2You have supplied packed `audio_latents` of shape zj, so the latent dims cannot be inferred. Make sure the supplied `num_frames` and `frame_rate` are correct.z*Provided `audio_latents` tensor has shape z}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, num_channels, audio_length, mel_bins].rQ   r
  rX   )r  r  r   r   r   r,   r   rM   r   base_image_seq_lenr\   max_image_seq_lenr   r!   gffffff?r"   gffffff @)r.   r(   )fps)r   )r   )totalr   cond_uncondr   audio_hidden_statesencoder_hidden_statesaudio_encoder_hidden_statestimestepencoder_attention_maskaudio_encoder_attention_maskr   r   r   r'  audio_num_framesvideo_coordsaudio_coordsr  r"  )rG   )r"  rM   rN   )r   latentr   r  )r!  )framesaudior$   )ar   r	   r   tensor_inputsr   r  r  r  r  r  r   rB   r7   r   r   r   r   r   catr   r   rT   rd   rb   rD   loggerinfowarningr/   rU   rh   in_channelsr	  float32rn   rp   r   rg   roundr`   rQ   mel_binsrf   latent_channelsr  nplinspacer*   r8   getcopydeepcopyr<   maxorderr  ropeprepare_video_coordsr,   
audio_ropeprepare_audio_coordsr   progress_bar	enumerater  r   cache_contextchunkr  rG   rK   steplocalspopupdateXLA_AVAILABLExm	mark_stepr   rj   rl   r   rP   r   r   r   r   r   timestep_conditioningr   tensordecoderq   postprocess_videorV   maybe_free_model_hooksr   )Hrt   r   r   r   r   r   r  r+   r.   r-   r  rG   r   r   r   rM   r  rN   r   rO   r   r  r   r!  r"  r  r#  r   r   r   r,   additive_attention_maskconnector_prompt_embedsconnector_audio_prompt_embedsconnector_attention_masklatent_num_frameslatent_heightlatent_widthr   video_sequence_lengthr  
duration_saudio_latents_per_secondr0  r   r   num_channels_latents_audior(   audio_schedulernum_warmup_stepsrope_interpolation_scaler1  r2  rK  itlatent_model_inputaudio_latent_model_inputr-  noise_pred_videonoise_pred_audionoise_pred_video_uncondnoise_pred_video_textnoise_pred_audio_uncondnoise_pred_audio_textcallback_kwargsr   callback_outputsvideor5  r   generated_mel_spectrogramsr$   r$   r)   __call__  s"   







 
 



	


6R




zLTX2Pipeline.__call__)rv   rX   rw   )r   r\   rX   NN)NTr   NNNNr\   rX   NN)NNNNN)r   r   )r   r#   )NN)
r   r   r   r   r  r=   NNNN)	r   rX   r   r
  r=   NNNN);__name__
__module____qualname____doc__model_cpu_offload_seq_optional_componentsr   r   r   r   r   r   r   r   r   r   r^   staticmethodr   Tensorr   r,   intr   r   rB   r   r   boolr   r   r   r   r   r   r   r   	Generatorr   r   r   r	  r  propertyr  rG   r   r  r  r  r  no_gradr   EXAMPLE_DOC_STRINGdictr   r   rw  __classcell__r$   r$   ru   r)   rL      s   	8E

L
	

[
5"			

1	

'








	

rL   )r   r   r   r   )NNNN)r=   );rC  r1   typingr   r   numpyr@  r   transformersr   r   r   	callbacksr   r	   loadersr
   r   models.autoencodersr   r   models.transformersr   
schedulersr   utilsr   r   r   utils.torch_utilsr   rq   r   pipeline_utilsr   rT   r   pipeline_outputr   rV   r   torch_xla.core.xla_modelcore	xla_modelrT  rS  
get_loggerrx  r8  r  r  r   r*   r   r,   rB   r<   rK   rL   r$   r$   r$   r)   <module>   sh   
)




<