"""
Benchmark LuxTTS with optimizations: fp16, torch.compile, k2 swoosh
"""
import torch
import time
import numpy as np
from zipvoice.luxvoice import LuxTTS
import soundfile as sf

# Reference speaker clip used to build the voice prompt for every benchmark run.
REF_AUDIO = "/home/ubuntu/LuxTTS/ref_audio.wav"

# Short and long inputs, so per-call overhead and length scaling both show up.
SHORT_TEXT = "Hello, this is a concurrency test."
LONG_TEXT = (
    "The quick brown fox jumps over the lazy dog. "
    "This is a test of the text to speech system "
    "for benchmarking purposes on the A100 GPU."
)

def profile_generate(tts, text, encoded, label="", n=5, *, num_steps=4,
                     sample_rate=48000, warmup=1):
    """Time ``n`` calls to ``tts.generate_speech`` and print averaged stats.

    Every run is timed, but the first ``warmup`` runs are excluded from the
    averages so one-off start-up costs do not skew the result. When CUDA is
    available the device is synchronized around each call. This makes the
    wall-clock time include queued GPU work, not just kernel launches.

    Args:
        tts: Object exposing ``generate_speech(text, encoded, num_steps=...)``.
            ``len(wav[0])`` of its return value is taken as the sample count
            (presumably a (channels, samples) waveform -- TODO confirm).
        text: Text to synthesize.
        encoded: Prompt encoding passed straight through to ``generate_speech``.
        label: Prefix for the printed summary line.
        n: Total number of runs, including the warmup runs.
        num_steps: Forwarded to ``generate_speech``.
        sample_rate: Output sample rate in Hz, used to convert samples to seconds.
        warmup: Number of leading runs excluded from the averages.

    Returns:
        ``(avg_time, avg_audio)``: mean wall-clock seconds and mean generated
        audio seconds over the non-warmup runs.

    Raises:
        ValueError: If ``warmup`` is negative or ``n`` leaves no runs to average.
    """
    # Without this guard an empty slice reaches np.mean and yields NaN silently.
    if warmup < 0 or n <= warmup:
        raise ValueError(
            f"need n > warmup >= 0 to average, got n={n}, warmup={warmup}"
        )

    # Only synchronize when CUDA exists; torch.cuda.synchronize raises otherwise.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    times = []
    audio_durs = []
    for _ in range(n):
        sync()
        t0 = time.perf_counter()
        wav = tts.generate_speech(text, encoded, num_steps=num_steps)
        sync()
        times.append(time.perf_counter() - t0)
        audio_durs.append(len(wav[0]) / sample_rate)

    avg_time = float(np.mean(times[warmup:]))
    avg_audio = float(np.mean(audio_durs[warmup:]))
    # Real-time factor: wall seconds per second of audio (lower is better).
    rtf = avg_time / avg_audio if avg_audio > 0 else float("inf")
    speedup = avg_audio / avg_time if avg_time > 0 else float("inf")
    print(f"  {label}: {avg_time:.4f}s wall, {avg_audio:.2f}s audio, RTF={rtf:.5f}, {speedup:.0f}x realtime")
    return avg_time, avg_audio

def _fp16_generate(generate_fn, tts, encoded, text):
    """Run one ``generate_fn`` call under CUDA fp16 autocast and return its output.

    ``encoded`` is the dict returned by ``tts.encode_prompt``. The sampling
    settings are fixed: num_step=4, guidance_scale=3.0, t_shift=0.5, speed=1.0.
    """
    with torch.amp.autocast('cuda', dtype=torch.float16):
        return generate_fn(
            encoded['prompt_tokens'], encoded['prompt_features_lens'],
            encoded['prompt_features'], encoded['prompt_rms'],
            text, tts.model, tts.vocos, tts.tokenizer,
            num_step=4, guidance_scale=3.0, t_shift=0.5, speed=1.0,
        )


def _time_fp16_generate(generate_fn, tts, encoded, text, n=5, sample_rate=48000):
    """Time ``n`` CUDA-synchronized fp16 generations; return [(wall_s, audio_s), ...]."""
    results = []
    for _ in range(n):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        wav = _fp16_generate(generate_fn, tts, encoded, text)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0
        results.append((elapsed, len(wav[0]) / sample_rate))
    return results


def _mean_stats(results):
    """Return (mean wall seconds, mean audio seconds) over ``(wall_s, audio_s)`` pairs."""
    return (
        np.mean([wall for wall, _ in results]),
        np.mean([audio for _, audio in results]),
    )


