"""Benchmark runner: orchestrates warmup, timing, metrics, and aggregation."""

from __future__ import annotations

import contextlib
import json
import logging
import subprocess
from dataclasses import dataclass, field
from pathlib import Path

import torch

from codecbench.bench.timer import BenchStats, CUDATimer, measure_peak_vram, reset_vram_stats
from codecbench.codecs.base import NeuralCodec, TokenBatch

logger = logging.getLogger(__name__)


@dataclass
class BenchConfig:
    """Benchmark sweep parameters: batch sizes, clip length, iteration counts, and dtypes."""
    batch_sizes: list[int] = field(default_factory=lambda: [1, 8, 32])
    clip_seconds: float = 6.0
    warmup_iters: int = 10
    measure_iters: int = 50
    dtypes: list[str] = field(default_factory=lambda: ["fp32", "bf16"])
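
# Example: a lighter-weight sweep for quick smoke tests (illustrative values only;
# nothing in this module depends on them):
#
#     cfg = BenchConfig(batch_sizes=[1], warmup_iters=2, measure_iters=5, dtypes=["fp32"])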


@dataclass
class RunResult:
    """Single benchmark run result (one codec, one config)."""
    codec: str
    sr: int
    batch_size: int
    clip_seconds: float
    dtype: str
    encode_stats: dict
    decode_stats: dict
    e2e_stats: dict
    tokens_per_sec: float
    tokens_shape: str
    peak_vram_mb: float
    gpu_name: str
    torch_version: str
    extra: dict = field(default_factory=dict)

    def to_jsonl(self) -> str:
        d = {
            "codec": self.codec,
            "sr": self.sr,
            "batch_size": self.batch_size,
            "clip_seconds": self.clip_seconds,
            "dtype": self.dtype,
            "encode_ms": self.encode_stats.get("mean_ms", 0),
            "encode_p50_ms": self.encode_stats.get("p50_ms", 0),
            "encode_p95_ms": self.encode_stats.get("p95_ms", 0),
            "decode_ms": self.decode_stats.get("mean_ms", 0),
            "decode_p50_ms": self.decode_stats.get("p50_ms", 0),
            "decode_p95_ms": self.decode_stats.get("p95_ms", 0),
            "e2e_ms": self.e2e_stats.get("mean_ms", 0),
            "tokens_per_sec": self.tokens_per_sec,
            "tokens_shape": self.tokens_shape,
            "peak_vram_mb": self.peak_vram_mb,
            "gpu_name": self.gpu_name,
            "torch": self.torch_version,
            **self.extra,
        }
        return json.dumps(d)
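
# Note: each `to_jsonl()` call yields one compact JSON object, i.e. one line of the
# .jsonl output. The file can be loaded back for analysis, e.g. with pandas (not a
# dependency of this module):
#
#     import pandas as pd
#     df = pd.read_json("results/bench.jsonl", lines=True)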


def _get_gpu_name() -> str:
    if torch.cuda.is_available():
        return torch.cuda.get_device_name(0)
    return "no-gpu"


def _get_git_commit() -> str:
    try:
        return subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"],
            stderr=subprocess.DEVNULL,
        ).decode().strip()
    except Exception:
        return "unknown"


def _dtype_from_str(s: str) -> torch.dtype:
    return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[s]


def _make_autocast(dtype_str: str, device: str = "cuda"):
    """Return a mixed-precision autocast context for the given dtype string.

    For "fp32" no autocast is needed, so a null context is returned; callers are
    expected to wrap inference in ``torch.inference_mode()`` themselves.
    """
    if dtype_str == "fp32":
        return contextlib.nullcontext()
    dt = _dtype_from_str(dtype_str)
    return torch.autocast(device_type=device.split(":")[0], dtype=dt)
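
# Example: the returned context is meant to be composed with inference mode by the
# caller (here `codec` and `wav` stand in for any loaded NeuralCodec and input batch):
#
#     with _make_autocast("bf16"), torch.inference_mode():
#         tokens = codec.encode(wav, codec.native_sr)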


def run_single_benchmark(
    codec: NeuralCodec,
    wav_batch: torch.Tensor,
    config: BenchConfig,
    dtype_str: str = "fp32",
    device: str = "cuda",
) -> RunResult:
    """Run benchmark for a single codec + config combination.

    Args:
        codec: loaded codec (already on device)
        wav_batch: [B, 1, T] at codec's native SR
        config: benchmark parameters
        dtype_str: "fp32", "fp16", or "bf16"
        device: CUDA device
    """
    sr = codec.native_sr
    bs = wav_batch.shape[0]
    timer = CUDATimer(device)

    # --- Warmup ---
    logger.info(
        "Warmup: %d iters for %s @ bs=%d dtype=%s",
        config.warmup_iters, codec.name, bs, dtype_str,
    )
    autocast_ctx = _make_autocast(dtype_str, device)
    for _ in range(config.warmup_iters):
        with autocast_ctx:
            with torch.inference_mode():
                tb = codec.encode(wav_batch, sr)
                _ = codec.decode(tb)
    torch.cuda.synchronize(device)

    # --- Measure ---
    encode_stats = BenchStats(label=f"{codec.name}_encode")
    decode_stats = BenchStats(label=f"{codec.name}_decode")
    e2e_stats = BenchStats(label=f"{codec.name}_e2e")

    reset_vram_stats(device)
    last_tb: TokenBatch | None = None

    for _ in range(config.measure_iters):
        with autocast_ctx:
            with torch.inference_mode():
                # Encode
                timer.record_start()
                tb = codec.encode(wav_batch, sr)
                enc_ms = timer.record_end()
                encode_stats.times_ms.append(enc_ms)
                last_tb = tb

                # Decode
                timer.record_start()
                _ = codec.decode(tb)
                dec_ms = timer.record_end()
                decode_stats.times_ms.append(dec_ms)

                e2e_stats.times_ms.append(enc_ms + dec_ms)

    peak_vram = measure_peak_vram(device)

    # Token rate: tokens per second of audio, computed from the actual waveform length
    # rather than config.clip_seconds, in case the caller passed a different clip.
    total_audio_seconds = bs * wav_batch.shape[-1] / sr
    if last_tb is not None:
        total_tokens = last_tb.token_count
        tokens_per_sec = total_tokens / total_audio_seconds
        tokens_shape = last_tb.shapes_summary()
    else:
        tokens_per_sec = 0.0
        tokens_shape = "N/A"

    return RunResult(
        codec=codec.name,
        sr=sr,
        batch_size=bs,
        clip_seconds=config.clip_seconds,
        dtype=dtype_str,
        encode_stats=encode_stats.as_dict(),
        decode_stats=decode_stats.as_dict(),
        e2e_stats=e2e_stats.as_dict(),
        tokens_per_sec=round(tokens_per_sec, 2),
        tokens_shape=tokens_shape,
        peak_vram_mb=round(peak_vram / (1024 ** 2), 1),  # bytes -> MiB
        gpu_name=_get_gpu_name(),
        torch_version=torch.__version__,
        extra={"commit": _get_git_commit()},
    )
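
# Example: running one configuration by hand (sketch; `codec` stands in for any
# NeuralCodec already loaded onto the GPU, and the waveform is synthetic noise):
#
#     cfg = BenchConfig(warmup_iters=3, measure_iters=10)
#     wav = torch.randn(4, 1, int(cfg.clip_seconds * codec.native_sr), device="cuda")
#     res = run_single_benchmark(codec, wav, cfg, dtype_str="bf16")
#     print(res.to_jsonl())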


def run_codec_suite(
    codec: NeuralCodec,
    config: BenchConfig,
    device: str = "cuda",
) -> list[RunResult]:
    """Run full benchmark suite for one codec across all batch sizes and dtypes."""
    results = []
    for dtype_str in config.dtypes:
        for bs in config.batch_sizes:
            # Synthetic input: white noise at the codec's native SR, peak-normalized to 0.9.
            n_samples = int(config.clip_seconds * codec.native_sr)
            wav_batch = torch.randn(bs, 1, n_samples, device=device)
            wav_batch = wav_batch / wav_batch.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) * 0.9

            try:
                result = run_single_benchmark(codec, wav_batch, config, dtype_str, device)
                results.append(result)
                logger.info(
                    "%s bs=%d %s: encode=%.1fms decode=%.1fms vram=%.0fMB tps=%.1f",
                    codec.name, bs, dtype_str,
                    result.encode_stats["mean_ms"],
                    result.decode_stats["mean_ms"],
                    result.peak_vram_mb,
                    result.tokens_per_sec,
                )
            except Exception as e:
                logger.error("Failed %s bs=%d %s: %s", codec.name, bs, dtype_str, e)
                continue

    return results


def save_results_jsonl(results: list[RunResult], path: str | Path) -> None:
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        for r in results:
            f.write(r.to_jsonl() + "\n")
    logger.info("Results saved to %s", path)
