"""
LuxTTS Batched Inference Engine

Supports:
- Dynamic batching: collects requests and processes them in batches
- Async request queue with configurable max batch size and max wait time
- Chunked text streaming: splits long text at sentence boundaries
- FastAPI server for HTTP testing
"""

import torch
import torch.nn as nn
import time
import asyncio
import numpy as np
import threading
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Tuple
from concurrent.futures import Future
import queue
import re

# Patch swoosh before importing LuxTTS
import zipvoice.models.modules.scaling as scaling

class FastSwooshL(nn.Module):
    """SwooshL activation rewritten as logaddexp(0, x-4) - 0.08*x - 0.035."""

    def forward(self, x):
        # x.new_zeros(()) builds a scalar zero on x's device with x's dtype,
        # so logaddexp broadcasts it against x without any transfer.
        zero_scalar = x.new_zeros(())
        return torch.logaddexp(zero_scalar, x - 4.0) - 0.08 * x - 0.035

class FastSwooshR(nn.Module):
    """SwooshR activation rewritten as logaddexp(0, x-1) - 0.08*x - 0.313261687."""

    def forward(self, x):
        # Scalar zero matching x's dtype/device; broadcast by logaddexp.
        zero_scalar = x.new_zeros(())
        return torch.logaddexp(zero_scalar, x - 1.0) - 0.08 * x - 0.313261687

# Install the logaddexp-based activations in place of ZipVoice's originals.
# This must happen BEFORE the LuxTTS import below, so that any modules LuxTTS
# constructs pick up the patched classes.
scaling.SwooshL = FastSwooshL
scaling.SwooshR = FastSwooshR

from zipvoice.luxvoice import LuxTTS
from zipvoice.modeling_utils import generate as orig_generate


# ─── Batched Generate ───────────────────────────────────────────────

@torch.inference_mode()
def batched_generate(
    texts: List[str],
    model,
    vocoder,
    tokenizer,
    prompt_tokens,        # list of list of ints (same prompt for all)
    prompt_features,      # (1, T_prompt, F)
    prompt_features_lens, # (1,)
    prompt_rms: float,
    num_steps: int = 4,
    guidance_scale: float = 3.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
) -> List[np.ndarray]:
    """
    Generate speech for multiple texts in a single batched forward pass.

    Args:
        texts: batch of input strings. An empty list returns [] immediately
            (fix: previously this would call the model with a zero-sized batch).
        model: acoustic model exposing ``sample(...)``.
        vocoder: vocoder exposing ``decode(features)``.
        tokenizer: exposes ``texts_to_token_ids(texts)``.
        prompt_tokens: tokenized reference prompt, replicated per batch item.
        prompt_features: prompt feature tensor of shape (1, T_prompt, F).
        prompt_features_lens: prompt length tensor of shape (1,).
        prompt_rms: RMS level of the reference audio, used for volume matching.
        num_steps / guidance_scale / speed / t_shift: sampling hyperparameters.
        target_rms: RMS level the model output is normalized against.

    Returns:
        List of 1-D numpy arrays (48 kHz audio), one per input text.
    """
    # Edge case: nothing to synthesize — avoid a zero-sized model call.
    if not texts:
        return []

    batch_size = len(texts)
    # NOTE(review): the 1.3x multiplier presumably mirrors the single-request
    # path's internal speed calibration — confirm against the LuxTTS defaults.
    speed_adjusted = speed * 1.3

    # Tokenize all texts
    all_tokens = tokenizer.texts_to_token_ids(texts)

    # Replicate the (single) prompt for every batch item. expand() creates
    # zero-copy views, which is safe here because the prompt is read-only.
    batch_prompt_tokens = prompt_tokens * batch_size
    batch_prompt_features = prompt_features.expand(batch_size, -1, -1)
    batch_prompt_features_lens = prompt_features_lens.expand(batch_size)

    # Run model.sample with the full batch
    (pred_features, pred_lens, _, _) = model.sample(
        tokens=all_tokens,
        prompt_tokens=batch_prompt_tokens,
        prompt_features=batch_prompt_features,
        prompt_features_lens=batch_prompt_features_lens,
        speed=speed_adjusted,
        t_shift=t_shift,
        duration='predict',
        num_step=num_steps,
        guidance_scale=guidance_scale,
    )

    # Vocoder: decode batch — assumes output shape (B, samples); TODO confirm.
    # The /0.1 undoes the feature scaling applied at training time
    # (presumably related to target_rms — verify against LuxTTS internals).
    pred_features_perm = pred_features.permute(0, 2, 1) / 0.1
    wav_batch = vocoder.decode(pred_features_perm).clamp(-1, 1)

    # Volume matching: only attenuate (never amplify) toward the prompt level.
    if prompt_rms < target_rms:
        wav_batch = wav_batch * (prompt_rms / target_rms)

    # Split into individual results, trimming padding.
    # Vocoder upsamples: 24kHz mel (hop=256) → 48kHz audio = 512 samples per mel frame
    results = []
    for i in range(batch_size):
        num_mel_frames = int(pred_lens[i].item())
        num_audio_samples = min(num_mel_frames * 512, wav_batch.shape[1])
        results.append(wav_batch[i, :num_audio_samples].cpu().numpy())

    return results


