"""
LTX-2 Persistent Inference Server
Loads all models ONCE into GPU memory, reuses across generations.
Avoids the default pipeline behavior of rebuilding models on every call.
"""

import gc
import os
import time
import logging

# Enable expandable allocator segments unless the caller already configured the
# allocator. This is placed before `import torch` below, presumably so the setting
# is seen when torch initializes its CUDA allocator.
# NOTE(review): older PyTorch releases read PYTORCH_CUDA_ALLOC_CONF instead — confirm
# the installed version honors this name.
if "PYTORCH_ALLOC_CONF" not in os.environ:
    os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"

import torch
import ltx_core.loader  # break circular import

from ltx_core.batch_split import BatchSplitAdapter
from ltx_core.components.diffusion_steps import EulerDiffusionStep
from ltx_core.components.noisers import GaussianNoiser
from ltx_core.components.patchifiers import AudioPatchifier, VideoLatentPatchifier
from ltx_core.loader.registry import DummyRegistry
from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
from ltx_core.model.audio_vae import (
    AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
    VOCODER_COMFY_KEYS_FILTER,
    AudioDecoderConfigurator,
    VocoderConfigurator,
    decode_audio as vae_decode_audio,
)
from ltx_core.model.transformer import (
    LTXV_MODEL_COMFY_RENAMING_MAP,
    LTXModelConfigurator,
    X0Model,
)
from ltx_core.model.upsampler import LatentUpsamplerConfigurator, upsample_video
from ltx_core.model.video_vae import (
    VAE_DECODER_COMFY_KEYS_FILTER,
    VAE_ENCODER_COMFY_KEYS_FILTER,
    TilingConfig,
    VideoDecoderConfigurator,
    VideoEncoderConfigurator,
    get_video_chunks_number,
)
from ltx_core.quantization import QuantizationPolicy
from ltx_core.text_encoders.gemma import (
    EMBEDDINGS_PROCESSOR_KEY_OPS,
    GEMMA_LLM_KEY_OPS,
    GEMMA_MODEL_OPS,
    EmbeddingsProcessorConfigurator,
    GemmaTextEncoderConfigurator,
    module_ops_from_gemma_root,
)
from ltx_core.tools import AudioLatentTools, VideoLatentTools
from ltx_core.types import AudioLatentShape, LatentState, VideoLatentShape, VideoPixelShape
from ltx_core.utils import find_matching_file
from ltx_core.loader import SDOps

from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
from ltx_pipelines.utils.helpers import (
    assert_resolution,
    combined_image_conditionings,
    create_noised_state,
)
from ltx_pipelines.utils.samplers import euler_denoising_loop
from ltx_pipelines.utils.denoisers import SimpleDenoiser
from ltx_pipelines.utils.types import ModalitySpec
from ltx_pipelines.utils.media_io import encode_video

# Root logging at INFO; each record is prefixed with its timestamp. (Per the stdlib
# docs, basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
logger = logging.getLogger(name=__name__)


