#!/usr/bin/env python3
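"""Smoke test: load each codec, round-trip synthetic audio, and check basic invariants.

Falls back to CPU when ``--device cuda`` is requested but CUDA is unavailable,
and exits with status 1 if any codec fails.

Usage:
    python scripts/smoke_test.py --codecs xcodec2 snac wavtokenizer bicodec
"""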

from __future__ import annotations

import argparse
import logging
import sys
import traceback

import torch

# Configure the root logger once at import time: INFO and above, with a
# wall-clock timestamp. (basicConfig is a no-op if the root logger already
# has handlers.)
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
# Logger used for all smoke-test output in this script.
logger = logging.getLogger("codecbench.smoke")


class _SmokeCheckError(Exception):
    """A smoke-test expectation was not met; the message says which one."""


def _require(condition: bool, message: str) -> None:
    """Raise :class:`_SmokeCheckError` with *message* unless *condition* holds.

    Used instead of ``assert`` so the checks are not stripped when Python
    runs with ``-O``.
    """
    if not condition:
        raise _SmokeCheckError(message)


def _tokens_equal(a, b) -> bool:
    """Return True if two token structures hold identical tensors.

    Supports a bare tensor, a dict (key sets must match) and a list/tuple
    (lengths must match), recursing into nested containers. Any other type
    is treated as not comparable and yields False.
    """
    if isinstance(a, torch.Tensor):
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
    if isinstance(a, dict):
        return (
            isinstance(b, dict)
            and a.keys() == b.keys()
            and all(_tokens_equal(a[k], b[k]) for k in a)
        )
    if isinstance(a, (list, tuple)):
        # Length check first: zip() would silently ignore extra trailing items.
        return (
            isinstance(b, (list, tuple))
            and len(a) == len(b)
            and all(_tokens_equal(x, y) for x, y in zip(a, b))
        )
    return False


def _check_codec(
    codec, codec_name: str, wav: torch.Tensor, sr: int, batch_size: int
) -> str:
    """Run the encode/decode checks on an already-loaded codec.

    Args:
        codec: Loaded codec object (from ``get_codec`` + ``load``).
        codec_name: Name the codec was requested under; ``encode`` output
            must report the same name.
        wav: Input audio batch passed to ``codec.encode``.
        sr: Sample rate of *wav*; ``encode`` output must report the same rate.
        batch_size: Expected batch dimension of tokens and reconstruction.

    Returns:
        A one-line summary of the token, reconstruction and flattened shapes.

    Raises:
        _SmokeCheckError: if any expectation is violated.
    """
    # 1. Encode
    with torch.inference_mode():
        tb = codec.encode(wav, sr)
    _require(
        tb.batch_size == batch_size,
        f"Expected batch_size={batch_size}, got {tb.batch_size}",
    )
    _require(
        tb.codec_name == codec_name,
        f"codec_name mismatch: expected {codec_name!r}, got {tb.codec_name!r}",
    )
    _require(
        tb.sample_rate == sr,
        f"sample_rate mismatch: expected {sr}, got {tb.sample_rate}",
    )

    # 2. Token sanity
    vmin, vmax = tb.observed_vocab()
    _require(vmin >= 0, f"Negative token value: {vmin}")
    _require(vmax < 100_000, f"Suspiciously large token: {vmax}")

    # 3. Decode
    with torch.inference_mode():
        recon = codec.decode(tb)
    _require(recon.ndim == 3, f"Expected 3-D output, got {recon.ndim}")
    _require(
        recon.shape[0] == batch_size, f"Batch dim mismatch: {recon.shape[0]}"
    )
    _require(not torch.isnan(recon).any(), "NaN in reconstructed audio")
    _require(not torch.isinf(recon).any(), "Inf in reconstructed audio")

    # 4. Output length sanity (should be within 2x of input)
    ratio = recon.shape[-1] / wav.shape[-1]
    _require(0.5 < ratio < 2.0, f"Output length ratio {ratio:.2f} is unreasonable")

    # 5. Determinism: encoding the same input twice must give identical tokens.
    with torch.inference_mode():
        tb2 = codec.encode(wav, sr)
    _require(
        _tokens_equal(tb.tokens, tb2.tokens),
        "Non-deterministic encoding (tokens differ on repeated encode)",
    )

    # 6. flatten_for_lm
    flat = codec.flatten_for_lm(tb)
    _require(flat.ndim == 2, f"Expected 2-D flat tokens, got {flat.ndim}")
    _require(
        flat.shape[0] == batch_size, f"Flat batch dim mismatch: {flat.shape[0]}"
    )

    return (
        f"OK: tokens={tb.shapes_summary()} vocab=[{vmin},{vmax}] "
        f"recon_shape={list(recon.shape)} flat_len={flat.shape[1]} det=True"
    )


def _smoke_one(codec_name: str, device: str) -> tuple[bool, str]:
    """Run the smoke test for a single codec.

    Loads the codec and builds a synthetic batch of 2 clips, each 6 s long at
    the codec's native sample rate. It then runs the checks in
    :func:`_check_codec`:
    - encode metadata,
    - token range,
    - decode shape, finiteness and length,
    - encode determinism,
    - ``flatten_for_lm`` rank.

    Args:
        codec_name: Registry name passed to ``get_codec``.
        device: Torch device string passed to ``codec.load`` and
            ``generate_synthetic``.

    Returns:
        ``(passed, message)``. ``passed`` is False when the codec is not
        registered, fails to load, or violates a check; ``message`` then
        describes the failure. Unexpected exceptions raised by the codec
        itself (e.g. inside ``encode``/``decode``) propagate to the caller.
    """
    from codecbench.codecs import get_codec
    from codecbench.audio.io import generate_synthetic

    batch_size = 2

    try:
        codec = get_codec(codec_name)
    except KeyError as e:
        return False, f"Not registered: {e}"

    try:
        codec.load(device=device)
    except Exception as e:
        return False, f"Load failed: {e}"

    sr = codec.native_sr
    wav = generate_synthetic(6.0, sr, batch_size=batch_size, device=device)

    try:
        info = _check_codec(codec, codec_name, wav, sr, batch_size)
    except _SmokeCheckError as e:
        return False, str(e)
    return True, info


def run_smoke(args: argparse.Namespace) -> None:
    """Smoke-test every codec in ``args.codecs`` and exit non-zero on failure.

    Args:
        args: Parsed CLI namespace with ``codecs`` (list of codec registry
            names) and ``device`` (torch device string such as ``"cuda"``,
            ``"cuda:1"`` or ``"cpu"``).

    Each codec is tested in turn with ``_smoke_one``. A ``(False, msg)``
    result is logged as FAIL; an unexpected exception is logged as ERROR
    with its traceback. Both count as failures. The CUDA caching allocator is
    emptied between codecs. Calls ``sys.exit(1)`` if any codec failed and
    returns normally otherwise.
    """
    device = args.device
    # Match plain "cuda" as well as indexed devices such as "cuda:1".
    if device.startswith("cuda") and not torch.cuda.is_available():
        logger.warning("CUDA not available, falling back to CPU")
        device = "cpu"

    passed = 0
    failed = 0

    for codec_name in args.codecs:
        logger.info("-" * 50)
        logger.info("Smoke test: %s", codec_name)
        try:
            ok, msg = _smoke_one(codec_name, device)
        except Exception as e:
            # logger.exception logs at ERROR level with the traceback attached,
            # so crash details are visible even when DEBUG output is disabled.
            logger.exception("  ERROR: %s", e)
            failed += 1
        else:
            if ok:
                logger.info("  PASS: %s", msg)
                passed += 1
            else:
                logger.error("  FAIL: %s", msg)
                failed += 1

        # Return cached-but-unused GPU memory before the next codec loads.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    logger.info("=" * 50)
    logger.info("Smoke test summary: %d passed, %d failed", passed, failed)

    if failed > 0:
        sys.exit(1)


def main() -> None:
    """Command-line entry point: parse options and run the smoke test.

    ``--codecs`` takes one or more codec registry names and defaults to
    xcodec2, snac, wavtokenizer and bicodec. ``--device`` defaults to
    ``"cuda"``.
    """
    default_codecs = ["xcodec2", "snac", "wavtokenizer", "bicodec"]

    cli = argparse.ArgumentParser(description="CodecBench smoke test")
    cli.add_argument("--codecs", nargs="+", default=default_codecs)
    cli.add_argument("--device", default="cuda")

    run_smoke(cli.parse_args())


if __name__ == "__main__":
    main()
