"""
True TTFB benchmark: measure actual per-request processing time (no queue wait).
Then calculate how many GPUs needed for 500 concurrency with ms-level TTFB.
"""
import torch
import torch.nn as nn
import time
import numpy as np

import zipvoice.models.modules.scaling as scaling
class FSL(nn.Module):
    """Functional SwooshL activation: logaddexp(0, x - 4) - 0.08*x - 0.035.

    Drop-in replacement for ``zipvoice.models.modules.scaling.SwooshL``.
    """
    def forward(self, x):
        # softplus(y) == logaddexp(0, y); using the fused op avoids
        # allocating a scalar zero tensor on every forward call.
        return torch.nn.functional.softplus(x - 4.0) - 0.08 * x - 0.035
class FSR(nn.Module):
    """Functional SwooshR activation: logaddexp(0, x - 1) - 0.08*x - 0.313261687.

    Drop-in replacement for ``zipvoice.models.modules.scaling.SwooshR``.
    """
    def forward(self, x):
        # softplus(y) == logaddexp(0, y); using the fused op avoids
        # allocating a scalar zero tensor on every forward call.
        return torch.nn.functional.softplus(x - 1.0) - 0.08 * x - 0.313261687
# Monkey-patch zipvoice's Swoosh activations with the functional
# re-implementations above. This must happen BEFORE `LuxTTS` is imported
# below, so the model modules pick up the patched classes.
scaling.SwooshL = FSL
scaling.SwooshR = FSR

from zipvoice.luxvoice import LuxTTS
from zipvoice.modeling_utils import generate as orig_generate
from batched_inference import batched_generate

# Reference audio used to build the voice prompt for every request.
REF_AUDIO = "/home/ubuntu/LuxTTS/ref_audio.wav"

# Benchmark inputs of increasing length (short sentence → paragraph).
TEXTS = [
    "Hello, this is a test.",
    "The quick brown fox jumps over the lazy dog near the riverbank.",
    "Welcome to the annual technology conference on artificial intelligence and machine learning.",
    "Ladies and gentlemen, welcome to the annual technology conference. Today we will discuss the latest advances in artificial intelligence, machine learning, and natural language processing. These fields have seen remarkable progress.",
]
# Labels derive their word counts from TEXTS so they can never drift out of
# sync (the previous hard-coded counts were wrong for the last two texts).
_SIZE_NAMES = ["short", "medium", "long", "very long"]
LABELS = [f"{name} ({len(text.split())}w)" for name, text in zip(_SIZE_NAMES, TEXTS)]

def _timed_batched(texts, steps, enc, model, vocoder, tokenizer, iters):
    """Time `iters` synchronized runs of batched_generate.

    Returns (list of per-run wall-clock seconds, results of the last run).
    """
    times = []
    results = None
    for _ in range(iters):
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        results = batched_generate(
            texts, model, vocoder, tokenizer,
            enc['prompt_tokens'], enc['prompt_features'],
            enc['prompt_features_lens'], enc['prompt_rms'],
            num_steps=steps)
        torch.cuda.synchronize()
        times.append(time.perf_counter() - t0)
    return times, results


def _bench_single(enc, model, vocoder, tokenizer):
    """Part 1: true TTFB for a single request (B=1, no queue, no contention)."""
    print("\n╔══════════════════════════════════════════════════════════════════╗")
    print("║  PART 1: TRUE TTFB — Single Request (no queue, no contention)  ║")
    print("╚══════════════════════════════════════════════════════════════════╝")

    print(f"\n{'Text':<16} {'Steps':>5} {'TTFB(ms)':>10} {'Audio(s)':>9} {'Speed':>8}")
    print("-" * 55)

    for text, label in zip(TEXTS, LABELS):
        for steps in [1, 2, 4]:
            times = []
            adur = 0
            for _ in range(8):
                torch.cuda.synchronize()
                t0 = time.perf_counter()
                wav = orig_generate(
                    enc['prompt_tokens'], enc['prompt_features_lens'],
                    enc['prompt_features'], enc['prompt_rms'],
                    text, model, vocoder, tokenizer, num_step=steps)
                torch.cuda.synchronize()
                times.append(time.perf_counter() - t0)
                # assumes 48 kHz output with samples on the last axis —
                # TODO confirm against vocoder config
                adur = wav.shape[-1] / 48000
            # Drop the first two runs as per-shape warmup before averaging.
            avg = np.mean(times[2:]) * 1000
            speed = adur / (avg / 1000)
            print(f"{label:<16} {steps:>5} {avg:>10.1f} {adur:>9.2f} {speed:>7.0f}x")
        print()


