#!/usr/bin/env python3
"""
Comprehensive Irodori-TTS-500M test suite.
Tests: basic inference, emoji style control, voice cloning,
       CFG modes, sampling params, multi-candidate, concurrency, torch.compile.
"""
from __future__ import annotations

import gc
import os
import secrets
import sys
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path

import torch

from irodori_tts.inference_runtime import (
    InferenceRuntime,
    RuntimeKey,
    SamplingRequest,
    SamplingResult,
    save_wav,
)

# Directory that receives every WAV written by this suite; created at import
# time if it does not exist yet.
OUT = Path("test_outputs")
OUT.mkdir(exist_ok=True)

# Hugging Face repo id from which download_checkpoint() fetches model.safetensors.
HF_REPO = "Aratako/Irodori-TTS-500M"
# Device and precision handed to RuntimeKey for both the model and the codec.
DEVICE = "cuda"
PRECISION = "bf16"


def banner(title: str) -> None:
    """Print ``title`` framed above and below by a 70-character ``=`` rule.

    A blank line precedes the frame and the title is indented by two spaces.
    """
    rule = "=" * 70
    framed = "\n".join(("", rule, f"  {title}", rule))
    print(framed, flush=True)


def print_result(tag: str, result: SamplingResult, outpath: str) -> None:
    """Print a one-line timing summary for a finished synthesis.

    Args:
        tag: Short label shown in square brackets at the start of the line.
        result: Synthesis result; ``audio``, ``sample_rate``, ``stage_timings``
            (iterated as ``(name, seconds)`` pairs), ``total_to_decode`` and
            ``used_seed`` are read from it.
        outpath: Path of the written WAV, echoed at the end of the line.
    """
    duration = result.audio.shape[-1] / result.sample_rate

    # Collect per-stage timings; a later duplicate stage name overwrites an earlier one.
    stage_seconds = {}
    for stage, seconds in result.stage_timings:
        stage_seconds[stage] = seconds

    # Stages missing from the timings are reported as 0 ms.
    sample_ms = stage_seconds.get("sample_rf", 0) * 1000
    decode_ms = stage_seconds.get("decode_latent", 0) * 1000
    overall_ms = result.total_to_decode * 1000

    # Real-time factor; infinite when no audio was produced.
    if duration > 0:
        real_time_factor = result.total_to_decode / duration
    else:
        real_time_factor = float("inf")

    summary = (
        f"  [{tag}] {duration:.2f}s audio | "
        f"rf={sample_ms:.0f}ms dec={decode_ms:.0f}ms total={overall_ms:.0f}ms | "
        f"RTF={real_time_factor:.3f} | seed={result.used_seed} | {outpath}"
    )
    print(summary, flush=True)


def synth(runtime: InferenceRuntime, tag: str, filename: str, **kwargs) -> SamplingResult:
    """Synthesize one request, save it under ``OUT`` and print its timing line.

    Args:
        runtime: Runtime whose ``synthesize`` method is invoked.
        tag: Label forwarded to ``print_result``.
        filename: File name (relative to ``OUT``) for the saved WAV.
        **kwargs: Forwarded verbatim to ``SamplingRequest``.

    Returns:
        The result object returned by ``runtime.synthesize``.
    """
    result = runtime.synthesize(SamplingRequest(**kwargs))
    destination = str(OUT / filename)
    save_wav(destination, result.audio, result.sample_rate)
    print_result(tag, result, destination)
    return result


# ── Download checkpoint once ─────────────────────────────────────────
def download_checkpoint() -> str:
    """Fetch ``model.safetensors`` from ``HF_REPO`` and return its local path.

    The elapsed time and the resolved path are printed once
    ``hf_hub_download`` returns.
    """
    # Imported lazily so that merely importing this module does not require
    # huggingface_hub.
    from huggingface_hub import hf_hub_download

    print("Downloading model.safetensors from HF...", flush=True)
    started = time.time()
    local_path = hf_hub_download(repo_id=HF_REPO, filename="model.safetensors")
    elapsed = time.time() - started
    print(f"  Downloaded in {elapsed:.1f}s -> {local_path}", flush=True)
    return local_path


