"""
Final benchmark: test longer texts to see if we approach 150x,
and profile where time is spent.
"""
import os
import time

import numpy as np
import torch
import torch.nn as nn

# Patch swoosh before import
import zipvoice.models.modules.scaling as scaling

class FastSwooshL(nn.Module):
    """Stateless activation computing ``log(1 + exp(x - 4)) - 0.08 * x - 0.035``."""

    def forward(self, x):
        """Apply the activation element-wise to ``x`` and return the result."""
        # 0-dim zero with x's dtype and device; logaddexp(0, y) is log(1 + exp(y)).
        origin = torch.zeros((), dtype=x.dtype, device=x.device)
        shifted = x - 4.0
        return torch.logaddexp(origin, shifted) - 0.08 * x - 0.035

class FastSwooshR(nn.Module):
    """Stateless activation computing ``log(1 + exp(x - 1)) - 0.08 * x - 0.313261687``."""

    def forward(self, x):
        """Apply the activation element-wise to ``x`` and return the result."""
        # 0-dim zero with x's dtype and device; logaddexp(0, y) is log(1 + exp(y)).
        origin = torch.zeros((), dtype=x.dtype, device=x.device)
        shifted = x - 1.0
        return torch.logaddexp(origin, shifted) - 0.08 * x - 0.313261687

# Swap in the fast activations. This has to run before the LuxTTS import that
# follows, so that code resolving these names from `scaling` afterwards
# (presumably the model modules) picks up the replacements.
scaling.SwooshL, scaling.SwooshR = FastSwooshL, FastSwooshR

from zipvoice.luxvoice import LuxTTS
from zipvoice.modeling_utils import generate as orig_generate

# Reference clip handed to `encode_prompt` in main(). Defaults to the path on
# the original benchmark machine; set LUXTTS_REF_AUDIO to run the benchmark on
# another host without editing this file.
REF_AUDIO = os.environ.get("LUXTTS_REF_AUDIO", "/home/ubuntu/LuxTTS/ref_audio.wav")

# Benchmark prompts of increasing length, keyed by the label printed in the
# results table. The three longest prompts share a common opening, so they are
# assembled from fragments instead of repeating the text.
_OPENING = (
    "Ladies and gentlemen, welcome to the annual technology conference. "
    "Today we will discuss the latest advances in artificial intelligence, "
    "machine learning, and natural language processing. "
    "These fields have seen remarkable progress over the past few years"
)
_OUTLOOK = (
    ", and we expect continued innovation in the years to come. "
    "Our keynote speakers include leading researchers from around the world "
    "who will share their insights on the future of computing, robotics, "
    "and automated reasoning systems."
)
_EVENTS = (
    " In addition to our keynote presentations, we have organized workshops, "
    "panel discussions, and networking events to help attendees connect with "
    "peers and explore new collaboration opportunities across disciplines "
    "and industries."
)

TEXTS = {
    "1-sentence": "Hello, this is a test.",
    "2-sentences": (
        "The quick brown fox jumps over the lazy dog. "
        "This is a benchmarking test."
    ),
    "paragraph": _OPENING + ".",
    "long-paragraph": _OPENING + _OUTLOOK,
    "2-paragraphs": _OPENING + _OUTLOOK + _EVENTS,
}

# Divisor used to turn a generated waveform's sample count into seconds.
# NOTE(review): presumes generate() returns 48 kHz audio — confirm.
_SAMPLE_RATE = 48000


def _timed_generate(tts, encoded, text, num_step):
    """Synthesize ``text`` once and return ``(wall_seconds, audio_seconds)``.

    The device is synchronized right before the clock starts and right before
    it stops, so the measurement covers the GPU work issued by the call.
    """
    torch.cuda.synchronize()
    started = time.perf_counter()
    wav = orig_generate(
        encoded['prompt_tokens'], encoded['prompt_features_lens'],
        encoded['prompt_features'], encoded['prompt_rms'],
        text, tts.model, tts.vocos, tts.tokenizer,
        num_step=num_step, guidance_scale=3.0, t_shift=0.5, speed=1.0
    )
    torch.cuda.synchronize()
    wall = time.perf_counter() - started
    return wall, len(wav[0]) / _SAMPLE_RATE


def _benchmark(tts, encoded, text, num_step, runs, warmup):
    """Time ``runs`` syntheses and average those after the first ``warmup``.

    Returns ``(mean_wall_seconds, mean_audio_seconds)``; both means are taken
    over the same post-warmup runs so their ratio is consistent.
    """
    samples = [_timed_generate(tts, encoded, text, num_step) for _ in range(runs)]
    walls, durations = zip(*samples[warmup:])
    return np.mean(walls), np.mean(durations)


def main():
    """Benchmark LuxTTS synthesis speed on a CUDA device and print the results.

    Prints three reports: speed versus text length for every entry in
    ``TEXTS``, tokenization time versus one full generate call for the
    "long-paragraph" text, and speed versus number of sampling steps. Ends
    with the peak CUDA memory allocated by the process.
    """
    print("=" * 70)
    print("LuxTTS Speed vs Text Length — A100 80GB")
    print("=" * 70)

    tts = LuxTTS(device='cuda')
    encoded = tts.encode_prompt(REF_AUDIO, duration=5)

    print(f"\n{'Text':<20} {'Words':>6} {'Wall(ms)':>10} {'Audio(s)':>10} {'RTF':>10} {'Speed':>8}")
    print("-" * 70)

    for label, text in TEXTS.items():
        words = len(text.split())
        # 8 runs per text; the first 2 are warmup and excluded from the means.
        avg_t, avg_a = _benchmark(tts, encoded, text, num_step=4, runs=8, warmup=2)
        rtf = avg_t / avg_a
        print(f"{label:<20} {words:>6} {avg_t*1000:>10.1f} {avg_a:>10.2f} {rtf:>10.5f} {1/rtf:>7.0f}x")

    # --- Profile breakdown for long text ---
    print("\n--- Time Breakdown (long-paragraph) ---")
    text = TEXTS["long-paragraph"]

    # Tokenize; the token ids are discarded, only the elapsed time is reported.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    tts.tokenizer.texts_to_token_ids([text])
    torch.cuda.synchronize()
    tok_time = time.perf_counter() - t0
    print(f"  Tokenization: {tok_time*1000:.1f}ms")

    # One end-to-end generate call, for comparison with the tokenization time.
    total, _ = _timed_generate(tts, encoded, text, num_step=4)
    print(f"  Total generate: {total*1000:.1f}ms")

    # --- Num steps comparison ---
    print("\n--- Steps vs Speed (long-paragraph) ---")
    for steps in [1, 2, 3, 4, 6, 8]:
        # 6 runs per setting, 1 warmup. The audio duration is averaged over the
        # same runs as the wall time instead of reusing the last run's value.
        avg_t, avg_a = _benchmark(tts, encoded, text, num_step=steps, runs=6, warmup=1)
        rtf = avg_t / avg_a
        print(f"  steps={steps}: {avg_t*1000:.1f}ms, {1/rtf:.0f}x realtime")

    peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"\nPeak VRAM: {peak:.2f} GB")

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
