#!/usr/bin/env python3
"""Run pyannote community-1 cloud speaker diarization on the lalal.ai-cleaned bob5 audio.

Splits the input recording into one clean (overlap-free) WAV per detected
speaker and writes the raw segment list to a JSON sidecar file.
"""

# NOTE: torchaudio must be imported (and patched) BEFORE pyannote.audio is
# imported further down -- newer torchaudio releases dropped
# `list_audio_backends`, which pyannote still calls at import time.
import torchaudio
if not hasattr(torchaudio, 'list_audio_backends'):
    # Shim the removed API; "soundfile" is the backend actually used here.
    torchaudio.list_audio_backends = lambda: ["soundfile"]

import os, json, numpy as np, soundfile as sf

# I/O locations for this run.
INPUT_FILE = "/home/ubuntu/bob5_lalal_16k.wav"
OUTPUT_DIR = "/home/ubuntu/bob5_lalal_diarized"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# SECURITY: this key was previously hard-coded in source. Prefer the
# environment variable; the literal fallback keeps existing deployments
# working, but a key committed to source control must be treated as leaked
# and rotated.
PYANNOTE_API_KEY = os.environ.get(
    "PYANNOTE_API_KEY", "sk_4477f5473f584d1190f2c3bdbf37445b"
)

# Build the cloud-backed diarization pipeline and run it on the input file.
# (Import is deliberately deferred until after the torchaudio shim above.)
print("Loading pyannote community-1-cloud...")
from pyannote.audio import Pipeline

pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-community-1-cloud", token=PYANNOTE_API_KEY
)

print(f"Running cloud diarization on {INPUT_FILE}...")
result = pipeline(INPUT_FILE)

# The cloud pipeline wraps its annotation in a result object.
diarization = result.speaker_diarization

# Load the waveform; collapse any multi-channel audio to mono by averaging.
audio, sr = sf.read(INPUT_FILE, dtype='float32')
if audio.ndim > 1:
    audio = np.mean(audio, axis=1)
total_samples = audio.shape[0]
duration = total_samples / sr
print(f"Audio: {duration:.1f}s at {sr}Hz\n")

# Gather every diarization turn, grouped per speaker and as one flat,
# chronologically ordered list.
speakers = {}
all_segments = []
for segment, _, label in diarization.itertracks(yield_label=True):
    span = (segment.start, segment.end)
    speakers.setdefault(label, []).append(span)
    all_segments.append(span + (label,))
# Stable sort on start time only, so ties keep their original track order.
all_segments.sort(key=lambda seg: seg[0])

# Per-speaker boolean activity masks at sample resolution, then a mask of
# samples where two or more speakers talk at once.
speaker_masks = {}
for label, spans in speakers.items():
    activity = np.zeros(total_samples, dtype=np.bool_)
    for start, end in spans:
        lo = int(start * sr)
        hi = min(int(end * sr), total_samples)
        activity[lo:hi] = True
    speaker_masks[label] = activity

active_count = np.zeros(total_samples, dtype=np.int8)
for activity in speaker_masks.values():
    active_count += activity.astype(np.int8)
overlap_mask = active_count >= 2
overlap_seconds = overlap_mask.sum() / sr
print(f"Total overlap: {overlap_seconds:.1f}s ({overlap_seconds/duration*100:.1f}%)\n")

print(f"Found {len(speakers)} speaker(s):\n")
# Export speakers in descending order of total attributed samples.
for spk in sorted(speakers, key=lambda s: np.sum(speaker_masks[s]), reverse=True):
    # Keep only the samples where this speaker talks alone.
    clean_mask = speaker_masks[spk] & ~overlap_mask
    original_speech = np.sum(speaker_masks[spk]) / sr
    clean_speech = np.sum(clean_mask) / sr
    removed = original_speech - clean_speech

    # Find contiguous True runs. Padding with False on both sides guarantees
    # every run yields exactly one rising and one falling edge, so the array
    # boundaries need no special-casing (the previous version indexed
    # clean_mask[0] / clean_mask[-1], which raises IndexError on zero-length
    # audio). starts/ends are [inclusive, exclusive) sample indices.
    padded = np.concatenate(([False], clean_mask, [False]))
    edges = np.flatnonzero(np.diff(padded.astype(np.int8)))
    starts, ends = edges[0::2], edges[1::2]

    chunks = [audio[si:ei] for si, ei in zip(starts, ends)]
    if chunks:
        out = np.concatenate(chunks)
        path = os.path.join(OUTPUT_DIR, f"{spk}.wav")
        sf.write(path, out, sr)
        print(f"  {spk}: {clean_speech:.1f}s clean (removed {removed:.1f}s overlap) -> {path}")

# Persist the raw segment timeline (start/end rounded to milliseconds).
raw_segs = [
    {"speaker": label, "start": round(start, 3), "end": round(end, 3)}
    for start, end, label in all_segments
]
json_path = os.path.join(OUTPUT_DIR, "diarization.json")
with open(json_path, "w") as f:
    json.dump(raw_segs, f, indent=2)

print("\nDone!")
