"""Comprehensive validation: original vs fast XCodec2 over all available samples.

Encodes every WAV file in data/chunks_16k with both the original XCodec2 encode_code
and FastXCodec2, compares the resulting VQ codes exactly, measures encode timing,
and sanity-checks the decode round-trip.

Phase 1: Original encode (sequential B=1, saves codes to disk)
Phase 2: Fast encode (sequential B=1, saves codes to disk)
Phase 3: Code-level comparison (per-segment + aggregate stats)
Phase 4: Decode round-trip sanity (encode→decode→SNR check on subset)
"""

import gc
import json
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio

# Put the parent of this script's directory first on the import path
# (presumably the repo root, so that the local `codecbench` package imported
# in phases 2 and 4 resolves when the script is run directly).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), os.pardir))

# Input chunks and output location, relative to the current working directory.
CHUNKS_DIR = Path("data") / "chunks_16k"
RESULTS_DIR = Path("results") / "validation"
# load_wav normalizes every chunk to this sample rate and duration.
TARGET_SR: int = 16_000
CHUNK_SEC: float = 6.0


def discover_chunks(max_n: int = 0) -> list[Path]:
    """Return the ``*.wav`` files directly inside ``CHUNKS_DIR``, name-sorted.

    Args:
        max_n: When positive, only the first ``max_n`` files (after sorting)
            are returned; zero or a negative value returns all of them.
    """
    wavs = sorted(CHUNKS_DIR.glob("*.wav"))
    return wavs[:max_n] if max_n > 0 else wavs


def load_wav(path: Path) -> torch.Tensor:
    """Load one chunk as mono audio at ``TARGET_SR``, exactly ``CHUNK_SEC`` long.

    The file is resampled if its rate differs from ``TARGET_SR``, multi-channel
    audio is averaged down to one channel, and the result is trimmed or
    right-padded with zeros to ``int(CHUNK_SEC * TARGET_SR)`` samples.

    Returns:
        Tensor of shape ``[1, int(CHUNK_SEC * TARGET_SR)]`` (``[1, 96000]`` with
        the default constants).
    """
    audio, rate = torchaudio.load(str(path))
    if rate != TARGET_SR:
        audio = torchaudio.functional.resample(audio, rate, TARGET_SR)
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)

    n_target = int(CHUNK_SEC * TARGET_SR)
    shortfall = n_target - audio.shape[1]
    if shortfall < 0:
        # Too long: keep only the leading n_target samples.
        audio = audio[:, :n_target]
    elif shortfall > 0:
        # Too short: zero-pad on the right.
        audio = F.pad(audio, (0, shortfall))
    return audio


def phase1_original(files: list[Path], out_dir: Path) -> tuple[float, int]:
    """Encode every file with the original XCodec2 and save its codes.

    Each file is loaded with ``load_wav``, encoded one at a time (batch size 1)
    via ``XCodec2Model.encode_code``, flattened, and written to
    ``out_dir / f"{stem}.pt"``.

    The timed region starts after a 3-iteration warmup and covers WAV loading,
    device copies, encoding and ``torch.save`` for every file. It therefore
    measures end-to-end per-segment throughput, not pure model time.

    Args:
        files: WAV files to encode; must be non-empty.
        out_dir: Directory for the per-file code tensors (created if missing).

    Returns:
        ``(ms_per_segment, total_codes)``: mean wall-clock milliseconds per file
        and the total number of codes written across all files.

    Raises:
        ValueError: If ``files`` is empty (checked before the model is loaded).
    """
    from xcodec2.modeling_xcodec2 import XCodec2Model

    if not files:
        raise ValueError("phase1_original needs at least one file")

    print(f"\n{'='*70}")
    print(f"PHASE 1: Original XCodec2 encode_code — {len(files)} files")
    print(f"{'='*70}")

    model = XCodec2Model.from_pretrained("HKUSTAudio/xcodec2")
    model.eval().cuda()

    # Warm up on a dummy input of the same length load_wav produces, so the
    # first timed file doesn't absorb one-time setup costs.
    n_samples = int(CHUNK_SEC * TARGET_SR)
    dummy = torch.randn(1, n_samples, device="cuda")
    with torch.inference_mode():
        for _ in range(3):
            model.encode_code(dummy)
    torch.cuda.synchronize()

    out_dir.mkdir(parents=True, exist_ok=True)
    total_codes = 0
    t0 = time.perf_counter()

    with torch.inference_mode():
        for i, f in enumerate(files, start=1):
            wav = load_wav(f).cuda()  # [1, n_samples]
            codes = model.encode_code(wav)
            # reshape (unlike view) also works if the output is non-contiguous.
            codes_flat = codes.reshape(-1).cpu()
            torch.save(codes_flat, out_dir / f"{f.stem}.pt")
            total_codes += len(codes_flat)

            if i % 100 == 0:
                elapsed = time.perf_counter() - t0
                print(f"  [{i}/{len(files)}] {elapsed:.1f}s elapsed, "
                      f"{elapsed/i*1000:.1f}ms/seg")

    torch.cuda.synchronize()
    total_time = time.perf_counter() - t0
    ms_per = total_time / len(files) * 1000
    rtf = CHUNK_SEC * 1000 / ms_per

    print(f"\n  Original done: {len(files)} files, {total_codes} total codes")
    print(f"  {total_time:.1f}s total, {ms_per:.1f}ms/seg, RTF={rtf:.1f}x")

    del model
    # Collect first so any cyclically-referenced GPU tensors are actually
    # freed; only then can empty_cache release their blocks.
    gc.collect()
    torch.cuda.empty_cache()
    return ms_per, total_codes


