import torch
import soundfile as sf
import numpy as np
from qwen_tts import Qwen3TTSModel
import time
import os

OUTPUT_DIR = "/home/ubuntu/qwen3_hindi_results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Loading Qwen3-TTS-12Hz-0.6B-Base...")
t0 = time.time()
model = Qwen3TTSModel.from_pretrained(
    "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
    device_map="cuda:0",
    dtype=torch.bfloat16,
)
print(f"Model loaded in {time.time()-t0:.1f}s\n")

MODI_LONG = "/home/ubuntu/qwen3_hindi_samples/modi_speech.wav"
MODI_MED = "/home/ubuntu/qwen3_hindi_samples/modi_medium.wav"

# Use English reference audio from the HF example for baseline comparison
EN_REF = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-TTS-Repo/clone.wav"
EN_REF_TEXT = "Okay. Yeah. I resent you. I love you. I respect you. But you know what? You blew it! And thanks to you."

tests = [
    # === GROUP A: Baselines with English reference (supported lang) ===
    ("A1_en_ref_en_text",
     "Hello, my name is Rahul. Today the weather is very nice and I am feeling great.",
     "English", EN_REF, EN_REF_TEXT),

    ("A2_en_ref_zh_text",
     "你好，我的名字叫拉胡尔。今天天气很好，我感觉很棒。",
     "Chinese", EN_REF, EN_REF_TEXT),

    # === GROUP B: Hindi Devanagari text with English ref (unsupported lang text) ===
    ("B1_en_ref_hindi_text_lang_en",
     "नमस्ते, मेरा नाम राहुल है। आज मौसम बहुत अच्छा है।",
     "English", EN_REF, EN_REF_TEXT),

    ("B2_en_ref_hindi_text_lang_zh",
     "नमस्ते, मेरा नाम राहुल है। आज मौसम बहुत अच्छा है।",
     "Chinese", EN_REF, EN_REF_TEXT),

    # === GROUP C: Hinglish (romanized Hindi) with English ref ===
    ("C1_en_ref_hinglish",
     "Namaste, mera naam Rahul hai. Aaj mausam bahut achha hai. Main bahut khush hoon.",
     "English", EN_REF, EN_REF_TEXT),

    # === GROUP D: Hindi text with Modi voice (Hindi ref audio) ===
    ("D1_modi_hindi_short",
     "नमस्ते, आज मैं आप सभी से बात करना चाहता हूँ।",
     "Chinese", MODI_LONG, "हमारे देश के नागरिकों को अपने अधिकारों के बारे में जानकारी होनी चाहिए"),

    ("D2_modi_hindi_medium",
     "भारत एक महान देश है। यहाँ की संस्कृति और परंपराएँ पूरे विश्व में प्रसिद्ध हैं। हम सब मिलकर इस देश को और आगे ले जाएंगे।",
     "Chinese", MODI_LONG, "हमारे देश के नागरिकों को अपने अधिकारों के बारे में जानकारी होनी चाहिए"),

    ("D3_modi_hindi_lang_en",
     "नमस्ते, आज मैं आप सभी से बात करना चाहता हूँ। भारत एक महान देश है।",
     "English", MODI_LONG, "हमारे देश के नागरिकों को अपने अधिकारों के बारे में जानकारी होनी चाहिए"),

    # === GROUP E: English text with Modi voice (cross-lingual) ===
    ("E1_modi_english",
     "India is a great country with a rich cultural heritage. We are working together to build a stronger nation.",
     "English", MODI_LONG, "हमारे देश के नागरिकों को अपने अधिकारों के बारे में जानकारी होनी चाहिए"),

    # === GROUP F: Hindi numbers & mixed content ===
    ("F1_modi_numbers",
     "भारत की जनसंख्या एक सौ चालीस करोड़ से अधिक है। दो हज़ार चौबीस में हमने बहुत प्रगति की।",
     "Chinese", MODI_MED, "हमारे देश के नागरिकों को अपने अधिकारों के बारे में जानकारी होनी चाहिए"),

    # === GROUP G: x_vector_only_mode (speaker embedding only, no in-context) ===
    ("G1_xvec_hindi",
     "नमस्ते, मेरा नाम राहुल है। आज मौसम बहुत अच्छा है।",
     "Chinese", MODI_LONG, None),  # ref_text=None triggers x_vector mode

    # === GROUP H: Hindi with Japanese/Korean language tag (will it do anything?) ===
    ("H1_hindi_lang_ja",
     "नमस्ते, मेरा नाम राहुल है।",
     "Japanese", MODI_LONG, "हमारे देश के नागरिकों को अपने अधिकारों के बारे में जानकारी होनी चाहिए"),

    ("H2_hindi_lang_ko",
     "नमस्ते, मेरा नाम राहुल है।",
     "Korean", MODI_LONG, "हमारे देश के नागरिकों को अपने अधिकारों के बारे में जानकारी होनी चाहिए"),
]

results_log = []

for name, text, lang, ref_audio, ref_text in tests:
    print(f"--- {name} ---")
    print(f"  Text: {text[:70]}{'...' if len(text)>70 else ''}")
    print(f"  Lang: {lang}, Ref: {os.path.basename(ref_audio) if isinstance(ref_audio, str) and '/' in ref_audio else 'url'}")
    try:
        t0 = time.time()
        kwargs = dict(
            text=text,
            language=lang,
            ref_audio=ref_audio,
            non_streaming_mode=True,
            max_new_tokens=2048,
        )
        if ref_text is not None:
            kwargs["ref_text"] = ref_text
        else:
            kwargs["x_vector_only_mode"] = True

        wavs, sr = model.generate_voice_clone(**kwargs)
        dur = time.time() - t0
        audio = wavs[0]
        audio_dur = len(audio) / sr
        outpath = os.path.join(OUTPUT_DIR, f"{name}.wav")
        sf.write(outpath, audio, sr)
        
        # Basic audio stats
        rms = np.sqrt(np.mean(audio**2))
        peak = np.max(np.abs(audio))
        silence_ratio = np.mean(np.abs(audio) < 0.01)
        
        result = f"  OK: {audio_dur:.1f}s audio, gen {dur:.1f}s, RMS={rms:.4f}, peak={peak:.4f}, silence={silence_ratio:.1%}"
        print(result)
        results_log.append((name, "OK", audio_dur, dur, rms, peak, silence_ratio))
    except Exception as e:
        import traceback
        print(f"  ERROR: {e}")
        traceback.print_exc()
        results_log.append((name, f"ERROR: {e}", 0, 0, 0, 0, 0))
    print()

print("\n" + "="*80)
print("SUMMARY")
print("="*80)
print(f"{'Name':<35} {'Status':>6} {'Audio':>6} {'GenTime':>8} {'RMS':>8} {'Silence':>8}")
print("-"*80)
for name, status, audio_dur, gen_dur, rms, peak, silence in results_log:
    if status == "OK":
        print(f"{name:<35} {'OK':>6} {audio_dur:>5.1f}s {gen_dur:>7.1f}s {rms:>8.4f} {silence:>7.1%}")
    else:
        print(f"{name:<35} {'FAIL':>6}")

print(f"\nAll wav files in: {OUTPUT_DIR}")
