#!/usr/bin/env python3
"""
Irodori-TTS Concurrency Deep-Dive Benchmark

The problem: diffusion models can't be served like autoregressive models (vLLM-style)
because:
  1. Each request needs N sequential forward passes (Euler steps) — no KV-cache pipelining
  2. The runtime holds a global lock — only 1 request at a time
  3. No continuous batching — can't merge/split in-flight requests
  4. Fixed compute cost regardless of output length

This script benchmarks 4 strategies to maximize throughput:
  A. Baseline: single runtime, serial requests (current behavior)
  B. Request batching: pack multiple texts into one batched call
  C. Multi-instance: N model copies on one GPU, true parallel via processes
  D. Multi-instance + batching: combine B and C

Then reports throughput (req/s and x-realtime), latency, and peak VRAM for each.
"""
from __future__ import annotations

import gc
import math
import multiprocessing as mp
import os
import queue
import sys
import time
import traceback
from dataclasses import dataclass, field
from pathlib import Path

import torch

# Output directory, relative to the current working directory. It is created
# as a side effect of importing this module.
# NOTE(review): nothing visible in this file writes into it — confirm it is
# still needed.
OUT = Path("bench_outputs")
OUT.mkdir(exist_ok=True)

# Hugging Face Hub repository that _download_ckpt() fetches the checkpoint from.
HF_REPO = "Aratako/Irodori-TTS-500M"


# ──────────────────────────────────────────────────────────────────────
# Shared test data
# ──────────────────────────────────────────────────────────────────────
# Sixteen short Japanese sentences used as the synthesis requests. Every
# strategy receives this list; the batched benchmark takes its batches from
# the front of it.
TEST_TEXTS = [
    "今日はいい天気ですね。散歩に行きましょう。",
    "新しいプロジェクトが始まりました。頑張りましょう。",
    "電車が遅れているようです。別のルートを探しましょう。",
    "この映画はとても感動的でした。また見たいです。",
    "来週の会議の準備はできていますか？",
    "日本の四季はとても美しいです。特に桜の季節が好きです。",
    "人工知能の発展は目覚ましいものがあります。",
    "音楽を聴くとリラックスできます。クラシックが好きです。",
    "東京タワーは日本で最も有名な観光地の一つです。",
    "おはようございます。今日も一日頑張りましょう。",
    "夕焼けがとても綺麗ですね。写真を撮りましょう。",
    "明日の天気予報では、関東地方は晴れのち曇りでしょう。",
    "新しいレストランがオープンしました。行ってみませんか？",
    "この本はとても面白かったです。おすすめです。",
    "週末は家族と過ごす予定です。楽しみにしています。",
    "技術の進歩により、私たちの生活は大きく変わりました。",
]