def phase2_fast(files: list[Path], out_dir: Path) -> tuple[float, int]:
    """Encode every file with FastXCodec2 and save its codes.

    Mirrors ``phase1_original`` so the two timings are comparable. Each file is
    loaded with ``load_wav``, encoded one at a time as a ``[1, 1, T]`` batch via
    ``FastXCodec2Codec.encode``, and its ``.tokens`` are flattened and written
    to ``out_dir / f"{stem}.pt"``. The timed region starts after the codec's own
    ``warmup()`` and covers loading, device copies, encoding and ``torch.save``.

    Args:
        files: WAV files to encode; must be non-empty.
        out_dir: Directory for the per-file code tensors (created if missing).

    Returns:
        ``(ms_per_segment, total_codes)``: mean wall-clock milliseconds per file
        and the total number of codes written across all files.

    Raises:
        ValueError: If ``files`` is empty (checked before the codec is loaded).
    """
    from codecbench.codecs.xcodec2_fast import FastXCodec2Codec

    if not files:
        raise ValueError("phase2_fast needs at least one file")

    print(f"\n{'='*70}")
    print(f"PHASE 2: FastXCodec2 _fast_encode — {len(files)} files")
    print(f"{'='*70}")

    fast = FastXCodec2Codec()
    fast.load(device="cuda")
    fast.warmup()

    out_dir.mkdir(parents=True, exist_ok=True)
    total_codes = 0
    t0 = time.perf_counter()

    with torch.inference_mode():
        for i, f in enumerate(files, start=1):
            wav = load_wav(f).cuda()  # [1, T]
            wav_3d = wav.unsqueeze(0)  # [1, 1, T]
            tb = fast.encode(wav_3d, TARGET_SR)
            # reshape (unlike view) also works if tokens are non-contiguous.
            codes_flat = tb.tokens.reshape(-1).cpu()
            torch.save(codes_flat, out_dir / f"{f.stem}.pt")
            total_codes += len(codes_flat)

            if i % 100 == 0:
                elapsed = time.perf_counter() - t0
                print(f"  [{i}/{len(files)}] {elapsed:.1f}s elapsed, "
                      f"{elapsed/i*1000:.1f}ms/seg")

    torch.cuda.synchronize()
    total_time = time.perf_counter() - t0
    ms_per = total_time / len(files) * 1000
    rtf = CHUNK_SEC * 1000 / ms_per

    print(f"\n  Fast done: {len(files)} files, {total_codes} total codes")
    print(f"  {total_time:.1f}s total, {ms_per:.1f}ms/seg, RTF={rtf:.1f}x")

    del fast
    # Collect first so cyclically-referenced GPU tensors are freed before
    # empty_cache tries to release cached blocks.
    gc.collect()
    torch.cuda.empty_cache()
    return ms_per, total_codes


