"""Quick single-request test to verify Qwen3-TTS performance."""
import time

import numpy as np
import torch

from faster_qwen3_tts import FasterQwen3TTS

# Reference speaker prompt for voice cloning: a WAV file plus its transcript.
# NOTE(review): path is machine-specific — assumes this exact file exists locally.
ref_audio = "/home/ubuntu/vibevoice/demo/voices/modi.wav"
# Transcript of ref_audio (Hindi); passed alongside the audio as the cloning prompt.
ref_text = "मेरे प्यारे देशवासियों, मुझे सीतापुर के ओजस्वी ने लिखा है कि अमृत महोत्सव से जुड़ी चर्चाएं उन्हें खूब पसंद आ रही हैं।"
# Text synthesized by both the non-streaming and streaming benchmark runs below.
test_text = "मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ. हमारा देश एक नये दौर में प्रवेश कर रहा है."

# Load the TTS model and report how much GPU memory it occupies after load.
print("Loading model...")
model = FasterQwen3TTS.from_pretrained("Qwen/Qwen3-TTS-12Hz-1.7B-Base")
print(f"VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB")

# One throwaway generation so kernel compilation / cache setup does not
# pollute the timed runs below. The audio itself is discarded.
print("Warmup...")
warm_audio, warm_sr = model.generate_voice_clone(
    text="Warmup test.",
    language="Auto",
    ref_audio=ref_audio,
    ref_text=ref_text,
)
print(f"Warmup done. SR={warm_sr}")

# --- Non-streaming benchmark: one full generation, wall-clock timed ---
print("\n--- Non-streaming ---")
torch.cuda.synchronize()
start = time.perf_counter()
audio_list, sr = model.generate_voice_clone(
    text=test_text,
    language="Auto",
    ref_audio=ref_audio,
    ref_text=ref_text,
)
torch.cuda.synchronize()
gen = time.perf_counter() - start

# Normalize the first returned waveform to a flat float numpy array.
wav = audio_list[0]
if isinstance(wav, torch.Tensor):
    wav = wav.float().cpu().numpy()
wav = wav.squeeze() if wav.ndim > 1 else wav

# RTF (real-time factor) = generation time / audio duration; < 1 is faster
# than real time.
dur = len(wav) / sr
print(f"Duration: {dur:.2f}s, GenTime: {gen:.2f}s, RTF: {gen/dur:.3f}x")

import soundfile as sf
sf.write("/home/ubuntu/qwen3_single_test.wav", wav, sr)

# --- Streaming benchmark: measures TTFB (time to first audio chunk) and RTF ---
print("\n--- Streaming ---")
chunks = []
torch.cuda.synchronize()
t0 = time.perf_counter()
ttfb = None  # latency until the first chunk is yielded
for chunk, chunk_sr, timing in model.generate_voice_clone_streaming(
    text=test_text, language="Auto",
    ref_audio=ref_audio, ref_text=ref_text,
    chunk_size=8,
):
    if ttfb is None:
        ttfb = time.perf_counter() - t0
    chunks.append(chunk)
torch.cuda.synchronize()
total = time.perf_counter() - t0

if chunks:
    # FIX: the original code used only chunks[0] when chunks were not torch
    # tensors, dropping every subsequent chunk and under-reporting duration
    # (which inflates RTF). Concatenate all chunks in both cases.
    if isinstance(chunks[0], torch.Tensor):
        full = torch.cat(chunks, dim=-1).cpu().float().numpy()
    else:
        full = np.concatenate([np.asarray(c) for c in chunks], axis=-1)
    if full.ndim > 1:
        full = full.squeeze()
    dur = len(full) / chunk_sr
    print(f"Duration: {dur:.2f}s, GenTime: {total:.2f}s, RTF: {total/dur:.3f}x, TTFB: {ttfb*1000:.0f}ms, Chunks: {len(chunks)}")

# High-water mark of GPU memory across both benchmark runs.
peak_gb = torch.cuda.max_memory_allocated() / 1e9
print(f"\nPeak VRAM: {peak_gb:.1f}GB")
