#!/usr/bin/env python3
"""Pyannote community-1-cloud diarization on bob5 vocals."""

import torchaudio
# Compatibility shim: this torchaudio build lacks list_audio_backends,
# which something imported later presumably probes (likely pyannote.audio
# — TODO confirm). Advertise the soundfile backend so that probe succeeds.
# NOTE: must run before `from pyannote.audio import Pipeline` below.
if not hasattr(torchaudio, 'list_audio_backends'):
    torchaudio.list_audio_backends = lambda: ["soundfile"]

import os, json, numpy as np, soundfile as sf

# Input vocals track (filename suggests 16 kHz WAV) and the directory
# where all diarization artifacts (JSON + per-speaker WAVs) are written.
INPUT_FILE = "/home/ubuntu/bob5_vocals_16k.wav"
OUTPUT_DIR = "/home/ubuntu/bob5_pyannote"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# SECURITY FIX: the API key was hard-coded in source. A committed key is
# compromised and should be rotated. Prefer the PYANNOTE_API_KEY environment
# variable; the original literal remains only as a backward-compatible
# fallback so existing invocations keep working until rotation.
PYANNOTE_API_KEY = os.environ.get(
    "PYANNOTE_API_KEY",
    "sk_4477f5473f584d1190f2c3bdbf37445b",  # TODO: rotate and remove
)

print("Loading pyannote community-1-cloud pipeline...")
from pyannote.audio import Pipeline

# Hosted (cloud) diarization pipeline; authenticates with the token above.
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-community-1-cloud",
    token=PYANNOTE_API_KEY,
)

print(f"Running cloud diarization on {INPUT_FILE}...")
result = pipeline(INPUT_FILE)

# The cloud pipeline wraps its annotation in a DiarizeOutput-style object
# exposing `speaker_diarization`; fall back to the raw result when that
# attribute is absent (plain pipelines return the annotation directly).
diarization = getattr(result, 'speaker_diarization', result)

# Flatten the annotation into a chronologically ordered list of turns.
segments = sorted(
    (
        {
            "speaker": label,
            "start": round(segment.start, 3),
            "end": round(segment.end, 3),
        }
        for segment, _, label in diarization.itertracks(yield_label=True)
    ),
    key=lambda item: item["start"],
)

# Index the (start, end) pairs by speaker label.
speakers = {}
for entry in segments:
    speakers.setdefault(entry["speaker"], []).append((entry["start"], entry["end"]))

def _total_speech(label):
    """Total seconds of speech attributed to *label*."""
    return sum(end - start for start, end in speakers[label])

# Report per-speaker totals, longest talker first.
print(f"\nFound {len(speakers)} speaker(s):")
for spk in sorted(speakers, key=_total_speech, reverse=True):
    total = _total_speech(spk)
    print(f"  {spk}: {total:.1f}s ({total/60:.1f} min), {len(speakers[spk])} segments")

# Persist the merged segment list for downstream tooling.
json_path = os.path.join(OUTPUT_DIR, "diarization.json")
with open(json_path, "w") as fh:
    json.dump(segments, fh, indent=2)

# Also save the exclusive (non-overlapping) diarization if available.
if hasattr(result, 'exclusive_speaker_diarization'):
    exclusive = result.exclusive_speaker_diarization
    # BUG FIX: the original iterated the annotation directly with
    # `for turn, speaker in exclusive`, but iterating a pyannote Annotation
    # does not yield (segment, label) pairs — unpacking a Segment yields its
    # (start, end) floats, so `turn.start` would raise AttributeError.
    # Use itertracks(yield_label=True), consistent with the main
    # diarization loop above.
    exc_segments = []
    for turn, _, speaker in exclusive.itertracks(yield_label=True):
        exc_segments.append({
            "speaker": speaker,
            "start": round(turn.start, 3),
            "end": round(turn.end, 3),
        })
    exc_segments.sort(key=lambda x: x["start"])
    with open(os.path.join(OUTPUT_DIR, "diarization_exclusive.json"), "w") as f:
        json.dump(exc_segments, f, indent=2)
    print(f"Exclusive diarization: {len(exc_segments)} segments saved")

# Cut each speaker's turns out of the waveform and write one WAV per
# speaker, longest total speech first.
audio, sr = sf.read(INPUT_FILE, dtype='float32')
if audio.ndim > 1:
    # Down-mix multi-channel audio to mono by averaging channels.
    audio = audio.mean(axis=1)

total_samples = len(audio)
by_talk_time = sorted(
    speakers, key=lambda s: sum(e - st for st, e in speakers[s]), reverse=True
)
for spk in by_talk_time:
    # Clamp each segment's end to the waveform length before slicing.
    chunks = [
        audio[int(start * sr):min(int(end * sr), total_samples)]
        for start, end in speakers[spk]
    ]
    if chunks:
        out = np.concatenate(chunks)
        path = os.path.join(OUTPUT_DIR, f"{spk}.wav")
        sf.write(path, out, sr)
        print(f"  Saved {spk} -> {path} ({len(out)/sr:.1f}s)")

print("\nDone!")
