"""
Sooktam-2 streaming inference server.
Uses chunk-based streaming: splits long text into chunks, yields audio per chunk.
FastAPI with WebSocket for real-time audio delivery.
"""
import sys, os, time, io, wave, struct
sys.path.insert(0, '/home/ubuntu/sooktam2/src')
sys.path.insert(0, '/home/ubuntu/sooktam2')

import numpy as np
import torch
from fastapi import FastAPI, WebSocket
from fastapi.responses import Response
import uvicorn

# Voice-cloning prompt: reference audio clip and its exact transcript.
# Every /tts request is conditioned on this pair (single fixed speaker).
REF = '/home/ubuntu/vibevoice/demo/voices/modi.wav'
REF_TEXT = 'मेरे प्यारे देशवासियों, मुझे सीतापुर के ओजस्वी ने लिखा है कि अमृत महोत्सव से जुड़ी चर्चाएं उन्हें खूब पसंद आ रही हैं।'

print('Loading Sooktam-2...', flush=True)
# Imported here (after the sys.path inserts above) so the local checkout is
# importable; trust_remote_code pulls the model class from the checkpoint dir.
from transformers import AutoModel
model = AutoModel.from_pretrained('/home/ubuntu/sooktam2', trust_remote_code=True)
# Warmup inference so the first real request doesn't pay compilation /
# allocation latency. NOTE(review): kwargs here presumably mirror the /tts
# call signature — confirm against the model's infer() definition.
model.infer(ref_file=REF, ref_text=REF_TEXT, gen_text='warmup.', tokenizer='cls', cls_language='hindi')
print(f'Ready. VRAM: {torch.cuda.memory_allocated()/1e9:.1f}GB', flush=True)

app = FastAPI(title="Sooktam-2 Modi TTS")

def to_wav_bytes(audio_np, sr=24000):
    """Serialize a mono audio array as an in-memory 16-bit PCM WAV file.

    Float input is assumed to lie in [-1, 1]; it is clipped and scaled to
    int16. Arrays already of dtype int16 are written through unchanged.

    Args:
        audio_np: 1-D numpy array of samples (float or int16).
        sr: sample rate in Hz written to the WAV header (default 24000).

    Returns:
        The complete WAV file as ``bytes``.
    """
    pcm = audio_np
    if pcm.dtype != np.int16:
        # Clip first so out-of-range floats saturate instead of wrapping.
        pcm = (np.clip(pcm, -1.0, 1.0) * 32767).astype(np.int16)
    sink = io.BytesIO()
    writer = wave.open(sink, 'wb')
    try:
        writer.setnchannels(1)      # mono
        writer.setsampwidth(2)      # 16-bit samples
        writer.setframerate(sr)
        writer.writeframes(pcm.tobytes())
    finally:
        # Closing the writer finalizes the RIFF header before we read back.
        writer.close()
    return sink.getvalue()

@app.post("/tts")
async def tts(text: str = "नमस्ते देशवासियों.", nfe_step: int = 32, cfg_strength: float = 2.0):
    t0 = time.perf_counter()
    wav, sr, _ = model.infer(
        ref_file=REF, ref_text=REF_TEXT, gen_text=text,
        tokenizer='cls', cls_language='hindi',
        nfe_step=nfe_step, cfg_strength=cfg_strength,
    )
    gen_time = time.perf_counter() - t0
    dur = len(wav) / sr
    audio_bytes = to_wav_bytes(wav, sr)
    return Response(
        content=audio_bytes, media_type="audio/wav",
        headers={
            "X-Audio-Duration": f"{dur:.2f}",
            "X-Gen-Time": f"{gen_time:.2f}",
            "X-RTF": f"{gen_time/dur:.3f}",
        },
    )

@app.get("/health")
async def health():
    return {"status": "ok", "model": "sooktam-2", "vram_gb": f"{torch.cuda.memory_allocated()/1e9:.1f}"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
