"""
Qwen3-TTS concurrency benchmark.
Tests 1, 2, 4, 8 concurrent streaming requests to find max concurrency.
"""
import os, time, threading, torch, subprocess
import soundfile as sf
from faster_qwen3_tts import FasterQwen3TTS

# Output directory: one saved WAV per concurrency level (written in main loop).
os.makedirs("/home/ubuntu/qwen3_conc_samples", exist_ok=True)

# Voice-cloning reference: audio clip plus its transcript (Hindi), shared by
# every request. test_text is the fixed sentence each concurrent request
# synthesizes, so all requests do comparable work.
ref_audio = "/home/ubuntu/vibevoice/demo/voices/modi.wav"
ref_text = "मेरे प्यारे देशवासियों, मुझे सीतापुर के ओजस्वी ने लिखा है कि अमृत महोत्सव से जुड़ी चर्चाएं उन्हें खूब पसंद आ रही हैं।"
test_text = "मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है."

# Load the model once; it is shared across all worker threads below.
print("Loading Qwen3-TTS 1.7B...")
model = FasterQwen3TTS.from_pretrained("Qwen/Qwen3-TTS-12Hz-1.7B-Base")
print(f"VRAM after load: {torch.cuda.memory_allocated()/1e9:.1f}GB")

# Warmup (2 runs to trigger CUDA graph capture)
for i in range(2):
    for chunk, sr, timing in model.generate_voice_clone_streaming(
        text="Warmup test.", language="Auto",
        ref_audio=ref_audio, ref_text=ref_text, chunk_size=8,
    ):
        pass  # drain the stream; warmup audio is discarded
print(f"Warmup done. VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")

def gpu_monitor(stop_event, samples_list):
    """Poll nvidia-smi at ~10 Hz, appending GPU utilization (%) to samples_list.

    Runs until ``stop_event`` is set. Meant to run in a daemon thread alongside
    the benchmark so average GPU load can be reported per batch.

    Args:
        stop_event: threading.Event signalling the monitor to stop.
        samples_list: shared list receiving float utilization samples.
    """
    while not stop_event.is_set():
        try:
            # Inside the try so a missing/broken nvidia-smi (OSError) cannot
            # escape and silently kill the monitor thread.
            r = subprocess.run(
                ['nvidia-smi', '--query-gpu=utilization.gpu', '--format=csv,noheader,nounits'],
                capture_output=True, text=True
            )
            samples_list.append(float(r.stdout.strip()))
        except (OSError, ValueError):
            # OSError: nvidia-smi unavailable; ValueError: unparsable output
            # (e.g. empty string). Drop the sample -- this is best-effort
            # telemetry, and a bare `except:` would also swallow SystemExit.
            pass
        # Event.wait instead of time.sleep so shutdown is immediate when
        # stop_event is set mid-interval.
        stop_event.wait(0.1)

def run_single_request(model, idx, results, text=None, ref_audio_path=None, ref_text_prompt=None):
    """Run one streaming TTS request and record metrics into results[idx].

    On success results[idx] becomes a dict with duration (s of audio),
    gen_time (wall s), rtf (gen_time/duration), ttfb (ms to first chunk),
    chunks, sr, and the concatenated audio array. On failure it becomes
    {"error": ..., "gen_time": ...}; exceptions never escape the thread.

    Args:
        model: object exposing ``generate_voice_clone_streaming``.
        idx: slot in ``results`` to write into.
        results: shared list, one slot per concurrent request.
        text: text to synthesize; defaults to module-level ``test_text``.
        ref_audio_path: reference audio; defaults to module-level ``ref_audio``.
        ref_text_prompt: reference transcript; defaults to module-level ``ref_text``.
    """
    import numpy as np

    # Backward-compatible parameterization: fall back to the module-level
    # benchmark inputs when not overridden by the caller.
    if text is None:
        text = test_text
    if ref_audio_path is None:
        ref_audio_path = ref_audio
    if ref_text_prompt is None:
        ref_text_prompt = ref_text

    chunks = []
    ttfb = None
    chunk_sr = None
    t0 = time.perf_counter()
    try:
        for audio_chunk, chunk_sr, _timing in model.generate_voice_clone_streaming(
            text=text, language="Auto",
            ref_audio=ref_audio_path, ref_text=ref_text_prompt,
            chunk_size=8,
        ):
            if ttfb is None:
                # Time-to-first-byte: latency until the first audio chunk.
                ttfb = time.perf_counter() - t0
            chunks.append(audio_chunk)
        total_time = time.perf_counter() - t0

        if not chunks:
            results[idx] = {"error": "no audio"}
            return

        # Chunks may be torch tensors or numpy arrays; normalize to 1-D numpy.
        if isinstance(chunks[0], torch.Tensor):
            full = torch.cat(chunks, dim=-1).cpu().float().numpy()
        else:
            full = np.concatenate(chunks, axis=-1)
        if full.ndim > 1:
            full = full.squeeze()
        dur = len(full) / chunk_sr
        rtf = total_time / dur if dur > 0 else 999  # sentinel for zero-length audio
        results[idx] = {
            "duration": dur, "gen_time": total_time,
            # `is not None` (not truthiness): a legitimate 0.0s TTFB must not
            # be reported as the -1 "never got a chunk" sentinel.
            "rtf": rtf, "ttfb": ttfb * 1000 if ttfb is not None else -1,
            "chunks": len(chunks), "sr": chunk_sr, "audio": full,
        }
    except Exception as e:
        # Benchmark boundary: record the failure so the main loop can report
        # it, rather than crashing the worker thread.
        results[idx] = {"error": str(e), "gen_time": time.perf_counter() - t0}

