"""For each extracted/overlap_NNN_extracted.wav, compute ECAPA cosine
similarity vs main_speaker_ref.wav. Save to similarities.json."""
import json, glob, time
from pathlib import Path
import torch
import torch.nn.functional as F
import librosa
from speechbrain.pretrained.interfaces import Pretrained


class Encoder(Pretrained):
    """Thin ECAPA-TDNN embedding extractor built on SpeechBrain's Pretrained."""

    MODULES_NEEDED = ["compute_features", "mean_var_norm", "embedding_model"]

    def encode_batch(self, wavs, wav_lens=None):
        """Return speaker embeddings for a batch of waveforms.

        wavs: 1-D tensor (single clip) or 2-D tensor (batch, samples).
        wav_lens: relative clip lengths in [0, 1]; defaults to all-ones.
        """
        # Promote a lone clip to a batch of one.
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)
        wavs = wavs.to(self.device).float()
        wav_lens = wav_lens.to(self.device)
        feats = self.mods.compute_features(wavs)
        normed = self.mods.mean_var_norm(feats, wav_lens)
        return self.mods.embedding_model(normed, wav_lens)


# --- Setup: locate inputs, load the encoder, embed the reference speaker ---
HERE = Path(__file__).parent
REF = HERE / "main_speaker_ref.wav"
# Extracted overlap clips, sorted so processing order is deterministic.
EXTRACTED = sorted(glob.glob(str(HERE / "extracted" / "overlap_*.wav")))

print("Loading ECAPA-TDNN...")
ecapa = Encoder.from_hparams(source="yangwang825/ecapa-tdnn-vox2")

print(f"Encoding reference: {REF.name}")
ref_wav, _ = librosa.load(str(REF), sr=16000)  # resample to the model's 16 kHz
ref_emb = ecapa.encode_batch(torch.tensor(ref_wav)).squeeze()

# manifest.json maps each overlap index to its metadata (start/end/speakers).
# read_text avoids the unclosed file handle from json.load(open(...)).
manifest = json.loads((HERE / "manifest.json").read_text())
overlap_meta = {o["index"]: o for o in manifest["overlaps"]}

# Score each extracted clip against the reference speaker embedding.
results = []
started = time.time()
for clip_path in EXTRACTED:
    # File stem looks like "overlap_NNN_extracted"; NNN is the overlap index.
    idx = int(Path(clip_path).stem.split("_")[1])
    wav, _ = librosa.load(clip_path, sr=16000)
    if len(wav) == 0:
        # Nothing was extracted for this overlap; score it as a total mismatch.
        sim = 0.0
    else:
        clip_emb = ecapa.encode_batch(torch.tensor(wav)).squeeze()
        sim = F.cosine_similarity(
            clip_emb.unsqueeze(0), ref_emb.unsqueeze(0)
        ).item()
    meta = overlap_meta.get(idx, {})
    record = {
        "index": idx,
        "start": meta.get("start"),
        "end": meta.get("end"),
        "duration": round(meta.get("end", 0) - meta.get("start", 0), 2),
        "speakers": meta.get("speakers", []),
        "similarity": round(sim, 4),
        "original": f"overlaps/overlap_{idx:04d}.wav",
        "extracted": f"extracted/overlap_{idx:04d}_extracted.wav",
    }
    results.append(record)
print(f"Computed {len(results)} similarities in {time.time()-started:.1f}s")

# Persist per-clip results, then print a quick distribution summary.
(HERE / "similarities.json").write_text(json.dumps(results, indent=2))
sims = [r["similarity"] for r in results]
if sims:
    # Sort once; min/median/max all come from the ordered list.
    ordered = sorted(sims)
    median = ordered[len(ordered) // 2]  # upper-middle element for even counts
    print(f"Sim stats: min={ordered[0]:.3f} median={median:.3f} max={ordered[-1]:.3f}")
    print(f"  >0.7: {sum(1 for s in sims if s > 0.7)} clips")
    print(f"  >0.5: {sum(1 for s in sims if s > 0.5)} clips")
    print(f"  <0.3: {sum(1 for s in sims if s < 0.3)} clips")
else:
    # Guard: min()/max() on an empty list would raise ValueError when no
    # extracted clips were found.
    print("No extracted clips found; nothing to summarize.")
