"""Compare XCodec Indic vs NanoCodec on Telugu audio."""

import sys
sys.path.insert(0, "/home/ubuntu/bench-codecs")

import time
import torch
import torchaudio
import numpy as np

# Input clip that main() loads with torchaudio and feeds to both codecs.
AUDIO_PATH: str = "/home/ubuntu/Telugu_ Movie Roast.wav"
# Device for every model and tensor; the timing code calls
# torch.cuda.synchronize(), so this must name a CUDA device.
DEVICE: str = "cuda"


def si_sdr(orig, recon, *, zero_mean=False):
    """Return the scale-invariant signal-to-distortion ratio (SI-SDR) in dB.

    ``recon`` is projected onto ``orig``. The projection is the target
    component, and the residual is counted as distortion, so a constant
    gain applied to ``recon`` does not change the score. Both signals are
    first truncated to the shorter length. No delay compensation is done.

    Args:
        orig: Reference signal, 1-D array-like.
        recon: Reconstructed / estimated signal, 1-D array-like.
        zero_mean: If True, subtract each signal's mean before scoring so
            a DC offset is not counted as distortion. The default (False)
            keeps the original behaviour.

    Returns:
        SI-SDR in dB as a NumPy float. If the target component is zero
        (e.g. ``recon`` is silent), the result is ``-inf``.

    Raises:
        ValueError: If either input is not one-dimensional.
    """
    o = np.asarray(orig, dtype=np.float64)
    r = np.asarray(recon, dtype=np.float64)
    if o.ndim != 1 or r.ndim != 1:
        # np.dot on 2-D inputs would silently compute a matrix product.
        raise ValueError(
            f"si_sdr expects 1-D signals, got shapes {o.shape} and {r.shape}"
        )

    n = min(o.shape[0], r.shape[0])
    o, r = o[:n], r[:n]
    if zero_mean and n:
        o = o - o.mean()
        r = r - r.mean()

    # Optimal scaling of the reference; 1e-9 guards against an all-zero reference.
    s_target = np.dot(o, r) * o / (np.dot(o, o) + 1e-9)
    e_noise = r - s_target
    return 10 * np.log10(np.dot(s_target, s_target) / (np.dot(e_noise, e_noise) + 1e-9))