class LTX2PersistentEngine:
    """Load every LTX-2 sub-model once, keep it resident on one device, and reuse it.

    ``__init__`` builds the Gemma text encoder, embeddings processor, video
    encoder, spatial upsampler, transformer, video decoder, audio decoder and
    vocoder, and stores each as an attribute. ``generate`` then runs inference
    against those instances without rebuilding anything.
    """

    def __init__(
        self,
        distilled_checkpoint_path: str,
        spatial_upsampler_path: str,
        gemma_root: str,
        quantization: QuantizationPolicy | None = None,
        device: torch.device | None = None,
    ):
        """Build all models and load them onto ``device``.

        Args:
            distilled_checkpoint_path: Checkpoint from which the embeddings processor,
                video encoder/decoder, transformer, audio decoder and vocoder are built.
            spatial_upsampler_path: Checkpoint for the latent spatial upsampler.
            gemma_root: Gemma directory. The weights are every ``*.safetensors`` found
                recursively under the folder of the first ``model*.safetensors`` match.
            quantization: Optional policy whose module ops and state-dict ops are
                appended to the transformer builder's own ops.
            device: Target device; defaults to ``torch.device("cuda")``. VRAM
                reporting uses ``torch.cuda`` APIs, so a CUDA device is expected.
        """
        self.device = device or torch.device("cuda")
        self.dtype = torch.bfloat16
        registry = DummyRegistry()

        t0 = self._clock()

        # ── 1. Text encoder (Gemma 3) ──
        logger.info("Loading Gemma text encoder...")
        t1 = self._clock()
        module_ops = module_ops_from_gemma_root(gemma_root)
        model_folder = find_matching_file(gemma_root, "model*.safetensors").parent
        weight_paths = [str(p) for p in model_folder.rglob("*.safetensors")]
        text_enc_builder = Builder(
            model_path=tuple(weight_paths),
            model_class_configurator=GemmaTextEncoderConfigurator,
            model_sd_ops=GEMMA_LLM_KEY_OPS,
            module_ops=(GEMMA_MODEL_OPS, *module_ops),
            registry=registry,
        )
        self.text_encoder = text_enc_builder.build(device=self.device, dtype=self.dtype).eval()
        self._log_loaded("Text encoder", t1)

        # ── 2. Embeddings processor ──
        logger.info("Loading embeddings processor...")
        t1 = self._clock()
        emb_builder = Builder(
            model_path=distilled_checkpoint_path,
            model_class_configurator=EmbeddingsProcessorConfigurator,
            model_sd_ops=EMBEDDINGS_PROCESSOR_KEY_OPS,
            registry=registry,
        )
        # NOTE(review): build() already receives `device`; the trailing .to() looks
        # redundant but is kept so placement behavior is unchanged.
        self.embeddings_processor = emb_builder.build(device=self.device, dtype=self.dtype).to(self.device).eval()
        self._log_loaded("Embeddings processor", t1)

        # ── 3. Video encoder (for image conditioning + upsampling) ──
        logger.info("Loading video encoder...")
        t1 = self._clock()
        enc_builder = Builder(
            model_path=distilled_checkpoint_path,
            model_class_configurator=VideoEncoderConfigurator,
            model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER,
            registry=registry,
        )
        self.video_encoder = enc_builder.build(device=self.device, dtype=self.dtype).to(self.device).eval()
        self._log_loaded("Video encoder", t1)

        # ── 4. Spatial upsampler ──
        logger.info("Loading spatial upsampler...")
        t1 = self._clock()
        up_builder = Builder(
            model_path=spatial_upsampler_path,
            model_class_configurator=LatentUpsamplerConfigurator,
            registry=registry,
        )
        self.upsampler = up_builder.build(device=self.device, dtype=self.dtype).to(self.device).eval()
        self._log_loaded("Upsampler", t1)

        # ── 5. Transformer (the big one) ──
        logger.info("Loading transformer...")
        t1 = self._clock()
        trans_builder = Builder(
            model_path=distilled_checkpoint_path,
            model_class_configurator=LTXModelConfigurator,
            model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
            loras=(),
            registry=registry,
        )
        # Append the quantization policy's module ops and state-dict key ops (if any)
        # after the builder's own, chaining the two SDOps mappings into one.
        sd_ops = trans_builder.model_sd_ops
        module_ops_trans = trans_builder.module_ops
        if quantization is not None:
            module_ops_trans = (*module_ops_trans, *quantization.module_ops)
            sd_ops = SDOps(
                name=f"chain_{sd_ops.name}+{quantization.sd_ops.name}",
                mapping=(*sd_ops.mapping, *quantization.sd_ops.mapping),
            )
        trans_builder = trans_builder.with_module_ops(module_ops_trans).with_sd_ops(sd_ops)
        # Unlike the other models, no dtype is forced here — presumably so the
        # quantization policy controls the weight dtypes (confirm).
        self.transformer = X0Model(trans_builder.build(device=self.device)).to(self.device).eval()
        self._log_loaded("Transformer", t1)

        # ── 6. Video decoder ──
        logger.info("Loading video decoder...")
        t1 = self._clock()
        dec_builder = Builder(
            model_path=distilled_checkpoint_path,
            model_class_configurator=VideoDecoderConfigurator,
            model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER,
            registry=registry,
        )
        self.video_decoder = dec_builder.build(device=self.device, dtype=self.dtype).to(self.device).eval()
        self._log_loaded("Video decoder", t1)

        # ── 7. Audio decoder + vocoder ──
        logger.info("Loading audio decoder + vocoder...")
        t1 = self._clock()
        audio_dec_builder = Builder(
            model_path=distilled_checkpoint_path,
            model_class_configurator=AudioDecoderConfigurator,
            model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
            registry=registry,
        )
        self.audio_decoder = audio_dec_builder.build(device=self.device, dtype=self.dtype).to(self.device).eval()

        vocoder_builder = Builder(
            model_path=distilled_checkpoint_path,
            model_class_configurator=VocoderConfigurator,
            model_sd_ops=VOCODER_COMFY_KEYS_FILTER,
            registry=registry,
        )
        self.vocoder = vocoder_builder.build(device=self.device, dtype=self.dtype).to(self.device).eval()
        self._log_loaded("Audio decoder+vocoder", t1)

        total_load = self._clock() - t0
        logger.info(f"ALL MODELS LOADED in {total_load:.1f}s")
        # Query the engine's own device rather than a hard-coded index 0.
        total_gb = torch.cuda.get_device_properties(self.device).total_memory / 1024**3
        logger.info(f"Total VRAM: {self._vram_gb():.2f} GB / {total_gb:.1f} GB")

    # ── internal helpers ──

    def _sync(self) -> None:
        """Block until all queued CUDA work on ``self.device`` has finished (no-op off CUDA)."""
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)

    def _clock(self) -> float:
        """Return a ``perf_counter`` timestamp taken after pending GPU work has completed.

        CUDA kernels launch asynchronously. Without the sync, a phase's GPU time
        would be charged to whichever later phase first blocks on the device.
        """
        self._sync()
        return time.perf_counter()

    def _vram_gb(self) -> float:
        """Return the memory currently allocated by PyTorch tensors on ``self.device``, in GiB."""
        return torch.cuda.memory_allocated(self.device) / 1024**3

    def _log_loaded(self, what: str, t_start: float) -> None:
        """Log that ``what`` finished loading, with the elapsed time since ``t_start`` and current VRAM."""
        elapsed = self._clock() - t_start
        logger.info(f"  {what} loaded in {elapsed:.1f}s | VRAM: {self._vram_gb():.1f}GB")

    @staticmethod
    def _latent_tools(
        height: int, width: int, num_frames: int, frame_rate: float
    ) -> tuple[VideoLatentTools, AudioLatentTools]:
        """Build the video and audio latent tools for a batch-1 clip at the given pixel size."""
        pixel_shape = VideoPixelShape(batch=1, frames=num_frames, height=height, width=width, fps=frame_rate)
        video_shape = VideoLatentShape.from_pixel_shape(pixel_shape)
        video_tools = VideoLatentTools(VideoLatentPatchifier(patch_size=1), video_shape, frame_rate)
        audio_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
        audio_tools = AudioLatentTools(AudioPatchifier(patch_size=1), audio_shape)
        return video_tools, audio_tools

    @staticmethod
    def _finalize_state(tools, state):
        """Return ``state`` after ``tools.clear_conditioning`` followed by ``tools.unpatchify``."""
        return tools.unpatchify(tools.clear_conditioning(state))

    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        seed: int = 42,
        height: int = 512,
        width: int = 768,
        num_frames: int = 41,
        frame_rate: float = 24.0,
        output_path: str | None = None,
    ) -> dict[str, float]:
        """Run two-stage text-to-video (+audio) generation and return per-phase timings.

        Stage 1 denoises at half resolution with ``DISTILLED_SIGMA_VALUES``. The video
        latent is then spatially upsampled and refined at full resolution with
        ``STAGE_2_DISTILLED_SIGMA_VALUES``. Finally, the video and audio latents are decoded.

        Args:
            prompt: Text prompt.
            seed: Seed for the ``torch.Generator`` used by the noiser and video decoder.
            height, width: Output resolution in pixels, validated by ``assert_resolution``
                for the two-stage pipeline. Stage 1 runs at ``height // 2`` x ``width // 2``.
            num_frames: Number of output frames.
            frame_rate: Output frame rate (also passed to the latent tools).
            output_path: If given, the decoded video and audio are written here via
                ``encode_video``.

        Returns:
            Seconds spent per phase, keyed by ``prompt_encode``, ``stage1_denoise``,
            ``upsample``, ``stage2_denoise``, ``decode``, ``file_encode`` (only when
            ``output_path`` is set) and ``total``. Every timestamp is taken after
            synchronizing the device, so queued GPU work is charged to the phase
            that launched it.
        """
        assert_resolution(height=height, width=width, is_two_stage=True)
        timings: dict[str, float] = {}
        t_total = self._clock()

        # The same seeded generator feeds the noiser (both stages) and the video decoder.
        generator = torch.Generator(device=self.device).manual_seed(seed)
        noiser = GaussianNoiser(generator=generator)

        # ── Encode prompt ──
        t0 = self._clock()
        raw_hs, mask = self.text_encoder.encode(prompt)
        ctx = self.embeddings_processor.process_hidden_states(raw_hs, mask)
        video_context, audio_context = ctx.video_encoding, ctx.audio_encoding
        timings["prompt_encode"] = self._clock() - t0

        # ── Stage 1: low-res generation at half resolution ──
        t0 = self._clock()
        stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)
        s1_w, s1_h = width // 2, height // 2

        # Image conditioning (none for text-to-video)
        stage_1_conditionings = combined_image_conditionings(
            images=[], height=s1_h, width=s1_w,
            video_encoder=self.video_encoder, dtype=self.dtype, device=self.device,
        )
        video_tools, audio_tools = self._latent_tools(
            height=s1_h, width=s1_w, num_frames=num_frames, frame_rate=frame_rate,
        )

        video_spec = ModalitySpec(context=video_context, conditionings=stage_1_conditionings)
        audio_spec = ModalitySpec(context=audio_context)

        video_state = create_noised_state(
            tools=video_tools, conditionings=video_spec.conditionings,
            noiser=noiser, dtype=self.dtype, device=self.device,
        )
        audio_state = create_noised_state(
            tools=audio_tools, conditionings=audio_spec.conditionings,
            noiser=noiser, dtype=self.dtype, device=self.device,
        )

        denoiser = SimpleDenoiser(video_context, audio_context)
        stepper = EulerDiffusionStep()
        transformer = BatchSplitAdapter(self.transformer, max_batch_size=1)

        video_state, audio_state = euler_denoising_loop(
            sigmas=stage_1_sigmas,
            video_state=video_state,
            audio_state=audio_state,
            stepper=stepper,
            transformer=transformer,
            denoiser=denoiser,
        )

        video_state = self._finalize_state(video_tools, video_state)
        audio_state = self._finalize_state(audio_tools, audio_state)
        timings["stage1_denoise"] = self._clock() - t0

        # ── Upsample ──
        t0 = self._clock()
        upscaled_video_latent = upsample_video(
            latent=video_state.latent[:1],
            video_encoder=self.video_encoder,
            upsampler=self.upsampler,
        )
        timings["upsample"] = self._clock() - t0

        # ── Stage 2: refine at full resolution ──
        t0 = self._clock()
        stage_2_sigmas = torch.Tensor(STAGE_2_DISTILLED_SIGMA_VALUES).to(self.device)
        stage_2_conditionings = combined_image_conditionings(
            images=[], height=height, width=width,
            video_encoder=self.video_encoder, dtype=self.dtype, device=self.device,
        )
        video_tools_2, audio_tools_2 = self._latent_tools(
            height=height, width=width, num_frames=num_frames, frame_rate=frame_rate,
        )

        # Stage 2 starts from the stage-1 outputs (upsampled video, stage-1 audio),
        # with noise_scale set to the first stage-2 sigma.
        start_sigma = stage_2_sigmas[0].item()
        video_spec_2 = ModalitySpec(
            context=video_context, conditionings=stage_2_conditionings,
            noise_scale=start_sigma, initial_latent=upscaled_video_latent,
        )
        audio_spec_2 = ModalitySpec(
            context=audio_context,
            noise_scale=start_sigma, initial_latent=audio_state.latent,
        )

        video_state_2 = create_noised_state(
            tools=video_tools_2, conditionings=video_spec_2.conditionings,
            noiser=noiser, dtype=self.dtype, device=self.device,
            noise_scale=video_spec_2.noise_scale, initial_latent=video_spec_2.initial_latent,
        )
        audio_state_2 = create_noised_state(
            tools=audio_tools_2, conditionings=audio_spec_2.conditionings,
            noiser=noiser, dtype=self.dtype, device=self.device,
            noise_scale=audio_spec_2.noise_scale, initial_latent=audio_spec_2.initial_latent,
        )

        denoiser_2 = SimpleDenoiser(video_context, audio_context)
        video_state_2, audio_state_2 = euler_denoising_loop(
            sigmas=stage_2_sigmas,
            video_state=video_state_2,
            audio_state=audio_state_2,
            stepper=stepper,
            transformer=transformer,
            denoiser=denoiser_2,
        )

        video_state_2 = self._finalize_state(video_tools_2, video_state_2)
        audio_state_2 = self._finalize_state(audio_tools_2, audio_state_2)
        timings["stage2_denoise"] = self._clock() - t0

        # ── Decode ──
        t0 = self._clock()
        tiling_config = TilingConfig.default()
        # NOTE(review): if decode_video returns a lazy iterator of chunks, the real
        # decoding happens while encode_video consumes it. Its cost would then land
        # in "file_encode", and would be skipped entirely when output_path is None — confirm.
        video_chunks = self.video_decoder.decode_video(video_state_2.latent, tiling_config, generator)
        audio = vae_decode_audio(audio_state_2.latent, self.audio_decoder, self.vocoder)
        timings["decode"] = self._clock() - t0

        # ── Encode to file ──
        if output_path:
            t0 = self._clock()
            video_chunks_number = get_video_chunks_number(num_frames, tiling_config)
            encode_video(video=video_chunks, fps=frame_rate, audio=audio,
                         output_path=output_path, video_chunks_number=video_chunks_number)
            timings["file_encode"] = self._clock() - t0

        timings["total"] = self._clock() - t_total
        return timings


