#!/usr/bin/env python3
import httpx, base64, time, os, struct

# Base URL of the TTS server under test. Override with the TTS_API_BASE
# environment variable. A trailing slash is stripped because request URLs are
# built as f"{API_BASE}/v1/...", which would otherwise contain "//".
API_BASE = os.environ.get("TTS_API_BASE", "http://localhost:8091").rstrip("/")
# WAV file sent as the voice-cloning reference with every request. Override
# with the TTS_REF_AUDIO environment variable when running on another machine.
REF_AUDIO = os.environ.get(
    "TTS_REF_AUDIO",
    "/home/ubuntu/modi_processed/segments/001_ERW9i1lwnBw/seg_0001.wav",
)

def encode_audio(path):
    """Return the file at *path* as a ``data:audio/wav;base64,...`` URI string.

    The file is read in binary mode and its entire contents are Base64
    encoded; no check is made that the bytes are actually WAV audio.
    """
    with open(path, "rb") as audio_file:
        encoded = base64.b64encode(audio_file.read())
    # Base64 output is pure ASCII, so decoding it cannot fail.
    return "data:audio/wav;base64," + encoded.decode()

# Encode the reference clip once up front; every request below sends this
# same data URI in its "ref_audio" field.
ref_b64 = encode_audio(REF_AUDIO)
print("Reference audio:", REF_AUDIO)
# Reported size is that of the Base64 text (data-URI prefix included), in
# whole kibibytes.
print(f"Encoded size: {len(ref_b64) >> 10} KB")

# Non-streaming synthesis cases. Each entry carries a display name (also used
# to derive the output file name), the text to synthesize, and the language
# hint forwarded to the server.
tests = [
    {"name": case_name, "text": case_text, "language": "Auto"}
    for case_name, case_text in (
        (
            "Hindi Modi Speech",
            "मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ।",
        ),
        (
            "Hindi Modi Medium",
            "डिजिटल इंडिया के माध्यम से हम गाँव गाँव तक तकनीक पहुँचा रहे हैं। यह एक क्रांति है।",
        ),
    )
]

def _wav_duration_seconds(content):
    """Return the playback length in seconds of an in-memory PCM WAV, or 0.0.

    The RIFF chunk list is walked instead of reading fixed offsets, so a
    ``fmt `` chunk that is not at byte 12, or extra chunks placed before
    ``data``, are handled. 0.0 is returned when the bytes are not a RIFF/WAVE
    container, when a header is truncated, or when the format fields describe
    a zero byte rate.

    A declared data size larger than the bytes actually present (for example
    a placeholder such as 0xFFFFFFFF) is clamped to what was received.
    """
    total = len(content)
    if total < 12 or content[:4] != b"RIFF" or content[8:12] != b"WAVE":
        return 0.0

    bytes_per_second = 0
    pos = 12  # first sub-chunk follows the 12-byte RIFF/WAVE header
    while pos + 8 <= total:
        chunk_id = content[pos:pos + 4]
        (chunk_size,) = struct.unpack_from("<I", content, pos + 4)
        body = pos + 8

        if chunk_id == b"fmt ":
            if chunk_size < 16 or body + 16 > total:
                return 0.0
            # Layout: format tag, channels, sample rate, byte rate,
            # block align, bits per sample.
            _tag, channels, sample_rate, _rate, _align, bits = struct.unpack_from(
                "<HHIIHH", content, body
            )
            bytes_per_second = sample_rate * channels * bits // 8
        elif chunk_id == b"data":
            if bytes_per_second <= 0:
                return 0.0
            return min(chunk_size, total - body) / bytes_per_second

        # Chunks are word-aligned: an odd-sized chunk is followed by a pad byte.
        pos = body + chunk_size + (chunk_size & 1)
    return 0.0


def _error_payload(resp):
    """Return the decoded JSON body of *resp* if it contains "error", else None.

    A 200 response may still carry a JSON error document instead of audio.
    A body that is not valid JSON (the normal case for audio bytes) yields None.
    """
    try:
        body = resp.json()
    except ValueError:  # JSONDecodeError and UnicodeDecodeError both derive from it
        return None
    # Membership is only meaningful on containers/strings; scalars such as
    # numbers or null cannot hold an "error" marker.
    if isinstance(body, (dict, list, str)) and "error" in body:
        return body
    return None