def _print_match_distribution(pct_arr: np.ndarray) -> None:
    """Print how many segments fall into each match-percentage bucket.

    Buckets are ``=100%``, then ``(99.5, 100)``, ``(99.0, 99.5]``, ... down to
    ``[0, 95]``. The top range is open at 100 because exact matches have their
    own bucket. The bottom range is closed at 0 so fully mismatched segments
    are counted. Every segment lands in exactly one bucket, and empty buckets
    are not printed.
    """
    edges = [100.0, 99.5, 99.0, 98.5, 98.0, 97.0, 96.0, 95.0, 0.0]
    rows = [("  =100.0%", int((pct_arr == 100.0).sum()))]
    for hi, lo in zip(edges, edges[1:]):
        upper_ok = pct_arr < hi if hi == 100.0 else pct_arr <= hi
        lower_ok = pct_arr >= lo if lo == 0.0 else pct_arr > lo
        left = "[" if lo == 0.0 else "("
        right = ")" if hi == 100.0 else "]"
        label = f"  {left}{lo:.1f}%, {hi:.1f}%{right}"
        rows.append((label, int((upper_ok & lower_ok).sum())))

    print(f"\n  Match distribution:")
    for label, count in rows:
        if count > 0:
            print(f"    {label}: {count} segments")


def phase3_compare(files: list[Path], orig_dir: Path, fast_dir: Path) -> dict:
    """Compare the saved original vs fast codes file by file and summarize.

    For each file, ``orig_dir/<stem>.pt`` and ``fast_dir/<stem>.pt`` are loaded
    and compared position by position. If the two sequences have different
    lengths, every position present in only one of them counts as a mismatch.
    A truncated or padded fast output therefore cannot look like a perfect
    match. Such segments are also counted in ``length_mismatch_segments``.

    Args:
        files: Files whose stems name the saved code tensors; must be non-empty.
        orig_dir: Directory written by ``phase1_original``.
        fast_dir: Directory written by ``phase2_fast``.

    Returns:
        JSON-serializable dict of aggregate and per-segment statistics.

    Raises:
        ValueError: If ``files`` is empty.
    """
    if not files:
        raise ValueError("phase3_compare needs at least one file")

    print(f"\n{'='*70}")
    print(f"PHASE 3: Code-level comparison — {len(files)} files")
    print(f"{'='*70}")

    total_codes = 0
    total_match = 0
    length_mismatches = 0
    per_segment_pct = []
    per_segment_diff_count = []

    for f in files:
        orig = torch.load(orig_dir / f"{f.stem}.pt", weights_only=True)
        fast = torch.load(fast_dir / f"{f.stem}.pt", weights_only=True)

        overlap = min(len(orig), len(fast))
        n = max(len(orig), len(fast))
        if len(orig) != len(fast):
            length_mismatches += 1
        match = int((orig[:overlap] == fast[:overlap]).sum().item())
        diff = n - match  # includes any surplus positions in the longer one
        pct = 100.0 * match / n if n > 0 else 0.0

        total_codes += n
        total_match += match
        per_segment_pct.append(pct)
        per_segment_diff_count.append(diff)

    overall_pct = 100.0 * total_match / total_codes if total_codes > 0 else 0
    pct_arr = np.array(per_segment_pct)
    diff_arr = np.array(per_segment_diff_count)
    segments_100pct = int((pct_arr == 100.0).sum())
    # argmin returns the first minimum, so ties go to the earliest file; a
    # segment name is always reported, even when everything matches.
    worst_idx = int(np.argmin(pct_arr))
    worst_segment = files[worst_idx].stem
    worst_pct = per_segment_pct[worst_idx]

    stats = {
        "num_files": len(files),
        "total_codes": total_codes,
        "total_match": total_match,
        "total_diff": total_codes - total_match,
        "overall_match_pct": round(overall_pct, 4),
        "per_segment_mean_pct": round(float(pct_arr.mean()), 4),
        "per_segment_median_pct": round(float(np.median(pct_arr)), 4),
        "per_segment_min_pct": round(float(pct_arr.min()), 2),
        "per_segment_max_pct": round(float(pct_arr.max()), 2),
        "per_segment_std_pct": round(float(pct_arr.std()), 4),
        "segments_100pct": segments_100pct,
        "segments_sub99pct": int((pct_arr < 99.0).sum()),
        "segments_sub98pct": int((pct_arr < 98.0).sum()),
        "segments_sub95pct": int((pct_arr < 95.0).sum()),
        "length_mismatch_segments": length_mismatches,
        "worst_segment": worst_segment,
        "worst_pct": worst_pct,
        "max_diff_per_segment": int(diff_arr.max()),
        "mean_diff_per_segment": round(float(diff_arr.mean()), 2),
        "diff_percentiles": {
            "p50": int(np.percentile(diff_arr, 50)),
            "p90": int(np.percentile(diff_arr, 90)),
            "p95": int(np.percentile(diff_arr, 95)),
            "p99": int(np.percentile(diff_arr, 99)),
            "p100": int(diff_arr.max()),
        },
    }

    print(f"\n  Overall match: {total_match}/{total_codes} ({overall_pct:.4f}%)")
    print(f"  Segments 100% match: {segments_100pct}/{len(files)}")
    print(f"  Segments <99% match: {stats['segments_sub99pct']}")
    print(f"  Segments <98% match: {stats['segments_sub98pct']}")
    print(f"  Segments <95% match: {stats['segments_sub95pct']}")
    if length_mismatches:
        print(f"  WARNING: {length_mismatches} segments have differing code lengths")
    print(f"  Worst segment: {worst_segment} ({worst_pct:.2f}%)")
    print(f"  Per-segment match: mean={pct_arr.mean():.4f}% "
          f"median={np.median(pct_arr):.4f}% std={pct_arr.std():.4f}%")
    print(f"  Per-segment diffs: mean={diff_arr.mean():.1f} "
          f"max={diff_arr.max()} p90={np.percentile(diff_arr, 90):.0f} "
          f"p99={np.percentile(diff_arr, 99):.0f}")

    _print_match_distribution(pct_arr)
    return stats