# ── Build runtime ────────────────────────────────────────────────────
def build_runtime(ckpt: str, *, compile_model: bool = False) -> InferenceRuntime:
    """Create an ``InferenceRuntime`` for ``ckpt`` and report how long it took.

    Args:
        ckpt: Checkpoint path stored in the ``RuntimeKey``.
        compile_model: Value of the key's ``compile_model`` flag.

    Returns:
        The runtime produced by ``InferenceRuntime.from_key``.
    """
    print(f"Building runtime (precision={PRECISION}, compile={compile_model})...", flush=True)
    started = time.time()
    # Model and codec share the same device and precision settings.
    runtime = InferenceRuntime.from_key(
        RuntimeKey(
            checkpoint=ckpt,
            model_device=DEVICE,
            codec_repo="facebook/dacvae-watermarked",
            model_precision=PRECISION,
            codec_device=DEVICE,
            codec_precision=PRECISION,
            compile_model=compile_model,
        )
    )
    print(f"  Runtime ready in {time.time()-started:.1f}s", flush=True)
    return runtime


# ══════════════════════════════════════════════════════════════════════
#  TEST 1 — Basic TTS (no reference audio)
# ══════════════════════════════════════════════════════════════════════
def test_basic_tts(rt: InferenceRuntime) -> None:
    """Synthesize a short, a medium and a long sentence without reference audio.

    Each case uses 40 steps and seed 42 and is written to
    ``t1_basic_<name>.wav``.
    """
    banner("TEST 1: Basic TTS (no reference)")
    shared = {"no_ref": True, "num_steps": 40, "seed": 42}
    cases = {
        "short": "今日はいい天気ですね。",
        "medium": "お電話ありがとうございます。ただいま電話が大変混み合っております。恐れ入りますが、発信音のあとに、ご用件をお話しください。",
        "long": "その森には、古い言い伝えがありました。月が最も高く昇る夜、静かに耳を澄ませば、風の歌声が聞こえるというのです。私は半信半疑でしたが、その夜、確かに誰かが私を呼ぶ声を聞いたのです。",
    }
    for label, sentence in cases.items():
        synth(rt, f"basic_{label}", f"t1_basic_{label}.wav", text=sentence, **shared)


# ══════════════════════════════════════════════════════════════════════
#  TEST 2 — Emoji style control
# ══════════════════════════════════════════════════════════════════════
def test_emoji_style(rt: InferenceRuntime) -> None:
    """Synthesize emoji-annotated prompts plus a plain one for comparison.

    Each case uses 40 steps and seed 42, without reference audio, and is
    written to ``t2_emoji_<name>.wav``.
    """
    banner("TEST 2: Emoji-based style control")
    shared = {"no_ref": True, "num_steps": 40, "seed": 42}
    cases = {
        "whisper": "なーに、どうしたの？…え？もっと近づいてほしい？…👂😮‍💨👂😮‍💨こういうのが好きなんだ？",
        "crying": "うぅ…😭そんなに酷いこと、言わないで…😭",
        "sneezing": "🤧🤧ごめんね、風邪引いちゃってて🤧…大丈夫、ただの風邪だからすぐ治るよ🥺",
        "neutral": "今日はいい天気ですね。散歩に行きましょう。",
    }
    for label, sentence in cases.items():
        synth(rt, f"emoji_{label}", f"t2_emoji_{label}.wav", text=sentence, **shared)


# ══════════════════════════════════════════════════════════════════════
#  TEST 3 — Voice cloning (use t1 output as reference)
# ══════════════════════════════════════════════════════════════════════
def test_voice_cloning(rt: InferenceRuntime) -> None:
    """Synthesize two sentences using test 1's short clip as reference audio.

    Prints a SKIP line and returns early when ``t1_basic_short.wav`` is not
    present in ``OUT``.
    """
    banner("TEST 3: Voice cloning (zero-shot)")

    reference = OUT / "t1_basic_short.wav"
    if not reference.exists():
        print("  SKIP: reference wav not found (run test 1 first)", flush=True)
        return

    ref_wav = str(reference)
    cases = {
        "clone_a": "こんにちは、元気ですか？今日は何をしますか？",
        "clone_b": "明日の天気予報では、関東地方は晴れのち曇りでしょう。",
    }
    for label, sentence in cases.items():
        synth(
            rt,
            label,
            f"t3_{label}.wav",
            text=sentence,
            ref_wav=ref_wav,
            num_steps=40,
            seed=42,
        )


