import os
import torch
import re
import soundfile as sf
from transformers import AutoTokenizer, AutoModelForCausalLM
from neucodec import NeuCodec

torch.set_float32_matmul_precision('high')

MODEL_NAME = "Scicom-intl/Multilingual-Expressive-TTS-0.6B"

print("Loading codec...")
codec = NeuCodec.from_pretrained("neuphonic/neucodec")
codec = codec.eval().to('cuda')

print(f"Loading model: {MODEL_NAME}...")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16).to('cuda')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Indic language test samples
samples = [
    # Hindi
    {"lang": "hindi", "speaker": "multilingual-tts_audio_Grace", "text": "नमस्ते, मेरा नाम ग्रेस है और मैं आपकी मदद करने के लिए यहाँ हूँ।"},
    {"lang": "hindi", "speaker": "multilingual-tts_audio_Grace", "text": "भारत एक विविधताओं से भरा देश है जहाँ अनेक भाषाएँ बोली जाती हैं।"},
    # Tamil
    {"lang": "tamil", "speaker": "multilingual-tts_audio_Grace", "text": "வணக்கம், என் பெயர் கிரேஸ், நான் உங்களுக்கு உதவ இங்கே இருக்கிறேன்."},
    # Bengali
    {"lang": "bengali", "speaker": "multilingual-tts_audio_Grace", "text": "নমস্কার, আমার নাম গ্রেস এবং আমি আপনাকে সাহায্য করতে এখানে আছি।"},
    # Telugu
    {"lang": "telugu", "speaker": "multilingual-tts_audio_Grace", "text": "నమస్కారం, నా పేరు గ్రేస్ మరియు మీకు సహాయం చేయడానికి నేను ఇక్కడ ఉన్నాను."},
    # Marathi
    {"lang": "marathi", "speaker": "multilingual-tts_audio_Grace", "text": "नमस्कार, माझं नाव ग्रेस आहे आणि मी तुम्हाला मदत करण्यासाठी इथे आहे."},
    # Urdu
    {"lang": "urdu", "speaker": "multilingual-tts_audio_Grace", "text": "السلام علیکم، میرا نام گریس ہے اور میں آپ کی مدد کے لیے یہاں ہوں۔"},
    # Malayalam
    {"lang": "malayalam", "speaker": "multilingual-tts_audio_Grace", "text": "നമസ്കാരം, എന്റെ പേര് ഗ്രേസ് ആണ്, നിങ്ങളെ സഹായിക്കാൻ ഞാൻ ഇവിടെയുണ്ട്."},
]

os.makedirs("indic_tts_outputs", exist_ok=True)

for i, s in enumerate(samples):
    print(f"\n[{i+1}/{len(samples)}] Generating {s['lang']}...")
    print(f"  Text: {s['text'][:60]}...")

    prompt = f"<|im_start|>{s['speaker']}: {s['text']}<|speech_start|>"
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.15,
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
    audio_tokens = re.findall(r'<\|s_(\d+)\|>', generated_text.split('<|speech_start|>')[-1])
    audio_tokens = [int(t) for t in audio_tokens]

    if len(audio_tokens) == 0:
        print(f"  WARNING: No audio tokens generated!")
        continue

    print(f"  Generated {len(audio_tokens)} audio tokens ({len(audio_tokens)/50:.1f}s)")

    audio_codes = torch.tensor(audio_tokens)[None, None]
    with torch.no_grad():
        audio_waveform = codec.decode_code(audio_codes.cuda())

    filename = f"indic_tts_outputs/{s['lang']}_{i+1}.wav"
    sf.write(filename, audio_waveform[0, 0].cpu().numpy(), 24000)
    print(f"  Saved: {filename}")

print("\nDone! All outputs in indic_tts_outputs/")
