import torch
import soundfile as sf
from qwen_tts import Qwen3TTSModel
import time
import os

OUTPUT_DIR = "/home/ubuntu/qwen3_hindi_results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Loading Qwen3-TTS-12Hz-0.6B-Base...")
t0 = time.time()
model = Qwen3TTSModel.from_pretrained(
    "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
    device_map="cuda:0",
    dtype=torch.bfloat16,
)
print(f"Model loaded in {time.time()-t0:.1f}s")

# --- TEST 1: Hindi text (Devanagari script) without voice reference ---
hindi_tests = [
    ("hindi_simple", "नमस्ते, मेरा नाम राहुल है। आज मौसम बहुत अच्छा है।"),
    ("hindi_medium", "भारत एक महान देश है। यहाँ की संस्कृति और परंपराएँ पूरे विश्व में प्रसिद्ध हैं।"),
    ("hindi_long", "आज हम बात करेंगे कृत्रिम बुद्धिमत्ता के बारे में। यह तकनीक बहुत तेज़ी से विकसित हो रही है और हमारे जीवन के हर पहलू को बदल रही है। शिक्षा, स्वास्थ्य, और उद्योग में इसका व्यापक उपयोग हो रहा है।"),
    ("hinglish", "Hello, mera naam Rahul hai. Aaj mausam bahut achha hai. Main Delhi mein rehta hoon."),
    ("hindi_numbers", "भारत की जनसंख्या एक सौ चालीस करोड़ से अधिक है।"),
]

print("\n=== TEST 1: Hindi text generation (no voice reference) ===")
for name, text in hindi_tests:
    print(f"\n--- {name} ---")
    print(f"  Text: {text[:80]}...")
    try:
        t0 = time.time()
        wavs, sr = model.generate_voice_clone(
            text=text,
            language="Chinese",  # closest available? trying different langs
        )
        dur = time.time() - t0
        outpath = os.path.join(OUTPUT_DIR, f"test1_{name}_lang_zh.wav")
        sf.write(outpath, wavs[0], sr)
        print(f"  Saved: {outpath} ({len(wavs[0])/sr:.1f}s audio, generated in {dur:.1f}s)")
    except Exception as e:
        print(f"  ERROR (lang=Chinese): {e}")

    try:
        t0 = time.time()
        wavs, sr = model.generate_voice_clone(
            text=text,
            language="English",
        )
        dur = time.time() - t0
        outpath = os.path.join(OUTPUT_DIR, f"test1_{name}_lang_en.wav")
        sf.write(outpath, wavs[0], sr)
        print(f"  Saved: {outpath} ({len(wavs[0])/sr:.1f}s audio, generated in {dur:.1f}s)")
    except Exception as e:
        print(f"  ERROR (lang=English): {e}")


# --- TEST 2: Hindi text with Modi voice reference ---
print("\n=== TEST 2: Hindi text + Modi voice reference (voice clone) ===")
ref_audio = "/home/ubuntu/qwen3_hindi_samples/modi_speech.wav"
ref_text_hindi = "हमारे देश के नागरिकों को अपने अधिकारों के बारे में जानकारी होनी चाहिए"

clone_tests = [
    ("clone_simple", "नमस्ते, मेरा नाम नरेंद्र है। आज मैं आप सभी से बात करना चाहता हूँ।"),
    ("clone_medium", "डिजिटल इंडिया अभियान ने भारत को एक नई दिशा दी है। हम प्रौद्योगिकी के माध्यम से हर नागरिक तक सेवाएँ पहुँचा रहे हैं।"),
    ("clone_english", "India is a great country with a rich cultural heritage. We are working together to build a stronger nation."),
]

for name, text in clone_tests:
    print(f"\n--- {name} ---")
    print(f"  Text: {text[:80]}...")

    for lang in ["Chinese", "English"]:
        try:
            t0 = time.time()
            wavs, sr = model.generate_voice_clone(
                text=text,
                language=lang,
                ref_audio=ref_audio,
                ref_text=ref_text_hindi,
            )
            dur = time.time() - t0
            outpath = os.path.join(OUTPUT_DIR, f"test2_{name}_lang_{lang.lower()[:2]}.wav")
            sf.write(outpath, wavs[0], sr)
            print(f"  [{lang}] Saved: {outpath} ({len(wavs[0])/sr:.1f}s, gen {dur:.1f}s)")
        except Exception as e:
            print(f"  [{lang}] ERROR: {e}")


# --- TEST 3: Supported language baseline (sanity check) ---
print("\n=== TEST 3: Supported languages baseline (Chinese + English) ===")
baseline_tests = [
    ("baseline_en", "Hello, my name is Rahul. Today the weather is very nice.", "English"),
    ("baseline_zh", "你好，我的名字叫拉胡尔。今天天气很好。", "Chinese"),
]

for name, text, lang in baseline_tests:
    try:
        t0 = time.time()
        wavs, sr = model.generate_voice_clone(
            text=text,
            language=lang,
        )
        dur = time.time() - t0
        outpath = os.path.join(OUTPUT_DIR, f"test3_{name}.wav")
        sf.write(outpath, wavs[0], sr)
        print(f"  [{name}] Saved: {outpath} ({len(wavs[0])/sr:.1f}s, gen {dur:.1f}s)")
    except Exception as e:
        print(f"  [{name}] ERROR: {e}")


print(f"\n=== All results in {OUTPUT_DIR} ===")
print("Done!")