# ══════════════════════════════════════════════════════════════════════
#  TEST 4 — CFG guidance modes
# ══════════════════════════════════════════════════════════════════════
def test_cfg_modes(rt: InferenceRuntime) -> None:
    """Synthesize the same sentence once per CFG guidance mode.

    The text scale is 3.0 throughout; the speaker scale is 3.0 for ``joint``
    and 5.0 for the other modes. Output goes to ``t4_cfg_<mode>.wav``.
    """
    banner("TEST 4: CFG guidance modes")
    sentence = "今日は素晴らしい一日になりそうです。"
    # Speaker guidance scale per mode, in the order the modes are exercised.
    speaker_scale = {"independent": 5.0, "joint": 3.0, "alternating": 5.0}
    for mode, cfg_spk in speaker_scale.items():
        synth(
            rt,
            f"cfg_{mode}",
            f"t4_cfg_{mode}.wav",
            text=sentence,
            no_ref=True,
            num_steps=30,
            seed=42,
            cfg_scale_text=3.0,
            cfg_scale_speaker=cfg_spk,
            cfg_guidance_mode=mode,
        )


# ══════════════════════════════════════════════════════════════════════
#  TEST 5 — Sampling parameter sweep
# ══════════════════════════════════════════════════════════════════════
def test_sampling_params(rt: InferenceRuntime) -> None:
    """Sweep step count, seed and text CFG scale one at a time on one sentence.

    Every run starts from ``num_steps=30, seed=42`` and overrides a single
    parameter; output goes to ``t5_<label>.wav``.
    """
    banner("TEST 5: Sampling parameter sweep")
    sentence = "東京タワーは日本で最も有名な観光地の一つです。"

    def run(label: str, **overrides) -> None:
        # Baseline request settings, with the swept parameter layered on top.
        params = {"text": sentence, "no_ref": True, "num_steps": 30, "seed": 42}
        params.update(overrides)
        synth(rt, label, f"t5_{label}.wav", **params)

    for steps in (10, 20, 40):
        run(f"steps_{steps}", num_steps=steps)

    for seed in (0, 123, 999):
        run(f"seed_{seed}", seed=seed)

    for cfg_t in (1.0, 3.0, 7.0):
        run(f"cfgtxt_{cfg_t}", cfg_scale_text=cfg_t)


# ══════════════════════════════════════════════════════════════════════
#  TEST 6 — Multi-candidate batch generation
# ══════════════════════════════════════════════════════════════════════
def test_multi_candidate(rt: InferenceRuntime) -> None:
    """Generate 1, 2 and 4 candidates per request with ``decode_mode="batch"``.

    Every candidate is saved as ``t6_cand<n>_<i>.wav``; the total and
    per-candidate wall time are printed for each ``n``.
    """
    banner("TEST 6: Multi-candidate generation")
    sentence = "桜の花が満開です。公園で花見をしましょう。"
    for n in (1, 2, 4):
        started = time.time()
        request = SamplingRequest(
            text=sentence,
            no_ref=True,
            num_steps=30,
            seed=42,
            num_candidates=n,
            decode_mode="batch",
        )
        result = rt.synthesize(request)
        elapsed = time.time() - started

        candidates = result.audios
        for index, candidate in enumerate(candidates):
            save_wav(str(OUT / f"t6_cand{n}_{index}.wav"), candidate, result.sample_rate)

        # The reported duration is that of the first candidate.
        first_sec = candidates[0].shape[-1] / result.sample_rate
        total_ms = elapsed * 1000
        per_candidate_ms = elapsed / n * 1000
        print(
            f"  [multi_n={n}] {len(candidates)} candidates, "
            f"{first_sec:.2f}s each, total={total_ms:.0f}ms, "
            f"per_cand={per_candidate_ms:.0f}ms",
            flush=True,
        )


# ══════════════════════════════════════════════════════════════════════
#  TEST 7 — CONCURRENCY (the big one)
# ══════════════════════════════════════════════════════════════════════
@dataclass
class ConcurrencyResult:
    """Outcome of one synthesis request issued by ``_concurrent_worker``."""

    # Label of the request; also used in the output file name.
    tag: str
    # True when synthesis and saving finished without raising.
    success: bool
    # Duration of the generated audio in seconds (0 on failure).
    audio_sec: float
    # Wall-clock seconds measured by the worker for this request.
    wall_sec: float
    # "<ExceptionType>: <message>" on failure, otherwise None.
    error: str | None = None


