"""Benchmark XCodec Indic: encode → decode round-trip on Modi audio segments."""

import sys
sys.path.insert(0, "/home/ubuntu/bench-codecs")

import time
import torch
import torchaudio
import numpy as np
from pathlib import Path

CKPT = "/home/ubuntu/xcodec_indic/indic_step_00198000.ckpt"
SEGMENTS_DIR = Path("/home/ubuntu/modi_processed/segments/001_ERW9i1lwnBw")
DEVICE = "cuda"

def load_model():
    """Build a FastXCodec2Codec backed by the custom Indic checkpoint.

    Loads the HF base XCodec2 model, overlays the fine-tuned Indic weights
    from CKPT, applies the semantic-model speedup patches, and warms up the
    encode/decode path so later timings are steady-state.

    Returns:
        FastXCodec2Codec: codec placed on DEVICE, ready for benchmarking.
    """
    from xcodec2.modeling_xcodec2 import XCodec2Model
    from codecbench.codecs.xcodec2_fast import (
        GPUMelExtractor, _apply_layer_truncation, _apply_sdpa_patch, FastXCodec2Codec,
    )

    print("Loading XCodec2 base model...")
    model = XCodec2Model.from_pretrained("HKUSTAudio/xcodec2")

    print(f"Loading custom Indic checkpoint: {CKPT}")
    # weights_only=False: the checkpoint may contain non-tensor training state.
    ckpt = torch.load(CKPT, map_location="cpu", weights_only=False)
    # Checkpoint may nest weights under "state_dict" or "model", or be flat.
    state_dict = ckpt.get("state_dict", ckpt.get("model", ckpt))
    cleaned = {}
    for k, v in state_dict.items():
        # Strip a leading Lightning-style "model." prefix, if present.
        k = k.replace("model.", "", 1) if k.startswith("model.") else k
        # Custom checkpoint uses act.beta, HF model expects act.bias
        k = k.replace(".act.beta", ".act.bias")
        cleaned[k] = v
    missing, unexpected = model.load_state_dict(cleaned, strict=False)
    if missing:
        print(f"  Missing {len(missing)} keys (expected - discriminator/criteria): {missing[:3]}...")
    if unexpected:
        print(f"  Unexpected {len(unexpected)} keys: {unexpected[:3]}...")

    model.eval().to(DEVICE)
    _apply_layer_truncation(model.semantic_model)
    _apply_sdpa_patch(model.semantic_model)

    # Assemble the fast codec wrapper without running its __init__
    # (which would re-download/re-load the model).
    codec = FastXCodec2Codec.__new__(FastXCodec2Codec)
    codec._model_id = "HKUSTAudio/xcodec2"
    codec._model = model
    codec._device = DEVICE
    codec._dtype = torch.float32
    codec._use_compile = False
    codec._mel_extractor = GPUMelExtractor(model.feature_extractor, device=DEVICE)
    codec._mel_extractor.to(DEVICE)

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    print("Warming up...")
    # 6 s of dummy audio at 16 kHz; a few round-trips trigger cuDNN autotune
    # and lazy kernel compilation so benchmark timings are steady-state.
    dummy = torch.randn(1, 1, 96000, device=DEVICE)
    with torch.inference_mode():
        for _ in range(3):
            tb = codec.encode(dummy, 16000)
            _ = codec.decode(tb)
    torch.cuda.synchronize()
    print("Model ready.\n")
    return codec


def _si_sdr(reference: np.ndarray, estimate: np.ndarray) -> float:
    """Scale-Invariant SDR (dB) of `estimate` against `reference` (1-D, equal length)."""
    reference = reference.astype(np.float64)
    estimate = estimate.astype(np.float64)
    # Project the estimate onto the reference; leftover is distortion/noise.
    scale = np.dot(reference, estimate) / (np.dot(reference, reference) + 1e-9)
    s_target = scale * reference
    e_noise = estimate - s_target
    return float(10 * np.log10(np.dot(s_target, s_target) / (np.dot(e_noise, e_noise) + 1e-9)))