def phase4_roundtrip(files: list[Path], n_test: int = 20) -> dict:
    """Round-trip the first ``n_test`` files through FastXCodec2 and report SNR.

    This is a coarse "audio isn't garbage" check, not a quality metric. Each
    file is loaded with ``load_wav``, encoded and decoded with
    ``FastXCodec2Codec``, and compared with the loaded input over their common
    length.

    Args:
        files: Candidate files; only the first ``n_test`` are used.
        n_test: Maximum number of files to round-trip.

    Returns:
        Dict with the number of files actually tested, the SNR mean/min/max in
        dB, and the largest absolute per-sample error.

    Raises:
        ValueError: If no files are selected (empty ``files`` or ``n_test <= 0``).
    """
    from codecbench.codecs.xcodec2_fast import FastXCodec2Codec

    subset = files[:n_test]
    if not subset:
        raise ValueError("phase4_roundtrip needs at least one file")

    print(f"\n{'='*70}")
    print(f"PHASE 4: Encode→Decode round-trip sanity — {len(subset)} files")
    print(f"{'='*70}")

    fast = FastXCodec2Codec()
    fast.load(device="cuda")
    fast.warmup()

    eps = 1e-10
    snr_list = []
    max_err_list = []

    with torch.inference_mode():
        for f in subset:
            wav = load_wav(f).cuda()
            wav_3d = wav.unsqueeze(0)

            tb = fast.encode(wav_3d, TARGET_SR)
            recon = fast.decode(tb)

            orig_np = wav_3d.squeeze().cpu().numpy()
            recon_np = recon.squeeze().cpu().numpy()

            # Compare only the overlapping span; the decoded length may differ.
            minlen = min(len(orig_np), len(recon_np))
            orig_np = orig_np[:minlen]
            recon_np = recon_np[:minlen]

            noise = orig_np - recon_np
            sig_power = np.mean(orig_np ** 2)
            noise_power = np.mean(noise ** 2)
            # eps in both terms keeps an all-silent input finite instead of
            # producing -inf, which would also leak into the JSON report.
            snr = 10 * np.log10((sig_power + eps) / (noise_power + eps))
            max_err = float(np.max(np.abs(noise)))

            snr_list.append(snr)
            max_err_list.append(max_err)

    snr_arr = np.array(snr_list)
    stats = {
        "n_tested": len(subset),
        "snr_mean_db": round(float(snr_arr.mean()), 2),
        "snr_min_db": round(float(snr_arr.min()), 2),
        "snr_max_db": round(float(snr_arr.max()), 2),
        "max_sample_error": round(float(max(max_err_list)), 6),
    }

    print(f"\n  Reconstruction SNR: mean={snr_arr.mean():.2f}dB "
          f"min={snr_arr.min():.2f}dB max={snr_arr.max():.2f}dB")
    print(f"  Max sample error: {max(max_err_list):.6f}")
    print(f"  (Note: lossy codec — SNR >5dB with recognizable speech = normal)")

    del fast
    # Collect first so cyclically-referenced GPU tensors are freed before
    # empty_cache tries to release cached blocks.
    gc.collect()
    torch.cuda.empty_cache()
    return stats


