"""A/B test 3 leakage-removal approaches on 5 selected overlap clips:
  (a) SoloSpeech 2nd pass (iterate)
  (b) Hard-mute SPEAKER_01-exclusive sub-frames using pyannote diarization
  (c) Per-window ECAPA cosine filter
Outputs to ab_test/  with subfolders for each approach.
"""
import os, json, time, yaml
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
import librosa
import soundfile as sf
from huggingface_hub import snapshot_download
from diffusers import DDIMScheduler

from solospeech.model.solospeech.conditioners import SoloSpeech_TSE
from solospeech.scripts.solospeech.utils import save_audio
from solospeech.vae_modules.autoencoder_wrapper import Autoencoder
from solospeech.corrector.fastgeco.model import ScoreModel
from solospeech.corrector.geco.util.other import pad_spec
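# NOTE: on SpeechBrain >= 1.0 this lives at speechbrain.inference.interfaces;
# the speechbrain.pretrained path still resolves there via a deprecation shim.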
from speechbrain.pretrained.interfaces import Pretrained


class Encoder(Pretrained):
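    """Bare-bones ECAPA-TDNN embedding extractor (follows the Pretrained
    recipe from the yangwang825/ecapa-tdnn-vox2 model card)."""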
    MODULES_NEEDED = ["compute_features", "mean_var_norm", "embedding_model"]
    def encode_batch(self, wavs, wav_lens=None):
        if len(wavs.shape) == 1:
            wavs = wavs.unsqueeze(0)
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()
        feats = self.mods.compute_features(wavs)
        feats = self.mods.mean_var_norm(feats, wav_lens)
        return self.mods.embedding_model(feats, wav_lens)


@torch.no_grad()
def sample_diffusion(tse_model, autoencoder, std, scheduler, device,
                     mixture, reference, lengths, reference_lengths,
                     ddim_steps=200, eta=0, seed=42):
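    """DDIM reverse loop: denoise a random latent conditioned on the mixture
    and reference latents, then decode back to audio through the VAE."""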
    g = torch.Generator(device=device).manual_seed(seed)
    scheduler.set_timesteps(ddim_steps)
    pred = torch.randn(mixture.shape, generator=g, device=device)
    for t in scheduler.timesteps:
        pred = scheduler.scale_model_input(pred, t)
        out, _ = tse_model(x=pred, timesteps=t, mixture=mixture,
                           reference=reference, x_len=lengths, ref_len=reference_lengths)
        pred = scheduler.step(model_output=out, timestep=t, sample=pred,
                              eta=eta, generator=g).prev_sample
    return autoencoder(embedding=pred.transpose(2, 1), std=std).squeeze(1)


# ---------- setup ----------
HERE = Path(__file__).parent
SR = 16000
NUM_CANDIDATES = 4
NUM_INFER_STEPS = 200
SEED = 42
INDICES = json.load(open("/tmp/abtest_indices.json"))
print("Indices:", INDICES)

OUT_BASE = HERE / "ab_test"
(OUT_BASE / "a_iterate").mkdir(parents=True, exist_ok=True)
(OUT_BASE / "b_hardmute").mkdir(parents=True, exist_ok=True)
(OUT_BASE / "c_ecapa_window").mkdir(parents=True, exist_ok=True)

manifest = json.load(open(HERE / "manifest.json"))
overlap_meta = {o["index"]: o for o in manifest["overlaps"]}
diarization = json.load(open(HERE / "diarization.json"))["output"]["diarization"]
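# each diarization turn is {"speaker": ..., "start": ..., "end": ...} in seconds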

# ---------- load SoloSpeech models ONCE ----------
print("\nLoading SoloSpeech models...")
local_dir = snapshot_download(repo_id="OpenSound/SoloSpeech-models")
device = "cuda:0"
with open(os.path.join(local_dir, "config_extractor.yaml")) as fp:
    tse_cfg = yaml.safe_load(fp)
autoencoder = Autoencoder(os.path.join(local_dir, "compressor.ckpt"),
                          os.path.join(local_dir, "config_compressor.json"),
                          "stft_vae", quantization_first=True).eval().to(device)
tse_model = SoloSpeech_TSE(tse_cfg["diffwrap"]["UDiT"],
                           tse_cfg["diffwrap"]["ViT"]).to(device)
tse_model.load_state_dict(torch.load(os.path.join(local_dir, "extractor.pt"))["model"])
tse_model.eval()
geco_model = ScoreModel.load_from_checkpoint(
    os.path.join(local_dir, "corrector.ckpt"),
    batch_size=1, num_workers=0, kwargs=dict(gpu=False)
)
geco_model.eval(no_ema=False)
geco_model.cuda()
ecapa = Encoder.from_hparams(source="yangwang825/ecapa-tdnn-vox2")
noise_scheduler = DDIMScheduler(**tse_cfg["ddim"]["diffusers"])
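# Warm-up: the first add_noise call moves the scheduler's alphas_cumprod onto
# the device; the noised latent itself is discarded.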
_lat = torch.randn((1, 128, 128), device=device)
_n = torch.randn(_lat.shape).to(device)
_t = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device).long()
_ = noise_scheduler.add_noise(_lat, _n, _t)

# ---------- enroll reference ONCE ----------
ref_wav, _ = librosa.load(str(HERE / "main_speaker_ref.wav"), sr=SR)
ref_t = torch.tensor(ref_wav).unsqueeze(0).to(device)
ref_latent, _ = autoencoder(audio=ref_t.unsqueeze(1))
ref_latent_b = ref_latent.repeat(NUM_CANDIDATES, 1, 1)
ref_lengths = torch.LongTensor([ref_latent.shape[-1]] * NUM_CANDIDATES).to(device)
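# Enrolment embedding of the target speaker, reused for best-candidate
# selection in (a) and per-window filtering in (c)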
ref_emb = ecapa.encode_batch(torch.tensor(ref_wav)).squeeze()


# ---------- approach (a) iterate SoloSpeech ----------
def solospeech_pass(mixture_np):
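    """One full SoloSpeech pass over a mixture waveform: diffusion TSE with
    NUM_CANDIDATES samples, ECAPA best-candidate selection, GeCo correction."""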
    with torch.no_grad():
        m = torch.tensor(mixture_np).unsqueeze(0).to(device)
        mw = m  # keep the raw mixture waveform around for the GeCo corrector
        m_lat, std = autoencoder(audio=m.unsqueeze(1))
        L = torch.LongTensor([m_lat.shape[-1]] * NUM_CANDIDATES).to(device)
        m_lat = m_lat.repeat(NUM_CANDIDATES, 1, 1)
        tse_pred = sample_diffusion(tse_model, autoencoder, std, noise_scheduler, device,
                                    m_lat.transpose(2, 1), ref_latent_b.transpose(2, 1),
                                    L, ref_lengths, ddim_steps=NUM_INFER_STEPS, seed=SEED)
        emb_pred = ecapa.encode_batch(tse_pred).squeeze()
        sims = F.cosine_similarity(emb_pred, ref_emb.unsqueeze(0), dim=1)
        _, idx = torch.max(sims, dim=0)
        pred = tse_pred[idx].unsqueeze(0)
        # GeCo corrector: refine the selected candidate against the mixture
        min_leng = min(pred.shape[-1], mw.shape[-1])
        x = pred[..., :min_leng]
        mw2 = mw[..., :min_leng]
        norm = mw2.abs().max()
        x = x / norm
        mw2 = mw2 / norm
        X = torch.unsqueeze(geco_model._forward_transform(geco_model._stft(x.cuda())), 0)
        X = pad_spec(X)
        M = torch.unsqueeze(geco_model._forward_transform(geco_model._stft(mw2.cuda())), 0)
        M = pad_spec(M)
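        # Single corrector step: linspace(0.5, 0.03, 1) yields just ts = [0.5],
        # so the loop below takes one reverse-SDE Euler step (dt = 0.5)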
        ts = torch.linspace(0.5, 0.03, 1, device=M.device)
        std_ge = geco_model.sde._std(0.5 * torch.ones((M.shape[0],), device=M.device))
        z = torch.randn_like(M)
        X_t = M + z * std_ge[:, None, None, None]
        for jdx in range(len(ts)):
            t = ts[jdx]
            dt = t - ts[jdx + 1] if jdx != len(ts) - 1 else ts[-1]
            f, g = geco_model.sde.sde(X_t, t, M)
            vec_t = torch.ones(M.shape[0], device=M.device) * t
            mean = X_t - (f - g**2 * geco_model.forward(X_t, vec_t, M, X, vec_t[:, None, None, None])) * dt
            if jdx == len(ts) - 1:
                X_t = mean
                break
            zz = torch.randn_like(X)
            X_t = mean + zz * g * torch.sqrt(dt)
        x_hat = geco_model.to_audio(X_t.squeeze(), min_leng)
        x_hat = x_hat * norm / x_hat.abs().max()
        return x_hat.detach().cpu().numpy().squeeze()


def approach_a(idx):
    """SoloSpeech 2nd pass: feed extracted output as new mixture."""
    extracted_path = HERE / f"extracted/overlap_{idx:04d}_extracted.wav"
    audio, _ = librosa.load(str(extracted_path), sr=SR)
    out = solospeech_pass(audio)
    save_audio(str(OUT_BASE / "a_iterate" / f"overlap_{idx:04d}.wav"), SR, torch.tensor(out))


def approach_b(idx):
    """Hard-mute SPEAKER_01-exclusive sub-frames inside the loose-cut window."""
    meta = overlap_meta[idx]
    win_s, win_e = meta["start"], meta["end"]
    extracted_path = HERE / f"extracted/overlap_{idx:04d}_extracted.wav"
    audio, _ = librosa.load(str(extracted_path), sr=SR)
    n = len(audio)

    # Find SPEAKER_01-only intervals inside this window:
    # = SPEAKER_01 turns minus any SPEAKER_00 turn overlap
    sp1 = [(t["start"], t["end"]) for t in diarization
           if t["speaker"] == "SPEAKER_01"
           and t["end"] > win_s and t["start"] < win_e]
    sp0 = [(t["start"], t["end"]) for t in diarization
           if t["speaker"] == "SPEAKER_00"
           and t["end"] > win_s and t["start"] < win_e]
    # subtract sp0 from sp1
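    # e.g. subtract([(0.0, 10.0)], [(4.0, 6.0)]) -> [(0.0, 4.0), (6.0, 10.0)]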
    def subtract(intervals, holes):
        out = []
        for s, e in intervals:
            cur = [(s, e)]
            for hs, he in holes:
                new = []
                for cs, ce in cur:
                    if he <= cs or hs >= ce:
                        new.append((cs, ce)); continue
                    if hs > cs: new.append((cs, hs))
                    if he < ce: new.append((he, ce))
                cur = new
            out.extend(cur)
        return out
    sp1_only = subtract(sp1, sp0)
    # clip to window and translate to clip-relative samples
    fade = int(round(0.02 * SR))  # 20ms fades
    audio = audio.copy()
    for s, e in sp1_only:
        s = max(s, win_s); e = min(e, win_e)
        if e <= s: continue
        i0 = int(round((s - win_s) * SR))
        i1 = int(round((e - win_s) * SR))
        i0 = max(0, min(n, i0)); i1 = max(0, min(n, i1))
        if i1 <= i0: continue
        # split the mute region into fade-out / zeros / fade-in; ramps shrink
        # symmetrically when the interval is shorter than two fades, so the
        # gain always returns to unity at the interval's edges
        ramp = min(fade, (i1 - i0) // 2)
        if ramp > 0:
            audio[i0:i0 + ramp] *= np.linspace(1, 0, ramp).astype(audio.dtype)
            audio[i1 - ramp:i1] *= np.linspace(0, 1, ramp).astype(audio.dtype)
        audio[i0 + ramp:i1 - ramp] = 0
    sf.write(str(OUT_BASE / "b_hardmute" / f"overlap_{idx:04d}.wav"), audio, SR, subtype="PCM_16")


def approach_c(idx, win_ms=400, hop_ms=200, sim_thresh=0.4):
    """Per-window ECAPA: walk audio in sliding windows, attenuate windows
    whose ECAPA cosine sim to reference falls below threshold."""
    extracted_path = HERE / f"extracted/overlap_{idx:04d}_extracted.wav"
    audio, _ = librosa.load(str(extracted_path), sr=SR)
    n = len(audio)
    win = int(SR * win_ms / 1000)
    hop = int(SR * hop_ms / 1000)
    if n < win:
        sf.write(str(OUT_BASE / "c_ecapa_window" / f"overlap_{idx:04d}.wav"), audio, SR)
        return
    gain = np.ones(n, dtype=np.float32)
    fade = int(0.02 * SR)
    for start in range(0, n - win + 1, hop):
        chunk = audio[start:start + win]
        emb = ecapa.encode_batch(torch.tensor(chunk)).squeeze()
        sim = F.cosine_similarity(emb.unsqueeze(0), ref_emb.unsqueeze(0)).item()
        if sim < sim_thresh:
            # zero out the hop-region centered in this window
            i0 = start + (win - hop) // 2
            i1 = i0 + hop
            i0 = max(0, i0); i1 = min(n, i1)
            gain[i0:i1] = 0
    # Smooth: apply ~20ms linear ramps at every 0->1 / 1->0 gain transition
    smoothed = gain.copy()
    diff = np.diff(np.concatenate([[1.0], gain, [1.0]]))
    transitions = np.where(diff != 0)[0]
    for ti in transitions:
        if ti == 0 or ti >= n:
            continue
        prev = smoothed[ti - 1]
        nxt = smoothed[ti]  # ti < n is guaranteed by the guard above
        ramp_end = min(n, ti + fade)
        smoothed[ti:ramp_end] = np.linspace(prev, nxt, ramp_end - ti)
    out = audio * smoothed
    sf.write(str(OUT_BASE / "c_ecapa_window" / f"overlap_{idx:04d}.wav"), out, SR, subtype="PCM_16")


# ---------- run all 3 approaches on each clip ----------
for idx in INDICES:
    print(f"\n--- clip #{idx} ---")
    t0 = time.time()
    approach_a(idx); print(f"  (a) iterate    {time.time()-t0:.1f}s")
    t0 = time.time()
    approach_b(idx); print(f"  (b) hardmute   {time.time()-t0:.1f}s")
    t0 = time.time()
    approach_c(idx); print(f"  (c) ecapa win  {time.time()-t0:.1f}s")

print(f"\nDone. Output: {OUT_BASE}/")
