#!/usr/bin/env python3
"""
Extract speaker WAVs with overlapping parts CUT OUT.

1. Load regular diarization (with overlaps)
2. For each time point, count how many speakers are active
3. If 2+ speakers overlap at any moment, REMOVE that time from ALL of them
4. Only keep audio where exactly ONE speaker is talking
5. Preserve original timeline (silence elsewhere)
"""
# Compatibility shim: some torchaudio builds lack list_audio_backends, which
# pyannote.audio probes at import time; advertise the soundfile backend.
import torchaudio

if not hasattr(torchaudio, 'list_audio_backends'):
    torchaudio.list_audio_backends = lambda: ["soundfile"]

import json
import os

import numpy as np
import soundfile as sf

# Source mix (expected mono-friendly WAV; stereo is downmixed later) and
# destination directory for the per-speaker clean tracks.
INPUT_FILE = "/home/ubuntu/bob5_vocals_16k.wav"
OUTPUT_DIR = "/home/ubuntu/bob5_clean"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# SECURITY: an API token was previously hard-coded here, i.e. a secret
# committed to source. Read it from the environment instead and fail fast
# with a clear message when it is missing. Rotate the leaked key.
PYANNOTE_API_KEY = os.environ.get("PYANNOTE_API_KEY")
if not PYANNOTE_API_KEY:
    raise SystemExit("PYANNOTE_API_KEY environment variable is not set (do not hard-code secrets)")

print("Loading pyannote community-1-cloud pipeline...")
from pyannote.audio import Pipeline

# Cloud-backed diarization pipeline: `token` authenticates the request
# against the pyannote API, so this step needs network access.
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-community-1-cloud",
    token=PYANNOTE_API_KEY,
)

print(f"Running diarization on {INPUT_FILE}...")
# NOTE(review): the pipeline is given a file path directly and the
# annotation is read from `result.speaker_diarization`; confirm both
# against the community-1-cloud pipeline docs for the installed version.
result = pipeline(INPUT_FILE)

# pyannote Annotation containing (possibly overlapping) speaker turns.
diarization = result.speaker_diarization

# Load the waveform as float32; collapse multi-channel audio to mono by
# averaging the channels so masks index a single sample axis.
audio, sr = sf.read(INPUT_FILE, dtype='float32')
if audio.ndim > 1:
    audio = audio.mean(axis=1)
total_samples = len(audio)
duration = total_samples / sr
print(f"Audio: {duration:.1f}s at {sr}Hz\n")

# Gather every diarization turn: grouped per speaker, plus one flat
# chronological list of (start, end, speaker) for the raw JSON dump.
speakers = {}
all_segments = []
for turn, _, label in diarization.itertracks(yield_label=True):
    span = (turn.start, turn.end)
    if label not in speakers:
        speakers[label] = []
    speakers[label].append(span)
    all_segments.append(span + (label,))

all_segments.sort(key=lambda seg: seg[0])

# Rasterize each speaker's turns onto the sample grid: one boolean mask
# per speaker, True wherever that speaker is talking.
print("Building overlap map...")
speaker_masks = {}
for spk, segs in speakers.items():
    active = np.zeros(total_samples, dtype=np.bool_)
    for seg_start, seg_end in segs:
        first = int(seg_start * sr)
        last = min(int(seg_end * sr), total_samples)
        active[first:last] = True
    speaker_masks[spk] = active

# Per-sample count of simultaneously active speakers.
active_count = np.zeros(total_samples, dtype=np.int8)
for active in speaker_masks.values():
    active_count += active.astype(np.int8)

# A sample counts as overlap when two or more speakers claim it.
overlap_mask = active_count >= 2
overlap_seconds = np.sum(overlap_mask) / sr
print(f"Total overlap found: {overlap_seconds:.1f}s ({overlap_seconds/duration*100:.1f}% of audio)")

# Write one timeline-length WAV per speaker containing only the samples
# where that speaker talks alone (all overlapping samples are zeroed).
print(f"\nExtracting {len(speakers)} clean speaker tracks (overlaps removed):\n")

stats = []
# Process speakers in descending order of total speech.
ordering = sorted(speakers, key=lambda name: int(np.sum(speaker_masks[name])), reverse=True)
for spk in ordering:
    spk_mask = speaker_masks[spk]
    # Solo samples: this speaker is active and nobody overlaps them.
    solo_mask = spk_mask & ~overlap_mask

    original_speech = np.sum(spk_mask) / sr
    clean_speech = np.sum(solo_mask) / sr
    removed = original_speech - clean_speech

    # Preserve the original timeline: zeros everywhere except solo speech.
    track = np.zeros(total_samples, dtype=np.float32)
    np.copyto(track, audio, where=solo_mask)

    path = os.path.join(OUTPUT_DIR, f"{spk}.wav")
    sf.write(path, track, sr)

    print(f"  {spk}: {clean_speech:.1f}s clean (was {original_speech:.1f}s, removed {removed:.1f}s overlap)")
    stats.append({
        "speaker": spk,
        "clean_seconds": round(clean_speech, 1),
        "original_seconds": round(original_speech, 1),
        "overlap_removed": round(removed, 1),
        "segments": len(speakers[spk]),
    })

# Persist per-speaker stats next to the audio.
with open(os.path.join(OUTPUT_DIR, "stats.json"), "w") as fh:
    json.dump({"total_overlap_seconds": round(overlap_seconds, 1), "speakers": stats}, fh, indent=2)

# Also dump the raw (still-overlapping) diarization for reference.
raw_segs = [
    {"speaker": label, "start": round(seg_start, 3), "end": round(seg_end, 3)}
    for seg_start, seg_end, label in all_segments
]
with open(os.path.join(OUTPUT_DIR, "diarization_raw.json"), "w") as fh:
    json.dump(raw_segs, fh, indent=2)

print(f"\nDone! Clean tracks in {OUTPUT_DIR}/")
