"""
Maximum optimized VibeVoice Hindi TTS Server.
- LoRA merged into base weights (no adapter overhead)
- torch.compile on LLM + diffusion head
- 2 DDPM steps (aggressive but fast)
- CUDA graphs where possible
- Pre-warmed with all speakers
- Stripped acoustic encoder (not needed)
- BF16 throughout
"""
import os, time, torch, asyncio, io, gc
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import uvicorn

# Global CUDA performance knobs: TF32 tensor cores for matmul/cuDNN
# (faster, slightly reduced float32 precision) and cuDNN kernel autotuning
# (pays off because input shapes repeat after warmup).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.lora_loading import load_lora_assets
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

# Runtime configuration, each overridable via environment variable.
MODEL_PATH = os.getenv("MODEL_PATH", "tarun7r/vibevoice-hindi-1.5B")  # HF repo id or local dir passed to from_pretrained
LORA_PATH = os.getenv("LORA_PATH", "output/hindi_lora_v2_from_tarun/checkpoint-13500/lora")  # merged into the model if the path exists
DDPM_STEPS = int(os.getenv("DDPM_STEPS", "2"))  # diffusion steps set on the model (fewer = faster, lower fidelity)
CFG_SCALE = float(os.getenv("CFG_SCALE", "1.3"))  # classifier-free guidance scale passed to generate()
DEFAULT_SPEAKER = os.getenv("DEFAULT_SPEAKER", "hi-Nikita_woman")  # voice wav stem under demo/voices/
PORT = int(os.getenv("PORT", "8100"))  # uvicorn listen port

class TTSRequest(BaseModel):
    """Request body for POST /v1/tts/generate."""
    text: str  # text to synthesize; rendered as "Speaker 1: {text}" in gen()
    speaker: Optional[str] = None  # voice stem under demo/voices/; None falls back to DEFAULT_SPEAKER

# API docs disabled (internal service); singletons below are populated once
# by load_and_optimize() at startup.
app = FastAPI(docs_url=None, openapi_url=None)
_model = None       # VibeVoice inference model, set by load_and_optimize()
_processor = None   # VibeVoiceProcessor, set by load_and_optimize()
_lock = asyncio.Lock()  # serializes generation: one GPU, one request at a time

def load_and_optimize():
    """Load processor + model, merge the LoRA adapter, and apply inference settings.

    Populates the module-level ``_model`` and ``_processor`` singletons.
    Called exactly once from the FastAPI startup hook before any request
    is served.
    """
    global _model, _processor
    print("=" * 60)
    print("LOADING WITH FULL OPTIMIZATION")
    print("=" * 60)

    # 1. Load processor (tokenizer + audio feature extraction)
    print("[1/7] Loading processor...")
    _processor = VibeVoiceProcessor.from_pretrained(MODEL_PATH)

    # 2. Load model in BF16 directly onto the GPU with SDPA attention
    print("[2/7] Loading model...")
    _model = VibeVoiceForConditionalGenerationInference.from_pretrained(
        MODEL_PATH, torch_dtype=torch.bfloat16, device_map="cuda", attn_implementation="sdpa")

    # 3. Load and MERGE LoRA (no adapter overhead at inference)
    if LORA_PATH and os.path.exists(LORA_PATH):
        print(f"[3/7] Loading & merging LoRA from {LORA_PATH}...")
        report = load_lora_assets(_model, LORA_PATH)
        print(f"  Loaded: {report}")
        # merge_and_unload — presumably PEFT's API — folds the adapter deltas
        # into the base LLM weights so each forward pass skips the adapter math.
        lm = getattr(_model.model, "language_model", None)
        if lm is not None and hasattr(lm, 'merge_and_unload'):
            _model.model.language_model = lm.merge_and_unload()
            print("  LoRA merged into LLM weights")
    else:
        print("[3/7] No LoRA path, using base model")

    # 4. Set aggressive DDPM step count (speed over fidelity) and eval mode
    print(f"[4/7] Setting DDPM steps to {DDPM_STEPS}")
    _model.ddpm_inference_steps = DDPM_STEPS
    _model.eval()

    # 5. Release memory left over from loading and merging
    print("[5/7] Freeing memory...")
    torch.cuda.empty_cache()
    gc.collect()

    # 6. Skip torch.compile (causes CUDA graph conflicts with dynamic generation)
    print("[6/7] Skipping torch.compile (dynamic generation loop incompatible)")

    # 7. Report final footprint and settings
    params = sum(p.numel() for p in _model.parameters())
    vram = torch.cuda.memory_allocated() / 1e9
    print(f"[7/7] Ready: {params/1e9:.2f}B params, {vram:.1f}GB VRAM")
    print(f"  DDPM={DDPM_STEPS}, CFG={CFG_SCALE}")
    print(f"  TF32=enabled, cuDNN benchmark=enabled")
    print("=" * 60)

