#!/usr/bin/env python3
"""Pyannote community-1 LOCAL (HuggingFace) diarization on bob5 vocals."""

import torchaudio

# Older torchaudio releases lack list_audio_backends(); downstream code
# appears to probe it, so install a minimal shim advertising soundfile.
if not hasattr(torchaudio, 'list_audio_backends'):
    def _list_audio_backends():
        return ["soundfile"]
    torchaudio.list_audio_backends = _list_audio_backends

import os, json, numpy as np, soundfile as sf, torch

# Path to the pre-extracted 16 kHz vocals track to diarize.
INPUT_FILE = "/home/ubuntu/bob5_vocals_16k.wav"
# Destination for the per-speaker WAVs and the raw diarization JSON.
OUTPUT_DIR = "/home/ubuntu/bob5_local"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Loading pyannote community-1 locally via HuggingFace...")
from pyannote.audio import Pipeline
from huggingface_hub import HfFolder

# Token cached by `huggingface-cli login`; required for the gated pyannote model.
hf_token = HfFolder.get_token()
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-community-1",
    token=hf_token,
)
# Prefer the GPU but fall back to CPU so the script still runs
# on machines without CUDA instead of crashing in pipeline.to().
pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

print(f"Running local diarization on {INPUT_FILE}...")
# Run the full diarization pipeline on the input file (slow step).
result = pipeline(INPUT_FILE)

# The pipeline's result exposes the diarization via this attribute;
# it supports itertracks(yield_label=True) below.
diarization = result.speaker_diarization

# Load the waveform once; per-speaker tracks are cut from this buffer.
audio, sr = sf.read(INPUT_FILE, dtype='float32')
# Down-mix multi-channel audio to mono by averaging the channels.
audio = audio if audio.ndim == 1 else audio.mean(axis=1)
total_samples = len(audio)
duration = total_samples / sr
print(f"Audio: {duration:.1f}s at {sr}Hz\n")

# Gather every turn, both grouped per speaker and as one flat timeline.
speakers = {}
all_segments = []
for segment, _, label in diarization.itertracks(yield_label=True):
    span = (segment.start, segment.end)
    speakers.setdefault(label, []).append(span)
    all_segments.append(span + (label,))
all_segments.sort(key=lambda seg: seg[0])

# Build a per-speaker boolean sample mask, then mark samples where two
# or more speakers are active at once.
print("Building overlap map...")
speaker_masks = {}
for label, spans in speakers.items():
    m = np.zeros(total_samples, dtype=np.bool_)
    for start_t, end_t in spans:
        lo = int(start_t * sr)
        hi = min(int(end_t * sr), total_samples)
        m[lo:hi] = True
    speaker_masks[label] = m

# Count concurrently-active speakers per sample.
active_count = np.zeros(total_samples, dtype=np.int8)
for m in speaker_masks.values():
    active_count += m.astype(np.int8)
overlap_mask = active_count >= 2
overlap_seconds = np.sum(overlap_mask) / sr
print(f"Total overlap: {overlap_seconds:.1f}s ({overlap_seconds/duration*100:.1f}%)\n")

# Extract clean concatenated tracks (no silence gaps, no overlaps)
# Speakers are processed in descending order of total speech time.
for spk in sorted(speakers, key=lambda s: np.sum(speaker_masks[s]), reverse=True):
    # Keep only samples where this speaker talks alone.
    clean_mask = speaker_masks[spk] & ~overlap_mask
    original_speech = np.sum(speaker_masks[spk]) / sr
    clean_speech = np.sum(clean_mask) / sr
    removed = original_speech - clean_speech

    # Boolean indexing selects the masked samples in order — identical to
    # the explicit run-detection (np.diff on the mask) plus np.concatenate,
    # and it cannot IndexError on a zero-length mask.
    out = audio[clean_mask]
    if out.size:
        path = os.path.join(OUTPUT_DIR, f"{spk}.wav")
        sf.write(path, out, sr)
        print(f"  {spk}: {clean_speech:.1f}s clean (removed {removed:.1f}s overlap) -> {path}")

# Persist the raw turn list so the diarization can be inspected later.
raw_segs = [
    {"speaker": label, "start": round(start, 3), "end": round(end, 3)}
    for start, end, label in all_segments
]
with open(os.path.join(OUTPUT_DIR, "diarization.json"), "w") as f:
    json.dump(raw_segs, f, indent=2)

print("\nDone!")
