"""Run SoloSpeech (default 200x4) on every overlap clip with main_speaker_ref.wav.
Loads all models ONCE for the whole batch — much faster than 156 separate runs."""
import os, sys, glob, time, yaml
from pathlib import Path
import torch
import torch.nn.functional as F
import librosa
import soundfile as sf
from huggingface_hub import snapshot_download
from diffusers import DDIMScheduler

# SoloSpeech imports
from solospeech.model.solospeech.conditioners import SoloSpeech_TSE
from solospeech.scripts.solospeech.utils import save_audio
from solospeech.vae_modules.autoencoder_wrapper import Autoencoder
from solospeech.corrector.fastgeco.model import ScoreModel
from solospeech.corrector.geco.util.other import pad_spec
from speechbrain.pretrained.interfaces import Pretrained


class Encoder(Pretrained):
    """SpeechBrain ECAPA-TDNN wrapper that exposes only the embedding path.

    Pulls features -> mean/var normalization -> embedding model from the
    pretrained module dict.
    """

    MODULES_NEEDED = ["compute_features", "mean_var_norm", "embedding_model"]

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        # Promote a single waveform (T,) to a batch of one (1, T).
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)
        # Default: every waveform is treated as full-length (relative len 1.0).
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)
        wavs = wavs.to(self.device).float()
        wav_lens = wav_lens.to(self.device)
        # NOTE: `normalize` is accepted for interface compatibility but unused.
        feats = self.mods.compute_features(wavs)
        normed = self.mods.mean_var_norm(feats, wav_lens)
        return self.mods.embedding_model(normed, wav_lens)


@torch.no_grad()
def sample_diffusion(tse_model, autoencoder, std, scheduler, device,
                     mixture, reference, lengths, reference_lengths,
                     ddim_steps=200, eta=0, seed=42):
    """Run the DDIM reverse process to extract the target-speaker signal.

    Args:
        tse_model: conditional denoiser; called as
            tse_model(x=..., timesteps=..., mixture=..., reference=...,
            x_len=..., ref_len=...) and returning (prediction, _).
        autoencoder: VAE decoder; called as
            autoencoder(embedding=..., std=...) to map latents to audio.
        std: latent std returned by the encoder, passed through to decoding.
        scheduler: diffusers-style scheduler (DDIM in this script).
        device: torch device string for noise generation.
        mixture, reference: conditioning latents, shape (B, T, C) after the
            caller's transpose — TODO confirm against SoloSpeech docs.
        lengths, reference_lengths: per-item latent lengths.
        ddim_steps: number of reverse steps.
        eta: DDIM stochasticity (0 = deterministic).
        seed: generator seed, fixed for reproducible extraction.

    Returns:
        Decoded waveform batch of shape (B, samples).
    """
    g = torch.Generator(device=device).manual_seed(seed)
    scheduler.set_timesteps(ddim_steps)
    pred = torch.randn(mixture.shape, generator=g, device=device)
    for t in scheduler.timesteps:
        # Scale ONLY the model input; keep the unscaled sample for step().
        # (For DDIMScheduler scale_model_input is the identity, so behavior
        # is unchanged here, but the original fed the scaled tensor back
        # into step(), which is wrong for schedulers that actually scale.)
        model_input = scheduler.scale_model_input(pred, t)
        out, _ = tse_model(x=model_input, timesteps=t, mixture=mixture,
                           reference=reference, x_len=lengths, ref_len=reference_lengths)
        pred = scheduler.step(model_output=out, timestep=t, sample=pred,
                              eta=eta, generator=g).prev_sample
    # Latents are (B, T, C); the decoder expects (B, C, T).
    pred = autoencoder(embedding=pred.transpose(2, 1), std=std).squeeze(1)
    return pred


# ---------------- config ----------------
HERE = Path(__file__).parent
REF = HERE / "main_speaker_ref.wav"  # enrollment audio of the target speaker
OVDIR = HERE / "overlaps"            # input dir: overlap_*.wav clips
OUTDIR = HERE / "extracted"          # output dir: one *_extracted.wav per clip
OUTDIR.mkdir(exist_ok=True)
NUM_INFER_STEPS = 200  # DDIM reverse-diffusion steps per clip
NUM_CANDIDATES = 4     # diffusion samples per clip; best kept by speaker reranking
SEED = 42              # fixed noise seed so runs are reproducible

