import torch
import re
import time
import soundfile as sf
from transformers import AutoTokenizer, AutoModelForCausalLM
from neucodec import NeuCodec

# Allow TF32 matmuls for faster fp32 GEMMs on Ampere+ GPUs.
torch.set_float32_matmul_precision('high')

MODEL_NAME = "Scicom-intl/Multilingual-Expressive-TTS-0.6B"
DEVICE = 'cuda'

# Load tokenizer, the bf16 LLM, and the neural audio codec onto the GPU.
print("Loading models...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16).to(DEVICE)
codec = NeuCodec.from_pretrained("neuphonic/neucodec").eval().to(DEVICE)

# Benchmark inputs: short and long Hindi, English, and Tamil utterances,
# all rendered with the same speaker voice.
_langs = ["hindi", "hindi_long", "english", "tamil"]
_texts = [
    "नमस्ते, मेरा नाम ग्रेस है।",
    "भारत एक विविधताओं से भरा देश है जहाँ अनेक भाषाएँ बोली जाती हैं।",
    "Hello, my name is Grace and I am here to help you.",
    "வணக்கம், என் பெயர் கிரேஸ்.",
]
samples = [{"lang": lang, "text": text} for lang, text in zip(_langs, _texts)]

speaker = "multilingual-tts_audio_Grace"

# Warmup: one short generation so CUDA kernels/caches are primed and the
# timed runs below are not skewed by first-call overhead.
print("Warmup run...")
warmup_prompt = f"<|im_start|>{speaker}: Hello<|speech_start|>"
warmup_inputs = tokenizer(warmup_prompt, return_tensors="pt", add_special_tokens=True).to('cuda')
with torch.no_grad():
    model.generate(**warmup_inputs, max_new_tokens=50, do_sample=True, temperature=0.7)
torch.cuda.synchronize()

# Results-table header for the per-sample benchmark below.
divider = "=" * 70
header = f"{'Lang':<15} {'TTFB(ms)':<12} {'Total(ms)':<12} {'Audio(s)':<10} {'Tokens':<8} {'RTF':<8}"
print("\n" + divider)
print(header)
print(divider)

# Per-sample benchmark: TTFB (prefill + first decode step) and full-generation
# latency.  NOTE: do_sample=True makes the token stream — and therefore audio
# length and RTF — stochastic from run to run.
SPEECH_TOKEN_RE = re.compile(r'<\|s_(\d+)\|>')  # codec token ids embedded in the decoded text
FRAME_RATE_HZ = 50.0  # assumed codec frame rate (speech tokens per second) — TODO confirm against neucodec

for s in samples:
    prompt = f"<|im_start|>{speaker}: {s['text']}<|speech_start|>"
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to('cuda')

    torch.cuda.synchronize()

    # Measure TTFB: time to produce the first generated token.  Use the same
    # sampling settings as the full run below so the measurements are comparable
    # (the original 1-token run omitted repetition_penalty).
    t_start = time.perf_counter()
    with torch.no_grad():
        model.generate(
            **inputs,
            max_new_tokens=1,
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.15,
        )
    torch.cuda.synchronize()
    ttfb_ms = (time.perf_counter() - t_start) * 1000

    # Full generation, timed independently (prefill is intentionally repeated).
    torch.cuda.synchronize()
    t_start_full = time.perf_counter()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024,
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.15,
        )
    torch.cuda.synchronize()
    total_ms = (time.perf_counter() - t_start_full) * 1000

    # Extract the speech-token ids emitted after the final <|speech_start|>.
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
    audio_tokens = [int(t) for t in SPEECH_TOKEN_RE.findall(generated_text.split('<|speech_start|>')[-1])]
    n_tokens = len(audio_tokens)
    audio_dur = n_tokens / FRAME_RATE_HZ
    # Real-time factor: generation time / audio duration; 0 means no audio tokens.
    rtf = (total_ms / 1000) / audio_dur if audio_dur > 0 else 0

    print(f"{s['lang']:<15} {ttfb_ms:<12.1f} {total_ms:<12.1f} {audio_dur:<10.2f} {n_tokens:<8} {rtf:<8.3f}")

# Second experiment: TTFB to the first *playable* audio chunk — the LLM needs
# roughly 20 speech tokens before the codec can decode a usable segment.
chunk_divider = "=" * 70
print("\n" + chunk_divider)
print("TTFB to first playable audio chunk (~20 tokens = 0.4s audio):")
print(chunk_divider)

# End-to-end chunk TTFB = LLM time to emit ~20 tokens + codec decode time.
# NOTE(review): max_new_tokens counts *all* generated tokens, so fewer than 20
# may actually be speech tokens — the printed token count shows what was parsed.
chunk_token_re = re.compile(r'<\|s_(\d+)\|>')  # hoisted out of the loop

for s in samples[:2]:
    prompt = f"<|im_start|>{speaker}: {s['text']}<|speech_start|>"
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to('cuda')

    torch.cuda.synchronize()
    t_start = time.perf_counter()

    # Generate 20 tokens (enough for the first audio chunk).
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=20,
            do_sample=True,
            temperature=0.7,
        )
    torch.cuda.synchronize()
    t_20tok = time.perf_counter()

    # Parse the emitted speech-token ids and decode them to audio.
    generated_text = tokenizer.decode(out[0], skip_special_tokens=False)
    audio_tokens = [int(t) for t in chunk_token_re.findall(generated_text.split('<|speech_start|>')[-1])]

    if audio_tokens:
        # Build the [1, 1, T] code tensor directly on the GPU instead of
        # allocating on CPU and then copying with .cuda().
        audio_codes = torch.tensor(audio_tokens, device='cuda')[None, None]
        with torch.no_grad():
            codec.decode_code(audio_codes)
    torch.cuda.synchronize()
    t_decoded = time.perf_counter()

    gen_ms = (t_20tok - t_start) * 1000
    decode_ms = (t_decoded - t_20tok) * 1000
    total_ttfb = (t_decoded - t_start) * 1000

    print(f"{s['lang']:<15} LLM={gen_ms:.0f}ms + Codec={decode_ms:.0f}ms = TTFB {total_ttfb:.0f}ms  ({len(audio_tokens)} tokens)")