def benchmark_segment(codec, wav_path: Path) -> dict:
    """Round-trip one audio file through the codec and measure latency/quality.

    Loads the file, resamples to 16 kHz mono, times encode and decode
    separately (CUDA-synchronized), computes SI-SDR of the reconstruction,
    and writes the reconstructed audio next to the segments tree.

    Args:
        codec: object with `encode(wav, sr)` / `decode(tokens)` (FastXCodec2Codec).
        wav_path: path to the input wav segment.

    Returns:
        dict with timing (ms), real-time factors, token count, SI-SDR (dB),
        and the path of the saved reconstruction.
    """
    wav, sr = torchaudio.load(str(wav_path))
    if wav.shape[0] > 1:
        # Downmix multi-channel audio; the codec expects a single channel.
        wav = wav.mean(dim=0, keepdim=True)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
        sr = 16000

    duration_s = wav.shape[1] / sr
    wav_gpu = wav.unsqueeze(0).to(DEVICE)  # [1, 1, T]

    # Encode. inference_mode keeps autograd out of the timed region, and the
    # synchronize() calls fence async CUDA work so wall-clock times are real.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with torch.inference_mode():
        tb = codec.encode(wav_gpu, sr)
    torch.cuda.synchronize()
    encode_ms = (time.perf_counter() - t0) * 1000

    n_tokens = tb.tokens.shape[-1]

    # Decode
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with torch.inference_mode():
        recon = codec.decode(tb)
    torch.cuda.synchronize()
    decode_ms = (time.perf_counter() - t0) * 1000

    recon_wav = recon.squeeze().cpu()
    orig_wav = wav.squeeze().cpu()

    # Trim both to the same length (codecs may pad to a frame boundary).
    min_len = min(orig_wav.shape[0], recon_wav.shape[0])
    orig_wav = orig_wav[:min_len]
    recon_wav = recon_wav[:min_len]

    si_sdr = _si_sdr(orig_wav.numpy(), recon_wav.numpy())

    # Save reconstructed audio three levels up (outside the segments tree).
    out_path = wav_path.parent.parent.parent / f"recon_{wav_path.stem}.wav"
    torchaudio.save(str(out_path), recon_wav.unsqueeze(0), 16000)

    return {
        "file": wav_path.name,
        "duration_s": duration_s,
        "n_tokens": n_tokens,
        "encode_ms": encode_ms,
        "decode_ms": decode_ms,
        "encode_rtf": encode_ms / (duration_s * 1000),
        "decode_rtf": decode_ms / (duration_s * 1000),
        "si_sdr_db": si_sdr,
        "recon_path": str(out_path),
    }


def main():
    """Run the encode/decode round-trip benchmark on the first 5 Modi segments."""
    codec = load_model()

    segments = sorted(SEGMENTS_DIR.glob("seg_*.wav"))[:5]
    if not segments:
        # Guard: without this, results[0] below raises IndexError and
        # np.mean([]) emits NaN with a RuntimeWarning.
        print(f"No seg_*.wav files found in {SEGMENTS_DIR}")
        return

    print(f"Benchmarking {len(segments)} segments from {SEGMENTS_DIR.name}\n")
    print(f"{'File':<16} {'Dur(s)':>7} {'Tokens':>7} {'Enc(ms)':>9} {'Dec(ms)':>9} {'Enc RTF':>8} {'Dec RTF':>8} {'SI-SDR':>8}")
    print("-" * 90)

    results = []
    for seg_path in segments:
        r = benchmark_segment(codec, seg_path)
        results.append(r)
        print(f"{r['file']:<16} {r['duration_s']:>7.2f} {r['n_tokens']:>7} {r['encode_ms']:>9.1f} {r['decode_ms']:>9.1f} {r['encode_rtf']:>8.4f} {r['decode_rtf']:>8.4f} {r['si_sdr_db']:>8.2f}")

    print("-" * 90)
    avg_enc = np.mean([r["encode_ms"] for r in results])
    avg_dec = np.mean([r["decode_ms"] for r in results])
    avg_dur = np.mean([r["duration_s"] for r in results])
    avg_si = np.mean([r["si_sdr_db"] for r in results])
    avg_enc_rtf = np.mean([r["encode_rtf"] for r in results])
    avg_dec_rtf = np.mean([r["decode_rtf"] for r in results])
    print(f"{'AVERAGE':<16} {avg_dur:>7.2f} {'':>7} {avg_enc:>9.1f} {avg_dec:>9.1f} {avg_enc_rtf:>8.4f} {avg_dec_rtf:>8.4f} {avg_si:>8.2f}")

    # pathlib instead of string rsplit to derive the output directory.
    print(f"\nReconstructed audio saved to: {Path(results[0]['recon_path']).parent}/")
    print("\nRTF < 1.0 means faster than real-time.")


if __name__ == "__main__":
    main()
