"""
Test Qwen3-TTS with Modi voice cloning.
Generate same Hindi texts as VibeVoice for comparison.
Measure RTF, TTFB, GPU utilization.
"""

import os
import time
import subprocess
import threading
import torch
import soundfile as sf
from faster_qwen3_tts import FasterQwen3TTS

os.makedirs("/home/ubuntu/qwen3_samples", exist_ok=True)

ref_audio = "/home/ubuntu/vibevoice/demo/voices/modi.wav"

texts = {
    "short": "नमस्ते, मेरे प्यारे देशवासियों.",
    "medium": "मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है.",
    "long": "मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है, जहाँ टेक्नोलॉजी और इनोवेशन हमारी ताकत बन रही है. डिजिटल इंडिया ने देश की तस्वीर बदल दी है. गाँव गाँव में इंटरनेट पहुँच रहा है.",
    "speech": "भारत आज दुनिया की पाँचवीं सबसे बड़ी अर्थव्यवस्था है. हमारे युवाओं की ऊर्जा, हमारे वैज्ञानिकों की प्रतिभा, और हमारे किसानों की मेहनत, यही हमारी असली ताकत है. आने वाले समय में भारत और भी ऊँचाइयों को छूएगा.",
}

ref_text = "मेरे प्यारे देशवासियों, मुझे सीतापुर के ओजस्वी ने लिखा है कि अमृत महोत्सव से जुड़ी चर्चाएं उन्हें खूब पसंद आ रही हैं।"

# Load 1.7B model
print("Loading Qwen3-TTS 1.7B...")
model = FasterQwen3TTS.from_pretrained("Qwen/Qwen3-TTS-12Hz-1.7B-Base")
print(f"Model loaded. GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")

# Warmup
print("\nWarming up...")
audio_list, sr = model.generate_voice_clone(
    text="Hello warmup test.", language="Hindi",
    ref_audio=ref_audio, ref_text=ref_text,
)
print(f"Warmup done. Sample rate: {sr}")

# Generate samples
print(f"\n{'Label':<10} {'Duration':>9} {'GenTime':>9} {'RTF':>7} {'TTFB':>8} {'GPU%':>6}")
print("-" * 60)

for label, text in texts.items():
    # GPU monitor
    gpu_samples = []
    stop_mon = [False]
    def mon():
        while not stop_mon[0]:
            r = subprocess.run(['nvidia-smi','--query-gpu=utilization.gpu,power.draw','--format=csv,noheader,nounits'],
                               capture_output=True, text=True)
            parts = r.stdout.strip().split(', ')
            if len(parts) >= 2:
                gpu_samples.append(float(parts[0]))
            time.sleep(0.15)
    mt = threading.Thread(target=mon, daemon=True)
    gpu_samples.clear()
    stop_mon[0] = False
    mt.start()

    # Streaming generation
    first_chunk_time = [None]
    all_chunks = []

    torch.cuda.synchronize()
    t0 = time.perf_counter()

    for audio_chunk, chunk_sr, timing in model.generate_voice_clone_streaming(
        text=text, language="Hindi",
        ref_audio=ref_audio, ref_text=ref_text,
        chunk_size=8,
    ):
        if first_chunk_time[0] is None:
            first_chunk_time[0] = time.perf_counter()
        all_chunks.append(audio_chunk)

    torch.cuda.synchronize()
    gen_time = time.perf_counter() - t0

    stop_mon[0] = True
    time.sleep(0.2)

    # Concatenate and save
    if all_chunks:
        full_audio = torch.cat(all_chunks, dim=-1) if isinstance(all_chunks[0], torch.Tensor) else all_chunks[0]
        if isinstance(full_audio, torch.Tensor):
            full_audio = full_audio.cpu().float().numpy()
        if full_audio.ndim > 1:
            full_audio = full_audio.squeeze()

        dur = len(full_audio) / chunk_sr
        rtf = gen_time / dur if dur > 0 else 999
        ttfb = (first_chunk_time[0] - t0) * 1000 if first_chunk_time[0] else -1
        avg_gpu = sum(gpu_samples) / len(gpu_samples) if gpu_samples else 0

        path = f"/home/ubuntu/qwen3_samples/modi_{label}.wav"
        sf.write(path, full_audio, chunk_sr)

        print(f"{label:<10} {dur:>8.2f}s {gen_time:>8.2f}s {rtf:>6.3f}x {ttfb:>7.0f}ms {avg_gpu:>5.0f}%  -> {path}")
    else:
        print(f"{label:<10} No audio generated")

# Also generate non-streaming for comparison
print("\n--- Non-streaming (full quality) ---")
for label in ["medium", "speech"]:
    text = texts[label]
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    audio_list, sr = model.generate_voice_clone(
        text=text, language="Hindi",
        ref_audio=ref_audio, ref_text=ref_text,
    )
    torch.cuda.synchronize()
    gen_time = time.perf_counter() - t0

    if audio_list:
        audio = audio_list[0]
        if isinstance(audio, torch.Tensor):
            audio = audio.cpu().float().numpy()
        if audio.ndim > 1:
            audio = audio.squeeze()
        dur = len(audio) / sr
        rtf = gen_time / dur
        path = f"/home/ubuntu/qwen3_samples/modi_{label}_full.wav"
        sf.write(path, audio, sr)
        print(f"{label:<10} {dur:>8.2f}s {gen_time:>8.2f}s {rtf:>6.3f}x  -> {path}")

peak_vram = torch.cuda.max_memory_allocated() / 1e9
print(f"\nPeak VRAM: {peak_vram:.1f}GB")
print("Done. Compare files in /home/ubuntu/qwen3_samples/ with /home/ubuntu/vibevoice/samples_final/")
