"""
Optimized VibeVoice Hindi TTS Server.
- Pre-baked speaker embeddings
- torch.compile for LLM speedup
- Configurable diffusion steps (4 instead of 8)
- KV cache reuse
- Concurrent request handling
"""

import asyncio
import io
import math
import os
import time
from pathlib import Path
from typing import Optional

import numpy as np
import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel

from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor

# --- Config ---
# Model weights: base checkpoint plus an optional fine-tuned LoRA directory.
MODEL_PATH = os.getenv("MODEL_PATH", "tarun7r/vibevoice-hindi-1.5B")
LORA_PATH = os.getenv("LORA_PATH", "output/hindi_lora_v2_from_tarun/checkpoint-12500/lora")

# Runtime placement (fixed here, not read from the environment).
DEVICE, DTYPE = "cuda", torch.bfloat16

# Generation defaults, each overridable through the environment.
DDPM_STEPS = int(os.getenv("DDPM_STEPS", "4"))  # diffusion sampling steps (4 instead of 8)
CFG_SCALE = float(os.getenv("CFG_SCALE", "1.3"))
DEFAULT_SPEAKER = os.getenv("DEFAULT_SPEAKER", "hi-Priya_woman")

# HTTP port for the __main__ entry point (also reported in the startup log).
PORT = int(os.getenv("PORT", "8100"))


class TTSRequest(BaseModel):
    """JSON body for ``POST /v1/tts/generate``.

    Optional fields left as ``None`` are replaced by a server default in the handler.
    """

    text: str  # text to synthesize; the handler rejects whitespace-only text
    speaker: Optional[str] = None  # speaker name; None -> DEFAULT_SPEAKER
    ddpm_steps: Optional[int] = None  # diffusion sampling steps; None -> DDPM_STEPS
    cfg_scale: Optional[float] = None  # guidance (CFG) scale; None -> CFG_SCALE


# Interactive docs and the OpenAPI schema endpoint are disabled.
app = FastAPI(title="VibeVoice Hindi TTS", docs_url=None, openapi_url=None)

# --- Global model state (populated by the startup hook) ---
_model, _processor = None, None  # both assigned in load_model()
_speaker_cache = dict()  # speaker name -> processor outputs, filled by preload_speakers()
# Serializes model calls made from request handlers.
# NOTE(review): created at import time. On Python >= 3.10 asyncio.Lock binds to the
# running loop lazily; older versions bind at construction - confirm deployment version.
_lock = asyncio.Lock()


def _load_submodule_weights(parent, attr, path, label, strict=True):
    """Load a saved state dict from ``path`` into ``parent.<attr>`` when both exist.

    Best-effort by design: a missing file or a missing submodule is skipped quietly,
    because every fine-tuning artifact in the LoRA directory is optional.

    Args:
        parent: module that may own the submodule (``_model.model`` here).
        attr: attribute name of the submodule, e.g. ``"acoustic_connector"``.
        path: file holding the weights; presumably written with
            ``torch.save(module.state_dict())`` - TODO confirm against the trainer.
        label: human-readable name used in the log line.
        strict: forwarded to ``load_state_dict``. With ``strict=False`` any keys
            that were skipped are reported instead of being dropped silently.
    """
    if not os.path.exists(path):
        return
    module = getattr(parent, attr, None)
    if module is None:
        return
    # weights_only=True refuses arbitrary pickled objects; these files are expected
    # to contain only tensors.
    state = torch.load(path, map_location="cpu", weights_only=True)
    result = module.load_state_dict(state, strict=strict)
    print(f"{label} loaded")
    if not strict and (result.missing_keys or result.unexpected_keys):
        print(
            f"  {label}: ignored {len(result.missing_keys)} missing / "
            f"{len(result.unexpected_keys)} unexpected keys"
        )