def _concurrent_worker(
    rt: InferenceRuntime, tag: str, text: str, seed: int
) -> ConcurrencyResult:
    """Run one synthesis request and report its outcome without raising.

    Args:
        rt: Runtime shared between the worker threads.
        tag: Label for the request; the WAV is saved as ``t7_<tag>.wav``.
        text: Text to synthesize.
        seed: Seed passed to the request.

    Returns:
        A ``ConcurrencyResult``. On success it carries the audio duration and
        the time spent in synthesis; on any ``Exception`` it carries
        ``success=False`` and ``"<ExceptionType>: <message>"`` in ``error``.
    """
    # Start the clock before entering the try block so the except handler can
    # always read it; perf_counter is monotonic, unlike time.time().
    t0 = time.perf_counter()
    try:
        req = SamplingRequest(
            text=text, no_ref=True, num_steps=30, seed=seed,
        )
        result = rt.synthesize(req)
        # Stop timing before the file write so disk I/O is not counted.
        wall = time.perf_counter() - t0
        audio_sec = result.audio.shape[-1] / result.sample_rate
        save_wav(str(OUT / f"t7_{tag}.wav"), result.audio, result.sample_rate)
    except Exception as e:
        # Deliberately broad: a failing request is reported to the caller as
        # data so the remaining concurrent requests still get summarized.
        return ConcurrencyResult(
            tag=tag,
            success=False,
            audio_sec=0.0,
            wall_sec=time.perf_counter() - t0,
            error=f"{type(e).__name__}: {e}",
        )
    return ConcurrencyResult(tag=tag, success=True, audio_sec=audio_sec, wall_sec=wall)


def _run_concurrency_round(rt: InferenceRuntime, texts: list[str], num_threads: int) -> None:
    """Synthesize up to ``num_threads`` texts in parallel and print a summary."""
    actual_tasks = min(num_threads, len(texts))
    print(f"\n  --- {num_threads} concurrent threads ({actual_tasks} tasks) ---", flush=True)

    wall_start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        futures = [
            pool.submit(_concurrent_worker, rt, f"conc{num_threads}_{i}", text, seed=i * 111)
            for i, text in enumerate(texts[:actual_tasks])
        ]
        # Collected in completion order; the worker never raises, so
        # fut.result() always yields a ConcurrencyResult.
        results = [fut.result() for fut in as_completed(futures)]
    total_wall = time.perf_counter() - wall_start

    ok = sum(1 for r in results if r.success)
    fail = len(results) - ok
    avg_wall = sum(r.wall_sec for r in results) / len(results) if results else 0
    total_audio = sum(r.audio_sec for r in results)

    print(
        f"  threads={num_threads} | ok={ok} fail={fail} | "
        f"total_wall={total_wall:.2f}s | avg_per_req={avg_wall:.2f}s | "
        f"total_audio={total_audio:.1f}s | "
        f"throughput={total_audio/total_wall:.2f}x realtime",
        flush=True,
    )
    for r in results:
        status = "OK" if r.success else f"FAIL: {r.error}"
        print(f"    {r.tag}: wall={r.wall_sec:.2f}s audio={r.audio_sec:.2f}s {status}", flush=True)


def _check_determinism(
    rt: InferenceRuntime,
    text: str = "決定性のテストです。",
    *,
    seed: int = 777,
    num_workers: int = 4,
) -> None:
    """Issue the same seeded request from several threads and compare outputs."""

    def _det_worker():
        req = SamplingRequest(text=text, no_ref=True, num_steps=20, seed=seed)
        return rt.synthesize(req)

    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        futs = [pool.submit(_det_worker) for _ in range(num_workers)]
        outputs = [f.result() for f in futs]

    shapes = [r.audio.shape for r in outputs]
    if any(s != shapes[0] for s in shapes):
        # Element-wise comparison is impossible when the lengths differ.
        print(f"  WARNING: shapes differ: {shapes}", flush=True)
        return

    reference = outputs[0].audio
    diffs = [(reference - other.audio).abs().max().item() for other in outputs[1:]]
    print(f"  All {len(outputs)} outputs same shape: {shapes[0]}", flush=True)
    print(f"  Max abs diffs vs first: {diffs}", flush=True)
    print(f"  Deterministic: {all(d == 0.0 for d in diffs)}", flush=True)


def test_concurrency(rt: InferenceRuntime) -> None:
    """Exercise one shared runtime from 1, 2, 4 and 8 threads, then check determinism.

    TEST 7 submits one request per thread (capped by the number of sample
    texts) and prints success counts, wall times and throughput. TEST 7b
    submits the same seeded request from four threads and reports whether the
    resulting audio is identical.
    """
    banner("TEST 7: Concurrency / Thread-safety")

    texts = [
        "今日の夕食は何にしようかな。",
        "新しいプロジェクトが始まりました。頑張りましょう。",
        "電車が遅れているようです。別のルートを探しましょう。",
        "この映画はとても感動的でした。また見たいです。",
        "来週の会議の準備はできていますか？",
        "日本の四季はとても美しいです。",
        "人工知能の発展は目覚ましいものがあります。",
        "音楽を聴くとリラックスできます。",
    ]

    for num_threads in (1, 2, 4, 8):
        _run_concurrency_round(rt, texts, num_threads)

    # Also test: same seed from multiple threads produces same output
    banner("TEST 7b: Determinism under concurrency")
    _check_determinism(rt)