# ---------------- load models once ----------------
# Checkpoints are fetched (and cached) from the HuggingFace Hub; all models
# are constructed a single time and reused for every clip in the batch.
print("[1/3] downloading checkpoints...")
local_dir = snapshot_download(repo_id="OpenSound/SoloSpeech-models")
tse_cfg_path = os.path.join(local_dir, "config_extractor.yaml")
vae_cfg_path = os.path.join(local_dir, "config_compressor.json")
ae_ckpt = os.path.join(local_dir, "compressor.ckpt")
tse_ckpt = os.path.join(local_dir, "extractor.pt")
geco_ckpt = os.path.join(local_dir, "corrector.ckpt")

device = "cuda:0"  # script assumes a CUDA GPU (geco_model.cuda() below)
print("[2/3] loading models...")
with open(tse_cfg_path) as fp:
    tse_cfg = yaml.safe_load(fp)
# Audio <-> latent compressor (VAE).
autoencoder = Autoencoder(ae_ckpt, vae_cfg_path, "stft_vae", quantization_first=True)
autoencoder.eval().to(device)
# Target-speaker extractor: diffusion model conditioned on a reference latent.
tse_model = SoloSpeech_TSE(tse_cfg["diffwrap"]["UDiT"], tse_cfg["diffwrap"]["ViT"]).to(device)
tse_model.load_state_dict(torch.load(tse_ckpt)["model"])
tse_model.eval()
# Corrector (fastgeco) that refines the extracted waveform.
geco_model = ScoreModel.load_from_checkpoint(
    geco_ckpt, batch_size=1, num_workers=0, kwargs=dict(gpu=False)
)
# NOTE(review): eval(no_ema=False) presumably switches to EMA weights — confirm
# against the fastgeco ScoreModel API.
geco_model.eval(no_ema=False)
geco_model.cuda()
# Speaker-embedding model used to rerank diffusion candidates.
ecapatdnn = Encoder.from_hparams(source="yangwang825/ecapa-tdnn-vox2")
noise_scheduler = DDIMScheduler(**tse_cfg["ddim"]["diffusers"])
# warmup noise scheduler dtypes
_lat = torch.randn((1, 128, 128), device=device)
_n = torch.randn(_lat.shape).to(device)
_t = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device).long()
_ = noise_scheduler.add_noise(_lat, _n, _t)

# ---------------- precompute reference embedding ----------------
# The enrollment reference is fixed for the whole batch, so encode it once
# instead of re-encoding inside every extract_one() call.
print(f"[3/3] enrolling reference {REF.name}")
ref_wav, _ = librosa.load(str(REF), sr=16000)
ref_t = torch.tensor(ref_wav).unsqueeze(0).to(device)
# VAE latent of the reference audio.
ref_latent, _ = autoencoder(audio=ref_t.unsqueeze(1))
# Tile to NUM_CANDIDATES so diffusion can draw all candidates in one batch.
ref_latent_b = ref_latent.repeat(NUM_CANDIDATES, 1, 1)
ref_lengths = torch.LongTensor([ref_latent.shape[-1]] * NUM_CANDIDATES).to(device)
# ECAPA speaker embedding of the reference, used for candidate reranking.
ecapa_ref = ecapatdnn.encode_batch(torch.tensor(ref_wav)).squeeze()


