"""
VibeVoice TTS server v2: single worker, dynamic batching.
Collects requests into batches, processes them together.
The batch size = your concurrency.

Usage:
  python serve_v2.py --test --max_batch 32 --ddpm_steps 10
  python serve_v2.py --port 8000 --max_batch 16
"""

import argparse
import asyncio
import io
import os
import queue
import threading
import time
import wave

import numpy as np
import torch

from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
from vibevoice.modular.streamer import AudioStreamer
from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor


class Request:
    """One TTS job: the normalized script plus completion state and metrics.

    The worker fills in `audio` (WAV bytes) and the timing fields, then fires
    `event` so the submitting thread can wake up.
    """
    __slots__ = ('text', 'created_at', 'event', 'audio', 'ttfb_ms', 'gen_time_s', 'audio_dur_s', 'error')

    def __init__(self, text):
        # The model expects an explicit speaker tag; add one when missing.
        if not text.startswith("Speaker"):
            text = f"Speaker 1: {text}"
        self.text = text
        self.created_at = time.perf_counter()  # submission time, for TTFB
        self.event = threading.Event()         # set once the request finishes
        self.audio = None                      # WAV bytes on success
        self.error = None                      # error string on failure
        self.ttfb_ms = -1                      # time-to-first-chunk, -1 = none seen
        self.gen_time_s = -1                   # wall time of the batch generate()
        self.audio_dur_s = 0                   # duration of produced audio


class TTSEngine:
    def __init__(self, model_path, voice_path, max_batch, ddpm_steps, cfg_scale, compile_model):
        self.voice_path = os.path.abspath(voice_path)
        self.max_batch = max_batch
        self.ddpm_steps = ddpm_steps
        self.cfg_scale = cfg_scale
        self.queue = queue.Queue()
        self.running = False

        print(f"Loading model: {model_path}")
        self.processor = VibeVoiceProcessor.from_pretrained(model_path)
        try:
            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                model_path, torch_dtype=torch.bfloat16, device_map="cuda",
                attn_implementation="flash_attention_2")
        except Exception:
            self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
                model_path, torch_dtype=torch.bfloat16, device_map="cuda",
                attn_implementation="sdpa")
        self.model.eval()
        self.model.set_ddpm_inference_steps(num_steps=ddpm_steps)

        if compile_model:
            print("Compiling LM + diffusion head...")
            self.model.model.language_model = torch.compile(self.model.model.language_model, mode="default")
            self.model.model.prediction_head = torch.compile(self.model.model.prediction_head, mode="default")
            for _ in range(2):
                self._warmup()
            print("Compile done.")

        mem = torch.cuda.memory_allocated() / 1e9
        total = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU: {total:.0f}GB | Model: {mem:.1f}GB | Max batch: {max_batch} | Steps: {ddpm_steps}")

    def _warmup(self):
        inp = self.processor(
            text=["Speaker 1: warmup."], voice_samples=[[self.voice_path]],
            padding=True, return_tensors="pt", return_attention_mask=True)
        for k, v in inp.items():
            if torch.is_tensor(v): inp[k] = v.to("cuda")
        self.model.generate(**inp, max_new_tokens=None, cfg_scale=self.cfg_scale,
            tokenizer=self.processor.tokenizer, generation_config={"do_sample": False},
            verbose=False, is_prefill=True, show_progress_bar=False)

    def submit(self, text: str) -> Request:
        req = Request(text)
        self.queue.put(req)
        return req

    def start(self):
        self.running = True
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def stop(self):
        self.running = False

    def _collect_batch(self):
        batch = []
        try:
            batch.append(self.queue.get(timeout=0.1))
        except queue.Empty:
            return batch
        deadline = time.perf_counter() + 0.03
        while len(batch) < self.max_batch and time.perf_counter() < deadline:
            try:
                batch.append(self.queue.get_nowait())
            except queue.Empty:
                time.sleep(0.001)
        return batch

    def _process(self, batch):
        texts = [r.text for r in batch]
        bs = len(texts)
        voice_batch = [[self.voice_path]] * bs

        inputs = self.processor(
            text=texts, voice_samples=voice_batch,
            padding=True, return_tensors="pt", return_attention_mask=True)
        for k, v in inputs.items():
            if torch.is_tensor(v): inputs[k] = v.to("cuda")

        streamer = AudioStreamer(batch_size=bs, stop_signal=None)
        first_chunks = [None] * bs
        sample_counts = [0] * bs

        def consumer(idx):
            for chunk in streamer.get_stream(idx):
                t = time.perf_counter()
                if first_chunks[idx] is None:
                    first_chunks[idx] = t
                sample_counts[idx] += chunk.shape[-1]

        threads = [threading.Thread(target=consumer, args=(i,), daemon=True) for i in range(bs)]
        for t in threads:
            t.start()

        torch.manual_seed(42)
        torch.cuda.manual_seed_all(42)

        torch.cuda.synchronize()
        t0 = time.perf_counter()
        outputs = self.model.generate(
            **inputs, max_new_tokens=None, cfg_scale=self.cfg_scale,
            tokenizer=self.processor.tokenizer, generation_config={"do_sample": False},
            verbose=False, is_prefill=True, audio_streamer=streamer, show_progress_bar=False)
        torch.cuda.synchronize()
        gen_time = time.perf_counter() - t0

        for t in threads:
            t.join(timeout=5)

        for i, req in enumerate(batch):
            req.gen_time_s = gen_time
            if first_chunks[i] is not None:
                req.ttfb_ms = (first_chunks[i] - req.created_at) * 1000
            if outputs.speech_outputs[i] is not None:
                audio_np = outputs.speech_outputs[i].cpu().float().numpy()
                if audio_np.ndim > 1:
                    audio_np = audio_np.squeeze()
                audio_np = np.clip(audio_np, -1.0, 1.0)
                audio_int16 = (audio_np * 32767).astype(np.int16)
                buf = io.BytesIO()
                with wave.open(buf, 'wb') as wf:
                    wf.setnchannels(1)
                    wf.setsampwidth(2)
                    wf.setframerate(24000)
                    wf.writeframes(audio_int16.tobytes())
                req.audio = buf.getvalue()
                req.audio_dur_s = len(audio_int16) / 24000.0
            else:
                req.error = "No audio"
            req.event.set()

    def _loop(self):
        while self.running:
            batch = self._collect_batch()
            if batch:
                try:
                    self._process(batch)
                except Exception as e:
                    for r in batch:
                        r.error = str(e)
                        r.event.set()