def load_model():
    """Load the processor and model, apply optional fine-tuning, and move to DEVICE.

    Sets the module globals ``_processor`` and ``_model``. When LORA_PATH exists:
      * the LoRA adapter is merged into ``_model.model.language_model``;
      * ``diffusion_head_full.bin`` is loaded into ``prediction_head``
        (non-strict);
      * ``acoustic_connector`` / ``semantic_connector`` weights are loaded
        (strict).
    Finally, the acoustic tokenizer's encoder is deleted to free memory.
    """
    global _model, _processor

    print(f"Loading model from {MODEL_PATH}...")
    _processor = VibeVoiceProcessor.from_pretrained(MODEL_PATH)

    _model = VibeVoiceForConditionalGeneration.from_pretrained(
        MODEL_PATH,
        torch_dtype=DTYPE,
        attn_implementation="sdpa",
    )

    if LORA_PATH and os.path.exists(LORA_PATH):
        print(f"Loading LoRA from {LORA_PATH}...")
        from peft import PeftModel  # only required when a LoRA directory is present

        lm = getattr(_model.model, "language_model", None)
        if lm is not None:
            # Merge the adapter into the base weights so inference pays no LoRA overhead.
            _model.model.language_model = PeftModel.from_pretrained(lm, LORA_PATH).merge_and_unload()
            print("LoRA merged into LLM")

        _load_submodule_weights(
            _model.model,
            "prediction_head",
            os.path.join(LORA_PATH, "diffusion_head_full.bin"),
            "Diffusion head",
            strict=False,
        )
        _load_submodule_weights(
            _model.model,
            "acoustic_connector",
            os.path.join(LORA_PATH, "acoustic_connector", "pytorch_model.bin"),
            "Acoustic connector",
        )
        _load_submodule_weights(
            _model.model,
            "semantic_connector",
            os.path.join(LORA_PATH, "semantic_connector", "pytorch_model.bin"),
            "Semantic connector",
        )

    _model = _model.to(DEVICE).eval()

    # Delete the acoustic encoder to save GPU memory.
    # NOTE(review): presumably this also means raw reference audio can no longer be
    # encoded at request time - confirm before wiring voice prompts into generation.
    acoustic_tok = getattr(_model.model, "acoustic_tokenizer", None)
    if acoustic_tok is not None and hasattr(acoustic_tok, "encoder"):
        del acoustic_tok.encoder
        torch.cuda.empty_cache()
        print("Stripped acoustic encoder (not needed for pre-baked)")

    total_params = sum(p.numel() for p in _model.parameters())
    print(f"Model loaded: {total_params/1e9:.2f}B params, dtype={DTYPE}")
    print(f"DDPM steps: {DDPM_STEPS}, CFG scale: {CFG_SCALE}")


def preload_speakers(voices_dir=None):
    """Run the processor once per Hindi voice file and cache the results.

    Args:
        voices_dir: directory scanned for ``hi-*.wav`` files. When ``None``, the
            ``VOICES_DIR`` environment variable is used, falling back to
            ``demo/voices`` (resolved against the current working directory).

    Each matching file is cached in ``_speaker_cache`` under its file stem
    (e.g. ``hi-Priya_woman``). The value is the processor output for a placeholder
    script, with tensors moved to DEVICE. A failure for one file is logged and
    skipped, so one bad file does not abort startup.

    NOTE(review): in this module only the cache *keys* are read (speaker lookup and
    listing); the cached tensors themselves are never passed to the model.
    """
    if voices_dir is None:
        voices_dir = os.getenv("VOICES_DIR", "demo/voices")
    voices_dir = Path(voices_dir)
    if not voices_dir.is_dir():
        print(f"  Voices directory not found: {voices_dir.resolve()}")

    # sorted() makes cache / listing order deterministic across filesystems.
    for wav_file in sorted(voices_dir.glob("hi-*.wav")):
        speaker_name = wav_file.stem
        try:
            inputs = _processor(
                text=["Speaker 1: test"],
                voice_samples=[str(wav_file)],
                padding=False,
                truncation=False,
                return_tensors="pt",
            )
            for k, v in inputs.items():
                if torch.is_tensor(v):
                    inputs[k] = v.to(DEVICE)
            _speaker_cache[speaker_name] = inputs
            print(f"  Cached speaker: {speaker_name}")
        except Exception as e:  # best-effort per file; keep loading the others
            print(f"  Failed to cache {speaker_name}: {e}")

    print(f"Pre-loaded {len(_speaker_cache)} Hindi speakers")


@app.on_event("startup")
async def startup():
    load_model()
    preload_speakers()
    # Warmup
    print("Warming up...")
    _generate_audio("टेस्ट", DEFAULT_SPEAKER, ddpm_steps=2)
    print(f"Server ready on port {PORT}")