# ---------------- batch loop ----------------
def extract_one(test_path, out_path):
    """Extract the enrolled target speaker from one overlap clip.

    Pipeline: encode the 16 kHz mixture into the VAE latent space, draw
    NUM_CANDIDATES diffusion samples conditioned on the precomputed reference
    latent, keep the candidate whose ECAPA embedding is closest to the
    enrolled speaker, run one corrector (geco) reverse-SDE step conditioned
    on the mixture, and write the result to ``out_path``.

    Args:
        test_path: path to the mixture wav (resampled to 16 kHz on load).
        out_path: destination wav path.
    """
    mixture, _ = librosa.load(test_path, sr=16000)
    with torch.no_grad():
        m = torch.tensor(mixture).unsqueeze(0).to(device)
        mixture_wav = m
        m_lat, std = autoencoder(audio=m.unsqueeze(1))
        lengths = torch.LongTensor([m_lat.shape[-1]] * NUM_CANDIDATES).to(device)
        # Same mixture latent replicated so all candidates run in one batch.
        m_lat = m_lat.repeat(NUM_CANDIDATES, 1, 1)

        tse_pred = sample_diffusion(
            tse_model, autoencoder, std, noise_scheduler, device,
            m_lat.transpose(2, 1), ref_latent_b.transpose(2, 1),
            lengths, ref_lengths, ddim_steps=NUM_INFER_STEPS, seed=SEED,
        )
        # rerank: keep the candidate most similar to the enrolled speaker
        emb_pred = ecapatdnn.encode_batch(tse_pred).squeeze()
        sims = F.cosine_similarity(emb_pred, ecapa_ref.unsqueeze(0), dim=1)
        _, idx = torch.max(sims, dim=0)
        pred = tse_pred[idx].unsqueeze(0)

        # corrector (geco): refine the winning candidate in the STFT domain
        min_leng = min(pred.shape[-1], mixture_wav.shape[-1])
        x = pred[..., :min_leng]
        mw = mixture_wav[..., :min_leng]
        # Clamp the peak: a silent clip previously divided by zero here and
        # propagated NaNs into the output file.
        norm = mw.abs().max().clamp(min=1e-8)
        x = x / norm
        mw = mw / norm
        X = torch.unsqueeze(geco_model._forward_transform(geco_model._stft(x.cuda())), 0)
        X = pad_spec(X)
        M = torch.unsqueeze(geco_model._forward_transform(geco_model._stft(mw.cuda())), 0)
        M = pad_spec(M)
        # Single reverse-SDE step: linspace with one point yields t = 0.5.
        timesteps = torch.linspace(0.5, 0.03, 1, device=M.device)
        std_ge = geco_model.sde._std(0.5 * torch.ones((M.shape[0],), device=M.device))
        z = torch.randn_like(M)
        X_t = M + z * std_ge[:, None, None, None]
        for jdx in range(len(timesteps)):
            t = timesteps[jdx]
            dt = t - timesteps[jdx + 1] if jdx != len(timesteps) - 1 else timesteps[-1]
            f, g = geco_model.sde.sde(X_t, t, M)
            vec_t = torch.ones(M.shape[0], device=M.device) * t
            mean_x_tm1 = X_t - (f - g**2 * geco_model.forward(
                X_t, vec_t, M, X, vec_t[:, None, None, None])) * dt
            if jdx == len(timesteps) - 1:
                # Final step takes the posterior mean without added noise.
                X_t = mean_x_tm1
                break
            zz = torch.randn_like(X)
            X_t = mean_x_tm1 + zz * g * torch.sqrt(dt)
        sample = X_t.squeeze()
        x_hat = geco_model.to_audio(sample.squeeze(), min_leng)
        # Rescale to the mixture's original peak; the clamp guards against a
        # silent corrector output (previously another divide-by-zero).
        x_hat = x_hat * norm / x_hat.abs().max().clamp(min=1e-8)
        x_hat = x_hat.detach().cpu()
        save_audio(out_path, 16000, x_hat)


# Drive the batch: process every overlap clip, skipping finished outputs so
# the script can be resumed, and report per-clip progress with an ETA.
files = sorted(glob.glob(str(OVDIR / "overlap_*.wav")))
print(f"\nProcessing {len(files)} overlap clips ...")
t0 = time.time()
total = len(files)
for pos, clip in enumerate(files, start=1):
    target = OUTDIR / (Path(clip).stem + "_extracted.wav")
    # Resume support: an existing output means this clip is already done.
    if target.exists():
        continue
    try:
        extract_one(clip, str(target))
    except torch.cuda.OutOfMemoryError as e:
        print(f"  [{pos}/{total}] OOM on {Path(clip).name} -- skipping ({e})")
        torch.cuda.empty_cache()
        continue
    elapsed = time.time() - t0
    eta = elapsed / pos * (total - pos)
    print(f"  [{pos}/{total}] {Path(clip).name} -> {target.name} "
          f"(elapsed {elapsed:.0f}s, eta {eta:.0f}s)")
print(f"\nDone in {time.time()-t0:.0f}s. Output: {OUTDIR}")