def _bench_batched(enc, model, vocoder, tokenizer):
    """Part 2: per-item processing time when requests are batched."""
    print("╔══════════════════════════════════════════════════════════════════╗")
    print("║  PART 2: BATCHED TTFB — Per-item time when batched             ║")
    print("╚══════════════════════════════════════════════════════════════════╝")

    text = TEXTS[1]  # medium text
    print(f"\nText: \"{text[:50]}...\"")
    print(f"\n{'Batch':>5} {'Steps':>5} {'Total(ms)':>10} {'TTFB/req(ms)':>13} {'Audio(s)':>9} {'Speed':>8} {'VRAM':>7}")
    print("-" * 65)

    for steps in [1, 2, 4]:
        for bs in [1, 8, 16, 32, 64]:
            torch.cuda.reset_peak_memory_stats()
            times, results = _timed_batched(
                [text] * bs, steps, enc, model, vocoder, tokenizer, iters=5)
            adur = len(results[0]) / 48000
            # Drop the first run as warmup before averaging.
            avg_total = np.mean(times[1:]) * 1000
            # All items in a batch complete together, so each request's
            # TTFB equals the full batch time.
            ttfb = avg_total
            per_item = avg_total / bs
            speed = adur / (per_item / 1000)
            vram = torch.cuda.max_memory_allocated() / 1024**3
            print(f"{bs:>5} {steps:>5} {avg_total:>10.1f} {ttfb:>13.1f} {adur:>9.2f} {speed:>7.0f}x {vram:>6.1f}G")
        print()


def _capacity_report(enc, model, vocoder, tokenizer):
    """Part 3: estimate GPUs needed for 500 concurrent requests per config."""
    print("╔══════════════════════════════════════════════════════════════════╗")
    print("║  PART 3: 500 CONCURRENCY — GPU requirement calculator          ║")
    print("╚══════════════════════════════════════════════════════════════════╝")

    # (label, batch size, diffusion steps)
    configs = [
        ("4 steps, B=32", 32, 4),
        ("4 steps, B=64", 64, 4),
        ("2 steps, B=32", 32, 2),
        ("2 steps, B=64", 64, 2),
        ("1 step,  B=64", 64, 1),
        ("1 step,  B=128", 128, 1),
    ]

    # Measure throughput of each config on the medium text.
    text = TEXTS[1]
    measured = []
    for label, bs, steps in configs:
        times, _ = _timed_batched(
            [text] * bs, steps, enc, model, vocoder, tokenizer, iters=4)
        batch_time_ms = np.mean(times[1:]) * 1000  # skip first (warmup) run
        req_per_s = bs / (batch_time_ms / 1000)
        measured.append((label, bs, steps, batch_time_ms, req_per_s))

    print(f"\n{'Config':<20} {'Batch ms':>10} {'req/s':>8} {'TTFB(ms)':>10} {'GPUs for':>10} {'TTFB @500':>10}")
    print(f"{'':20} {'':>10} {'1 GPU':>8} {'1 req':>10} {'500 req/s':>10} {'per GPU':>10}")
    print("-" * 75)

    for label, bs, steps, batch_ms, rps in measured:
        ttfb_single = batch_ms  # TTFB for 1 request = batch time
        gpus_needed = int(np.ceil(500 / rps))
        # With enough GPUs, each GPU handles 500/gpus requests; queued
        # batches on one GPU run sequentially, hence the multiplier.
        reqs_per_gpu = 500 / gpus_needed
        batches_per_gpu = int(np.ceil(reqs_per_gpu / bs))
        ttfb_at_500 = batches_per_gpu * batch_ms
        print(f"{label:<20} {batch_ms:>10.0f} {rps:>8.1f} {ttfb_single:>10.0f} {gpus_needed:>10} {ttfb_at_500:>10.0f}")

    print("\nNote: TTFB = time until audio is fully generated (non-streaming).")
    print("For streaming first-chunk delivery, split text into sentences:")
    print(f"  → First sentence TTFB would be ~{measured[0][3]/3:.0f}ms (1/3 the full text)")


def main():
    """Run the three-part TTFB benchmark and print the report to stdout.

    Part 1: single-request latency (B=1). Part 2: per-item latency under
    batching. Part 3: GPU-count extrapolation for 500-way concurrency.
    Requires a CUDA device and the LuxTTS checkpoint/reference audio.
    """
    print("=" * 70)
    print("TRUE TTFB BENCHMARK — A100 80GB")
    print("=" * 70)

    tts = LuxTTS(device='cuda')
    enc = tts.encode_prompt(REF_AUDIO, duration=5)
    model = tts.model
    vocoder = tts.vocos
    tokenizer = tts.tokenizer

    # Warmup: amortize kernel compilation / allocator growth before timing.
    for _ in range(3):
        orig_generate(enc['prompt_tokens'], enc['prompt_features_lens'],
                      enc['prompt_features'], enc['prompt_rms'],
                      "warmup", model, vocoder, tokenizer, num_step=4)

    _bench_single(enc, model, vocoder, tokenizer)
    _bench_batched(enc, model, vocoder, tokenizer)
    _capacity_report(enc, model, vocoder, tokenizer)


if __name__ == "__main__":  # run the benchmark only when executed as a script
    main()
