"""Mute SPEAKER_01 in clip_4m12s_to_13m10s.wav using pyannote precision-2
exclusiveDiarization. Keeps SPEAKER_00 and all gaps as-is. 20ms fades."""
import json
from pathlib import Path
import numpy as np
import soundfile as sf

# --- Configuration -----------------------------------------------------------
HERE = Path(__file__).parent                 # directory this script lives in
JSON_PATH = HERE / "dairize.json"            # pyannote job result with diarization turns
INPUT = HERE / "clip_4m12s_to_13m10s.wav"    # source clip to process
OUTPUT = HERE / "clip_speaker00_only.wav"    # destination: same audio with target speaker muted

MUTE_SPEAKER = "SPEAKER_01"  # diarization label whose turns are silenced
FADE_MS = 20                 # ramp length at each mute boundary, in milliseconds
MERGE_GAP = 0.05             # merge mute intervals separated by less than this (seconds)

# --- Load diarization and collect intervals to mute --------------------------
# read_text() closes the file deterministically; the original
# json.load(open(...)) leaked the handle until garbage collection.
job = json.loads(JSON_PATH.read_text(encoding="utf-8"))
output = job.get("output") or job  # tolerate either layout
turns = output.get("exclusiveDiarization") or output.get("diarization") or []
intervals = sorted(
    (t["start"], t["end"]) for t in turns if t["speaker"] == MUTE_SPEAKER
)

# Merge intervals that overlap or sit within MERGE_GAP seconds of each other,
# so the 20ms fades don't chatter across tiny diarization gaps.
merged = []
for s, e in intervals:
    if merged and s <= merged[-1][1] + MERGE_GAP:
        # Extend the previous interval; max() guards against nested turns.
        merged[-1] = (merged[-1][0], max(merged[-1][1], e))
    else:
        merged.append((s, e))
print(f"Muting {len(merged)} merged intervals of {MUTE_SPEAKER} "
      f"(total {sum(e-s for s,e in merged):.1f}s)")

# --- Decode the clip ---------------------------------------------------------
# always_2d keeps the array shaped (samples, channels) even for mono files.
audio, sr = sf.read(str(INPUT), always_2d=True)
n = len(audio)
# Fade ramp length in samples; clamp to at least one sample.
fade_len = max(1, int(round(FADE_MS / 1000 * sr)))
print(f"sr={sr}, samples={n}, channels={audio.shape[1]}")

# --- Mute the collected intervals --------------------------------------------
# Work in float32 so the gain ramps multiply cleanly before PCM_16 export.
audio = audio.astype(np.float32)
for start_s, end_s in merged:
    # Seconds -> sample indices, clamped to the buffer bounds.
    lo = min(n, max(0, int(round(start_s * sr))))
    hi = min(n, max(0, int(round(end_s * sr))))
    if hi <= lo:
        continue  # interval lies entirely outside the file

    # Fade out: gain 1 -> 0 across the first fade_len samples of the region
    # (shorter if the whole region is smaller than one ramp).
    fade_out_end = min(lo + fade_len, hi)
    if fade_out_end > lo:
        gain = np.linspace(1.0, 0.0, fade_out_end - lo, dtype=np.float32)
        audio[lo:fade_out_end] *= gain[:, None]

    # Hard-zero the interior between the two ramps, if any remains.
    if hi - fade_len > fade_out_end:
        audio[fade_out_end : hi - fade_len] = 0

    # Fade back in: gain 0 -> 1 across the last fade_len samples.
    fade_in_start = max(hi - fade_len, fade_out_end)
    if hi > fade_in_start:
        gain = np.linspace(0.0, 1.0, hi - fade_in_start, dtype=np.float32)
        audio[fade_in_start:hi] *= gain[:, None]

# --- Export ------------------------------------------------------------------
# Write 16-bit PCM; soundfile converts the float32 buffer on the way out.
destination = str(OUTPUT)
sf.write(destination, audio, sr, subtype="PCM_16")
print(f"Done -> {OUTPUT}")
