#!/usr/bin/env python3
"""Re-extract speaker WAVs: concatenate segments with no silence gaps, overlaps already removed."""
import json, os, numpy as np, soundfile as sf

INPUT_FILE = "/home/ubuntu/bob5_vocals_16k.wav"
STATS_FILE = "/home/ubuntu/bob5_clean/diarization_raw.json"
OUTPUT_DIR = "/home/ubuntu/bob5_final"


def build_speaker_masks(segments, sr, total_samples):
    """Return ``{speaker: bool mask}`` marking the samples each speaker is active in.

    Args:
        segments: Iterable of dicts with ``"speaker"``, ``"start"`` and ``"end"``
            keys; start/end are in seconds.
        sr: Sample rate used to convert seconds to sample indices.
        total_samples: Length of the audio; every mask has this length.

    Speakers appear in the returned dict in order of first appearance.
    """
    masks = {}
    for seg in segments:
        spk = seg["speaker"]
        if spk not in masks:
            masks[spk] = np.zeros(total_samples, dtype=np.bool_)
        # Clamp both ends into [0, total_samples]: an unclamped negative index
        # would be interpreted by numpy as "from the end" and mark the wrong
        # samples (or silently none at all).
        start = min(max(int(seg["start"] * sr), 0), total_samples)
        stop = min(max(int(seg["end"] * sr), 0), total_samples)
        masks[spk][start:stop] = True
    return masks


def find_overlap_mask(masks, total_samples):
    """Return a bool mask that is True where two or more speakers are active.

    Uses two boolean accumulators instead of an integer counter so the result
    cannot wrap around, no matter how many speakers are active at once.
    """
    seen = np.zeros(total_samples, dtype=np.bool_)
    overlap = np.zeros(total_samples, dtype=np.bool_)
    for mask in masks.values():
        overlap |= seen & mask  # active now AND already claimed by another speaker
        seen |= mask
    return overlap


def extract_speaker_audio(audio, sr, segments):
    """Yield ``(speaker, samples)`` with each speaker's non-overlapped audio.

    Args:
        audio: 1-D array of samples.
        sr: Sample rate of ``audio``.
        segments: Diarization segments (see ``build_speaker_masks``).

    Yields:
        One ``(speaker, samples)`` pair per speaker, ordered by total active
        sample count (overlaps included), largest first. ``samples`` holds the
        speaker's clean samples in time order with the gaps removed. Speakers
        with no clean samples are skipped.
    """
    total_samples = len(audio)
    masks = build_speaker_masks(segments, sr, total_samples)
    overlap = find_overlap_mask(masks, total_samples)

    by_activity = sorted(
        masks, key=lambda spk: np.count_nonzero(masks[spk]), reverse=True
    )
    for spk in by_activity:
        # Boolean indexing keeps the selected samples in order, which is the
        # same as concatenating every contiguous clean run back to back.
        clean = audio[masks[spk] & ~overlap]
        if clean.size:
            yield spk, clean


def main():
    """Read the diarization and audio files and write one WAV per speaker."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    with open(STATS_FILE, encoding="utf-8") as f:
        all_segments = json.load(f)

    audio, sr = sf.read(INPUT_FILE, dtype='float32')
    if audio.ndim > 1:
        # Down-mix multi-channel audio to mono by averaging the channels.
        audio = audio.mean(axis=1)

    for spk, out in extract_speaker_audio(audio, sr, all_segments):
        path = os.path.join(OUTPUT_DIR, f"{spk}.wav")
        sf.write(path, out, sr)
        print(f"  {spk}: {len(out)/sr:.1f}s -> {path}")

    print("\nDone!")


if __name__ == "__main__":
    main()