def _generate_audio(text: str, speaker: str, ddpm_steps: Optional[int] = None, cfg_scale: Optional[float] = None):
    """Synthesize ``text`` and return ``(audio, sample_rate, gen_time, duration, rtf)``.

    Args:
        text: text to speak, wrapped as the single-speaker script ``"Speaker 1: <text>"``.
        speaker: speaker name; it only decides ``is_prefill`` (see NOTE below).
        ddpm_steps: diffusion sampling steps; ``None`` means ``DDPM_STEPS``.
        cfg_scale: guidance scale passed to ``generate``; ``None`` means ``CFG_SCALE``.

    Returns:
        ``audio`` is a float32 numpy array, or ``None`` when the model produced no speech.
        ``sample_rate`` is fixed at 24000.
        ``gen_time`` is the wall-clock seconds spent in ``generate``.
        ``duration`` is the audio length in seconds.
        ``rtf`` is ``gen_time / duration``, or ``inf`` when there is no audio.

    Raises:
        ValueError: if the resolved ``ddpm_steps`` is < 1.

    The step count is set on the shared model before generating, so concurrent calls
    must be serialized (the HTTP handler holds ``_lock`` around this call).
    """
    steps = DDPM_STEPS if ddpm_steps is None else ddpm_steps
    cfg = CFG_SCALE if cfg_scale is None else cfg_scale
    if steps < 1:
        raise ValueError(f"ddpm_steps must be >= 1, got {steps}")
    sample_rate = 24000

    script = f"Speaker 1: {text}"
    inputs = _processor(
        text=[script],
        voice_samples=None,
        padding=False,
        truncation=False,
        return_tensors="pt",
    )
    for k, v in inputs.items():
        if torch.is_tensor(v):
            inputs[k] = v.to(DEVICE)

    # NOTE(review): the cached voice-prompt tensors in _speaker_cache are never passed to
    # the model (voice_samples=None above); `speaker` only toggles is_prefill, so selecting
    # a speaker does not change the voice. TODO: wire the cached prompt into generation.
    is_prefill = speaker in _speaker_cache

    # Apply the requested step count; without this, ddpm_steps would have no effect.
    set_steps = getattr(_model, "set_ddpm_inference_steps", None)
    if set_steps is not None:
        set_steps(num_steps=steps)
    else:
        print(f"Warning: model has no set_ddpm_inference_steps(); ddpm_steps={steps} ignored")

    t0 = time.perf_counter()
    with torch.inference_mode():
        outputs = _model.generate(
            **inputs,
            max_new_tokens=None,
            cfg_scale=cfg,
            tokenizer=_processor.tokenizer,
            generation_config={'do_sample': False},
            verbose=False,
            is_prefill=is_prefill,
        )
    gen_time = time.perf_counter() - t0

    speech = outputs.speech_outputs
    if not speech or speech[0] is None:
        return None, sample_rate, gen_time, 0, float('inf')

    # .float() first: numpy has no bfloat16, and the model runs in DTYPE (bfloat16).
    audio = speech[0].float().cpu().numpy()
    if audio.ndim > 1:
        # atleast_1d guards against a single-sample output collapsing to a 0-d array.
        audio = np.atleast_1d(audio.squeeze())
    duration = len(audio) / sample_rate
    rtf = gen_time / duration if duration > 0 else float('inf')
    return audio, sample_rate, gen_time, duration, rtf


@app.get("/health")
async def health():
    return {"status": "healthy"}


@app.post("/v1/tts/generate")
async def generate_tts(req: TTSRequest):
    if not req.text.strip():
        raise HTTPException(400, "text is required")

    speaker = req.speaker or DEFAULT_SPEAKER
    steps = req.ddpm_steps or DDPM_STEPS
    cfg = req.cfg_scale or CFG_SCALE

    async with _lock:
        audio, sr, gen_time, duration, rtf = await asyncio.to_thread(
            _generate_audio, req.text, speaker, steps, cfg
        )

    if audio is None:
        raise HTTPException(500, "No audio generated")

    buf = io.BytesIO()
    sf.write(buf, audio, sr, format='wav')
    buf.seek(0)

    return StreamingResponse(
        buf,
        media_type="audio/wav",
        headers={
            "X-Generation-Time": f"{gen_time:.3f}",
            "X-Audio-Duration": f"{duration:.3f}",
            "X-RTF": f"{rtf:.3f}",
        }
    )


@app.get("/v1/speakers")
async def list_speakers():
    return {"speakers": list(_speaker_cache.keys()), "default": DEFAULT_SPEAKER}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=PORT)