def main():
    """Benchmark LuxTTS on CUDA and print the results of each section.

    Sections:
      [A] fp32 baseline through ``LuxTTS.generate_speech``.
      [B] fp16 autocast around the low-level ``zipvoice.modeling_utils.generate``.
      [C] the same as [B], with ``model.fm_decoder`` wrapped in ``torch.compile``.
    Finally prints parameter counts of the baseline model.
    """
    print("=" * 60)
    print("LuxTTS Optimized Benchmark — A100 80GB")
    print("=" * 60)

    # Check k2. This is an optional-capability probe: report any import failure
    # instead of aborting the benchmark. Unlike a bare ``except:`` this does not
    # swallow KeyboardInterrupt/SystemExit.
    try:
        import k2
    except Exception as exc:
        print("k2 NOT available (swoosh DISABLED — slower)")
        print(f"  reason: {exc}")
    else:
        # getattr: a missing __version__ must not be misreported as "k2 missing".
        k2_version = getattr(k2, "__version__", "unknown")
        print(f"k2 version: {k2_version} (swoosh ENABLED)")

    # === BASELINE: fp32, no compile ===
    print("\n--- [A] BASELINE: fp32, no compile ---")
    tts = LuxTTS(device='cuda')
    encoded = tts.encode_prompt(REF_AUDIO, duration=5)

    mem = torch.cuda.memory_allocated() / 1024**3
    print(f"  Model VRAM: {mem:.2f} GB")

    profile_generate(tts, SHORT_TEXT, encoded, "short", n=5)
    profile_generate(tts, LONG_TEXT, encoded, "long", n=5)

    # === FP16 AUTOCAST ===
    print("\n--- [B] FP16 autocast ---")

    # Call the low-level generate() directly under autocast. Monkey-patching
    # zipvoice.modeling_utils.generate would leak into the rest of the process,
    # and nothing timed below goes through the patched name anyway.
    from zipvoice.modeling_utils import generate as orig_gen

    # Reset the peak counter so "Peak VRAM" reflects this section, not the fp32
    # baseline above. The baseline model (tts) is still resident and is included.
    torch.cuda.reset_peak_memory_stats()

    tts2 = LuxTTS(device='cuda')
    encoded2 = tts2.encode_prompt(REF_AUDIO, duration=5)

    # The first run of each text is treated as warmup and excluded from the averages.
    avg, avg_a = _mean_stats(
        _time_fp16_generate(orig_gen, tts2, encoded2, SHORT_TEXT)[1:]
    )
    print(f"  short fp16: {avg:.4f}s wall, {avg_a:.2f}s audio, {avg_a/avg:.0f}x realtime")

    avg, avg_a = _mean_stats(
        _time_fp16_generate(orig_gen, tts2, encoded2, LONG_TEXT)[1:]
    )
    print(f"  long fp16:  {avg:.4f}s wall, {avg_a:.2f}s audio, {avg_a/avg:.0f}x realtime")

    peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"  Peak VRAM: {peak:.2f} GB")

    # === TORCH.COMPILE on fm_decoder ===
    print("\n--- [C] torch.compile(fm_decoder) + fp16 ---")
    # Broad catch: this section is optional and must not abort the model-info report.
    try:
        tts3 = LuxTTS(device='cuda')
        encoded3 = tts3.encode_prompt(REF_AUDIO, duration=5)

        print("  Compiling fm_decoder...")
        tts3.model.fm_decoder = torch.compile(tts3.model.fm_decoder, mode="reduce-overhead")

        # Warm up on BOTH texts. A new input shape may trigger a recompile or
        # CUDA-graph capture, which would otherwise land inside the timed loop.
        print("  Warmup (first call triggers compilation)...")
        for warm_text in (SHORT_TEXT, LONG_TEXT):
            for _ in range(3):
                _fp16_generate(orig_gen, tts3, encoded3, warm_text)
        torch.cuda.synchronize()

        # Benchmark: all runs are averaged, since warmup already happened above.
        for text, label in [(SHORT_TEXT, "short"), (LONG_TEXT, "long")]:
            avg, avg_a = _mean_stats(
                _time_fp16_generate(orig_gen, tts3, encoded3, text)
            )
            print(f"  {label} compiled+fp16: {avg:.4f}s wall, {avg_a:.2f}s audio, {avg_a/avg:.0f}x realtime")
    except Exception as e:
        print(f"  torch.compile failed: {e}")

    # === Model parameter count ===
    print("\n--- Model Info ---")
    total_params = sum(p.numel() for p in tts.model.parameters())
    print(f"  Total params: {total_params/1e6:.1f}M")
    for name, module in tts.model.named_children():
        params = sum(p.numel() for p in module.parameters())
        if params > 0:
            print(f"    {name}: {params/1e6:.1f}M")

# Script entry point: run the full benchmark only when executed directly, not on import.
if __name__ == "__main__":
    main()