@dataclass
class BenchResult:
    """Measurements collected for one benchmark configuration.

    The stored fields are raw totals; the properties derive throughput and
    latency statistics from them. Every property returns 0 instead of raising
    when its denominator or sample list is empty.
    """

    # Human-readable label of the strategy/configuration.
    strategy: str
    # Number of requests the run was asked to serve.
    num_requests: int
    # Wall-clock duration of the whole run, in seconds.
    total_wall_sec: float
    # Seconds of audio produced by the successful requests.
    total_audio_sec: float
    # One latency sample (seconds) per recorded request.
    per_request_latencies: list[float]
    # Peak GPU memory in GB (bytes / 1e9).
    vram_peak_gb: float
    # Number of failed requests.
    errors: int = 0

    @property
    def throughput_rps(self) -> float:
        """Requests per wall-clock second."""
        if self.total_wall_sec > 0:
            return self.num_requests / self.total_wall_sec
        return 0

    @property
    def throughput_rtx(self) -> float:
        """Seconds of audio generated per wall-clock second."""
        if self.total_wall_sec > 0:
            return self.total_audio_sec / self.total_wall_sec
        return 0

    @property
    def avg_latency(self) -> float:
        """Arithmetic mean of the latency samples."""
        samples = self.per_request_latencies
        if not samples:
            return 0
        return sum(samples) / len(samples)

    @property
    def p50_latency(self) -> float:
        """Median latency (the upper middle element for even-sized samples)."""
        ordered = sorted(self.per_request_latencies)
        if not ordered:
            return 0
        return ordered[len(ordered) // 2]

    @property
    def p99_latency(self) -> float:
        """99th-percentile latency, clamped to the largest sample."""
        ordered = sorted(self.per_request_latencies)
        if not ordered:
            return 0
        rank = int(len(ordered) * 0.99)
        return ordered[min(rank, len(ordered) - 1)]


def banner(title: str) -> None:
    """Print *title* between two 70-character '=' rules, after a blank line."""
    rule = "=" * 70
    print(f"\n{rule}\n  {title}\n{rule}", flush=True)


def print_bench(r: BenchResult) -> None:
    """Print a multi-line, human-readable report for one benchmark result.

    Latencies are shown in milliseconds. A blank line follows the report and
    stdout is flushed once at the end.
    """
    report = (
        f"  Strategy:     {r.strategy}",
        f"  Requests:     {r.num_requests} ({r.errors} errors)",
        f"  Wall time:    {r.total_wall_sec:.2f}s",
        f"  Audio total:  {r.total_audio_sec:.1f}s",
        f"  Throughput:   {r.throughput_rps:.2f} req/s | {r.throughput_rtx:.1f}x realtime",
        f"  Latency:      avg={r.avg_latency*1000:.0f}ms p50={r.p50_latency*1000:.0f}ms p99={r.p99_latency*1000:.0f}ms",
        f"  VRAM peak:    {r.vram_peak_gb:.2f} GB",
    )
    for line in report:
        print(line)
    print(flush=True)


# ──────────────────────────────────────────────────────────────────────
# Helper: build runtime
# ──────────────────────────────────────────────────────────────────────
def _download_ckpt() -> str:
    """Fetch ``model.safetensors`` from the HF_REPO repository.

    Returns:
        Whatever ``hf_hub_download`` returns for the file (annotated here as
        a string path).
    """
    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
    return local_path


def _build_runtime(ckpt: str, device: str = "cuda", precision: str = "bf16"):
    """Create an ``InferenceRuntime`` for checkpoint *ckpt*.

    The model and the codec share the same *device* and *precision*; the codec
    is always loaded from ``facebook/dacvae-watermarked``.
    """
    from irodori_tts.inference_runtime import InferenceRuntime, RuntimeKey

    key = RuntimeKey(
        checkpoint=ckpt,
        codec_repo="facebook/dacvae-watermarked",
        model_device=device,
        codec_device=device,
        model_precision=precision,
        codec_precision=precision,
    )
    return InferenceRuntime.from_key(key)


# ══════════════════════════════════════════════════════════════════════
# STRATEGY A: Serial baseline
# ══════════════════════════════════════════════════════════════════════
def bench_serial(ckpt: str, texts: list[str], num_steps: int = 30) -> BenchResult:
    """Benchmark strategy A: one runtime serving requests strictly one by one.

    Args:
        ckpt: Checkpoint path handed to ``_build_runtime``.
        texts: One synthesis request per entry; entry ``i`` uses ``seed=i``.
        num_steps: Number of sampler steps per request.

    Returns:
        A ``BenchResult`` labelled ``"A: Serial"``. A request that raises is
        counted in ``errors``; its latency sample is still recorded.
    """
    banner("STRATEGY A: Serial (single runtime, one-at-a-time)")
    from irodori_tts.inference_runtime import SamplingRequest

    rt = _build_runtime(ckpt)
    # Reset after loading so the peak reflects inference, not model loading.
    torch.cuda.reset_peak_memory_stats()

    latencies: list[float] = []
    total_audio = 0.0
    errors = 0

    # perf_counter() is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted, which would corrupt the intervals.
    wall_start = time.perf_counter()
    for i, text in enumerate(texts):
        t0 = time.perf_counter()
        try:
            result = rt.synthesize(SamplingRequest(
                text=text, no_ref=True, num_steps=num_steps, seed=i,
            ))
            # CUDA kernels run asynchronously; wait for outstanding GPU work
            # so the latency covers the whole request.
            torch.cuda.synchronize()
            total_audio += result.audio.shape[-1] / result.sample_rate
        except Exception as e:
            print(f"  ERROR [{i}]: {e}", flush=True)
            errors += 1
        latencies.append(time.perf_counter() - t0)
    wall_total = time.perf_counter() - wall_start
    peak = torch.cuda.max_memory_allocated() / 1e9

    # Drop the runtime and release cached GPU memory before the next strategy.
    del rt
    gc.collect()
    torch.cuda.empty_cache()

    r = BenchResult("A: Serial", len(texts), wall_total, total_audio, latencies, peak, errors)
    print_bench(r)
    return r


# ══════════════════════════════════════════════════════════════════════
# STRATEGY B: Request batching (pack N texts into 1 model call)
# ══════════════════════════════════════════════════════════════════════
def bench_batched(ckpt: str, texts: list[str], batch_sizes: list[int] | None = None,
                  num_steps: int = 30) -> list[BenchResult]:
    """Benchmark strategy B: several texts packed into one batched sampler call.

    For each batch size the function tokenizes the texts, runs a single
    ``sample_euler_rf_cfg`` call for the whole batch, then decodes and trims
    each item individually.

    Args:
        ckpt: Checkpoint path handed to ``_build_runtime``.
        texts: Pool of request texts. A batch takes the first ``bs`` entries;
            if ``bs`` exceeds the pool size the pool is cycled, so the batch
            really contains ``bs`` items.
        batch_sizes: Batch sizes to try; defaults to ``[1, 2, 4, 8, 16]``.
        num_steps: Number of sampler steps.

    Returns:
        One ``BenchResult`` per batch size. Each request's latency is the full
        batch wall time, because no item is available before the batch has
        finished. A failing batch yields a result with ``errors == bs``.
        Returns an empty list when *texts* is empty.
    """
    banner("STRATEGY B: Request batching (N texts → 1 batched inference)")
    from irodori_tts.inference_runtime import find_flattening_point
    from irodori_tts.codec import unpatchify_latent
    from irodori_tts.rf import sample_euler_rf_cfg
    from irodori_tts.text_normalization import normalize_text

    if batch_sizes is None:
        batch_sizes = [1, 2, 4, 8, 16]

    results: list[BenchResult] = []
    if not texts:
        print("  No texts supplied; skipping batched benchmark.", flush=True)
        return results

    rt = _build_runtime(ckpt)

    for bs in batch_sizes:
        # Cycle through the pool so the batch always holds exactly `bs` texts;
        # a plain slice would silently shrink it when bs > len(texts) while
        # the reference tensors below are still built with `bs` rows.
        batch_texts = [texts[i % len(texts)] for i in range(bs)]
        torch.cuda.reset_peak_memory_stats()

        # Monotonic clock: immune to system clock adjustments.
        t0 = time.perf_counter()
        try:
            normalized = [normalize_text(t).strip() for t in batch_texts]
            text_ids, text_mask = rt.tokenizer.batch_encode(
                normalized, max_length=rt.default_text_max_len,
            )
            text_ids = text_ids.to(rt.model_device)
            text_mask = text_mask.to(rt.model_device)

            # No reference speaker: all-zero latent with an all-False mask.
            ref_len = max(1, rt.model_cfg.speaker_patch_size)
            ref_latent = torch.zeros(
                (bs, ref_len, rt.model_cfg.latent_dim * rt.model_cfg.latent_patch_size),
                device=rt.model_device, dtype=next(rt.model.parameters()).dtype,
            )
            ref_mask = torch.zeros((bs, ref_len), dtype=torch.bool, device=rt.model_device)

            # Every item is sampled for a fixed 30-second target length.
            target_samples = int(30.0 * rt.codec.sample_rate)
            latent_steps = math.ceil(target_samples / int(rt.codec.model.hop_length))
            patched_steps = math.ceil(latent_steps / rt.model_cfg.latent_patch_size)

            with torch.inference_mode():
                z_patched = sample_euler_rf_cfg(
                    model=rt.model,
                    text_input_ids=text_ids,
                    text_mask=text_mask,
                    ref_latent=ref_latent,
                    ref_mask=ref_mask,
                    sequence_length=patched_steps,
                    num_steps=num_steps,
                    cfg_scale_text=3.0,
                    cfg_scale_speaker=0.0,
                    seed=42,
                )
                z = unpatchify_latent(
                    z_patched, rt.model_cfg.latent_patch_size, rt.model_cfg.latent_dim,
                )[:, :latent_steps]

                audios = []
                for i in range(bs):
                    audio_i = rt.codec.decode_latent(z[i:i+1]).cpu()[0]
                    # Cut each item at the point reported by
                    # find_flattening_point, capped at the target length.
                    fp = find_flattening_point(z[i])
                    trim = min(target_samples, int(fp * int(rt.codec.model.hop_length)))
                    if trim > 0:
                        audio_i = audio_i[:, :trim]
                    audios.append(audio_i)

            wall = time.perf_counter() - t0
            total_audio = sum(a.shape[-1] / rt.codec.sample_rate for a in audios)
            # Each request waits for the whole batch, so its latency is the
            # batch wall time (not wall / bs, which would understate it).
            latencies = [wall] * bs
            peak = torch.cuda.max_memory_allocated() / 1e9

            r = BenchResult(f"B: Batch={bs}", bs, wall, total_audio, latencies, peak)
        except Exception as e:
            wall = time.perf_counter() - t0
            peak = torch.cuda.max_memory_allocated() / 1e9
            print(f"  ERROR batch_size={bs}: {e}", flush=True)
            traceback.print_exc()
            # One latency sample per failed request keeps the statistics
            # aligned with num_requests.
            r = BenchResult(f"B: Batch={bs}", bs, wall, 0, [wall] * bs, peak, errors=bs)

        print_bench(r)
        results.append(r)

    del rt
    gc.collect()
    torch.cuda.empty_cache()
    return results


# ══════════════════════════════════════════════════════════════════════
# STRATEGY C: Multi-instance (N model copies, true parallelism)
# ══════════════════════════════════════════════════════════════════════
def _worker_process(worker_id: int, ckpt: str, task_queue: mp.Queue,
                    result_queue: mp.Queue, ready_event: mp.Event,
                    num_steps: int) -> None:
    """Each worker loads its own model copy and processes tasks from queue."""
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        import torch as _torch
        from irodori_tts.inference_runtime import InferenceRuntime, RuntimeKey, SamplingRequest

        rt = InferenceRuntime.from_key(RuntimeKey(
            checkpoint=ckpt,
            model_device="cuda",
            codec_repo="facebook/dacvae-watermarked",
            model_precision="bf16",
            codec_device="cuda",
            codec_precision="bf16",
        ))
        ready_event.set()

        while True:
            try:
                task = task_queue.get(timeout=5)
            except Exception:
                break
            if task is None:
                break

            task_id, text, seed = task
            t0 = time.time()
            try:
                result = rt.synthesize(SamplingRequest(
                    text=text, no_ref=True, num_steps=num_steps, seed=seed,
                ))
                audio_sec = result.audio.shape[-1] / result.sample_rate
                result_queue.put((task_id, True, audio_sec, time.time() - t0, None))
            except Exception as e:
                result_queue.put((task_id, False, 0, time.time() - t0, str(e)))

        vram = _torch.cuda.max_memory_allocated() / 1e9
        result_queue.put(("__vram__", worker_id, vram))
    except Exception as e:
        ready_event.set()
        result_queue.put(("__error__", worker_id, str(e)))


def bench_multi_instance(ckpt: str, texts: list[str],
                         num_instances_list: list[int] | None = None,
                         num_steps: int = 30) -> list[BenchResult]:
    """Benchmark strategy C: N worker processes, each with its own model copy.

    For every instance count the function spawns the workers, waits for them
    to signal readiness, enqueues one task per text and collects the results.

    Args:
        ckpt: Checkpoint path forwarded to every worker.
        texts: One request per entry; request ``i`` uses seed ``i * 111``.
        num_instances_list: Worker counts to try; defaults to ``[1, 2, 3, 4]``.
        num_steps: Number of sampler steps per request.

    Returns:
        One ``BenchResult`` per instance count. Requests that never produced a
        result (timeout, dead workers) are counted as errors. ``vram_peak_gb``
        is an estimate: the largest per-worker peak times the worker count.
    """
    banner("STRATEGY C: Multi-instance (N model copies, process-parallel)")
    if num_instances_list is None:
        num_instances_list = [1, 2, 3, 4]
    results: list[BenchResult] = []

    for n_inst in num_instances_list:
        print(f"\n  --- {n_inst} instance(s) ---", flush=True)
        ctx = mp.get_context("spawn")
        task_q: mp.Queue = ctx.Queue()
        result_q: mp.Queue = ctx.Queue()
        ready_events = [ctx.Event() for _ in range(n_inst)]

        workers = []
        for i in range(n_inst):
            p = ctx.Process(
                target=_worker_process,
                args=(i, ckpt, task_q, result_q, ready_events[i], num_steps),
            )
            p.start()
            workers.append(p)

        print(f"  Waiting for {n_inst} workers to load model...", flush=True)
        # Event.wait() returns False on timeout; do not claim readiness then.
        # A worker also sets its event when it fails to start, so actual
        # failures are detected through "__error__" messages below.
        n_ready = sum(1 for ev in ready_events if ev.wait(timeout=120))
        if n_ready == n_inst:
            print(f"  All {n_inst} workers ready.", flush=True)
        else:
            print(f"  WARNING: only {n_ready}/{n_inst} workers signalled ready in time.",
                  flush=True)

        latencies: list[float] = []
        vram_reports: list[float] = []
        failed_workers = 0
        total_audio = 0.0
        errors = 0
        collected = 0

        # Start the clock before enqueueing: workers begin as soon as the
        # first task arrives. perf_counter() is monotonic.
        wall_start = time.perf_counter()
        for i, text in enumerate(texts):
            task_q.put((i, text, i * 111))

        # Stop early when every worker has reported a fatal error; nobody is
        # left to answer the remaining requests.
        while collected < len(texts) and failed_workers < n_inst:
            try:
                item = result_q.get(timeout=120)
            except queue.Empty:
                print(f"    timed out waiting for results "
                      f"({collected}/{len(texts)} collected)", flush=True)
                break
            tag = item[0]
            if tag == "__vram__":
                # A worker may exit (and report) while others still work;
                # keep the report instead of dropping it.
                vram_reports.append(item[2])
                continue
            if tag == "__error__":
                failed_workers += 1
                print(f"    worker {item[1]} failed: {item[2]}", flush=True)
                continue
            task_id, ok, audio_sec, lat, err = item
            latencies.append(lat)
            if ok:
                total_audio += audio_sec
            else:
                errors += 1
                print(f"    task {task_id} error: {err}", flush=True)
            collected += 1

        wall_total = time.perf_counter() - wall_start
        # Requests without any answer are failures too.
        errors += len(texts) - collected

        for _ in range(n_inst):
            task_q.put(None)

        # Give the workers up to 30 s to send their final VRAM report. A
        # single empty poll only ends the wait when no worker is alive.
        deadline = time.monotonic() + 30
        while len(vram_reports) + failed_workers < n_inst and time.monotonic() < deadline:
            try:
                item = result_q.get(timeout=2)
            except queue.Empty:
                if not any(p.is_alive() for p in workers):
                    break
                continue
            if item[0] == "__vram__":
                vram_reports.append(item[2])
            elif item[0] == "__error__":
                failed_workers += 1
                print(f"    worker {item[1]} failed: {item[2]}", flush=True)

        for p in workers:
            p.join(timeout=30)
            if p.is_alive():
                p.terminate()
                # Reap the terminated process so it does not linger.
                p.join(timeout=5)

        peak_vram = max(vram_reports) * n_inst if vram_reports else 0

        r = BenchResult(
            f"C: {n_inst} instances", len(texts), wall_total,
            total_audio, latencies, peak_vram, errors,
        )
        print_bench(r)
        results.append(r)

    return results


# ══════════════════════════════════════════════════════════════════════
# STRATEGY D: Multi-instance + batching
# ══════════════════════════════════════════════════════════════════════
def _worker_batched_process(worker_id: int, ckpt: str,
                            task_queue: mp.Queue, result_queue: mp.Queue,
                            ready_event: mp.Event, batch_size: int,
                            num_steps: int) -> None:
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        import torch as _torch
        import math as _math
        from irodori_tts.inference_runtime import (
            InferenceRuntime, RuntimeKey, find_flattening_point,
        )
        from irodori_tts.codec import unpatchify_latent
        from irodori_tts.rf import sample_euler_rf_cfg
        from irodori_tts.text_normalization import normalize_text

        rt = InferenceRuntime.from_key(RuntimeKey(
            checkpoint=ckpt,
            model_device="cuda",
            codec_repo="facebook/dacvae-watermarked",
            model_precision="bf16",
            codec_device="cuda",
            codec_precision="bf16",
        ))
        ready_event.set()

        while True:
            batch = []
            try:
                item = task_queue.get(timeout=5)
                if item is None:
                    break
                batch.append(item)
            except Exception:
                break

            while len(batch) < batch_size:
                try:
                    item = task_queue.get_nowait()
                    if item is None:
                        task_queue.put(None)
                        break
                    batch.append(item)
                except Exception:
                    break

            if not batch:
                continue

            bs = len(batch)
            t0 = time.time()
            try:
                normalized = [normalize_text(b[1]).strip() for b in batch]
                with _torch.inference_mode():
                    text_ids, text_mask = rt.tokenizer.batch_encode(
                        normalized, max_length=rt.default_text_max_len,
                    )
                    text_ids = text_ids.to(rt.model_device)
                    text_mask = text_mask.to(rt.model_device)

                    dtype = next(rt.model.parameters()).dtype
                    ref_len = max(1, rt.model_cfg.speaker_patch_size)
                    ref_latent = _torch.zeros(
                        (bs, ref_len, rt.model_cfg.latent_dim * rt.model_cfg.latent_patch_size),
                        device=rt.model_device, dtype=dtype,
                    )
                    ref_mask = _torch.zeros((bs, ref_len), dtype=_torch.bool, device=rt.model_device)

                    target_samples = int(30.0 * rt.codec.sample_rate)
                    latent_steps = _math.ceil(target_samples / int(rt.codec.model.hop_length))
                    patched_steps = _math.ceil(latent_steps / rt.model_cfg.latent_patch_size)

                    z_patched = sample_euler_rf_cfg(
                        model=rt.model,
                        text_input_ids=text_ids,
                        text_mask=text_mask,
                        ref_latent=ref_latent,
                        ref_mask=ref_mask,
                        sequence_length=patched_steps,
                        num_steps=num_steps,
                        cfg_scale_text=3.0,
                        cfg_scale_speaker=0.0,
                        seed=42,
                    )
                    z = unpatchify_latent(
                        z_patched, rt.model_cfg.latent_patch_size, rt.model_cfg.latent_dim,
                    )[:, :latent_steps]

                    for i in range(bs):
                        audio_i = rt.codec.decode_latent(z[i:i+1]).cpu()[0]
                        fp = find_flattening_point(z[i])
                        trim = min(target_samples, int(fp * int(rt.codec.model.hop_length)))
                        if trim > 0:
                            audio_i = audio_i[:, :trim]
                        audio_sec = audio_i.shape[-1] / rt.codec.sample_rate
                        lat = time.time() - t0
                        result_queue.put((batch[i][0], True, audio_sec, lat, None))

            except Exception as e:
                lat = time.time() - t0
                for b in batch:
                    result_queue.put((b[0], False, 0, lat, str(e)))

        vram = _torch.cuda.max_memory_allocated() / 1e9
        result_queue.put(("__vram__", worker_id, vram))
    except Exception as e:
        ready_event.set()
        result_queue.put(("__error__", worker_id, str(e)))


def bench_multi_instance_batched(ckpt: str, texts: list[str],
                                 configs: list[tuple[int, int]] | None = None,
                                 num_steps: int = 30) -> list[BenchResult]:
    """Benchmark strategy D: several worker processes, each batching its requests.

    Args:
        ckpt: Checkpoint path forwarded to every worker.
        texts: One request per entry; task ``i`` carries seed ``i * 111``.
        configs: ``(num_instances, batch_size)`` pairs to try; defaults to
            ``[(2, 4), (2, 8), (4, 4)]``.
        num_steps: Number of sampler steps per batch.

    Returns:
        One ``BenchResult`` per configuration. Requests that never produced a
        result (timeout, dead workers) are counted as errors. ``vram_peak_gb``
        is an estimate: the largest per-worker peak times the worker count.
    """
    banner("STRATEGY D: Multi-instance + batching")
    if configs is None:
        configs = [(2, 4), (2, 8), (4, 4)]
    results: list[BenchResult] = []

    for n_inst, bs in configs:
        print(f"\n  --- {n_inst} instances x batch_size={bs} ---", flush=True)
        ctx = mp.get_context("spawn")
        task_q: mp.Queue = ctx.Queue()
        result_q: mp.Queue = ctx.Queue()
        ready_events = [ctx.Event() for _ in range(n_inst)]

        workers = []
        for i in range(n_inst):
            p = ctx.Process(
                target=_worker_batched_process,
                args=(i, ckpt, task_q, result_q, ready_events[i], bs, num_steps),
            )
            p.start()
            workers.append(p)

        # Event.wait() returns False on timeout; do not claim readiness then.
        # A worker also sets its event when it fails to start, so actual
        # failures are detected through "__error__" messages below.
        n_ready = sum(1 for ev in ready_events if ev.wait(timeout=120))
        if n_ready == n_inst:
            print(f"  All {n_inst} workers ready.", flush=True)
        else:
            print(f"  WARNING: only {n_ready}/{n_inst} workers signalled ready in time.",
                  flush=True)

        latencies: list[float] = []
        vram_reports: list[float] = []
        failed_workers = 0
        total_audio = 0.0
        errors = 0
        collected = 0

        # Start the clock before enqueueing: workers begin as soon as the
        # first task arrives. perf_counter() is monotonic.
        wall_start = time.perf_counter()
        for i, text in enumerate(texts):
            task_q.put((i, text, i * 111))

        # Stop early when every worker has reported a fatal error; nobody is
        # left to answer the remaining requests.
        while collected < len(texts) and failed_workers < n_inst:
            try:
                item = result_q.get(timeout=120)
            except queue.Empty:
                print(f"    timed out waiting for results "
                      f"({collected}/{len(texts)} collected)", flush=True)
                break
            tag = item[0]
            if tag == "__vram__":
                # A worker may exit (and report) while others still work;
                # keep the report instead of dropping it.
                vram_reports.append(item[2])
                continue
            if tag == "__error__":
                failed_workers += 1
                print(f"    worker {item[1]} failed: {item[2]}", flush=True)
                continue
            task_id, ok, audio_sec, lat, err = item
            latencies.append(lat)
            if ok:
                total_audio += audio_sec
            else:
                errors += 1
                print(f"    task {task_id} error: {err}", flush=True)
            collected += 1

        wall_total = time.perf_counter() - wall_start
        # Requests without any answer are failures too.
        errors += len(texts) - collected

        for _ in range(n_inst):
            task_q.put(None)

        # Give the workers up to 30 s to send their final VRAM report. A
        # single empty poll only ends the wait when no worker is alive.
        deadline = time.monotonic() + 30
        while len(vram_reports) + failed_workers < n_inst and time.monotonic() < deadline:
            try:
                item = result_q.get(timeout=2)
            except queue.Empty:
                if not any(p.is_alive() for p in workers):
                    break
                continue
            if item[0] == "__vram__":
                vram_reports.append(item[2])
            elif item[0] == "__error__":
                failed_workers += 1
                print(f"    worker {item[1]} failed: {item[2]}", flush=True)

        for p in workers:
            p.join(timeout=30)
            if p.is_alive():
                p.terminate()
                # Reap the terminated process so it does not linger.
                p.join(timeout=5)

        peak_vram = max(vram_reports) * n_inst if vram_reports else 0

        r = BenchResult(
            f"D: {n_inst}x inst, bs={bs}", len(texts), wall_total,
            total_audio, latencies, peak_vram, errors,
        )
        print_bench(r)
        results.append(r)

    return results


# ══════════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════════
def main() -> None:
    """Run strategies A-D in order on TEST_TEXTS and print a comparison.

    The first result (serial) is the baseline against which the speed-up of
    every other configuration is reported.
    """
    print(f"GPU: {torch.cuda.get_device_name(0)}", flush=True)
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB", flush=True)
    print(f"Texts: {len(TEST_TEXTS)} requests", flush=True)

    ckpt = _download_ckpt()

    # Strategy A comes first so that it ends up at index 0 (the baseline).
    all_results: list[BenchResult] = [bench_serial(ckpt, TEST_TEXTS, num_steps=30)]
    all_results += bench_batched(
        ckpt, TEST_TEXTS, batch_sizes=[1, 2, 4, 8, 16], num_steps=30,
    )
    all_results += bench_multi_instance(
        ckpt, TEST_TEXTS, num_instances_list=[1, 2, 4], num_steps=30,
    )
    all_results += bench_multi_instance_batched(
        ckpt, TEST_TEXTS, configs=[(2, 4), (2, 8), (4, 4)], num_steps=30,
    )

    # ── Summary table ────────────────────────────────────────────────
    banner("SUMMARY: All strategies compared")
    print(f"  {'Strategy':<30} {'Reqs':>5} {'Wall(s)':>8} {'Audio(s)':>9} "
          f"{'RPS':>6} {'RTX':>6} {'AvgLat':>8} {'P99Lat':>8} {'VRAM(GB)':>9} {'Err':>4}")
    print("  " + "-" * 110)
    for row in all_results:
        print(f"  {row.strategy:<30} {row.num_requests:>5} {row.total_wall_sec:>8.2f} "
              f"{row.total_audio_sec:>9.1f} {row.throughput_rps:>6.2f} "
              f"{row.throughput_rtx:>6.1f}x {row.avg_latency*1000:>7.0f}ms "
              f"{row.p99_latency*1000:>7.0f}ms {row.vram_peak_gb:>9.2f} {row.errors:>4}")

    baseline, *others = all_results
    serial_rps = baseline.throughput_rps
    print(f"\n  Baseline serial: {serial_rps:.2f} req/s")
    for row in others:
        if serial_rps > 0:
            speedup = row.throughput_rps / serial_rps
        else:
            speedup = 0
        print(f"  {row.strategy:<30} -> {speedup:.1f}x speedup vs serial")


if __name__ == "__main__":
    # Make "spawn" the process-wide default start method before anything
    # runs (force=True overrides a previously chosen method). The benchmark
    # functions also request a "spawn" context explicitly.
    # NOTE(review): presumably chosen because forked children cannot safely
    # reuse CUDA state initialised in the parent — confirm against the
    # PyTorch multiprocessing notes.
    mp.set_start_method("spawn", force=True)
    main()
