"""
LTX-2 Concurrency Benchmark
Tests single vs concurrent video generation on the same GPU.
Uses the DistilledPipeline (fastest: 8 steps stage1, 4 steps stage2).
"""

import gc
import os
import sys
import time
import threading
import torch
import logging

# Let INFO-level records through the root logger.
logging.getLogger().setLevel(level=logging.INFO)

# Allocator configuration: must be set before any CUDA allocation.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")

import ltx_core.loader  # must import before quantization to avoid circular import
from ltx_core.quantization import QuantizationPolicy
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
from ltx_pipelines.distilled import DistilledPipeline
from ltx_pipelines.utils.media_io import encode_video

# ── Paths ──
# Every model artifact lives directly under MODEL_DIR.
MODEL_DIR = "/home/ubuntu/ltx2-models"
DISTILLED_CKPT = MODEL_DIR + "/ltx-2.3-22b-distilled-fp8.safetensors"
SPATIAL_UP = MODEL_DIR + "/ltx-2.3-spatial-upscaler-x2-1.0.safetensors"
GEMMA_ROOT = MODEL_DIR + "/gemma3"

# ── Test prompts ──
# Indexed by the benchmark; each entry is one full text prompt.
PROMPTS = [
    (
        "A golden retriever running through a field of wildflowers at sunset, "
        "cinematic lighting"
    ),
    (
        "A futuristic cityscape at night with neon lights reflecting on wet streets, "
        "aerial drone shot"
    ),
    (
        "Ocean waves crashing against rocky cliffs, slow motion, "
        "dramatic sky with storm clouds"
    ),
    (
        "A chef preparing sushi in a traditional Japanese restaurant, "
        "close-up hands working, warm lighting"
    ),
]


def gpu_mem_mb():
    """Return ``torch.cuda.memory_allocated()`` converted from bytes to MB."""
    allocated_bytes = torch.cuda.memory_allocated()
    return allocated_bytes / (1024 * 1024)


def gpu_mem_reserved_mb():
    """Return ``torch.cuda.memory_reserved()`` converted from bytes to MB."""
    reserved_bytes = torch.cuda.memory_reserved()
    return reserved_bytes / (1024 * 1024)


def print_gpu_stats(label=""):
    """Print one line of GPU memory figures (in MB), tagged with *label*.

    The line shows allocated and reserved memory as reported by
    ``gpu_mem_mb()`` / ``gpu_mem_reserved_mb()``, the device's total memory,
    and ``total - reserved`` as an estimate of what is still free.

    Args:
        label: Free-form tag inserted into the ``[GPU ...]`` prefix.
    """
    alloc = gpu_mem_mb()
    reserved = gpu_mem_reserved_mb()
    # The two helpers above call the torch.cuda counters without a device
    # argument, i.e. for the current device. Read the total from that same
    # device instead of a hard-coded index 0 so all figures on the line
    # describe one GPU.
    device = torch.cuda.current_device()
    total = torch.cuda.get_device_properties(device).total_memory / 1024**2
    print(f"[GPU {label}] Allocated: {alloc:.0f}MB | Reserved: {reserved:.0f}MB | Total: {total:.0f}MB | Free(est): {total - reserved:.0f}MB")


def run_single_generation(pipeline, prompt, seed, output_path, height, width, num_frames, frame_rate, label=""):
    """Generate one video with *pipeline*, encode it to *output_path*, and time both phases.

    CUDA kernels are launched asynchronously, so the calling thread's current
    stream is synchronized right before the generation timer starts and right
    before it is read. Only the current stream is synchronized (not the whole
    device) so that a caller running several generations on separate streams
    does not make each one wait for the others.

    Args:
        pipeline: Callable invoked with the keyword arguments below plus
            ``images=[]`` and a default ``TilingConfig``; must return a
            ``(video, audio)`` pair.
        prompt: Text prompt; its first 50 characters are echoed in the log.
        seed: Seed forwarded to the pipeline.
        output_path: Destination file passed to ``encode_video``.
        height: Output height forwarded to the pipeline.
        width: Output width forwarded to the pipeline.
        num_frames: Frame count forwarded to the pipeline and used to derive
            the chunk count given to the encoder.
        frame_rate: Frame rate forwarded to the pipeline and used as the
            encoder's ``fps``.
        label: Tag used as the prefix of every log line.

    Returns:
        dict with ``"gen_time"``, ``"enc_time"`` and ``"total"`` (their sum),
        all in seconds.
    """
    tiling_config = TilingConfig.default()
    video_chunks = get_video_chunks_number(num_frames, tiling_config)

    print(f"[{label}] Starting generation: {prompt[:50]}...")
    print_gpu_stats(label)

    # Drain work already queued on this stream so it is not billed to this run.
    torch.cuda.current_stream().synchronize()
    t0 = time.perf_counter()

    video, audio = pipeline(
        prompt=prompt,
        seed=seed,
        height=height,
        width=width,
        num_frames=num_frames,
        frame_rate=frame_rate,
        images=[],
        tiling_config=tiling_config,
    )

    # Wait for the kernels launched by the pipeline before reading the clock.
    # NOTE(review): if `video` is produced lazily, the remaining decode work is
    # accounted to the encode phase below — confirm against the pipeline.
    torch.cuda.current_stream().synchronize()
    t_gen = time.perf_counter() - t0
    print(f"[{label}] Generation done in {t_gen:.1f}s")
    print_gpu_stats(label)

    t0_enc = time.perf_counter()
    encode_video(video=video, fps=frame_rate, audio=audio, output_path=output_path, video_chunks_number=video_chunks)
    t_enc = time.perf_counter() - t0_enc
    print(f"[{label}] Encoding done in {t_enc:.1f}s")

    return {"gen_time": t_gen, "enc_time": t_enc, "total": t_gen + t_enc}