# ─── Request Queue & Batch Scheduler ────────────────────────────────

@dataclass
class TTSRequest:
    """A single TTS request queued for dynamic batching."""

    # text: the input string to synthesize.
    # future: resolved by the worker thread with the generated audio
    #   (numpy array) on success, or the batch's exception on failure.
    # timestamp: enqueue time (time.time()); useful for queue-latency stats.
    text: str
    future: Future = field(default_factory=Future)
    timestamp: float = field(default_factory=time.time)


class BatchedTTSEngine:
    """
    Batched TTS engine with dynamic batching.

    Collects requests from an internal queue and processes them in batches
    on a background worker thread for maximum GPU utilization.

    Typical usage::

        engine = BatchedTTSEngine(max_batch_size=32)
        engine.encode_prompt("ref.wav")
        engine.start()
        audio = engine.submit("Hello world.").result()
    """

    def __init__(
        self,
        max_batch_size: int = 32,
        max_wait_ms: float = 50.0,
        num_steps: int = 4,
        guidance_scale: float = 3.0,
        speed: float = 1.0,
        t_shift: float = 0.5,
        device: str = 'cuda',
    ):
        """
        Args:
            max_batch_size: upper bound on requests per batched forward pass.
            max_wait_ms: after the first request arrives, how long to keep
                collecting more before dispatching a (possibly partial) batch.
            num_steps / guidance_scale / speed / t_shift: sampling parameters
                forwarded to ``batched_generate``.
            device: torch device string for the model (e.g. 'cuda', 'cpu').
        """
        self.max_batch_size = max_batch_size
        self.max_wait_s = max_wait_ms / 1000.0
        self.num_steps = num_steps
        self.guidance_scale = guidance_scale
        self.speed = speed
        self.t_shift = t_shift

        # Request queue and worker-thread state.
        # Fix: _thread is initialized here so stop() is safe even when
        # start() was never called (previously raised AttributeError).
        self._queue: queue.Queue[TTSRequest] = queue.Queue()
        self._running = False
        self._thread: Optional[threading.Thread] = None

        # Prompt state — populated by encode_prompt(). Initialized to None so
        # a premature submit() yields a clear RuntimeError instead of an
        # AttributeError inside the worker thread.
        self._prompt_tokens = None
        self._prompt_features = None
        self._prompt_features_lens = None
        self._prompt_rms = None

        # Stats
        self.total_requests = 0
        self.total_batches = 0
        self.total_audio_seconds = 0.0
        self.total_gpu_seconds = 0.0

        # Fix: only touch torch.cuda when CUDA is actually usable — the
        # original crashed on CPU-only hosts despite accepting device='cpu'.
        self._use_cuda = torch.cuda.is_available() and str(device).startswith('cuda')

        # Load model
        print("Loading LuxTTS model...")
        t0 = time.time()
        self._tts = LuxTTS(device=device)
        self._model = self._tts.model
        self._vocoder = self._tts.vocos
        self._tokenizer = self._tts.tokenizer
        self._device = device
        print(f"Model loaded in {time.time()-t0:.1f}s")

        if self._use_cuda:
            mem = torch.cuda.memory_allocated() / 1024**3
            print(f"GPU memory: {mem:.2f} GB")

    def encode_prompt(self, audio_path: str, duration: float = 5) -> dict:
        """Encode a reference audio prompt; must be called before submit()."""
        encoded = self._tts.encode_prompt(audio_path, duration=duration)
        self._prompt_tokens = encoded['prompt_tokens']
        self._prompt_features = encoded['prompt_features']
        self._prompt_features_lens = encoded['prompt_features_lens']
        self._prompt_rms = encoded['prompt_rms']
        print("Prompt encoded.")
        return encoded

    def submit(self, text: str) -> Future:
        """Submit a TTS request. Returns a Future with the audio numpy array."""
        req = TTSRequest(text=text)
        self._queue.put(req)
        return req.future

    def start(self):
        """Start the batch processing thread."""
        self._running = True
        self._thread = threading.Thread(target=self._batch_loop, daemon=True)
        self._thread.start()
        print(f"Batch engine started (max_batch={self.max_batch_size}, max_wait={self.max_wait_s*1000:.0f}ms)")

    def stop(self):
        """Stop the batch processing thread (safe even if start() never ran)."""
        self._running = False
        if self._thread is not None:
            self._thread.join(timeout=5)
            self._thread = None

    def _batch_loop(self):
        """Main loop: collect requests into batches and process."""
        while self._running:
            batch: List[TTSRequest] = []

            # Wait for the first request with a short timeout so the loop can
            # observe _running going False and exit promptly.
            try:
                first = self._queue.get(timeout=0.1)
                batch.append(first)
            except queue.Empty:
                continue

            # Collect more requests up to max_batch or until the wait deadline.
            deadline = time.time() + self.max_wait_s
            while len(batch) < self.max_batch_size:
                remaining = deadline - time.time()
                if remaining <= 0:
                    break
                try:
                    req = self._queue.get(timeout=remaining)
                    batch.append(req)
                except queue.Empty:
                    break

            # Process batch
            self._process_batch(batch)

    def _process_batch(self, batch: List["TTSRequest"]):
        """Run one batched forward pass and resolve each request's future.

        Any exception (including a missing prompt) is propagated to every
        future in the batch rather than crashing the worker thread.
        """
        texts = [r.text for r in batch]
        bs = len(texts)

        try:
            # Fix: explicit error instead of AttributeError when the caller
            # forgot to encode a prompt before submitting requests.
            if self._prompt_tokens is None:
                raise RuntimeError("encode_prompt() must be called before submitting requests")

            # Synchronize around the GPU work so the timing is accurate.
            if self._use_cuda:
                torch.cuda.synchronize()
            t0 = time.time()

            results = batched_generate(
                texts=texts,
                model=self._model,
                vocoder=self._vocoder,
                tokenizer=self._tokenizer,
                prompt_tokens=self._prompt_tokens,
                prompt_features=self._prompt_features,
                prompt_features_lens=self._prompt_features_lens,
                prompt_rms=self._prompt_rms,
                num_steps=self.num_steps,
                guidance_scale=self.guidance_scale,
                speed=self.speed,
                t_shift=self.t_shift,
            )

            if self._use_cuda:
                torch.cuda.synchronize()
            elapsed = time.time() - t0

            # Stats (48 kHz output sample rate).
            self.total_batches += 1
            self.total_requests += bs
            total_audio = sum(len(r) / 48000 for r in results)
            self.total_audio_seconds += total_audio
            self.total_gpu_seconds += elapsed

            # Resolve futures
            for req, result in zip(batch, results):
                req.future.set_result(result)

        except Exception as e:
            for req in batch:
                req.future.set_exception(e)

    def stats(self) -> dict:
        """Return cumulative throughput/latency statistics as a plain dict."""
        return {
            "total_requests": self.total_requests,
            "total_batches": self.total_batches,
            "avg_batch_size": self.total_requests / max(1, self.total_batches),
            "total_audio_s": self.total_audio_seconds,
            "total_gpu_s": self.total_gpu_seconds,
            "avg_rtf": self.total_gpu_seconds / max(0.001, self.total_audio_seconds),
            "effective_speed": self.total_audio_seconds / max(0.001, self.total_gpu_seconds),
            "throughput_req_per_s": self.total_requests / max(0.001, self.total_gpu_seconds),
            # Fix: 0.0 instead of crashing when CUDA is unavailable.
            "peak_vram_gb": (torch.cuda.max_memory_allocated() / 1024**3
                             if torch.cuda.is_available() else 0.0),
        }


