#!/usr/bin/env python3
"""Extract speaker WAVs from exclusive (no-overlap) diarization."""
import json, os, numpy as np, soundfile as sf

# Source audio, diarization result, and destination directory for this run.
INPUT_FILE = "/home/ubuntu/bob5_vocals_16k.wav"
DIAR_JSON = "/home/ubuntu/bob5_pyannote/diarization_exclusive.json"
OUTPUT_DIR = "/home/ubuntu/bob5_pyannote_exclusive"


def group_segments_by_speaker(segments):
    """Group diarization segments by speaker label.

    Args:
        segments: Iterable of mappings, each providing the keys
            ``"speaker"``, ``"start"`` and ``"end"``.

    Returns:
        Dict mapping each speaker label to a list of ``(start, end)``
        tuples, in the order the segments were given.
    """
    speakers = {}
    for seg in segments:
        speakers.setdefault(seg["speaker"], []).append((seg["start"], seg["end"]))
    return speakers


def speaking_time(spans):
    """Return the summed ``end - start`` length of ``(start, end)`` spans."""
    return sum(end - start for start, end in spans)


def extract_speaker_audio(audio, sr, spans):
    """Concatenate the samples of *audio* covered by *spans*.

    Args:
        audio: 1-D sample array.
        sr: Sample rate; span times are multiplied by it to get indices.
        spans: Iterable of ``(start, end)`` times in seconds.

    Returns:
        A 1-D array holding the selected samples back to back. Empty
        (with the dtype of *audio*) when no samples are selected.
    """
    n_samples = len(audio)
    chunks = []
    for start, end in spans:
        # Clamp both bounds into [0, n_samples]. Without the lower clamp a
        # slightly negative time turns into a negative index, which numpy
        # interprets as "from the end" and yields wrong or empty audio.
        first = min(max(int(start * sr), 0), n_samples)
        last = min(max(int(end * sr), 0), n_samples)
        chunks.append(audio[first:last])
    if not chunks:
        # np.concatenate raises ValueError on an empty sequence.
        return audio[:0]
    return np.concatenate(chunks)


def main():
    """Write one WAV per speaker plus a copy of the diarization JSON."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    with open(DIAR_JSON, encoding="utf-8") as f:
        segments = json.load(f)
    speakers = group_segments_by_speaker(segments)

    audio, sr = sf.read(INPUT_FILE, dtype='float32')
    if audio.ndim > 1:
        # Multi-channel input: average the channels down to mono.
        audio = audio.mean(axis=1)

    print(f"Audio: {len(audio)/sr:.1f}s at {sr}Hz")
    print(f"Found {len(speakers)} speakers (exclusive, no overlaps)\n")

    # Durations are the nominal ones from the JSON; they drive both the
    # report and the ordering (longest-speaking speaker first).
    totals = {spk: speaking_time(spans) for spk, spans in speakers.items()}
    for spk in sorted(speakers, key=totals.get, reverse=True):
        total = totals[spk]
        out = extract_speaker_audio(audio, sr, speakers[spk])
        path = os.path.join(OUTPUT_DIR, f"{spk}.wav")
        sf.write(path, out, sr)
        print(f"  {spk}: {total:.1f}s ({total/60:.1f} min), {len(speakers[spk])} segments -> {path}")

    # Also save the exclusive JSON there (the segment list exactly as loaded).
    with open(os.path.join(OUTPUT_DIR, "diarization_exclusive.json"), "w", encoding="utf-8") as f:
        json.dump(segments, f, indent=2)

    # NOTE(review): this script does not remove overlaps itself; it assumes
    # DIAR_JSON already holds exclusive segments.
    print("\nDone! All overlaps removed.")


if __name__ == "__main__":
    main()
