import torch
import soundfile as sf
import numpy as np
from qwen_tts import Qwen3TTSModel
import time
import os

# Directory that receives one <case name>.wav per test case. It is created up
# front so the sf.write calls in the generation loop do not hit a missing dir.
OUTPUT_DIR = "/home/ubuntu/qwen3_hindi_results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load the 0.6B Base checkpoint onto the first CUDA device in bfloat16.
print("Loading Qwen3-TTS-12Hz-0.6B-Base...")
# perf_counter is monotonic and intended for measuring elapsed time;
# time.time() is a wall clock and can jump (e.g. NTP adjustments).
t0 = time.perf_counter()
model = Qwen3TTSModel.from_pretrained(
    "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
    device_map="cuda:0",
    dtype=torch.bfloat16,
)
print(f"Model loaded in {time.perf_counter() - t0:.1f}s\n")

# Reference clip for the voice-clone cases, and the text passed with it as
# ref_text (presumably the transcript of the clip -- confirm if the clip changes).
# Both clone cases share this pair, so the transcript lives in one place.
REF_AUDIO = "/home/ubuntu/qwen3_hindi_samples/modi_speech.wav"
REF_TEXT = "हमारे देश के नागरिकों को अपने अधिकारों के बारे में जानकारी होनी चाहिए"

if not os.path.isfile(REF_AUDIO):
    # Without the clip every clone case fails with its own traceback; say why once.
    print(f"WARNING: reference audio not found: {REF_AUDIO}\n")

# Each case is (name, text, language, ref_audio, ref_text). `name` becomes the
# output file stem; ref_audio/ref_text are None for cases without a reference.
tests = [
    # -- Baselines in supported languages --
    ("01_baseline_en", "Hello, my name is Rahul. Today the weather is very nice and I am happy.", "English", None, None),
    ("02_baseline_zh", "你好，我的名字叫拉胡尔。今天天气很好。", "Chinese", None, None),
    # -- Hindi (Devanagari) text without a voice reference, under two language tags --
    ("03_hindi_noref_en", "नमस्ते, मेरा नाम राहुल है। आज मौसम बहुत अच्छा है।", "English", None, None),
    ("04_hindi_noref_zh", "नमस्ते, मेरा नाम राहुल है। आज मौसम बहुत अच्छा है।", "Chinese", None, None),
    # -- Hinglish (romanized Hindi) --
    ("05_hinglish_en", "Namaste, mera naam Rahul hai. Aaj mausam bahut achha hai.", "English", None, None),
    # -- Hindi with Modi voice clone --
    # NOTE(review): Hindi text under the "Chinese" tag; presumably chosen because
    # Hindi is not a supported tag -- confirm this is the intended variant.
    ("06_hindi_clone", "नमस्ते, मेरा नाम नरेंद्र है। आज मैं आप सभी से बात करना चाहता हूँ।", "Chinese",
     REF_AUDIO, REF_TEXT),
    # -- English text with Modi voice clone (cross-lingual) --
    ("07_english_clone_modi", "India is a great country. We are building a digital future for every citizen.", "English",
     REF_AUDIO, REF_TEXT),
]

# Run every case through model.generate_voice_clone, save each result as
# <OUTPUT_DIR>/<name>.wav, and keep going when a case fails so one bad input
# does not abort the whole comparison.
failed = []  # names of cases that raised, reported in the final summary
for name, text, lang, ref_audio, ref_text in tests:
    print(f"--- {name} ---")
    print(f"  Text: {text[:70]}{'...' if len(text) > 70 else ''}")
    print(f"  Lang: {lang}, Ref: {'yes' if ref_audio else 'no'}")

    # Build the request outside the try so only the model/IO calls are guarded.
    kwargs = dict(
        text=text,
        language=lang,
        non_streaming_mode=True,
        max_new_tokens=2048,
    )
    if ref_audio:
        # Cases without a reference call the same entry point with no
        # ref_audio/ref_text at all.
        kwargs["ref_audio"] = ref_audio
        kwargs["ref_text"] = ref_text

    try:
        # perf_counter: monotonic clock intended for measuring durations.
        t0 = time.perf_counter()
        wavs, sr = model.generate_voice_clone(**kwargs)
        dur = time.perf_counter() - t0
        # First returned waveform; assumes a 1-D sample array so that
        # len(audio) / sr is its length in seconds -- TODO confirm.
        audio = wavs[0]
        audio_dur = len(audio) / sr
        outpath = os.path.join(OUTPUT_DIR, f"{name}.wav")
        sf.write(outpath, audio, sr)
    except Exception as e:  # experiment boundary: report and move on
        import traceback
        # Flush the buffered stdout (case header + this line) before the
        # traceback is written to stderr, so redirected logs stay in order.
        print(f"  ERROR: {e}", flush=True)
        traceback.print_exc()
        failed.append(name)
    else:
        print(f"  OK: {audio_dur:.1f}s audio, generated in {dur:.1f}s, saved to {outpath}")
    print()

if failed:
    print(f"Failed ({len(failed)}/{len(tests)}): {', '.join(failed)}")
print(f"\nAll results in: {OUTPUT_DIR}")
print("Listen to the files to compare what works vs what breaks for Hindi!")