# ─── Sentence Chunker (for streaming-like behavior) ────────────────

def split_sentences(text: str) -> List[str]:
    """Break *text* into sentences at ``.``, ``!`` or ``?`` boundaries.

    Whitespace-only fragments are dropped; an empty/blank input yields [].
    """
    trimmed = text.strip()
    fragments = re.split(r'(?<=[.!?])\s+', trimmed)
    # str.strip returns '' (falsy) for whitespace-only fragments, so filter
    # keeps exactly the fragments with visible content.
    return list(filter(str.strip, fragments))


# ─── CLI Benchmark ──────────────────────────────────────────────────

def run_benchmark():
    """CLI benchmark: sweep concurrency levels and report throughput/latency.

    For each concurrency level, submits that many requests at once, waits for
    all futures, and prints wall time, batching stats, real-time factor and
    TTFB percentiles. Stats are reset between levels so each is independent.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--ref-audio", default="/home/ubuntu/LuxTTS/ref_audio.wav")
    parser.add_argument("--max-batch", type=int, default=32)
    parser.add_argument("--max-wait-ms", type=float, default=50)
    parser.add_argument("--num-steps", type=int, default=4)
    parser.add_argument("--concurrency", type=int, nargs="+", default=[1, 8, 16, 32, 64, 128, 256, 500])
    args = parser.parse_args()

    engine = BatchedTTSEngine(
        max_batch_size=args.max_batch,
        max_wait_ms=args.max_wait_ms,
        num_steps=args.num_steps,
    )
    engine.encode_prompt(args.ref_audio)
    engine.start()

    # Ten fixed sentences cycled across requests so audio lengths are
    # comparable between concurrency levels.
    TEST_TEXTS = [
        "Hello, this is a test of the batched inference system.",
        "The quick brown fox jumps over the lazy dog near the riverbank.",
        "Welcome to the annual technology conference on artificial intelligence.",
        "Machine learning has transformed how we approach complex problems.",
        "Natural language processing enables computers to understand human speech.",
        "Deep learning models continue to improve in both speed and accuracy.",
        "The future of computing lies in efficient parallel processing architectures.",
        "Voice synthesis technology has made remarkable progress in recent years.",
        "Today we demonstrate the power of batched inference for text to speech.",
        "High concurrency serving requires careful optimization of GPU resources.",
    ]

    print("\n" + "=" * 70)
    print("BATCHED INFERENCE BENCHMARK")
    print("=" * 70)
    print(f"Config: max_batch={args.max_batch}, max_wait={args.max_wait_ms}ms, steps={args.num_steps}")

    # Warmup: the first request is excluded from measurements (it pays any
    # one-time lazy-initialization cost).
    print("\nWarmup...")
    f = engine.submit("Warmup request.")
    f.result(timeout=30)
    time.sleep(0.5)

    for conc in args.concurrency:
        # Reset engine stats and the CUDA peak-memory counter so each
        # concurrency level is measured independently.
        engine.total_requests = 0
        engine.total_batches = 0
        engine.total_audio_seconds = 0.0
        engine.total_gpu_seconds = 0.0
        torch.cuda.reset_peak_memory_stats()

        print(f"\n--- Concurrency: {conc} ---")
        submit_times = []
        futures = []
        t0 = time.time()

        # Submit all requests at once (burst load).
        for i in range(conc):
            text = TEST_TEXTS[i % len(TEST_TEXTS)]
            submit_times.append(time.time())
            futures.append(engine.submit(text))

        # Wait for all to complete — track per-request TTFB and completion.
        # NOTE: "TTFB" here is time until the FULL audio is ready (there is
        # no streaming), measured from each request's own submission time.
        ttfbs = []
        latencies = []
        audio_durations = []
        for i, f in enumerate(futures):
            result = f.result(timeout=120)
            done_time = time.time()
            ttfb = done_time - submit_times[i]  # time from submit to result ready
            ttfbs.append(ttfb)
            latencies.append(done_time - t0)
            audio_durations.append(len(result) / 48000)

        total_wall = time.time() - t0
        stats = engine.stats()

        # TTFB stats (per-request: time from that request's submission to its completion)
        ttfb_p50 = np.percentile(ttfbs, 50) * 1000
        ttfb_p95 = np.percentile(ttfbs, 95) * 1000
        ttfb_p99 = np.percentile(ttfbs, 99) * 1000
        ttfb_min = min(ttfbs) * 1000
        ttfb_max = max(ttfbs) * 1000
        ttfb_avg = np.mean(ttfbs) * 1000

        # Audio duration stats
        avg_audio = np.mean(audio_durations)

        print(f"  Wall time:     {total_wall:.3f}s")
        print(f"  Batches:       {stats['total_batches']}")
        print(f"  Avg batch sz:  {stats['avg_batch_size']:.1f}")
        print(f"  Throughput:    {conc/total_wall:.1f} req/s")
        print(f"  GPU time:      {stats['total_gpu_s']:.3f}s")
        print(f"  Audio gen:     {stats['total_audio_s']:.1f}s total (avg {avg_audio:.2f}s/req)")
        print(f"  Effective RTF: {stats['avg_rtf']:.5f} ({stats['effective_speed']:.0f}x realtime)")
        print(f"  Peak VRAM:     {stats['peak_vram_gb']:.2f} GB")
        print(f"  --- TTFB (time to first byte / full audio) ---")
        print(f"  TTFB min:      {ttfb_min:.0f}ms")
        print(f"  TTFB p50:      {ttfb_p50:.0f}ms")
        print(f"  TTFB p95:      {ttfb_p95:.0f}ms")
        print(f"  TTFB p99:      {ttfb_p99:.0f}ms")
        print(f"  TTFB max:      {ttfb_max:.0f}ms")
        print(f"  TTFB avg:      {ttfb_avg:.0f}ms")

        # Small gap between tests
        time.sleep(0.3)

    engine.stop()
    print("\nDone.")


# Entry point: run the CLI benchmark when executed as a script.
if __name__ == "__main__":
    run_benchmark()
