"""
Benchmark LuxTTS concurrency on A100.
Tests: single request latency, then concurrent throughput.
"""
import torch
import time
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed
from zipvoice.luxvoice import LuxTTS
import soundfile as sf
import os

# --- Create a short reference audio if none exists ---
REF_AUDIO = "/home/ubuntu/LuxTTS/ref_audio.wav"  # synthetic speaker-reference clip (created by create_ref_audio)
TEST_TEXT = "The quick brown fox jumps over the lazy dog. This is a test of the text to speech system for benchmarking purposes."  # longer prompt for section [4]
SHORT_TEXT = "Hello, this is a concurrency test."  # short prompt used for latency/throughput loops

def create_ref_audio() -> None:
    """Write a 5-second 440 Hz sine tone to REF_AUDIO if it does not exist.

    Acts as a stand-in speaker-reference clip so the benchmark can run
    without a real recording. No-op when the file is already present.
    """
    if os.path.exists(REF_AUDIO):
        return
    sr = 24000    # sample rate of the reference clip
    duration = 5  # seconds
    # endpoint=False yields exactly sr samples per second; the default
    # inclusive endpoint would spread sr*duration points over [0, duration]
    # and slightly distort the tone's sample spacing.
    t = np.linspace(0, duration, sr * duration, endpoint=False)
    # 440 Hz tone at 0.3 amplitude; scale first, then cast once to float32.
    audio = (0.3 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
    sf.write(REF_AUDIO, audio, sr)
    print(f"Created reference audio: {REF_AUDIO}")

def benchmark() -> None:
    """Run the full LuxTTS benchmark on a single GPU and print results.

    Stages (mirror the numbered sections printed to stdout):
      [1] model load time + resident GPU memory
      [2] reference-prompt encoding time
      [3] short-text single-request latency (2 warmup runs)
      [4] long-text single-request latency + peak GPU memory
      [5] sequential throughput over 10 requests
      [6] threaded-concurrency sweep at 1/5/10/25/50/100 workers

    Side effects: creates a reference WAV if missing, writes one sample
    WAV to disk, prints all metrics; nothing is returned.
    """
    create_ref_audio()

    print("=" * 60)
    print("LuxTTS Concurrency Benchmark — A100 80GB")
    print("=" * 60)

    # Load model
    print("\n[1] Loading model...")
    t0 = time.time()
    tts = LuxTTS(device='cuda')
    load_time = time.time() - t0
    print(f"    Model loaded in {load_time:.1f}s")

    # Check GPU memory after load (allocated = live tensors,
    # reserved = what the CUDA caching allocator holds from the driver)
    mem_used = torch.cuda.memory_allocated() / 1024**3
    mem_reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"    GPU memory: {mem_used:.2f} GB allocated, {mem_reserved:.2f} GB reserved")

    # Encode prompt once; the encoded prompt is reused by every request below
    print("\n[2] Encoding reference prompt...")
    t0 = time.time()
    encoded = tts.encode_prompt(REF_AUDIO, duration=5)
    encode_time = time.time() - t0
    print(f"    Encoded in {encode_time:.2f}s")

    # --- Single request latency ---
    # Two runs double as CUDA warmup (first run pays kernel-compile/cache costs)
    print("\n[3] Single request latency (warmup)...")
    for i in range(2):
        t0 = time.time()
        wav = tts.generate_speech(SHORT_TEXT, encoded, num_steps=4)
        elapsed = time.time() - t0
        # wav[0] is treated as a mono waveform at 48 kHz — the 48000 divisor
        # encodes that assumption; TODO confirm against generate_speech's output.
        audio_dur = len(wav[0]) / 48000
        rtf = elapsed / audio_dur  # real-time factor: wall seconds per audio second
        print(f"    Run {i+1}: {elapsed:.3f}s wall, {audio_dur:.2f}s audio, RTF={rtf:.4f} ({1/rtf:.0f}x realtime)")

    # Save a sample (the last warmup output) for manual listening
    sf.write("/home/ubuntu/LuxTTS/benchmark_sample.wav", wav[0].numpy(), 48000)

    # --- Longer text latency ---
    print("\n[4] Longer text latency...")
    t0 = time.time()
    wav = tts.generate_speech(TEST_TEXT, encoded, num_steps=4)
    elapsed = time.time() - t0
    audio_dur = len(wav[0]) / 48000
    rtf = elapsed / audio_dur
    print(f"    {elapsed:.3f}s wall, {audio_dur:.2f}s audio, RTF={rtf:.4f} ({1/rtf:.0f}x realtime)")

    # Peak covers everything since process start (stats not yet reset here)
    mem_peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"    Peak GPU memory: {mem_peak:.2f} GB")

    # --- Sequential throughput baseline ---
    print("\n[5] Sequential throughput (10 requests)...")
    t0 = time.time()
    for _ in range(10):
        tts.generate_speech(SHORT_TEXT, encoded, num_steps=4)
    seq_time = time.time() - t0
    print(f"    10 requests in {seq_time:.2f}s = {10/seq_time:.1f} req/s")

    # --- Concurrent throughput with threads ---
    # Note: Python GIL + CUDA means threads serialize on GPU,
    # but this simulates real server concurrency patterns

    def do_request(idx: int) -> tuple[int, float, float]:
        # One synthesis request; returns (index, wall seconds, audio seconds)
        t0 = time.time()
        wav = tts.generate_speech(SHORT_TEXT, encoded, num_steps=4)
        elapsed = time.time() - t0
        audio_dur = len(wav[0]) / 48000
        return idx, elapsed, audio_dur

    # One batch per concurrency level; batch size == worker count, so each
    # worker handles exactly one request
    for concurrency in [1, 5, 10, 25, 50, 100]:
        print(f"\n[6] Concurrent requests: {concurrency}...")
        torch.cuda.reset_peak_memory_stats()  # isolate peak memory per level

        t0 = time.time()
        results = []
        with ThreadPoolExecutor(max_workers=concurrency) as executor:
            futures = [executor.submit(do_request, i) for i in range(concurrency)]
            for f in as_completed(futures):
                results.append(f.result())

        total_time = time.time() - t0
        latencies = [r[1] for r in results]
        audio_durs = [r[2] for r in results]
        total_audio = sum(audio_durs)
        peak_mem = torch.cuda.max_memory_allocated() / 1024**3

        print(f"    Total wall time: {total_time:.2f}s")
        print(f"    Throughput: {concurrency/total_time:.1f} req/s")
        print(f"    Audio generated: {total_audio:.1f}s total")
        print(f"    Latency — p50: {np.median(latencies):.3f}s, p95: {np.percentile(latencies, 95):.3f}s, max: {max(latencies):.3f}s")
        print(f"    Peak GPU mem: {peak_mem:.2f} GB")

    # --- Summary ---
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Model VRAM: {mem_used:.2f} GB")
    print(f"A100 total: 80 GB")
    # Rough capacity estimate: how many model-sized footprints fit in 80 GB
    print(f"Theoretical model copies: {int(80 / mem_used)} (if memory-bound)")
    print(f"Sequential req/s: {10/seq_time:.1f}")
    print("Note: True batched inference (not just threaded) would")
    print("need model code changes to batch the forward pass.")

# Script entry point: run the full benchmark when executed directly.
if __name__ == "__main__":
    benchmark()
