"""Approach (d): hard-mute ALL SPEAKER_01 turns (including overlap with SP00).
Runs on every overlap clip listed in /tmp/abtest_indices.json.
"""
import json
from pathlib import Path
import numpy as np
import soundfile as sf
import librosa

HERE = Path(__file__).parent
SR = 16000            # sample rate all clips are loaded/resampled at
FADE = int(0.02 * SR)  # 20 ms fades, expressed in samples

# Clip indices selected for the A/B test (shared by every approach script).
# Path.read_text() instead of json.load(open(...)) so file handles are closed.
INDICES = json.loads(Path("/tmp/abtest_indices.json").read_text())

manifest = json.loads((HERE / "manifest.json").read_text())
# Fast lookup of each overlap clip's metadata by its index.
overlap_meta = {o["index"]: o for o in manifest["overlaps"]}
diarization = json.loads((HERE / "diarization.json").read_text())["output"]["diarization"]

OUT = HERE / "ab_test" / "d_full_sp01"
OUT.mkdir(parents=True, exist_ok=True)

# Hoisted out of the per-clip loop: only SPEAKER_01 turns are ever muted,
# so filter the diarization once instead of on every iteration.
SP01_TURNS = [t for t in diarization if t["speaker"] == "SPEAKER_01"]

for idx in INDICES:
    meta = overlap_meta[idx]
    win_s, win_e = meta["start"], meta["end"]
    audio_path = HERE / f"extracted/overlap_{idx:04d}_extracted.wav"
    audio, _ = librosa.load(str(audio_path), sr=SR)
    # astype() already returns a new array (copy=True by default), so the
    # extra .copy() the original carried was redundant.
    audio = audio.astype(np.float32)
    n = len(audio)

    # All SP01 turns intersecting the clip window, clipped to it.
    # No subtraction of SP00 speech: approach (d) mutes SP01 even where
    # the two speakers overlap.
    sp1 = sorted((max(t["start"], win_s), min(t["end"], win_e))
                 for t in SP01_TURNS
                 if t["end"] > win_s and t["start"] < win_e)
    # Merge touching/overlapping intervals so each span is faded exactly once.
    merged = []
    for s, e in sp1:
        if merged and s <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], e))
        else:
            merged.append((s, e))

    for s, e in merged:
        # Window-relative seconds -> sample indices, clamped to the clip.
        i0 = max(0, min(n, int(round((s - win_s) * SR))))
        i1 = max(0, min(n, int(round((e - win_s) * SR))))
        if i1 <= i0:
            continue
        # Shrink the fades symmetrically when the span is short so a fade-in
        # always exists. The previous fixed-length scheme skipped the fade-in
        # entirely for spans <= FADE samples, leaving an audible click where
        # the signal jumped from 0 back to full level at i1.
        fade = min(FADE, (i1 - i0) // 2)
        fo = i0 + fade  # end of the fade-out ramp
        fi = i1 - fade  # start of the fade-in ramp
        if fade:
            audio[i0:fo] *= np.linspace(1.0, 0.0, fade, dtype=np.float32)
        audio[fo:fi] = 0.0  # hard mute between the two ramps
        if fade:
            audio[fi:i1] *= np.linspace(0.0, 1.0, fade, dtype=np.float32)

    sf.write(str(OUT / f"overlap_{idx:04d}.wav"), audio, SR, subtype="PCM_16")

print(f"Done. Wrote {len(INDICES)} files to {OUT}")
