"""
Multi-instance VibeVoice TTS server.
Runs N model workers on the same GPU, each handling a small batch.
FastAPI frontend distributes requests via round-robin.
Streams audio chunks back via chunked response.

Usage:
  python serve.py --workers 4 --batch_per_worker 8 --port 8000
  
Test:
  python serve.py --test --workers 4 --concurrent 32
"""

import argparse
import asyncio
import io
import queue
import threading
import time
import uuid
import wave
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import numpy as np

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


@dataclass
class TTSRequest:
    text: str
    request_id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    created_at: float = field(default_factory=time.perf_counter)
    result_event: threading.Event = field(default_factory=threading.Event)  # set once the request finishes
    audio: Optional[bytes] = None       # complete WAV bytes on success
    ttfb_ms: float = -1                 # time from submission to finished audio, in ms
    gen_time_s: float = -1              # model.generate wall time for the whole batch
    audio_duration_s: float = 0         # length of the generated audio, in seconds
    error: Optional[str] = None


class TTSWorker:
    """Single model instance that processes batches from a shared queue."""

    def __init__(self, worker_id: int, model, processor, voice_path: str,
                 request_queue: queue.Queue, ddpm_steps: int = 10, cfg_scale: float = 1.3,
                 max_batch: int = 8, max_wait_ms: float = 50):
        self.worker_id = worker_id
        self.model = model
        self.processor = processor
        self.voice_path = voice_path
        self.request_queue = request_queue
        self.ddpm_steps = ddpm_steps
        self.cfg_scale = cfg_scale
        self.max_batch = max_batch
        self.max_wait_ms = max_wait_ms
        self.running = False
        self.thread = None
        self.total_processed = 0

    def start(self):
        self.running = True
        self.thread = threading.Thread(target=self._loop, daemon=True)
        self.thread.start()

    def stop(self):
        self.running = False
        if self.thread is not None:
            # The loop exits within ~0.1s (the queue-get timeout).
            self.thread.join(timeout=1.0)

    def _collect_batch(self) -> List[TTSRequest]:
        """Block briefly for the first request, then greedily fill the batch
        until max_batch is reached or the max_wait_ms window expires."""
        batch = []
        try:
            req = self.request_queue.get(timeout=0.1)
            batch.append(req)
        except queue.Empty:
            return batch

        # Hold the batch open for up to max_wait_ms to let more requests arrive.
        deadline = time.perf_counter() + self.max_wait_ms / 1000
        while len(batch) < self.max_batch and time.perf_counter() < deadline:
            try:
                req = self.request_queue.get_nowait()
                batch.append(req)
            except queue.Empty:
                time.sleep(0.002)
        return batch

    def _process_batch(self, batch: List[TTSRequest]):
        # VibeVoice expects scripts in "Speaker N: ..." form; prepend a default
        # speaker tag if the caller did not supply one.
        texts = []
        for req in batch:
            t = req.text if req.text.startswith("Speaker") else f"Speaker 1: {req.text}"
            texts.append(t)

        # Same reference voice for every item in the batch.
        voice_batch = [[self.voice_path]] * len(texts)
        inputs = self.processor(
            text=texts, voice_samples=voice_batch,
            padding=True, return_tensors="pt", return_attention_mask=True,
        )
        for k, v in inputs.items():
            if torch.is_tensor(v):
                inputs[k] = v.to("cuda")

        # Fixed seed for reproducible synthesis. Note that the model (and its
        # DDPM step count) is shared across all worker threads.
        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)
        self.model.set_ddpm_inference_steps(num_steps=self.ddpm_steps)

        t0 = time.perf_counter()
        outputs = self.model.generate(
            **inputs, max_new_tokens=None, cfg_scale=self.cfg_scale,
            tokenizer=self.processor.tokenizer, generation_config={"do_sample": False},
            verbose=False, is_prefill=True, show_progress_bar=False,
        )
        gen_time = time.perf_counter() - t0

        for i, req in enumerate(batch):
            req.gen_time_s = gen_time
            # Generation is not streamed, so "TTFB" here is really end-to-end
            # latency: time from submission until the full audio is ready.
            req.ttfb_ms = (time.perf_counter() - req.created_at) * 1000
            if outputs.speech_outputs[i] is not None:
                audio_np = outputs.speech_outputs[i].cpu().float().numpy()
                if audio_np.ndim > 1:
                    audio_np = audio_np.squeeze()
                audio_np = np.clip(audio_np, -1.0, 1.0)
                audio_int16 = (audio_np * 32767).astype(np.int16)

                # Pack as 16-bit PCM mono WAV at the model's 24 kHz output rate.
                buf = io.BytesIO()
                with wave.open(buf, 'wb') as wf:
                    wf.setnchannels(1)
                    wf.setsampwidth(2)
                    wf.setframerate(24000)
                    wf.writeframes(audio_int16.tobytes())
                req.audio = buf.getvalue()
                req.audio_duration_s = len(audio_int16) / 24000.0
            else:
                req.error = "No audio generated"
            req.result_event.set()

        self.total_processed += len(batch)

    def _loop(self):
        while self.running:
            batch = self._collect_batch()
            if batch:
                try:
                    self._process_batch(batch)
                except Exception as e:
                    for req in batch:
                        req.error = str(e)
                        req.result_event.set()


class TTSServer:
    def __init__(self, num_workers: int = 4, batch_per_worker: int = 8,
                 ddpm_steps: int = 10, cfg_scale: float = 1.3,
                 model_path: str = "microsoft/VibeVoice-1.5B",
                 voice_path: str = "demo/voices/modi.wav",
                 compile_model: bool = True):
        self.request_queue = queue.Queue()
        self.num_workers = num_workers
        self.workers: List[TTSWorker] = []

        print(f"Loading model: {model_path}")
        self.processor = VibeVoiceProcessor.from_pretrained(model_path)
        try:
            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                model_path, torch_dtype=torch.bfloat16, device_map="cuda",
                attn_implementation="flash_attention_2")
        except Exception:
            # flash-attn unavailable on this system; fall back to PyTorch SDPA.
            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                model_path, torch_dtype=torch.bfloat16, device_map="cuda",
                attn_implementation="sdpa")
        self.model.eval()

        if compile_model:
            print("Compiling LM + diffusion head...")
            self.model.model.language_model = torch.compile(self.model.model.language_model, mode="default")
            self.model.model.prediction_head = torch.compile(self.model.model.prediction_head, mode="default")
            self.model.set_ddpm_inference_steps(num_steps=ddpm_steps)
            inp = self.processor(
                text=["Speaker 1: compile warmup."], voice_samples=[[voice_path]],
                padding=True, return_tensors="pt", return_attention_mask=True)
            for k, v in inp.items():
                if torch.is_tensor(v):
                    inp[k] = v.to("cuda")
            # Two warmup passes: the first triggers compilation, the second
            # confirms the compiled graphs are hit.
            for _ in range(2):
                _ = self.model.generate(
                    **inp, max_new_tokens=None, cfg_scale=cfg_scale,
                    tokenizer=self.processor.tokenizer, generation_config={"do_sample": False},
                    verbose=False, is_prefill=True, show_progress_bar=False)
            print("Compile warmup done.")

        mem_gb = torch.cuda.memory_allocated() / 1e9
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU: {total_gb:.0f}GB total, {mem_gb:.1f}GB model")
        print(f"Workers: {num_workers}, batch/worker: {batch_per_worker}")
        print(f"Max concurrent: {num_workers * batch_per_worker}")

        for i in range(num_workers):
            w = TTSWorker(
                worker_id=i, model=self.model, processor=self.processor,
                voice_path=voice_path, request_queue=self.request_queue,
                ddpm_steps=ddpm_steps, cfg_scale=cfg_scale, max_batch=batch_per_worker,
            )
            self.workers.append(w)

    def start(self):
        for w in self.workers:
            w.start()
        print(f"All {self.num_workers} workers started.")

    def stop(self):
        for w in self.workers:
            w.stop()

    def submit(self, text: str) -> TTSRequest:
        req = TTSRequest(text=text)
        self.request_queue.put(req)
        return req
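
# Direct (non-HTTP) usage sketch, assuming a CUDA GPU and the default voice
# file from the repo's demo assets:
#
#   server = TTSServer(num_workers=2, batch_per_worker=4)
#   server.start()
#   req = server.submit("Speaker 1: Hello.")
#   req.result_event.wait(timeout=60)
#   if req.audio:
#       with open("hello.wav", "wb") as f:
#           f.write(req.audio)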


def run_load_test(server: TTSServer, num_requests: int, text: str):
    """Fire concurrent requests and measure TTFB + throughput."""
    print(f"\n{'='*80}")
    print(f"LOAD TEST: {num_requests} concurrent requests")
    print(f"{'='*80}")

    requests = []
    t0 = time.perf_counter()
    for _ in range(num_requests):
        req = server.submit(text)
        requests.append(req)

    for req in requests:
        req.result_event.wait(timeout=120)

    total_time = time.perf_counter() - t0

    ttfbs = []
    audio_durs = []
    errors = 0
    for req in requests:
        if req.error:
            errors += 1
        elif req.audio is None:
            errors += 1  # timed out without producing a result
        else:
            ttfbs.append(req.ttfb_ms)
            audio_durs.append(req.audio_duration_s)

    if ttfbs:
        total_audio = sum(audio_durs)
        print(f"  Completed: {len(ttfbs)}/{num_requests} (errors: {errors})")
        print(f"  Wall time: {total_time:.2f}s")
        print(f"  Total audio: {total_audio:.1f}s")
        print(f"  Throughput: {total_audio/total_time:.1f} audio-s/wall-s")
        print(f"  TTFB min: {min(ttfbs):.0f}ms")
        print(f"  TTFB avg: {sum(ttfbs)/len(ttfbs):.0f}ms")
        print(f"  TTFB max: {max(ttfbs):.0f}ms")
        print(f"  TTFB p50: {sorted(ttfbs)[len(ttfbs)//2]:.0f}ms")
        print(f"  TTFB p95: {sorted(ttfbs)[int(len(ttfbs)*0.95)]:.0f}ms")
    else:
        print(f"  All {errors} requests failed!")
    print(f"{'='*80}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=4)
    parser.add_argument("--batch_per_worker", type=int, default=8)
    parser.add_argument("--ddpm_steps", type=int, default=10)
    parser.add_argument("--cfg_scale", type=float, default=1.3)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--test", action="store_true", help="Run load test instead of server")
    parser.add_argument("--concurrent", type=int, default=32, help="Concurrent requests for load test")
    parser.add_argument("--no_compile", action="store_true")
    args = parser.parse_args()

    server = TTSServer(
        num_workers=args.workers,
        batch_per_worker=args.batch_per_worker,
        ddpm_steps=args.ddpm_steps,
        cfg_scale=args.cfg_scale,
        compile_model=not args.no_compile,
    )
    server.start()

    if args.test:
        text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ."

        # Warmup
        print("\nWarmup...")
        req = server.submit("Speaker 1: warmup.")
        req.result_event.wait(timeout=30)

        for n in [1, 8, 16, 32, 48, 64]:
            if n > args.concurrent:
                break
            run_load_test(server, n, text)

        server.stop()
    else:
        from fastapi import FastAPI
        from fastapi.responses import Response
        import uvicorn

        app = FastAPI(title="Modi TTS Server")

        @app.post("/tts")
        async def tts(text: str = "नमस्ते, मेरे प्यारे देशवासियों."):  # "Hello, my dear countrymen."
            req = server.submit(text)
            # Wait off the event loop; Event.wait would otherwise block it.
            done = await asyncio.get_running_loop().run_in_executor(None, req.result_event.wait, 60)
            if not done:
                return {"error": "timeout"}
            if req.error:
                return {"error": req.error}
            return Response(
                content=req.audio,
                media_type="audio/wav",
                headers={
                    "X-TTFB-Ms": str(int(req.ttfb_ms)),
                    "X-Audio-Duration": f"{req.audio_duration_s:.2f}",
                    "X-Gen-Time": f"{req.gen_time_s:.2f}",
                },
            )
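
        # Example request against this endpoint (sketch; `text` is passed as a
        # query parameter):
        #   curl -X POST "http://localhost:8000/tts?text=Speaker%201:%20Hello" -o out.wav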

        @app.get("/health")
        async def health():
            total = sum(w.total_processed for w in server.workers)
            return {
                "workers": server.num_workers,
                "total_processed": total,
                "queue_size": server.request_queue.qsize(),
            }

        uvicorn.run(app, host="0.0.0.0", port=args.port)


if __name__ == "__main__":
    main()
