"""
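LuxTTS Benchmark v2 - Focused on longer texts for meaningful RTF measurements
and proper concurrency scaling analysis on A100.

RTF in this script is audio duration divided by generation time (reported as
"x realtime"), so higher is faster.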
"""
import time
import torch
import numpy as np
import soundfile as sf
import sys
import concurrent.futures

# Make the local LuxTTS checkout importable; `zipvoice` is imported inside main().
sys.path.insert(0, '/home/ubuntu/LuxTTS')

# Reference speaker clip encoded once in main() via lux_tts.encode_prompt().
REFERENCE_AUDIO = '/home/ubuntu/test_reference.wav'

# Benchmark prompts: multi-sentence paragraphs, so each request produces enough
# audio for stable RTF numbers. Warmup uses the first five entries; the
# sequential and concurrency runs index into this list modulo its length.
LONG_TEXTS = [
    "In the heart of the bustling city, there lies a hidden garden where flowers bloom year round. The old gardener tends to each plant with care, knowing that every petal and leaf tells a story of seasons past. Visitors often pause in amazement at the beauty that thrives in this unlikely urban oasis.",
    "The ancient library stood at the crossroads of two great civilizations, its walls lined with scrolls and manuscripts that chronicled thousands of years of human achievement. Scholars traveled from distant lands to study within its hallowed halls, seeking wisdom that could illuminate the mysteries of the universe.",
    "Technology has fundamentally transformed every aspect of modern life, from the way we communicate and work to how we entertain ourselves and connect with others. The rapid pace of innovation shows no signs of slowing down, with breakthroughs in artificial intelligence, quantum computing, and biotechnology reshaping our understanding of what is possible.",
    "The mountain trail wound its way through ancient forests and across crystal clear streams, revealing breathtaking vistas at every turn. Hikers who made the challenging journey to the summit were rewarded with a panoramic view that stretched from the distant ocean to the snow capped peaks of the neighboring range.",
    "Music has been a fundamental part of human culture since the dawn of civilization. From the earliest drums and flutes to modern electronic compositions, our ability to create and appreciate music reflects something deeply profound about the human experience and our need to express emotions beyond what words alone can convey.",
    "The scientific method represents one of humanity's greatest intellectual achievements, providing a systematic framework for understanding the natural world through observation, hypothesis, experimentation, and analysis. This approach has led to countless discoveries that have improved human health, expanded our knowledge, and driven technological progress.",
    "Education is the cornerstone of a thriving society, empowering individuals with the knowledge and skills they need to contribute meaningfully to their communities. A well-designed educational system nurtures curiosity, encourages critical thinking, and helps students develop the resilience and adaptability they will need throughout their lives.",
    "The ocean covers more than seventy percent of the earth's surface, yet we have explored less than five percent of its depths. Marine biologists continue to discover new species in the deep sea, revealing an astonishing diversity of life that has adapted to extreme pressure, darkness, and cold temperatures.",
    "Artificial intelligence is reshaping industries across the globe, from healthcare and finance to transportation and entertainment. While the potential benefits are enormous, including more accurate medical diagnoses, efficient resource allocation, and personalized education, the technology also raises important questions about privacy, employment, and the nature of human intelligence.",
    "The art of storytelling has been passed down through countless generations, evolving from oral traditions around ancient campfires to the written word and now to digital media. Great stories have the power to transport us to different worlds, help us understand perspectives unlike our own, and ultimately remind us of what it means to be human.",
    "Climate change represents one of the most significant challenges facing humanity today. Rising temperatures, shifting weather patterns, and increasing frequency of extreme events demand urgent action from governments, businesses, and individuals worldwide. The transition to renewable energy sources and sustainable practices is not just an environmental imperative but also an economic opportunity.",
    "The human brain contains approximately eighty six billion neurons, each forming thousands of connections with other neurons, creating an incredibly complex network that gives rise to consciousness, creativity, and all forms of human thought. Neuroscientists are only beginning to understand how this remarkable organ processes information and generates experience.",
    "Architecture is more than just the design of buildings; it is a reflection of culture, values, and aspirations. From the ancient pyramids of Egypt to the soaring skyscrapers of modern cities, the structures we build tell the story of who we are and what we believe is important. Great architecture inspires wonder and creates spaces where people can live, work, and gather.",
    "The history of space exploration is a testament to human curiosity and ingenuity. From the first satellite launches in the nineteen fifties to the Mars rovers and deep space probes of today, our journey beyond Earth has expanded our understanding of the cosmos and inspired generations of scientists, engineers, and dreamers to push the boundaries of what is possible.",
    "Good nutrition is essential for maintaining health and preventing disease. A balanced diet rich in fruits, vegetables, whole grains, lean proteins, and healthy fats provides the body with the nutrients it needs to function optimally. Understanding the relationship between food and health empowers individuals to make informed choices that can significantly improve their quality of life.",
    "The global economy is an interconnected web of trade, finance, and commerce that links nations and communities around the world. Economic policies, technological innovation, and geopolitical events in one region can have ripple effects across the entire system, making international cooperation and sound governance more important than ever before.",
]

def print_gpu_info():
    """Print name, total VRAM and current memory usage of CUDA device 0.

    Does nothing (prints no output) when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return

    mib = 1024 ** 2
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name}")
    print(f"VRAM: {props.total_memory / 1024**3:.1f} GB")

    allocated = torch.cuda.memory_allocated(0) / mib
    reserved = torch.cuda.memory_reserved(0) / mib
    print(f"GPU Memory: {allocated:.0f} MB allocated / {reserved:.0f} MB reserved")


def run_inference(lux_tts, encoded_prompt, text, *, num_steps=4, sample_rate=48000):
    """Synthesize ``text`` once and time it.

    Args:
        lux_tts: Loaded model exposing
            ``generate_speech(text, prompt, num_steps=...)`` that returns a tensor.
        encoded_prompt: Reference prompt as returned by ``lux_tts.encode_prompt``.
        text: Text to synthesize.
        num_steps: Sampling steps forwarded to ``generate_speech`` (default 4).
        sample_rate: Sample rate in Hz used to convert the output sample count
            into seconds. The 48000 default assumes LuxTTS emits 48 kHz audio —
            TODO confirm against the model.

    Returns:
        ``(elapsed_s, audio_duration_s)``: wall-clock generation time (bracketed
        by device synchronization when CUDA is available) and the length of the
        generated audio in seconds.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Drain previously queued GPU work so it is not billed to this request.
        torch.cuda.synchronize()
    start = time.perf_counter()
    final_wav = lux_tts.generate_speech(text, encoded_prompt, num_steps=num_steps)
    if use_cuda:
        # CUDA kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    # detach()/cpu() are no-ops for a CPU tensor without grad, but they keep
    # .numpy() from raising if the model returns a GPU or grad-tracking tensor.
    wav_np = final_wav.detach().cpu().numpy().squeeze()
    audio_dur = len(wav_np) / float(sample_rate)
    return elapsed, audio_dur


def _run_sequential(lux_tts, encoded_prompt, n_runs=16):
    """Run ``n_runs`` requests back to back and print per-run and summary stats.

    Texts are taken from ``LONG_TEXTS`` in order, wrapping around when
    ``n_runs`` exceeds its length.

    Returns:
        List of ``(gen_time_s, audio_dur_s, rtf)`` tuples, one per run, where
        ``rtf = audio_dur_s / gen_time_s`` (higher is faster).
    """
    print("\n" + "=" * 80)
    print("SEQUENTIAL BENCHMARK - Long Texts (~50-80 words each)")
    print("=" * 80)

    seq_results = []
    for i in range(n_runs):
        text = LONG_TEXTS[i % len(LONG_TEXTS)]
        gen_time, audio_dur = run_inference(lux_tts, encoded_prompt, text)
        rtf = audio_dur / gen_time
        seq_results.append((gen_time, audio_dur, rtf))
        words = len(text.split())
        print(f"  [{i+1:2d}/{n_runs}] gen={gen_time:.4f}s  audio={audio_dur:.2f}s  RTF={rtf:.1f}x  ({words}w)")

    gen_times = [r[0] for r in seq_results]
    audio_durs = [r[1] for r in seq_results]
    rtfs = [r[2] for r in seq_results]

    print("\n--- Sequential Summary ---")
    print(f"  Avg generation time:   {np.mean(gen_times):.4f}s")
    print(f"  Avg audio duration:    {np.mean(audio_durs):.2f}s")
    print(f"  Avg RTF per request:   {np.mean(rtfs):.1f}x realtime")
    print(f"  Max RTF:               {np.max(rtfs):.1f}x")
    print(f"  Min RTF:               {np.min(rtfs):.1f}x")
    print(f"  P50 latency:           {np.percentile(gen_times, 50)*1000:.1f}ms")
    print(f"  P95 latency:           {np.percentile(gen_times, 95)*1000:.1f}ms")
    print(f"  Aggregate throughput:  {sum(audio_durs)/sum(gen_times):.1f}x realtime")
    return seq_results


def _run_concurrency_level(lux_tts, encoded_prompt, conc, n_requests):
    """Push ``n_requests`` requests through ``conc`` worker threads and print stats.

    Returns:
        Summary dict consumed by the final scaling table, or ``None`` when
        every request failed.
    """
    print(f"\n--- Concurrency={conc}, requests={n_requests} ---")

    def worker(idx):
        text = LONG_TEXTS[idx % len(LONG_TEXTS)]
        # Failures are returned rather than raised, so a bad request is counted
        # and reported instead of aborting the whole concurrency level.
        try:
            # run_inference calls torch.cuda.synchronize(), which waits for all
            # work on the device, so per-request latency here includes time spent
            # behind other threads' work.
            return run_inference(lux_tts, encoded_prompt, text)
        except Exception as e:
            return e

    results = []
    errors = []

    torch.cuda.synchronize()
    wall_start = time.perf_counter()

    with concurrent.futures.ThreadPoolExecutor(max_workers=conc) as pool:
        futures = [pool.submit(worker, i) for i in range(n_requests)]
        for f in concurrent.futures.as_completed(futures):
            r = f.result()
            if isinstance(r, Exception):
                errors.append(r)
            else:
                results.append(r)

    torch.cuda.synchronize()
    wall_time = time.perf_counter() - wall_start

    if errors:
        print(f"  ERRORS: {len(errors)}")
        for e in errors[:3]:
            print(f"    {type(e).__name__}: {e}")

    if not results:
        return None

    # Compute every statistic once; the printout and the summary dict share them.
    gen_times = [g for g, _ in results]
    audio_durs = [a for _, a in results]
    total_audio = sum(audio_durs)
    throughput = total_audio / wall_time
    avg_latency_ms = np.mean(gen_times) * 1000
    p50_ms = np.percentile(gen_times, 50) * 1000
    p95_ms = np.percentile(gen_times, 95) * 1000
    p99_ms = np.percentile(gen_times, 99) * 1000
    avg_rtf = np.mean([a / g for g, a in results])

    print(f"  Completed:         {len(results)}/{n_requests}")
    print(f"  Wall clock time:   {wall_time:.2f}s")
    print(f"  Total audio gen'd: {total_audio:.1f}s")
    print(f"  THROUGHPUT:        {throughput:.1f}x realtime")
    print(f"  Avg latency:       {avg_latency_ms:.0f}ms")
    print(f"  P50 latency:       {p50_ms:.0f}ms")
    print(f"  P95 latency:       {p95_ms:.0f}ms")
    print(f"  P99 latency:       {p99_ms:.0f}ms")
    print(f"  Avg audio/request: {np.mean(audio_durs):.2f}s")
    print(f"  Avg RTF/request:   {avg_rtf:.1f}x")

    mem_alloc = torch.cuda.memory_allocated(0) / 1024**2
    mem_res = torch.cuda.memory_reserved(0) / 1024**2
    print(f"  GPU Mem:           {mem_alloc:.0f} MB alloc / {mem_res:.0f} MB reserved")

    return {
        'wall_time': wall_time,
        'total_audio': total_audio,
        'throughput': throughput,
        'avg_latency_ms': avg_latency_ms,
        'p50_latency_ms': p50_ms,
        'p95_latency_ms': p95_ms,
        'p99_latency_ms': p99_ms,
        'avg_rtf': avg_rtf,
        'errors': len(errors),
        'n_completed': len(results),
    }


def _print_final_summary(seq_results, conc_summary):
    """Print peak VRAM, the sequential RTF/latency recap and the concurrency table."""
    print("\n" + "=" * 80)
    print("FINAL RESULTS - LuxTTS on NVIDIA A100-SXM4-80GB")
    print("=" * 80)

    # max_memory_* report the peak since the process started (stats are never
    # reset here), so they include activations from every benchmark phase,
    # e.g. concurrency=32, not just the model weights.
    label = "Peak VRAM usage (entire run): "
    print(f"\n{label}~{torch.cuda.max_memory_allocated(0)/1024**2:.0f} MB allocated")
    print(f"{' ' * len(label)}~{torch.cuda.max_memory_reserved(0)/1024**2:.0f} MB reserved")

    print("\nSingle-Request (sequential):")
    seq_rtfs = [r[2] for r in seq_results]
    seq_gens = [r[0] for r in seq_results]
    print(f"  RTF:     {np.mean(seq_rtfs):.1f}x avg  |  {np.max(seq_rtfs):.1f}x max  |  {np.min(seq_rtfs):.1f}x min")
    print(f"  Latency: {np.mean(seq_gens)*1000:.0f}ms avg  |  {np.percentile(seq_gens, 95)*1000:.0f}ms p95")

    if conc_summary:
        print("\nConcurrency Scaling Table:")
        print(f"  {'Conc':>5} | {'Throughput':>12} | {'Avg Lat':>10} | {'P50 Lat':>10} | {'P95 Lat':>10} | {'P99 Lat':>10} | {'RTF/req':>8} | {'Err':>4}")
        print(f"  {'-'*5}-+-{'-'*12}-+-{'-'*10}-+-{'-'*10}-+-{'-'*10}-+-{'-'*10}-+-{'-'*8}-+-{'-'*4}")
        for conc in sorted(conc_summary):
            s = conc_summary[conc]
            print(f"  {conc:>5} | {s['throughput']:>10.1f}x  | {s['avg_latency_ms']:>8.0f}ms | {s['p50_latency_ms']:>8.0f}ms | {s['p95_latency_ms']:>8.0f}ms | {s['p99_latency_ms']:>8.0f}ms | {s['avg_rtf']:>6.1f}x  | {s['errors']:>4}")


def main():
    """Load LuxTTS, warm it up, and run the sequential and concurrency benchmarks.

    Results are printed to stdout. Requires a CUDA device (the model is loaded
    with ``device='cuda'``).
    """
    print("=" * 80)
    print("LuxTTS BENCHMARK v2 - A100-SXM4-80GB - Long Text Focus")
    print("=" * 80)
    print_gpu_info()

    print("\nLoading LuxTTS model...")
    # Imported here because it depends on the sys.path entry added at module level.
    from zipvoice.luxvoice import LuxTTS
    lux_tts = LuxTTS('YatharthS/LuxTTS', device='cuda')
    print_gpu_info()

    print("\nEncoding reference prompt...")
    encoded_prompt = lux_tts.encode_prompt(REFERENCE_AUDIO, rms=0.01)

    # Warmup: keep one-time setup costs out of the measured runs.
    print("\nWarming up (5 runs)...")
    for text in LONG_TEXTS[:5]:
        run_inference(lux_tts, encoded_prompt, text)
    print("Warmup complete.")

    seq_results = _run_sequential(lux_tts, encoded_prompt)

    print("\n" + "=" * 80)
    print("CONCURRENCY BENCHMARK - Long Texts")
    print("=" * 80)

    conc_summary = {}
    for conc in [1, 2, 4, 8, 16, 32]:
        # At least 32 requests per level, and at least two per worker thread.
        n_requests = max(32, conc * 2)
        summary = _run_concurrency_level(lux_tts, encoded_prompt, conc, n_requests)
        if summary is not None:
            conc_summary[conc] = summary

    _print_final_summary(seq_results, conc_summary)

    print("\nDone.")


if __name__ == '__main__':
    main()