def main():
    """Benchmark the persistent engine: one warmup generation, then one timed run per prompt.

    Paths default to the original hard-coded locations. They can be overridden with
    the ``LTX2_MODEL_DIR`` (checkpoints) and ``LTX2_OUTPUT_DIR`` (rendered .mp4 files)
    environment variables.
    """
    from pathlib import Path

    model_dir = Path(os.environ.get("LTX2_MODEL_DIR", "/home/ubuntu/ltx2-models"))
    output_dir = Path(os.environ.get("LTX2_OUTPUT_DIR", "/home/ubuntu/ltx2_bench_output"))

    # Create the output directory before the (slow) model load so a bad path fails fast.
    output_dir.mkdir(parents=True, exist_ok=True)

    engine = LTX2PersistentEngine(
        distilled_checkpoint_path=str(model_dir / "ltx-2.3-22b-distilled-fp8.safetensors"),
        spatial_upsampler_path=str(model_dir / "ltx-2.3-spatial-upscaler-x2-1.0.safetensors"),
        gemma_root=str(model_dir / "gemma3"),
        quantization=QuantizationPolicy.fp8_cast(),
    )

    prompts = [
        "A golden retriever running through a field of wildflowers at sunset, cinematic lighting",
        "A futuristic cityscape at night with neon lights reflecting on wet streets, aerial drone shot",
        "Ocean waves crashing against rocky cliffs, slow motion, dramatic sky with storm clouds",
    ]

    # Warmup: the first generation absorbs one-off costs, so it is reported separately.
    logger.info("=" * 60)
    logger.info("WARMUP RUN")
    logger.info("=" * 60)
    t = engine.generate(
        prompt=prompts[0], seed=42, num_frames=41,
        output_path=str(output_dir / "persistent_warmup.mp4"),
    )
    logger.info(f"Warmup: {t}")

    # Timed runs
    for i, prompt in enumerate(prompts):
        logger.info("=" * 60)
        logger.info(f"RUN {i+1}")
        logger.info("=" * 60)
        t = engine.generate(
            prompt=prompt, seed=42 + i, num_frames=41,
            output_path=str(output_dir / f"persistent_{i}.mp4"),
        )
        logger.info(f"Run {i+1} timings: {t}")
        for k, v in t.items():
            logger.info(f"  {k}: {v:.2f}s")

    # Report the peak on the engine's own device (not whatever CUDA device is current).
    logger.info(f"Peak VRAM: {torch.cuda.max_memory_allocated(engine.device)/1024**3:.2f} GB")


if __name__ == "__main__":
    main()