def load_test(engine, concurrencies, text):
    """Fire waves of concurrent requests at the engine and print latency stats.

    For each concurrency level n, submits n copies of `text`, waits for all of
    them, then reports throughput, TTFB percentiles, and per-user RTF.
    """
    for n in concurrencies:
        print(f"\n{'='*80}")
        print(f"LOAD TEST: {n} concurrent requests (max_batch={engine.max_batch})")
        print(f"{'='*80}")

        pending = [engine.submit(text) for _ in range(n)]
        for req in pending:
            req.event.wait(timeout=120)

        succeeded = [req for req in pending if not req.error]
        failed = len(pending) - len(succeeded)

        if not succeeded:
            print(f"  All {failed} failed!")
            continue

        ttfbs = sorted(req.ttfb_ms for req in succeeded)
        count = len(ttfbs)
        durations = [req.audio_dur_s for req in succeeded]
        total_audio = sum(durations)
        # The batch shares one generate() call, so the slowest gen_time is wall time.
        wall = max(req.gen_time_s for req in succeeded)
        print(f"  OK: {len(succeeded)}/{n}  Errors: {failed}")
        print(f"  Wall time:   {wall:.2f}s")
        print(f"  Total audio: {total_audio:.1f}s")
        print(f"  Throughput:  {total_audio/wall:.1f} audio-s/wall-s")
        print(f"  TTFB min:    {ttfbs[0]:.0f}ms")
        print(f"  TTFB avg:    {sum(ttfbs)/count:.0f}ms")
        print(f"  TTFB p50:    {ttfbs[count//2]:.0f}ms")
        print(f"  TTFB p95:    {ttfbs[min(int(count*0.95), count-1)]:.0f}ms")
        print(f"  TTFB max:    {ttfbs[-1]:.0f}ms")
        avg_dur = total_audio / count
        rtf = wall / avg_dur if avg_dur > 0 else 0
        print(f"  RTF/user:    {rtf:.3f}x")
        print(f"  Streamable:  {'YES' if rtf < 1.0 else 'NO'}")
    print(f"\n{'='*80}")


def main():
    """Entry point: build the batching engine, then either run the built-in
    load test (--test) or expose a FastAPI /tts + /health HTTP server."""
    p = argparse.ArgumentParser()
    p.add_argument("--max_batch", type=int, default=16)
    p.add_argument("--ddpm_steps", type=int, default=10)
    p.add_argument("--cfg_scale", type=float, default=1.3)
    p.add_argument("--port", type=int, default=8000)
    p.add_argument("--test", action="store_true")
    p.add_argument("--no_compile", action="store_true")
    args = p.parse_args()

    engine = TTSEngine(
        model_path="microsoft/VibeVoice-1.5B",
        voice_path="demo/voices/modi.wav",
        max_batch=args.max_batch,
        ddpm_steps=args.ddpm_steps,
        cfg_scale=args.cfg_scale,
        compile_model=not args.no_compile,
    )
    engine.start()

    if args.test:
        text = "Speaker 1: मेरे प्यारे देशवासियों, आज मैं आपके साथ कुछ बहुत ज़रूरी बातें करना चाहता हूँ."

        # Warmup the worker loop
        print("\nWarmup...")
        r = engine.submit("Speaker 1: warmup test.")
        r.event.wait(timeout=30)

        load_test(engine, [1, 4, 8, 16, 32, 48, 64], text)
        engine.stop()
    else:
        # Server deps imported lazily so --test works without fastapi/uvicorn.
        from fastapi import FastAPI
        from fastapi.responses import Response
        import uvicorn

        app = FastAPI(title="Modi TTS")

        @app.post("/tts")
        async def tts(text: str = "नमस्ते देशवासियों."):
            req = engine.submit(text)
            # Wait on a worker thread: Event.wait() would block the event loop,
            # and asyncio.get_event_loop() is deprecated inside coroutines —
            # asyncio.to_thread is the supported replacement.
            await asyncio.to_thread(req.event.wait, 60)
            if req.error:
                return {"error": req.error}
            return Response(
                content=req.audio, media_type="audio/wav",
                headers={"X-TTFB-Ms": f"{req.ttfb_ms:.0f}",
                         "X-Audio-Duration": f"{req.audio_dur_s:.2f}"})

        @app.get("/health")
        async def health():
            # Queue depth is a rough backlog indicator for monitoring.
            return {"queue": engine.queue.qsize(), "max_batch": engine.max_batch}

        uvicorn.run(app, host="0.0.0.0", port=args.port)


# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()