def load_xcodec():
    """Build the XCodec2 codec with the Indic fine-tuned weights and warm it up.

    Steps:
    1. Instantiate ``HKUSTAudio/xcodec2``.
    2. Overlay the Indic checkpoint (with key renaming).
    3. Apply the codecbench speed patches (layer truncation and SDPA
       attention) to the semantic model.
    4. Wrap the model in a ``FastXCodec2Codec`` shell constructed without
       calling its ``__init__``.
    5. Run three encode/decode passes on 6 s of random noise at 16 kHz so
       that later timings are not dominated by one-off startup costs.

    Side effect: enables TF32 for CUDA matmuls and cuDNN process-wide, so
    any model run later in this process (e.g. NanoCodec) also uses TF32.

    Returns:
        The ready-to-use ``FastXCodec2Codec`` instance on ``DEVICE``.
    """
    from xcodec2.modeling_xcodec2 import XCodec2Model
    from codecbench.codecs.xcodec2_fast import (
        GPUMelExtractor, _apply_layer_truncation, _apply_sdpa_patch, FastXCodec2Codec,
    )

    model = XCodec2Model.from_pretrained("HKUSTAudio/xcodec2")
    # NOTE: weights_only=False unpickles arbitrary objects; only use with a trusted checkpoint.
    ckpt = torch.load(
        "/home/ubuntu/xcodec_indic/indic_step_00198000.ckpt",
        map_location="cpu",
        weights_only=False,
    )
    state_dict = ckpt.get("state_dict", ckpt)
    cleaned = {}
    for key, tensor in state_dict.items():
        # Drop one leading "model." prefix (presumably from a training wrapper; confirm).
        if key.startswith("model."):
            key = key[len("model."):]
        # The checkpoint names this parameter ".act.beta"; the model expects ".act.bias".
        cleaned[key.replace(".act.beta", ".act.bias")] = tensor

    # strict=False tolerates key mismatches silently. A failed remap would leave
    # the base HKUST weights in place, so report any mismatch explicitly.
    result = model.load_state_dict(cleaned, strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(
            f"[load_xcodec] warning: {len(result.missing_keys)} missing keys "
            f"(e.g. {result.missing_keys[:5]}), {len(result.unexpected_keys)} "
            f"unexpected keys (e.g. {result.unexpected_keys[:5]}) in Indic checkpoint",
            file=sys.stderr,
        )
    model.eval().to(DEVICE)
    _apply_layer_truncation(model.semantic_model)
    _apply_sdpa_patch(model.semantic_model)

    # Populate the attributes FastXCodec2Codec needs without running its __init__,
    # which would load its own copy of the model.
    codec = FastXCodec2Codec.__new__(FastXCodec2Codec)
    codec._model_id = "HKUSTAudio/xcodec2"
    codec._model = model
    codec._device = DEVICE
    codec._dtype = torch.float32
    codec._use_compile = False
    codec._mel_extractor = GPUMelExtractor(model.feature_extractor, device=DEVICE)
    codec._mel_extractor.to(DEVICE)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Warmup: 96000 samples = 6 s at 16 kHz, shaped [batch, channel, samples].
    dummy = torch.randn(1, 1, 96000, device=DEVICE)
    with torch.no_grad():
        for _ in range(3):
            tb = codec.encode(dummy, 16000)
            codec.decode(tb)
    torch.cuda.synchronize()
    return codec


def load_nanocodec():
    """Load NVIDIA NanoCodec (22 kHz, 1.89 kbps, 21.5 fps) and warm it up.

    The NeMo ``AudioCodecModel`` is put in eval mode on ``DEVICE``. It then
    runs three gradient-free encode/decode passes over 6 s of random noise
    so that the timed run that follows excludes one-off startup costs.

    Returns:
        The warmed-up NeMo model.
    """
    from nemo.collections.tts.models import AudioCodecModel

    codec_model = AudioCodecModel.from_pretrained("nvidia/nemo-nano-codec-22khz-1.89kbps-21.5fps")
    codec_model.eval().to(DEVICE)

    warmup_samples = 22050 * 6  # 6 s at 22.05 kHz
    noise = torch.randn(1, warmup_samples, device=DEVICE)
    noise_len = torch.tensor([warmup_samples], device=DEVICE)
    with torch.no_grad():
        for _ in range(3):
            tokens, tokens_len = codec_model.encode(audio=noise, audio_len=noise_len)
            codec_model.decode(tokens=tokens, tokens_len=tokens_len)
    torch.cuda.synchronize()
    return codec_model


def run_xcodec(codec, wav, sr):
    """Round-trip ``wav`` through XCodec at 16 kHz, timing encode and decode.

    Args:
        codec: Codec returned by ``load_xcodec``.
        wav: Audio tensor shaped ``[channels, samples]`` (as produced by
            ``torchaudio.load``). Multichannel input is averaged to mono.
        sr: Sample rate of ``wav`` in Hz.

    Returns:
        ``(enc_ms, dec_ms, si_sdr_db, n_tokens, duration_s, out_path)``.
        The timings are wall-clock milliseconds bracketed by CUDA syncs.
        The reconstruction is written to ``out_path`` as a 16 kHz WAV.
    """
    # The codec is fed a single [1, 1, T] mono stream. Downmix multichannel
    # input rather than passing a [1, C, T] tensor through.
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    wav_16k = torchaudio.functional.resample(wav, sr, 16000)  # [1, T]
    wav_gpu = wav_16k.unsqueeze(0).to(DEVICE)  # [1, 1, T]

    with torch.no_grad():
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        tb = codec.encode(wav_gpu, 16000)
        torch.cuda.synchronize()
        enc_ms = (time.perf_counter() - t0) * 1000

        t0 = time.perf_counter()
        recon = codec.decode(tb)
        torch.cuda.synchronize()
        dec_ms = (time.perf_counter() - t0) * 1000

    recon_wav = recon.squeeze().cpu().numpy()
    orig_wav = wav_16k.squeeze().numpy()
    quality = si_sdr(orig_wav, recon_wav)

    out_path = "/home/ubuntu/telugu_recon_xcodec.wav"
    torchaudio.save(out_path, torch.from_numpy(recon_wav).unsqueeze(0), 16000)

    n_tokens = tb.tokens.shape[-1]
    dur = wav_16k.shape[1] / 16000
    return enc_ms, dec_ms, quality, n_tokens, dur, out_path


def run_nanocodec(model, wav, sr):
    """Round-trip ``wav`` through NanoCodec at the model's native rate, timing each stage.

    Args:
        model: NeMo ``AudioCodecModel`` returned by ``load_nanocodec``.
        wav: Audio tensor shaped ``[channels, samples]`` (as produced by
            ``torchaudio.load``). Multichannel input is averaged to mono.
        sr: Sample rate of ``wav`` in Hz.

    Returns:
        ``(enc_ms, dec_ms, si_sdr_db, n_frames, n_codebooks, duration_s, out_path)``.
        The reconstruction is trimmed to the valid length reported by the
        decoder and written to ``out_path`` at ``model.sample_rate``.
    """
    # Downmix so the codec sees one mono stream and audio_len counts samples.
    # Previously, squeeze(0) on a [C, T] tensor with C > 1 was a no-op, which
    # made audio_len equal to the channel count.
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    target_sr = model.sample_rate
    wav_rs = torchaudio.functional.resample(wav, sr, target_sr)  # [1, T]
    audio = wav_rs.to(DEVICE)  # [batch=1, T]
    audio_len = torch.tensor([audio.shape[-1]], device=DEVICE)

    with torch.no_grad():
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        codes, codes_len = model.encode(audio=audio, audio_len=audio_len)
        torch.cuda.synchronize()
        enc_ms = (time.perf_counter() - t0) * 1000

        t0 = time.perf_counter()
        recon_audio, recon_len = model.decode(tokens=codes, tokens_len=codes_len)
        torch.cuda.synchronize()
        dec_ms = (time.perf_counter() - t0) * 1000

    # The decoder output may be padded past the valid samples; keep only recon_len of them.
    recon_wav = recon_audio.squeeze()[: int(recon_len[0])].cpu().numpy()
    orig_wav = wav_rs.squeeze().numpy()
    quality = si_sdr(orig_wav, recon_wav)

    out_path = "/home/ubuntu/telugu_recon_nanocodec.wav"
    torchaudio.save(out_path, torch.from_numpy(recon_wav).unsqueeze(0), target_sr)

    n_tokens = codes.shape[-1]
    n_cb = codes.shape[1]
    dur = wav_rs.shape[1] / target_sr
    return enc_ms, dec_ms, quality, n_tokens, n_cb, dur, out_path


def main():
    """Benchmark XCodec Indic and NanoCodec on ``AUDIO_PATH`` and print a comparison table.

    The codecs are handled one at a time. XCodec is loaded, run, and
    released (garbage-collected, CUDA cache emptied) before NanoCodec is
    loaded. This keeps both models from being resident on the GPU at once.
    """
    import gc

    wav, sr = torchaudio.load(AUDIO_PATH)
    dur = wav.shape[1] / sr
    print(f"Input: {AUDIO_PATH}")
    print(f"Duration: {dur:.2f}s, SR: {sr}Hz, Channels: {wav.shape[0]}\n")

    # XCodec
    print("Loading XCodec Indic...")
    xcodec = load_xcodec()
    x_enc, x_dec, x_si, x_tok, x_dur, x_path = run_xcodec(xcodec, wav, sr)

    # Free XCodec VRAM before NanoCodec is loaded. gc.collect() clears any
    # reference cycles so empty_cache() can actually return the memory.
    del xcodec
    gc.collect()
    torch.cuda.empty_cache()

    # NanoCodec
    print("Loading NanoCodec...")
    nanocodec = load_nanocodec()
    print()
    n_enc, n_dec, n_si, n_tok, n_cb, n_dur, n_path = run_nanocodec(nanocodec, wav, sr)
    # Report the model's actual rate rather than a hard-coded label.
    n_sr_label = f"{nanocodec.sample_rate / 1000:g} kHz"

    print(f"{'':20} {'XCodec Indic':>15} {'NanoCodec':>15}")
    print(f"{'-'*52}")
    print(f"{'Output SR':20} {'16 kHz':>15} {n_sr_label:>15}")
    print(f"{'Codebooks':20} {'1':>15} {n_cb:>15}")
    print(f"{'Tokens (frames)':20} {x_tok:>15} {n_tok:>15}")
    print(f"{'FPS':20} {x_tok/x_dur:>15.1f} {n_tok/n_dur:>15.1f}")
    print(f"{'Encode (ms)':20} {x_enc:>15.1f} {n_enc:>15.1f}")
    print(f"{'Decode (ms)':20} {x_dec:>15.1f} {n_dec:>15.1f}")
    print(f"{'Encode RTF':20} {x_enc/(x_dur*1000):>15.4f} {n_enc/(n_dur*1000):>15.4f}")
    print(f"{'Decode RTF':20} {x_dec/(x_dur*1000):>15.4f} {n_dec/(n_dur*1000):>15.4f}")
    print(f"{'SI-SDR (dB)':20} {x_si:>15.2f} {n_si:>15.2f}")
    print("\nReconstructed files:")
    print(f"  XCodec:    {x_path}")
    print(f"  NanoCodec: {n_path}")


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