def main(max_files: int = 0, n_roundtrip: int = 30) -> None:
    """Run all four validation phases and write ``validation_report.json``.

    Args:
        max_files: Passed to ``discover_chunks``; 0 (the default) uses every
            chunk found.
        n_roundtrip: Number of files checked in the phase-4 round trip.
    """
    files = discover_chunks(max_files)
    print(f"Discovered {len(files)} chunks in {CHUNKS_DIR}")

    if not files:
        print(f"ERROR: No WAV files found in {CHUNKS_DIR}/", file=sys.stderr)
        sys.exit(1)

    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    orig_dir = RESULTS_DIR / "original_codes"
    fast_dir = RESULTS_DIR / "fast_codes"

    # Phase 1: Original
    orig_ms, orig_total = phase1_original(files, orig_dir)

    # Phase 2: Fast
    fast_ms, fast_total = phase2_fast(files, fast_dir)
    if orig_total != fast_total:
        # Different totals mean the encoders disagree on output length for at
        # least one file; surface it here rather than only in per-file stats.
        print(f"\nWARNING: total code count differs — original={orig_total}, "
              f"fast={fast_total}")

    # Phase 3: Compare
    compare_stats = phase3_compare(files, orig_dir, fast_dir)

    # Phase 4: Roundtrip
    rt_stats = phase4_roundtrip(files, n_test=n_roundtrip)

    # Final summary
    speedup = orig_ms / fast_ms if fast_ms > 0 else 0
    orig_rtf = CHUNK_SEC * 1000 / orig_ms if orig_ms > 0 else 0
    fast_rtf = CHUNK_SEC * 1000 / fast_ms if fast_ms > 0 else 0
    print(f"\n{'='*70}")
    print(f"FINAL VALIDATION SUMMARY")
    print(f"{'='*70}")
    print(f"  Files tested:        {len(files)}")
    print(f"  Total VQ codes:      {compare_stats['total_codes']}")
    print(f"  Overall code match:  {compare_stats['overall_match_pct']:.4f}%")
    print(f"  Codes that differ:   {compare_stats['total_diff']}")
    print(f"  Worst segment:       {compare_stats['worst_pct']:.2f}% "
          f"({compare_stats['worst_segment']})")
    print(f"  Segments 100%:       {compare_stats['segments_100pct']}/{len(files)}")
    print(f"  Segments <95%:       {compare_stats['segments_sub95pct']}")
    print(f"  Original speed:      {orig_ms:.1f}ms/seg (RTF={orig_rtf:.1f}x)")
    print(f"  Fast speed:          {fast_ms:.1f}ms/seg (RTF={fast_rtf:.1f}x)")
    print(f"  Speedup:             {speedup:.2f}x")
    print(f"  Roundtrip SNR:       {rt_stats['snr_mean_db']:.1f}dB "
          f"(min={rt_stats['snr_min_db']:.1f}dB)")
    print(f"{'='*70}")

    report = {
        "num_files": len(files),
        "original_ms_per_seg": round(orig_ms, 2),
        "fast_ms_per_seg": round(fast_ms, 2),
        "speedup": round(speedup, 2),
        "original_total_codes": orig_total,
        "fast_total_codes": fast_total,
        "code_comparison": compare_stats,
        "roundtrip": rt_stats,
    }
    report_path = RESULTS_DIR / "validation_report.json"
    with open(report_path, "w", encoding="utf-8") as fh:
        json.dump(report, fh, indent=2)
    print(f"\nFull report saved to {report_path}")


if __name__ == "__main__":
    # main() returns None, so this exits with status 0 on success; main's own
    # sys.exit(1) path is unaffected.
    sys.exit(main())
