"""
Benchmark LuxTTS with swoosh fixed for inference + fp16 + torch.compile
"""
import torch
import torch.nn as nn
import time
import sys
import numpy as np
import soundfile as sf

# Patch swoosh BEFORE importing LuxTTS — replace slow autograd versions with simple forward-only
import zipvoice.models.modules.scaling as scaling

class FastSwooshL(nn.Module):
    """Forward-only SwooshL activation: log(1 + e^(x-4)) - 0.08*x - 0.035.

    Inference-only replacement for the autograd-instrumented SwooshL in
    zipvoice's scaling module (see the patch below this class).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # softplus(y) == logaddexp(0, y), so this is identical to the
        # original logaddexp form but skips allocating a zero tensor on
        # every call and uses a single fused kernel.
        return nn.functional.softplus(x - 4.0) - 0.08 * x - 0.035

class FastSwooshR(nn.Module):
    """Forward-only SwooshR activation: log(1 + e^(x-1)) - 0.08*x - 0.313261687.

    Inference-only replacement for the autograd-instrumented SwooshR in
    zipvoice's scaling module (see the patch below this class).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # softplus(y) == logaddexp(0, y); avoids the per-call zero-tensor
        # allocation of the logaddexp formulation.
        return nn.functional.softplus(x - 1.0) - 0.08 * x - 0.313261687

# Install the fast variants under the names zipvoice resolves; this must
# happen before `zipvoice.luxvoice` is imported below so the model picks
# up the patched classes.
scaling.SwooshL = FastSwooshL
scaling.SwooshR = FastSwooshR

from zipvoice.luxvoice import LuxTTS
from zipvoice.modeling_utils import generate as orig_generate

# Reference audio used to build the speaker prompt.
REF_AUDIO = "/home/ubuntu/LuxTTS/ref_audio.wav"
# Benchmark texts of increasing length (labels in main() estimate
# ~2s / ~8s / ~15s of synthesized audio respectively).
SHORT_TEXT = "Hello, this is a concurrency test."
LONG_TEXT = "The quick brown fox jumps over the lazy dog. This is a test of the text to speech system for benchmarking purposes on the A100 GPU."
VERY_LONG = "Ladies and gentlemen, welcome to the annual technology conference. Today we will discuss the latest advances in artificial intelligence, machine learning, and natural language processing. These fields have seen remarkable progress over the past few years, and we expect continued innovation in the years to come."

def bench(tts, encoded, text, label, n=10, sample_rate=48000):
    """Time fp32 generation of `text` over `n` runs and print wall/RTF stats.

    Args:
        tts: loaded LuxTTS instance (supplies model, vocos, tokenizer).
        encoded: dict from tts.encode_prompt() holding the prompt tensors.
        text: text to synthesize.
        label: human-readable label for the printed result line.
        n: number of runs; the first is treated as warmup and excluded
           from the averages (unless n == 1).
        sample_rate: output sample rate used to convert samples to seconds.

    Returns:
        (avg_wall_seconds, avg_audio_seconds) averaged over the timed runs.
    """
    times = []
    adurs = []
    for _ in range(n):
        torch.cuda.synchronize()  # drain queued GPU work before timing
        t0 = time.perf_counter()
        wav = orig_generate(
            encoded['prompt_tokens'], encoded['prompt_features_lens'],
            encoded['prompt_features'], encoded['prompt_rms'],
            text, tts.model, tts.vocos, tts.tokenizer,
            num_step=4, guidance_scale=3.0, t_shift=0.5, speed=1.0
        )
        torch.cuda.synchronize()  # wait for generation to actually finish
        times.append(time.perf_counter() - t0)
        adurs.append(len(wav[0]) / sample_rate)
    # Skip the first (warmup) run, but avoid averaging an empty slice when n == 1.
    skip = 1 if n > 1 else 0
    avg_t = np.mean(times[skip:])
    avg_a = np.mean(adurs[skip:])
    rtf = avg_t / avg_a
    print(f"  {label:20s}: {avg_t*1000:.1f}ms wall, {avg_a:.2f}s audio, RTF={rtf:.5f}, {1/rtf:.0f}x realtime")
    return avg_t, avg_a

