import time
import torch
import numpy as np
import soundfile as sf
import sys
import os
import threading
import concurrent.futures
from dataclasses import dataclass

# Put the local LuxTTS checkout first on the import path; presumably this is
# where `zipvoice.luxvoice` (imported inside main()) is found — verify locally.
sys.path.insert(0, '/home/ubuntu/LuxTTS')

# Reference audio encoded once in main() via `lux_tts.encode_prompt(...)`; the
# resulting prompt is reused for every generation in the benchmarks.
REFERENCE_AUDIO = '/home/ubuntu/test_reference.wav'

# Benchmark sentences. The sequential and concurrency benchmarks cycle through
# them by index (modulo the list length), so any number of runs is supported.
TEST_TEXTS = [
    "Hey, what's up? I'm feeling really great if you ask me honestly!",
    "The quick brown fox jumps over the lazy dog near the riverbank.",
    "Technology has transformed the way we communicate with each other across the globe.",
    "In the heart of the city, there's a small café that serves the best espresso you'll ever taste.",
    "Scientists have discovered a new species of deep sea fish that glows in the dark.",
    "The weather forecast predicts heavy rain throughout the weekend, so bring an umbrella.",
    "She walked through the garden, admiring the beautiful roses that bloomed in spring.",
    "Artificial intelligence is changing every industry from healthcare to transportation.",
    "The old library on the corner has thousands of books dating back to the eighteenth century.",
    "Music has the power to bring people together regardless of their background or beliefs.",
    "The mountain trail was steep and challenging but the view from the top was absolutely breathtaking.",
    "Every morning he would wake up early, brew a cup of coffee, and read the newspaper.",
    "The conference will feature speakers from over twenty different countries this year.",
    "Children learn best when they are engaged and having fun with the material.",
    "The stock market experienced significant volatility during the third quarter of the fiscal year.",
    "A balanced diet combined with regular exercise is essential for maintaining good health.",
    "The documentary explores the impact of climate change on arctic wildlife populations.",
    "Innovation requires both creativity and the willingness to take calculated risks.",
    "The restaurant recently updated its menu to include more plant based options for customers.",
    "Learning a new language opens doors to understanding different cultures and perspectives.",
    "The spacecraft successfully completed its orbit around Mars and sent back stunning images.",
    "Public transportation systems need significant investment to meet growing urban demands.",
    "The novel won several literary awards and was translated into thirty five languages.",
    "Renewable energy sources like solar and wind power are becoming increasingly cost effective.",
    "The team worked tirelessly through the night to meet the project deadline on time.",
    "Historical preservation is crucial for maintaining our connection to the past.",
    "The film festival showcased independent movies from emerging directors worldwide.",
    "Effective communication skills are vital in both personal and professional relationships.",
    "The research paper presents compelling evidence for a new approach to treating the disease.",
    "Volunteers spent the entire weekend cleaning up the beach and planting native vegetation.",
    "The symphony orchestra performed a magnificent rendition of Beethoven's ninth symphony.",
    "Digital literacy is becoming an essential skill for success in the modern workforce.",
]

@dataclass
class BenchmarkResult:
    """Measurements collected from a single speech-generation request.

    ``rtf`` is expressed as seconds of audio produced per second of
    generation (higher means faster). This is the inverse of the
    conventional "real-time factor" (compute time / audio time).
    """

    # The sentence that was synthesized.
    text: str
    # Seconds spent generating, as timed by the caller.
    generation_time: float
    # Length of the produced audio, in seconds.
    audio_duration: float
    # Speed versus realtime: audio_duration / generation_time.
    rtf: float