# Concurrency sweep: for each level, fire batch_size simultaneous streaming
# requests against the shared model and report averaged per-request metrics.
print(f"\n{'='*80}")
print(f"{'Batch':>6} {'AvgDur':>8} {'AvgGen':>8} {'AvgRTF':>8} {'AvgTTFB':>9} {'Throughput':>11} {'GPU%':>6} {'Stream?':>8}")
print("-" * 80)

for batch_size in [1, 2, 4, 8, 12, 16]:
    results = [None] * batch_size  # one slot per request, filled by worker threads

    # Background GPU-utilization sampler (daemon thread, ~10 Hz).
    gpu_samples = []
    stop_gpu = threading.Event()
    gpu_t = threading.Thread(target=gpu_monitor, args=(stop_gpu, gpu_samples), daemon=True)
    gpu_t.start()

    # Launch all requests at once to simulate batch_size concurrent clients.
    threads = []
    t_start = time.perf_counter()
    for i in range(batch_size):
        t = threading.Thread(target=run_single_request, args=(model, i, results))
        threads.append(t)
        t.start()

    # NOTE(review): join(timeout=300) does not kill stragglers; a request
    # still running after 5 min leaves its results slot as None, which the
    # `r and ...` filters below tolerate.
    for t in threads:
        t.join(timeout=300)

    wall_time = time.perf_counter() - t_start
    stop_gpu.set()
    time.sleep(0.2)  # let the sampler thread observe the stop event

    valid = [r for r in results if r and "error" not in r]
    errors = [r for r in results if r and "error" in r]

    if valid:
        avg_dur = sum(r["duration"] for r in valid) / len(valid)
        avg_gen = sum(r["gen_time"] for r in valid) / len(valid)
        avg_rtf = sum(r["rtf"] for r in valid) / len(valid)
        avg_ttfb = sum(r["ttfb"] for r in valid) / len(valid)
        # Aggregate throughput: audio-seconds generated per wall-clock second.
        throughput = sum(r["duration"] for r in valid) / wall_time
        avg_gpu = sum(gpu_samples) / len(gpu_samples) if gpu_samples else 0
        streamable = "YES" if avg_rtf < 1.0 else "NO"  # RTF < 1 => real-time streaming viable

        print(f"{batch_size:>6} {avg_dur:>7.2f}s {avg_gen:>7.2f}s {avg_rtf:>7.3f}x {avg_ttfb:>8.0f}ms {throughput:>9.1f} s/s {avg_gpu:>5.0f}% {streamable:>8}")

        if errors:
            print(f"       ({len(errors)} errors: {errors[0].get('error', '')[:60]})")

        # Save one sample
        sf.write(f"/home/ubuntu/qwen3_conc_samples/batch{batch_size}.wav",
                 valid[0]["audio"], valid[0]["sr"])
    else:
        print(f"{batch_size:>6} ALL FAILED - {errors[0].get('error', '') if errors else 'unknown'}")

    # Early exit once requests fall far behind real time; the `valid` guard
    # ensures avg_rtf from a previous batch cannot trigger this when the
    # current batch produced no valid results.
    if valid and avg_rtf > 2.0:
        print(f"       Stopping -- RTF > 2.0x, not viable")
        break

print(f"\nPeak VRAM: {torch.cuda.max_memory_allocated()/1e9:.1f}GB")