# ══════════════════════════════════════════════════════════════════════
#  TEST 8 — torch.compile speedup
# ══════════════════════════════════════════════════════════════════════
def _time_synthesis(rt: InferenceRuntime, text: str, *, runs: int = 3) -> tuple[list[float], float]:
    """Run ``runs`` identical requests; return per-run seconds and the last audio length."""
    timings: list[float] = []
    result = None
    for _ in range(runs):
        # NOTE(review): this measures until synthesize() returns; confirm that
        # it does not return before the GPU work has finished.
        t0 = time.perf_counter()
        req = SamplingRequest(text=text, no_ref=True, num_steps=30, seed=42)
        result = rt.synthesize(req)
        timings.append(time.perf_counter() - t0)
    audio_sec = result.audio.shape[-1] / result.sample_rate if result is not None else 0.0
    return timings, audio_sec


def _benchmark_runtime(ckpt: str, text: str, label: str, *, compile_model: bool) -> list[float]:
    """Build a runtime, time it, print the timings, and release GPU memory afterwards."""
    rt = build_runtime(ckpt, compile_model=compile_model)
    try:
        times, audio_sec = _time_synthesis(rt, text)
        # The label is padded to 11 columns so both report lines align.
        print(
            f"  {label:<11} runs={[f'{t:.2f}s' for t in times]} audio={audio_sec:.2f}s",
            flush=True,
        )
    finally:
        # Drop the runtime even if synthesis failed, so the next runtime
        # does not have to share VRAM with this one.
        del rt
        gc.collect()
        torch.cuda.empty_cache()
    return times


def test_compile(ckpt: str) -> None:
    """Compare synthesis latency with and without ``compile_model``.

    A fresh runtime is built for each setting, timed over three identical
    requests and released before the next one is built. The reported speedup
    uses the last run of each setting.

    Args:
        ckpt: Checkpoint path passed to ``build_runtime``.
    """
    banner("TEST 8: torch.compile acceleration")
    text = "コンパイルテストです。速度を比較します。"

    times_no = _benchmark_runtime(ckpt, text, "No compile:", compile_model=False)
    times_yes = _benchmark_runtime(ckpt, text, "Compiled:", compile_model=True)

    print(
        f"  Speedup (last run): {times_no[-1]/times_yes[-1]:.2f}x "
        f"({times_no[-1]*1000:.0f}ms -> {times_yes[-1]*1000:.0f}ms)",
        flush=True,
    )


# ══════════════════════════════════════════════════════════════════════
#  MAIN
# ══════════════════════════════════════════════════════════════════════
def main() -> None:
    """Run the whole suite against one freshly built runtime.

    Prints environment information, downloads the checkpoint, runs tests 1-7
    on a shared runtime, releases it, runs the compile benchmark, and finally
    lists the generated files.

    Raises:
        SystemExit: If CUDA is not available; every test here runs on a CUDA
            device.
    """
    print(f"PyTorch: {torch.__version__}", flush=True)

    # Check availability before any device query: get_device_name(0) raises an
    # opaque torch error on a machine without CUDA.
    if not torch.cuda.is_available():
        print("CUDA: False", flush=True)
        raise SystemExit("CUDA is not available; this test suite requires a CUDA-capable GPU.")

    print(f"CUDA: True, device: {torch.cuda.get_device_name(0)}", flush=True)
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB", flush=True)

    ckpt = download_checkpoint()
    rt = build_runtime(ckpt)

    # Show VRAM after loading
    alloc = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    print(f"VRAM after load: allocated={alloc:.2f}GB reserved={reserved:.2f}GB", flush=True)

    test_basic_tts(rt)
    test_emoji_style(rt)
    test_voice_cloning(rt)
    test_cfg_modes(rt)
    test_sampling_params(rt)
    test_multi_candidate(rt)
    test_concurrency(rt)

    # Release the shared runtime before test_compile builds its own runtimes.
    del rt
    gc.collect()
    torch.cuda.empty_cache()

    test_compile(ckpt)

    banner("ALL TESTS COMPLETE")
    print(f"Output files in: {OUT.resolve()}", flush=True)
    for f in sorted(OUT.iterdir()):
        print(f"  {f.name}  ({f.stat().st_size / 1024:.1f} KB)", flush=True)


if __name__ == "__main__":
    main()
