#!/usr/bin/env python3
"""Run reconstruction quality evaluation for neural audio codecs.

Computes: log-mel L1, multi-resolution STFT loss, SI-SDR, and high-frequency (HF) energy delta in dB.

Usage:
    python scripts/run_eval.py --codecs xcodec2 snac wavtokenizer bicodec \
        --audio-dir data/eval_clips/ --output-dir results
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from pathlib import Path

import torch

# Configure the root logger once at import time: INFO level, with each line
# formatted as "HH:MM:SS LEVEL logger-name: message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    datefmt="%H:%M:%S",
)
# Logger shared by every function in this script.
logger = logging.getLogger("codecbench.eval")


def _sanity_check(tb, audio_out: torch.Tensor, codec_name: str) -> list[str]:
    """Run sanity checks on codec output and return the problems found.

    Args:
        tb: Object returned by ``codec.encode``. Only its ``tokens`` attribute
            is read; it may be a tensor, a dict of tensors, or a list/tuple
            of tensors. Any other container type is left unchecked.
        audio_out: Decoded audio tensor produced by ``codec.decode``.
        codec_name: Name used as the prefix of every issue message.

    Returns:
        Human-readable issue strings; an empty list means every check passed.
    """
    issues: list[str] = []

    if torch.isnan(audio_out).any():
        issues.append(f"{codec_name}: NaN in output audio")
    if torch.isinf(audio_out).any():
        issues.append(f"{codec_name}: Inf in output audio")
    if audio_out.ndim != 3:
        issues.append(f"{codec_name}: output shape {audio_out.shape} is not [B, 1, T']")

    # Token checks
    def _check_tensor(t, label: str) -> None:
        # A non-tensor entry is reported instead of crashing on ``t.dtype``.
        if not isinstance(t, torch.Tensor):
            issues.append(
                f"{codec_name} {label}: expected a tensor, got {type(t).__name__}"
            )
            return

        # Only float and complex dtypes can hold NaN.
        can_hold_nan = t.is_floating_point() or t.is_complex()
        # Every integer dtype (including uint8 / int8) is a valid token type;
        # float, complex and bool are not.
        if can_hold_nan or t.dtype == torch.bool:
            issues.append(f"{codec_name} {label}: token dtype is {t.dtype}, expected integer")
        if can_hold_nan and torch.isnan(t).any():
            issues.append(f"{codec_name} {label}: NaN in tokens")

    tokens = tb.tokens
    if isinstance(tokens, torch.Tensor):
        _check_tensor(tokens, "tokens")
    elif isinstance(tokens, dict):
        for k, v in tokens.items():
            _check_tensor(v, f"tokens[{k}]")
    elif isinstance(tokens, (list, tuple)):
        for i, t in enumerate(tokens):
            _check_tensor(t, f"tokens[{i}]")

    return issues


def _determinism_check(codec, wav: torch.Tensor, sr: int) -> bool:
    """Encode the same input twice and report whether the tokens are identical.

    Args:
        codec: Object exposing ``encode(wav, sr)`` whose result has a
            ``tokens`` attribute.
        wav: Input audio tensor passed unchanged to both ``encode`` calls.
        sr: Sample rate passed unchanged to both ``encode`` calls.

    Returns:
        ``True`` when both encodings have the same structure (same dict keys,
        same sequence lengths) and every pair of leaves is equal; ``False``
        otherwise.
    """
    with torch.inference_mode():
        first = codec.encode(wav, sr)
        second = codec.encode(wav, sr)

    scalar_types = (bool, int, float, str, bytes, type(None))

    def _equal(a, b) -> bool:
        if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
            return torch.equal(a, b)
        if isinstance(a, dict) and isinstance(b, dict):
            # Key sets must match: a missing or extra key is a difference.
            return a.keys() == b.keys() and all(_equal(a[k], b[k]) for k in a)
        if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
            # zip() stops at the shorter sequence, so compare lengths first.
            return len(a) == len(b) and all(_equal(x, y) for x, y in zip(a, b))
        if isinstance(a, scalar_types) and isinstance(b, scalar_types):
            # Plain metadata leaves (e.g. an int stored next to the codes).
            return type(a) is type(b) and a == b
        return False

    return _equal(first.tokens, second.tokens)


def run_eval(args: argparse.Namespace) -> None:
    """Evaluate each requested codec on a directory of clips and write reports.

    For every name in ``args.codecs`` the codec is looked up and loaded, all
    ``.wav`` / ``.flac`` files directly inside ``args.audio_dir`` are loaded at
    the codec's native sample rate, padded or cropped to ``args.seconds``,
    stacked into a single batch, encoded and decoded, and the reconstruction is
    scored with log-mel L1, multi-resolution STFT loss, SI-SDR and HF energy
    delta. Codecs that are unknown or fail to load are skipped with a log
    message.

    Files written to ``args.output_dir``:
        * ``eval_results.jsonl`` - one JSON object per evaluated codec
        * ``eval_summary.csv``
        * ``report.md`` - also includes ``bench_results.jsonl`` data when that
          file already exists in the output directory

    Args:
        args: Parsed CLI namespace providing ``codecs``, ``audio_dir``,
            ``seconds``, ``output_dir`` and ``device``.

    Raises:
        SystemExit: With exit code 1 when the audio directory does not exist,
            contains no ``.wav`` / ``.flac`` files, or no codec produced a
            result.
    """
    # Project imports stay local so that importing this module (or running
    # ``--help``) does not require them.
    from codecbench.codecs import get_codec
    from codecbench.audio.io import load_audio
    from codecbench.audio.batching import pad_or_crop
    from codecbench.metrics import log_mel_l1, multi_resolution_stft_loss, si_sdr, hf_energy_delta_db
    from codecbench.reporting import generate_summary_csv, generate_report_md, load_jsonl

    device = args.device
    # Match "cuda" as well as indexed forms such as "cuda:0".
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        logger.warning("CUDA not available, falling back to CPU")
        device = "cpu"

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    audio_dir = Path(args.audio_dir)

    if not audio_dir.exists():
        logger.error("Audio directory %s does not exist", audio_dir)
        sys.exit(1)

    audio_files = sorted(audio_dir.glob("*.wav")) + sorted(audio_dir.glob("*.flac"))
    if not audio_files:
        logger.error("No .wav or .flac files found in %s", audio_dir)
        sys.exit(1)

    logger.info("Found %d audio files in %s", len(audio_files), audio_dir)
    target_seconds = args.seconds

    eval_results = []

    for codec_name in args.codecs:
        try:
            codec = get_codec(codec_name)
        except KeyError as e:
            logger.warning("Skipping %s: %s", codec_name, e)
            continue

        logger.info("=" * 60)
        logger.info("Loading %s for evaluation", codec_name)
        try:
            codec.load(device=device)
        except Exception as e:
            logger.error("Failed to load %s: %s", codec_name, e)
            continue

        sr = codec.native_sr
        target_len = int(target_seconds * sr)

        # Load and batch audio (resampled per codec, since native rates differ)
        wav_list = []
        for af in audio_files:
            wav, _ = load_audio(af, target_sr=sr)
            wav_list.append(pad_or_crop(wav, target_len))

        wav_batch = torch.stack(wav_list, dim=0).to(device)  # [B, 1, T]

        # Sanity: determinism check on first sample
        det_ok = _determinism_check(codec, wav_batch[:1], sr)
        logger.info("%s determinism check: %s", codec_name, "PASS" if det_ok else "FAIL")

        # Encode + decode
        with torch.inference_mode():
            tb = codec.encode(wav_batch, sr)
            recon = codec.decode(tb)

        # Sanity checks
        issues = _sanity_check(tb, recon, codec_name)
        for issue in issues:
            logger.warning("SANITY: %s", issue)

        # Vocab stats are best-effort: a failure is logged and recorded as -1.
        try:
            vmin, vmax = tb.observed_vocab()
            # Builtin ints keep the record JSON-serialisable.
            vmin, vmax = int(vmin), int(vmax)
        except Exception as e:
            logger.warning("%s: could not determine vocab range: %s", codec_name, e)
            vmin, vmax = -1, -1
        else:
            logger.info("%s vocab range: [%d, %d]", codec_name, vmin, vmax)

        # Metrics, coerced to builtin floats so round() and json.dumps() work
        # even if a metric returns a 0-d tensor or a NumPy scalar.
        mel = float(log_mel_l1(wav_batch, recon, sr))
        mrstft = float(multi_resolution_stft_loss(wav_batch, recon))
        sisdr_val = float(si_sdr(wav_batch, recon))
        hf_delta = float(hf_energy_delta_db(wav_batch, recon, sr))

        logger.info(
            "%s metrics: mel_l1=%.4f mrstft=%.4f sisdr=%.1f hf_delta=%.1f",
            codec_name, mel, mrstft, sisdr_val, hf_delta,
        )

        result = {
            "codec": codec_name,
            "sr": sr,
            "n_files": len(audio_files),
            "clip_seconds": target_seconds,
            "mel_l1": round(mel, 6),
            "mrstft": round(mrstft, 6),
            "sisdr": round(sisdr_val, 2),
            "hf_energy_delta_db": round(hf_delta, 2),
            "deterministic": det_ok,
            "vocab_min": vmin,
            "vocab_max": vmax,
            "tokens_shape": tb.shapes_summary(),
            "sanity_issues": issues,
        }
        eval_results.append(result)

        # Drop the model and every large tensor before loading the next codec.
        del codec, tb, recon, wav_batch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not eval_results:
        logger.error("No evaluation results produced")
        sys.exit(1)

    # Save
    eval_jsonl_path = output_dir / "eval_results.jsonl"
    with open(eval_jsonl_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r) + "\n" for r in eval_results)

    eval_csv_path = output_dir / "eval_summary.csv"
    generate_summary_csv(eval_results, eval_csv_path)

    # Merge with bench results if available
    bench_jsonl = output_dir / "bench_results.jsonl"
    bench_data = load_jsonl(bench_jsonl) if bench_jsonl.exists() else []

    report_path = output_dir / "report.md"
    generate_report_md(bench_data, report_path, eval_results=eval_results)

    logger.info("=" * 60)
    logger.info("Evaluation complete. Outputs:")
    logger.info("  JSONL:  %s", eval_jsonl_path)
    logger.info("  CSV:    %s", eval_csv_path)
    logger.info("  Report: %s", report_path)


def main() -> None:
    """Parse the command-line options and hand them to ``run_eval``."""
    cli = argparse.ArgumentParser(
        description="CodecBench reconstruction quality evaluation",
    )

    # Codecs evaluated when --codecs is not given.
    default_codecs = ["xcodec2", "snac", "wavtokenizer", "bicodec"]
    cli.add_argument("--codecs", nargs="+", default=default_codecs)
    cli.add_argument(
        "--audio-dir",
        required=True,
        help="Directory of .wav/.flac files",
    )
    cli.add_argument("--seconds", type=float, default=6.0)
    cli.add_argument("--output-dir", default="results")
    cli.add_argument("--device", default="cuda")

    run_eval(cli.parse_args())


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