def print_gpu_info():
    """Print the name, memory and software versions for CUDA device 0.

    Emits nothing when CUDA is unavailable. The "allocated" and "reserved"
    lines report PyTorch's current allocator usage on device 0, in MiB.
    """
    if not torch.cuda.is_available():
        return

    props = torch.cuda.get_device_properties(0)
    mib = 1024 ** 2
    gib = 1024 ** 3
    allocated_mib = torch.cuda.memory_allocated(0) / mib
    reserved_mib = torch.cuda.memory_reserved(0) / mib

    print(f"GPU: {props.name}")
    print(f"VRAM: {props.total_memory / gib:.1f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"PyTorch Version: {torch.__version__}")
    print(f"GPU Memory Allocated: {allocated_mib:.1f} MB")
    print(f"GPU Memory Reserved: {reserved_mib:.1f} MB")


def run_single_inference(lux_tts, encoded_prompt, text, warmup=False, *,
                         num_steps=4, sample_rate=48000):
    """Synthesize *text* once and measure generation time and audio length.

    Args:
        lux_tts: Loaded model exposing
            ``generate_speech(text, prompt, num_steps=...)``.
        encoded_prompt: Prompt produced by ``lux_tts.encode_prompt``.
        text: Sentence to synthesize.
        warmup: When true, the call only warms the model up and ``None`` is
            returned instead of a result.
        num_steps: Sampling steps forwarded to ``generate_speech``.
        sample_rate: Sample rate in Hz, used to convert the number of output
            samples into seconds (this script treats model output as 48 kHz).

    Returns:
        A ``BenchmarkResult``, or ``None`` when *warmup* is true.
    """
    cuda = torch.cuda.is_available()
    # CUDA kernels run asynchronously. Synchronizing before and after
    # generation makes the timer cover the GPU work, not just kernel launches.
    if cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    final_wav = lux_tts.generate_speech(text, encoded_prompt, num_steps=num_steps)
    if cuda:
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    if warmup:
        return None

    # detach()/cpu() make .numpy() safe if the model returns a CUDA tensor or
    # one that tracks gradients. atleast_1d stops a single-sample output from
    # squeezing to a 0-d array, which has no len().
    samples = np.atleast_1d(final_wav.detach().cpu().numpy().squeeze())
    audio_duration = len(samples) / float(sample_rate)
    rtf = audio_duration / elapsed if elapsed > 0 else float("inf")

    return BenchmarkResult(
        text=text,
        generation_time=elapsed,
        audio_duration=audio_duration,
        rtf=rtf,
    )


def benchmark_sequential(lux_tts, encoded_prompt, n_runs=10, n_warmup=3):
    """Benchmark sequential single-request inference.

    First runs *n_warmup* untimed warm-up generations. Then runs *n_runs*
    timed generations one after another, cycling through ``TEST_TEXTS``,
    and prints per-run lines plus aggregate statistics.

    Args:
        lux_tts: Loaded LuxTTS model.
        encoded_prompt: Prompt produced by ``lux_tts.encode_prompt``.
        n_runs: Number of timed generations. Nothing runs if this is <= 0.
        n_warmup: Number of warm-up generations; their results are discarded.

    Returns:
        list of ``BenchmarkResult`` in run order. The list is empty when
        *n_runs* <= 0.
    """
    print("\n" + "=" * 70)
    print("SEQUENTIAL BENCHMARK (single request at a time)")
    print("=" * 70)

    # With no timed runs, the aggregates below would divide by zero (the sum
    # of generation times) and take percentiles of an empty list.
    if n_runs <= 0:
        print("No runs requested; skipping sequential benchmark.")
        return []

    print(f"Warming up ({n_warmup} runs)...")
    for i in range(n_warmup):
        # The modulo keeps warm-up valid even if TEST_TEXTS is shorter than n_warmup.
        run_single_inference(lux_tts, encoded_prompt, TEST_TEXTS[i % len(TEST_TEXTS)], warmup=True)

    print(f"Running {n_runs} sequential inferences...")
    results = []
    for i in range(n_runs):
        text = TEST_TEXTS[i % len(TEST_TEXTS)]
        result = run_single_inference(lux_tts, encoded_prompt, text)
        results.append(result)
        print(f"  [{i+1}/{n_runs}] gen={result.generation_time:.4f}s, audio={result.audio_duration:.2f}s, RTF={result.rtf:.1f}x")

    gen_times = [r.generation_time for r in results]
    audio_durs = [r.audio_duration for r in results]
    rtfs = [r.rtf for r in results]
    total_audio = sum(audio_durs)
    total_gen = sum(gen_times)

    print(f"\n--- Sequential Results ---")
    print(f"  Avg generation time: {np.mean(gen_times):.4f}s (std: {np.std(gen_times):.4f}s)")
    print(f"  Avg audio duration:  {np.mean(audio_durs):.2f}s")
    print(f"  Avg RTF:             {np.mean(rtfs):.1f}x realtime")
    print(f"  Min RTF:             {np.min(rtfs):.1f}x")
    print(f"  Max RTF:             {np.max(rtfs):.1f}x")
    print(f"  P50 gen time:        {np.percentile(gen_times, 50):.4f}s")
    print(f"  P95 gen time:        {np.percentile(gen_times, 95):.4f}s")
    print(f"  P99 gen time:        {np.percentile(gen_times, 99):.4f}s")
    print(f"  Total audio produced: {total_audio:.2f}s in {total_gen:.2f}s")
    # Aggregate throughput (ratio of sums), unlike the mean of per-run RTFs above.
    print(f"  Throughput:          {total_audio/total_gen:.1f}x realtime")

    return results


def benchmark_concurrent(lux_tts, encoded_prompt, concurrency_levels=(1, 2, 4, 8, 16, 32), runs_per_level=16):
    """Benchmark concurrent inference using ThreadPoolExecutor.

    For each concurrency level, ``max(runs_per_level, concurrency)`` requests
    are submitted to a thread pool with that many workers. The requests cycle
    through ``TEST_TEXTS`` and all share the same model instance. Each request
    goes through ``run_single_inference``, the same measurement path as the
    sequential benchmark. That function calls ``torch.cuda.synchronize()``,
    which waits for all queued device work. Under concurrency, per-request
    latency therefore includes time spent behind other in-flight requests,
    i.e. it is latency under load.

    Args:
        lux_tts: Loaded LuxTTS model.
        encoded_prompt: Prompt produced by ``lux_tts.encode_prompt``.
        concurrency_levels: Iterable of worker counts to test.
        runs_per_level: Minimum number of requests per level.

    Returns:
        dict mapping each concurrency level to a metrics dict with keys
        ``wall_time``, ``total_audio``, ``throughput_rtx``, ``avg_gen_time``,
        ``p50_gen_time``, ``p95_gen_time``, ``avg_rtf``, ``min_rtf``,
        ``completed`` and ``errors``. A level where every request failed is
        still included, with NaN latency/throughput metrics, so its failures
        show up in the summary.
    """
    print("\n" + "=" * 70)
    print("CONCURRENCY BENCHMARK")
    print("=" * 70)

    cuda = torch.cuda.is_available()
    nan = float("nan")
    all_results = {}

    for concurrency in concurrency_levels:
        actual_runs = max(runs_per_level, concurrency)
        print(f"\n--- Concurrency={concurrency}, total_requests={actual_runs} ---")

        results = []
        errors = []

        if cuda:
            torch.cuda.synchronize()
        wall_start = time.perf_counter()

        with concurrent.futures.ThreadPoolExecutor(max_workers=concurrency) as executor:
            futures = [
                executor.submit(run_single_inference, lux_tts, encoded_prompt,
                                TEST_TEXTS[i % len(TEST_TEXTS)])
                for i in range(actual_runs)
            ]
            for future in concurrent.futures.as_completed(futures):
                try:
                    results.append(future.result())
                except Exception as exc:  # record the failure and keep measuring the rest
                    errors.append(exc)

        if cuda:
            torch.cuda.synchronize()
        wall_time = time.perf_counter() - wall_start

        if errors:
            print(f"  ERRORS: {len(errors)}")
            for exc in errors[:3]:
                print(f"    {type(exc).__name__}: {exc}")

        if not results:
            # Keep the level in the results so its failures are not silently dropped.
            print(f"  Completed:       0/{actual_runs}")
            all_results[concurrency] = {
                'wall_time': wall_time,
                'total_audio': 0.0,
                'throughput_rtx': nan,
                'avg_gen_time': nan,
                'p50_gen_time': nan,
                'p95_gen_time': nan,
                'avg_rtf': nan,
                'min_rtf': nan,
                'completed': 0,
                'errors': len(errors),
            }
            continue

        gen_times = [r.generation_time for r in results]
        audio_durs = [r.audio_duration for r in results]
        rtfs = [r.rtf for r in results]
        total_audio = sum(audio_durs)

        stats = {
            'wall_time': wall_time,
            'total_audio': total_audio,
            'throughput_rtx': total_audio / wall_time,
            'avg_gen_time': np.mean(gen_times),
            'p50_gen_time': np.percentile(gen_times, 50),
            'p95_gen_time': np.percentile(gen_times, 95),
            'avg_rtf': np.mean(rtfs),
            'min_rtf': np.min(rtfs),
            'completed': len(results),
            'errors': len(errors),
        }

        print(f"  Completed:       {len(results)}/{actual_runs}")
        print(f"  Wall clock time: {wall_time:.2f}s")
        print(f"  Total audio:     {total_audio:.2f}s")
        print(f"  Throughput:      {stats['throughput_rtx']:.1f}x realtime (audio/wall)")
        print(f"  Avg gen time:    {stats['avg_gen_time']:.4f}s")
        print(f"  P50 gen time:    {stats['p50_gen_time']:.4f}s")
        print(f"  P95 gen time:    {stats['p95_gen_time']:.4f}s")
        print(f"  Avg RTF:         {stats['avg_rtf']:.1f}x (per-request)")
        print(f"  Min RTF:         {stats['min_rtf']:.1f}x")

        if cuda:
            mem_allocated = torch.cuda.memory_allocated(0) / 1024**2
            mem_reserved = torch.cuda.memory_reserved(0) / 1024**2
            print(f"  GPU Mem Alloc:   {mem_allocated:.0f} MB")
            print(f"  GPU Mem Reserve: {mem_reserved:.0f} MB")

        all_results[concurrency] = stats

    return all_results


def benchmark_text_lengths(lux_tts, encoded_prompt, n_repeats=3):
    """Benchmark how text length affects generation time.

    Each of four fixed passages (labelled roughly 10/25/50/100 words) is
    synthesized *n_repeats* times. The average generation time, average
    audio duration and RTF are printed for each passage.

    The RTF here is ``avg_audio / avg_time``, a ratio of means. The
    sequential and concurrency benchmarks instead report the mean of
    per-request RTFs.

    Args:
        lux_tts: Loaded LuxTTS model.
        encoded_prompt: Prompt produced by ``lux_tts.encode_prompt``.
        n_repeats: Number of generations per passage; must be >= 1.

    Returns:
        dict mapping each label to ``(avg_gen_time, avg_audio_duration, rtf)``.

    Raises:
        ValueError: if *n_repeats* < 1, because averages over zero runs are
            undefined.
    """
    if n_repeats < 1:
        raise ValueError(f"n_repeats must be >= 1, got {n_repeats}")

    print("\n" + "=" * 70)
    print("TEXT LENGTH BENCHMARK")
    print("=" * 70)

    length_texts = [
        ("short ~10w", "Hey there, how are you doing today my friend?"),
        ("medium ~25w", "The quick brown fox jumps over the lazy dog. This is a medium length sentence that should produce a decent amount of audio output."),
        ("long ~50w", "In the heart of the bustling city, there lies a hidden garden where flowers bloom year round. The old gardener tends to each plant with care, knowing that every petal and leaf tells a story of seasons past. Visitors often pause in amazement at the beauty that thrives in this unlikely urban oasis."),
        ("very long ~100w", "The ancient library stood at the crossroads of two great civilizations, its walls lined with scrolls and manuscripts that chronicled thousands of years of human achievement. Scholars traveled from distant lands to study within its hallowed halls, seeking wisdom that could illuminate the mysteries of the universe. The head librarian, an elderly woman with silver hair and piercing blue eyes, guided each visitor with patience and care. She believed that knowledge was the most precious gift one could share, and she dedicated her life to preserving these treasures for future generations who would carry the torch of learning forward into an uncertain but hopeful future."),
    ]

    summary = {}
    for label, text in length_texts:
        runs = [run_single_inference(lux_tts, encoded_prompt, text) for _ in range(n_repeats)]
        avg_time = np.mean([r.generation_time for r in runs])
        avg_audio = np.mean([r.audio_duration for r in runs])
        avg_rtf = avg_audio / avg_time
        print(f"  {label:15s}: gen={avg_time:.4f}s, audio={avg_audio:.2f}s, RTF={avg_rtf:.1f}x")
        summary[label] = (avg_time, avg_audio, avg_rtf)

    return summary


def print_summary(seq_results, conc_results):
    """Print final summary table.

    Args:
        seq_results: Sequence of objects with ``rtf`` and ``generation_time``
            attributes (the ``BenchmarkResult`` list from the sequential
            benchmark). May be empty.
        conc_results: dict mapping concurrency level to a metrics dict with
            at least ``throughput_rtx``, ``avg_gen_time``, ``p95_gen_time``,
            ``avg_rtf`` and ``errors``. May be empty.
    """
    cuda = torch.cuda.is_available()
    # Report the device actually used rather than assuming a specific GPU model.
    device_name = torch.cuda.get_device_name(0) if cuda else "CPU"

    print("\n" + "=" * 70)
    print(f"FINAL SUMMARY - LuxTTS on {device_name}")
    print("=" * 70)

    if seq_results:
        rtfs = [r.rtf for r in seq_results]
        gen_times = [r.generation_time for r in seq_results]
        print(f"\nSingle Request Performance:")
        print(f"  Average RTF:        {np.mean(rtfs):.1f}x realtime")
        print(f"  Average latency:    {np.mean(gen_times)*1000:.1f}ms")
        print(f"  P95 latency:        {np.percentile(gen_times, 95)*1000:.1f}ms")

    if conc_results:
        print(f"\nConcurrency Scaling:")
        print(f"  {'Conc':>6s} | {'Throughput':>12s} | {'Avg Latency':>12s} | {'P95 Latency':>12s} | {'Avg RTF/req':>12s} | {'Errors':>6s}")
        print(f"  {'-'*6} | {'-'*12} | {'-'*12} | {'-'*12} | {'-'*12} | {'-'*6}")
        for conc in sorted(conc_results):
            r = conc_results[conc]
            print(f"  {conc:>6d} | {r['throughput_rtx']:>10.1f}x  | {r['avg_gen_time']*1000:>10.1f}ms | {r['p95_gen_time']*1000:>10.1f}ms | {r['avg_rtf']:>10.1f}x  | {r['errors']:>6d}")

    # Peak allocator statistics only make sense on a CUDA device.
    if cuda:
        mem_allocated = torch.cuda.max_memory_allocated(0) / 1024**2
        mem_reserved = torch.cuda.max_memory_reserved(0) / 1024**2
        print(f"\nPeak GPU Memory:")
        print(f"  Max Allocated: {mem_allocated:.0f} MB")
        print(f"  Max Reserved:  {mem_reserved:.0f} MB")


def main():
    """Run the full LuxTTS benchmark suite and print a summary.

    Steps:
      1. Load the model on CUDA.
      2. Encode ``REFERENCE_AUDIO`` as the prompt.
      3. Write a short sanity-check clip to disk.
      4. Run the sequential, text-length and concurrency benchmarks.
      5. Print the summary tables.

    Raises:
        SystemExit: before the model is loaded, if CUDA is unavailable or
            ``REFERENCE_AUDIO`` does not exist, so misconfiguration fails
            fast instead of after a slow model load.
    """
    # The model is loaded with device='cuda' below, and every timing helper
    # synchronizes the GPU, so CUDA is required.
    if not torch.cuda.is_available():
        raise SystemExit("CUDA is not available; this benchmark loads LuxTTS with device='cuda'.")
    if not os.path.isfile(REFERENCE_AUDIO):
        raise SystemExit(f"Reference audio not found: {REFERENCE_AUDIO}")

    sample_rate = 48000  # rate used for the sanity-check file and duration math

    print("=" * 70)
    # Name the GPU actually in use rather than assuming an A100.
    print(f"LuxTTS BENCHMARK SUITE - {torch.cuda.get_device_name(0)}")
    print("=" * 70)
    print_gpu_info()

    print("\nLoading model...")
    load_start = time.perf_counter()
    # The import sits inside the timed region, so load_time covers module
    # import as well as model construction.
    from zipvoice.luxvoice import LuxTTS
    lux_tts = LuxTTS('YatharthS/LuxTTS', device='cuda')
    load_time = time.perf_counter() - load_start
    print(f"Model loaded in {load_time:.2f}s")

    print_gpu_info()

    print("\nEncoding reference prompt...")
    encode_start = time.perf_counter()
    encoded_prompt = lux_tts.encode_prompt(REFERENCE_AUDIO, rms=0.01)
    encode_time = time.perf_counter() - encode_start
    print(f"Prompt encoded in {encode_time:.2f}s")

    # Basic sanity test: generate one clip and write it out for listening.
    print("\nRunning sanity check...")
    test_wav = lux_tts.generate_speech("Hello world, this is a test.", encoded_prompt, num_steps=4)
    # detach()/cpu() make .numpy() safe if the output is a CUDA or grad-tracking tensor.
    test_wav_np = test_wav.detach().cpu().numpy().squeeze()
    sf.write('/home/ubuntu/luxtts_test_output.wav', test_wav_np, sample_rate)
    print(f"Sanity check passed. Output: {len(test_wav_np)/sample_rate:.2f}s audio at 48kHz")

    # Run benchmarks
    seq_results = benchmark_sequential(lux_tts, encoded_prompt, n_runs=15)
    benchmark_text_lengths(lux_tts, encoded_prompt)
    conc_results = benchmark_concurrent(lux_tts, encoded_prompt,
                                        concurrency_levels=[1, 2, 4, 8, 16, 32],
                                        runs_per_level=16)

    print_summary(seq_results, conc_results)


# Run the benchmark suite only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