def bench_fp16(tts, encoded, text, label, n=10, sample_rate=48000):
    """Time generation of `text` under fp16 autocast and print wall/RTF stats.

    Same protocol as bench(), but wraps generation in
    torch.amp.autocast(dtype=torch.float16).

    Args:
        tts: loaded LuxTTS instance (supplies model, vocos, tokenizer).
        encoded: dict from tts.encode_prompt() holding the prompt tensors.
        text: text to synthesize.
        label: human-readable label for the printed result line.
        n: number of runs; the first is treated as warmup and excluded
           from the averages (unless n == 1).
        sample_rate: output sample rate used to convert samples to seconds.

    Returns:
        (avg_wall_seconds, avg_audio_seconds) averaged over the timed runs.
    """
    times = []
    adurs = []
    for _ in range(n):
        torch.cuda.synchronize()  # drain queued GPU work before timing
        t0 = time.perf_counter()
        with torch.amp.autocast('cuda', dtype=torch.float16):
            wav = orig_generate(
                encoded['prompt_tokens'], encoded['prompt_features_lens'],
                encoded['prompt_features'], encoded['prompt_rms'],
                text, tts.model, tts.vocos, tts.tokenizer,
                num_step=4, guidance_scale=3.0, t_shift=0.5, speed=1.0
            )
        torch.cuda.synchronize()  # wait for generation to actually finish
        times.append(time.perf_counter() - t0)
        adurs.append(len(wav[0]) / sample_rate)
    # Skip the first (warmup) run, but avoid averaging an empty slice when n == 1.
    skip = 1 if n > 1 else 0
    avg_t = np.mean(times[skip:])
    avg_a = np.mean(adurs[skip:])
    rtf = avg_t / avg_a
    print(f"  {label:20s}: {avg_t*1000:.1f}ms wall, {avg_a:.2f}s audio, RTF={rtf:.5f}, {1/rtf:.0f}x realtime")
    return avg_t, avg_a

def main():
    """Run four benchmark configurations and report VRAM use.

    Sections:
      A: fp32 weights with the patched fast swoosh activations.
      B: fp16 via torch.amp.autocast (weights stay fp32).
      C: weights cast to fp16 in place via .half().
      D: torch.compile on the fp16 flow-matching decoder.

    Note: sections mutate `tts` in place (C halves the weights, D wraps
    the decoder in torch.compile), so order matters.
    """
    print("=" * 70)
    print("LuxTTS Speed Benchmark — A100 80GB — Fast Swoosh Patched")
    print("=" * 70)

    tts = LuxTTS(device='cuda')
    encoded = tts.encode_prompt(REF_AUDIO, duration=5)
    mem = torch.cuda.memory_allocated() / 1024**3
    print(f"Model VRAM: {mem:.2f} GB")

    # Parameter count of the acoustic model (vocoder not included).
    total = sum(p.numel() for p in tts.model.parameters())
    print(f"Params: {total/1e6:.1f}M")

    # --- A: fp32 baseline with fast swoosh ---
    print("\n--- [A] fp32 + fast swoosh (no autograd) ---")
    bench(tts, encoded, SHORT_TEXT, "short (~2s audio)")
    bench(tts, encoded, LONG_TEXT, "long (~8s audio)")
    bench(tts, encoded, VERY_LONG, "very long (~15s audio)")

    # --- B: fp16 ---
    print("\n--- [B] fp16 + fast swoosh ---")
    bench_fp16(tts, encoded, SHORT_TEXT, "short (~2s audio)")
    bench_fp16(tts, encoded, LONG_TEXT, "long (~8s audio)")
    bench_fp16(tts, encoded, VERY_LONG, "very long (~15s audio)")

    # --- C: model.half() direct ---
    print("\n--- [C] model.half() + fast swoosh ---")
    tts.model = tts.model.half()
    tts.vocos = tts.vocos.half()
    # Cast prompt tensors to fp16 to match the halved weights; non-float
    # entries (token ids, lengths) pass through unchanged.
    encoded_h = {
        k: v.half() if isinstance(v, torch.Tensor) and v.is_floating_point() else v
        for k, v in encoded.items()
    }

    bench(tts, encoded_h, SHORT_TEXT, "short (~2s audio)")
    bench(tts, encoded_h, LONG_TEXT, "long (~8s audio)")
    bench(tts, encoded_h, VERY_LONG, "very long (~15s audio)")

    peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"\nPeak VRAM: {peak:.2f} GB")

    # --- D: torch.compile on fp16 model ---
    print("\n--- [D] torch.compile + model.half() + fast swoosh ---")
    try:
        tts.model.fm_decoder = torch.compile(tts.model.fm_decoder, mode="reduce-overhead")
        # Trigger compilation outside the timed runs so bench() only
        # re-measures the post-compile steady state.
        print("  Warming up compile (3 runs)...")
        for _ in range(3):
            orig_generate(
                encoded_h['prompt_tokens'], encoded_h['prompt_features_lens'],
                encoded_h['prompt_features'], encoded_h['prompt_rms'],
                SHORT_TEXT, tts.model, tts.vocos, tts.tokenizer,
                num_step=4, guidance_scale=3.0, t_shift=0.5, speed=1.0
            )
        bench(tts, encoded_h, SHORT_TEXT, "short (~2s audio)")
        bench(tts, encoded_h, LONG_TEXT, "long (~8s audio)")
        bench(tts, encoded_h, VERY_LONG, "very long (~15s audio)")
    except Exception as e:
        # Broad catch is deliberate: torch.compile can fail on unsupported
        # graphs, and the earlier benchmark results should still be kept.
        print(f"  torch.compile failed: {e}")

    # Cumulative peak across all sections (the stats are never reset).
    peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"\nFinal Peak VRAM: {peak:.2f} GB")

if __name__ == "__main__":
    main()