# Run every non-streaming case: POST the text with the reference voice, save
# the returned audio, and report latency, size, duration and real-time factor.
for test in tests:
    print(f"\n{'='*60}")
    print(f"Test: {test['name']}")
    print(f"Text: {test['text']}")

    payload = {
        "model": "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
        "input": test["text"],
        "voice": "clone",
        "task_type": "Base",
        "language": test["language"],
        "ref_audio": ref_b64,
        "ref_text": "मेरे प्यारे देशवासियों",
        "response_format": "wav",
    }

    # perf_counter is monotonic, so the measured interval is immune to
    # wall-clock adjustments during the request.
    t0 = time.perf_counter()
    try:
        resp = httpx.post(f"{API_BASE}/v1/audio/speech", json=payload, timeout=120.0)
        elapsed = time.perf_counter() - t0

        if resp.status_code != 200:
            print(f"  ERROR {resp.status_code}: {resp.text[:500]}")
            continue

        err = _error_payload(resp)
        if err is not None:
            print(f"  JSON Error: {err}")
            continue

        audio = resp.content
        out_file = f"/home/ubuntu/vllm_tts_{test['name'].replace(' ', '_').lower()}.wav"
        with open(out_file, "wb") as f:
            f.write(audio)

        size_kb = len(audio) / 1024
        audio_dur = _wav_duration_seconds(audio)

        print(f"  Time: {elapsed:.2f}s")
        print(f"  Size: {size_kb:.1f} KB")
        print(f"  Audio duration: {audio_dur:.2f}s")
        if audio_dur > 0:
            # Real-time factor: wall-clock seconds spent per second of audio.
            rtf = elapsed / audio_dur
            print(f"  RTF: {rtf:.2f}x")
        print(f"  Saved: {out_file}")

    except Exception as e:
        # Per-case boundary: report the failure and move on to the next case
        # rather than aborting the whole run.
        elapsed = time.perf_counter() - t0
        print(f"  EXCEPTION ({elapsed:.2f}s): {e}")

# Streaming case: request raw PCM with "stream": True and measure time to
# first byte, total time, and real-time factor from the bytes received.
print("\n\n=== Streaming test ===")
payload_stream = {
    "model": "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    "input": "मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ।",
    "voice": "clone",
    "task_type": "Base",
    "language": "Auto",
    "ref_audio": ref_b64,
    "ref_text": "मेरे प्यारे देशवासियों",
    "response_format": "pcm",
    "stream": True,
}

# Raw PCM carries no header, so the duration has to be derived from an assumed
# format: 24 kHz, 16-bit (2 bytes per sample), mono.
# NOTE(review): confirm this matches the server's actual PCM output.
pcm_bytes_per_second = 24000 * 2

t0 = time.perf_counter()
ttfb = None
total_bytes = 0
chunks = 0
http_error = None
try:
    with httpx.stream("POST", f"{API_BASE}/v1/audio/speech", json=payload_stream, timeout=120.0) as resp:
        if resp.status_code != 200:
            # A streamed response must be read explicitly before .text is
            # available; without this check the error body would be counted
            # as audio bytes.
            resp.read()
            http_error = f"  ERROR {resp.status_code}: {resp.text[:500]}"
        else:
            for chunk in resp.iter_bytes():
                if ttfb is None:
                    ttfb = time.perf_counter() - t0
                total_bytes += len(chunk)
                chunks += 1
    total_time = time.perf_counter() - t0

    if http_error is not None:
        print(http_error)
    elif ttfb is None or total_bytes == 0:
        # Nothing usable arrived: TTFB and RTF are undefined, so say so
        # instead of failing on None arithmetic or a division by zero.
        print(f"  ERROR: stream ended after {total_time:.2f}s without any audio bytes")
    else:
        audio_dur = total_bytes / pcm_bytes_per_second
        print(f"  TTFB: {ttfb*1000:.0f}ms")
        print(f"  Total time: {total_time:.2f}s")
        print(f"  Audio duration: {audio_dur:.2f}s")
        print(f"  RTF: {total_time/audio_dur:.2f}x")
        print(f"  Chunks: {chunks}")
        print(f"  Total bytes: {total_bytes}")
except Exception as e:
    # Top-level boundary for the streaming case: report and finish the script.
    print(f"  Stream EXCEPTION: {e}")