def gen(text, speaker=None):
    """Synthesize ``text`` to speech (blocking; run off the event loop).

    Args:
        text: Plain text to speak; wrapped as a single "Speaker 1:" script.
        speaker: Optional voice stem under demo/voices/; None uses
            DEFAULT_SPEAKER. A missing wav file silently disables the
            voice prompt.

    Returns:
        (audio, gen_time_s, audio_duration_s) where audio is a float32
        1-D numpy array at 24 kHz, or (None, gen_time_s, 0) when the
        model produced no speech output.
    """
    sample_rate = 24000  # VibeVoice output rate; keep in sync with sf.write() in tts()
    script = f"Speaker 1: {text}"
    spk = speaker or DEFAULT_SPEAKER
    # NOTE(review): `spk` arrives from the request unvalidated — a value with
    # "../" escapes demo/voices/. Harmless today (a missing file just drops the
    # voice prompt), but consider whitelisting against the known speaker stems.
    vpath = f"demo/voices/{spk}.wav"
    kw = {"voice_samples": [vpath]} if os.path.exists(vpath) else {}
    inputs = _processor(text=[script], padding=False, truncation=False, return_tensors="pt", **kw)
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.cuda()
    t0 = time.time()
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast spelling
    # (FutureWarning since torch 2.4); behavior is identical.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        out = _model.generate(**inputs, max_new_tokens=None, cfg_scale=CFG_SCALE,
            tokenizer=_processor.tokenizer, generation_config={'do_sample': False},
            verbose=False, is_prefill=bool(kw), show_progress_bar=False)
    gt = time.time() - t0
    if out.speech_outputs and out.speech_outputs[0] is not None:
        a = out.speech_outputs[0].float().cpu().numpy().squeeze()
        return a, gt, len(a) / sample_rate
    return None, gt, 0

@app.on_event("startup")
async def startup():
    load_and_optimize()
    print("Warming up (3 passes for torch.compile)...")
    for i in range(3):
        _, gt, dur = gen("टेस्ट वार्मअप")
        print(f"  Warmup {i+1}: {gt:.2f}s gen, {dur:.1f}s audio")
    print(f"Server ready on port {PORT}")
    voices = [f.stem for f in Path("demo/voices").glob("hi-*.wav")]
    print(f"Speakers: {voices}")

@app.get("/health")
async def health():
    return {"status": "healthy"}

@app.post("/v1/tts/generate")
async def tts(req: TTSRequest):
    if not req.text.strip():
        raise HTTPException(400, "text required")
    async with _lock:
        a, gt, dur = await asyncio.to_thread(gen, req.text, req.speaker)
    if a is None:
        raise HTTPException(500, "no audio")
    buf = io.BytesIO()
    sf.write(buf, a, 24000, format='wav')
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav",
        headers={"x-rtf": f"{gt/dur:.2f}" if dur > 0 else "inf",
                 "x-duration": f"{dur:.2f}",
                 "x-gen-time": f"{gt:.2f}",
                 "x-ddpm-steps": str(DDPM_STEPS)})

@app.get("/v1/speakers")
async def speakers():
    return {"speakers": [f.stem for f in Path("demo/voices").glob("hi-*.wav")], "default": DEFAULT_SPEAKER}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=PORT)