def _print_banner(title):
    """Print *title* framed above and below by a 70-character rule of '='."""
    rule = "=" * 70
    print(rule)
    print(title)
    print(rule)


def _run_in_threads(job, count, label_prefix, stats_label, streams=None):
    """Run ``job(idx, label)`` on *count* threads and time the whole batch.

    Args:
        job: Callable invoked as ``job(idx, label)`` on each worker thread.
        count: Number of worker threads to start.
        label_prefix: Prefix of the per-worker label ``f"{label_prefix}-{idx}"``.
        stats_label: Label for the GPU stats line printed before the batch.
        streams: Optional sequence with one CUDA stream per worker. When given,
            worker ``idx`` runs with ``streams[idx]`` as its current stream.

    Returns:
        ``(results, errors, elapsed)``: per-worker return values (``None`` for
        a failed worker), per-worker exceptions (``None`` on success), and the
        wall-clock seconds from thread start to device synchronization.
    """
    results = [None] * count
    errors = [None] * count

    def worker(idx):
        label = f"{label_prefix}-{idx}"
        # torch.cuda.stream(None) is a no-op context manager.
        stream = streams[idx] if streams is not None else None
        try:
            # Inference mode is thread-local: the decorator on main() does not
            # reach this thread, so enter it here to run the threaded tests
            # under the same mode as the sequential baseline.
            with torch.inference_mode(), torch.cuda.stream(stream):
                results[idx] = job(idx, label)
        except Exception as exc:
            # Thread boundary: record the failure so the summary can report it
            # instead of losing it with the thread.
            errors[idx] = exc
            print(f"[{label}] ERROR: {exc}")

    # Start from a quiet device with cached blocks released.
    torch.cuda.synchronize()
    gc.collect()
    torch.cuda.empty_cache()
    print_gpu_stats(stats_label)

    started = time.perf_counter()
    workers = [threading.Thread(target=worker, args=(i,)) for i in range(count)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - started
    return results, errors, elapsed


@torch.inference_mode()
def main():
    """Run the benchmark suite and print a timing and memory summary.

    Loads one ``DistilledPipeline`` and reuses it for five tests: a warmup
    generation, a timed single generation, two sequential generations, two
    generations on plain threads, and two generations on threads that each
    use their own CUDA stream. Videos are written under ``OUTPUT_DIR``.
    """
    # ── Configuration ──
    # Small resolution for faster testing: 512x768 final (256x384 stage1)
    HEIGHT = 512
    WIDTH = 768
    NUM_FRAMES = 41  # 8*5+1, ~1.7s at 24fps — short for benchmarking
    FRAME_RATE = 24.0
    OUTPUT_DIR = "/home/ubuntu/ltx2_bench_output"
    # Using pre-quantized FP8 checkpoint — no runtime quantization needed
    _print_banner("LTX-2 Concurrency Benchmark")
    print(f"Resolution: {WIDTH}x{HEIGHT} (stage1: {WIDTH//2}x{HEIGHT//2})")
    print(f"Frames: {NUM_FRAMES} ({NUM_FRAMES/FRAME_RATE:.1f}s at {FRAME_RATE}fps)")
    print("Checkpoint: FP8 pre-quantized")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print()

    # ── Load pipeline ──
    print("Loading pipeline...")
    t0 = time.perf_counter()
    pipeline = DistilledPipeline(
        distilled_checkpoint_path=DISTILLED_CKPT,
        spatial_upsampler_path=SPATIAL_UP,
        gemma_root=GEMMA_ROOT,
        loras=[],
        quantization=QuantizationPolicy.fp8_cast(),  # handles upcast during inference for FP8 weights
    )
    load_time = time.perf_counter() - t0
    print(f"Pipeline loaded in {load_time:.1f}s")
    print_gpu_stats("after load")
    print()

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    def generate(prompt_idx, seed, filename, label):
        """Generate and encode one clip with the shared benchmark settings."""
        return run_single_generation(
            pipeline, PROMPTS[prompt_idx], seed=seed,
            output_path=f"{OUTPUT_DIR}/{filename}",
            height=HEIGHT, width=WIDTH, num_frames=NUM_FRAMES, frame_rate=FRAME_RATE,
            label=label,
        )

    # ── Test 1: Single generation (warmup) ──
    _print_banner("TEST 1: Single generation (warmup)")
    r1 = generate(0, 42, "single_warmup.mp4", "warmup")
    print()

    # ── Test 2: Single generation (timed) ──
    _print_banner("TEST 2: Single generation (timed)")
    torch.cuda.synchronize()
    r2 = generate(1, 43, "single_timed.mp4", "single")
    print()

    # ── Test 3: Sequential 2 generations ──
    _print_banner("TEST 3: Sequential 2 generations")
    torch.cuda.synchronize()
    t0_seq = time.perf_counter()
    for i in range(2):
        generate(i, 42 + i, f"seq_{i}.mp4", f"seq-{i}")
    torch.cuda.synchronize()
    seq_total = time.perf_counter() - t0_seq
    print(f"Sequential 2x total: {seq_total:.1f}s")
    print()

    # ── Test 4: Concurrent 2 generations (threads) ──
    _print_banner("TEST 4: Concurrent 2 generations (same pipeline, threads)")
    _, errors, conc_total = _run_in_threads(
        lambda idx, label: generate(idx, 42 + idx, f"concurrent_{idx}.mp4", label),
        count=2,
        label_prefix="conc",
        stats_label="before concurrent",
    )
    print(f"Concurrent 2x total wall time: {conc_total:.1f}s")
    print()

    # ── Test 5: Concurrent with CUDA streams ──
    _print_banner("TEST 5: Concurrent 2 generations (separate CUDA streams)")
    streams = [torch.cuda.Stream() for _ in range(2)]
    _, errors_streams, streams_total = _run_in_threads(
        lambda idx, label: generate(idx, 42 + idx, f"stream_{idx}.mp4", label),
        count=2,
        label_prefix="stream",
        stats_label="before streams",
        streams=streams,
    )
    print(f"CUDA streams 2x total wall time: {streams_total:.1f}s")
    print()

    # ── Summary ──
    _print_banner("SUMMARY")
    print(f"Pipeline load time:          {load_time:.1f}s")
    print(f"Single gen (warmup):         {r1['gen_time']:.1f}s gen, {r1['total']:.1f}s total")
    print(f"Single gen (timed):          {r2['gen_time']:.1f}s gen, {r2['total']:.1f}s total")
    print(f"Sequential 2x:               {seq_total:.1f}s total")
    print(f"Concurrent 2x (threads):     {conc_total:.1f}s total")
    print(f"CUDA streams 2x:             {streams_total:.1f}s total")
    print()
    # A run in which a worker failed finished early, so its wall time is not
    # comparable with the sequential baseline; report that instead of a ratio.
    for heading, wall, errs in (
        ("Speedup (concurrent vs seq):", conc_total, errors),
        ("Speedup (streams vs seq):   ", streams_total, errors_streams),
    ):
        failed = sum(e is not None for e in errs)
        if failed:
            print(f"{heading} n/a ({failed} of {len(errs)} generations failed)")
        elif wall > 0:
            print(f"{heading} {seq_total / wall:.2f}x")
    print()

    # Report worker errors (e.g. OOM) collected during the threaded tests
    for i, e in enumerate(errors):
        if e is not None:
            print(f"Concurrent thread {i} error: {e}")
    for i, e in enumerate(errors_streams):
        if e is not None:
            print(f"Stream thread {i} error: {e}")

    # Peak memory
    print(f"Peak GPU memory allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB")
    print(f"Peak GPU memory reserved:  {torch.cuda.max_memory_reserved()/1024**3:.2f} GB")


# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
