"""
Mute SPEAKER_01 in RAW_TALKS_clip_2m33s_to_10m10s.mp3 using diarization
intervals from dairize.json. Keeps SPEAKER_00 and all gaps as-is.
20ms linear fade in/out at every mute boundary to avoid clicks.
"""
import json
import subprocess
import sys
from pathlib import Path
import numpy as np
import soundfile as sf

# All paths are resolved relative to this script's directory.
HERE = Path(__file__).parent
# Diarization output: {"diarization": [{"start": s, "end": s, "speaker": "..."}, ...]}
JSON_PATH = HERE / "dairize.json"
INPUT_MP3 = HERE / "RAW_TALKS_clip_2m33s_to_10m10s.mp3"
TMP_WAV = HERE / "_tmp_input.wav"   # decoded PCM scratch file (deleted at the end)
OUT_WAV = HERE / "_tmp_output.wav"  # muted PCM scratch file (deleted at the end)
OUT_MP3 = HERE / "RAW_TALKS_clip_speaker00_only.mp3"

MUTE_SPEAKER = "SPEAKER_01"  # diarization label whose segments get silenced
FADE_MS = 20  # linear fade in/out at each mute edge
MERGE_GAP = 0.05  # merge mute intervals closer than this (seconds)

# 1. Load diarization, collect intervals to mute.
# Fix: the original used json.load(open(JSON_PATH)), which never closes the
# file handle; use a context manager and an explicit encoding instead.
with open(JSON_PATH, encoding="utf-8") as fh:
    data = json.load(fh)
intervals = [
    (s["start"], s["end"]) for s in data["diarization"] if s["speaker"] == MUTE_SPEAKER
]
intervals.sort()
# Merge overlapping / near-adjacent intervals (closer than MERGE_GAP seconds)
# so each mute region gets exactly one fade-out and one fade-in.
merged = []
for s, e in intervals:
    if merged and s <= merged[-1][1] + MERGE_GAP:
        # Extend previous interval; max() handles fully-nested intervals.
        merged[-1] = (merged[-1][0], max(merged[-1][1], e))
    else:
        merged.append((s, e))
print(f"Muting {len(merged)} merged intervals of {MUTE_SPEAKER} "
      f"(total {sum(e-s for s,e in merged):.1f}s)")

# 2. Decode mp3 -> wav (sample rate / channel layout preserved by ffmpeg)
decode_cmd = [
    "ffmpeg",
    "-y",                 # overwrite the temp wav if a stale one exists
    "-i", str(INPUT_MP3),
    "-c:a", "pcm_s16le",  # 16-bit PCM so soundfile can read it
    str(TMP_WAV),
]
subprocess.run(decode_cmd, check=True, capture_output=True)

# 3. Load, mute with fades, save
# always_2d keeps a (N, channels) shape even for mono input.
audio, sr = sf.read(str(TMP_WAV), always_2d=True)  # shape (N, ch); dtype is soundfile's default float read type
n = audio.shape[0]
# Fade length in samples; floor of 1 so the linspace ramps are never empty.
fade_len = max(1, int(round(FADE_MS / 1000 * sr)))
print(f"sr={sr}, samples={n}, channels={audio.shape[1]}, fade_samples={fade_len}")

# Silence each merged [s, e) interval in place, in three regions:
# fade-out over the first fade_len samples, hard zero in the middle,
# fade-in over the last fade_len samples. For intervals shorter than
# 2*fade_len the min/max clamps below collapse the regions in order,
# so slices never overlap or go out of range.
for s, e in merged:
    i0 = int(round(s * sr))
    i1 = int(round(e * sr))
    # Clamp to the clip, in case diarization times run past the audio.
    i0 = max(0, min(n, i0))
    i1 = max(0, min(n, i1))
    if i1 <= i0:
        continue
    # Fade-out at the start of the mute region (segment goes 1 -> 0)
    fo_end = min(i0 + fade_len, i1)
    if fo_end > i0:
        # (len, 1) column so the ramp broadcasts across all channels.
        ramp = np.linspace(1.0, 0.0, fo_end - i0, dtype=audio.dtype)[:, None]
        audio[i0:fo_end] *= ramp
    # Hard zero in the middle
    if i1 - fade_len > fo_end:
        audio[fo_end : i1 - fade_len] = 0
    # Fade-in at the end of the mute region (segment goes 0 -> 1)
    fi_start = max(i1 - fade_len, fo_end)
    if i1 > fi_start:
        ramp = np.linspace(0.0, 1.0, i1 - fi_start, dtype=audio.dtype)[:, None]
        audio[fi_start:i1] *= ramp

# Write the muted audio back out as 16-bit PCM.
sf.write(str(OUT_WAV), audio, sr, subtype="PCM_16")

# 4. Re-encode to mp3 at a reasonable bitrate
encode_cmd = [
    "ffmpeg",
    "-y",
    "-i", str(OUT_WAV),
    "-c:a", "libmp3lame",
    "-b:a", "192k",
    str(OUT_MP3),
]
subprocess.run(encode_cmd, check=True, capture_output=True)

# Cleanup temps
for scratch in (TMP_WAV, OUT_WAV):
    scratch.unlink(missing_ok=True)
print(f"Done -> {OUT_MP3}")
