#!/usr/bin/env python3
"""Pyannote diarization on ORIGINAL audio."""

import torchaudio
# Compatibility shim: when this torchaudio build does not expose
# `list_audio_backends`, install a stub that reports "soundfile" as the only
# backend. The patch is applied before pyannote is imported further down --
# presumably because pyannote (or one of its dependencies) calls this function
# while loading; TODO confirm which torchaudio/pyannote versions need it.
if not hasattr(torchaudio, 'list_audio_backends'):
    torchaudio.list_audio_backends = lambda: ["soundfile"]

import os, json, torch, numpy as np, soundfile as sf

# Recording to diarize, and the directory that receives every output file.
INPUT_FILE = "jalsa_original_16k.wav"
OUTPUT_DIR = "jalsa_pyannote"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Loading pyannote pipeline...")
# Imported here, after the torchaudio shim at the top of the file has run.
from pyannote.audio import Pipeline

# `HfFolder` is deprecated in huggingface_hub; prefer the top-level
# `get_token()` and only fall back to `HfFolder` on releases that lack it.
try:
    from huggingface_hub import get_token
except ImportError:
    from huggingface_hub import HfFolder
    get_token = HfFolder.get_token

hf_token = get_token()
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1",
    token=hf_token,
)
# Some pyannote releases print a hint and return None (instead of raising)
# when the gated model cannot be downloaded; stop with a clear message rather
# than failing later with an AttributeError on `None.to(...)`.
if pipeline is None:
    raise SystemExit(
        "Could not load 'pyannote/speaker-diarization-3.1'. Make sure you are "
        "logged in to Hugging Face and have accepted the model's user conditions."
    )

# Prefer the GPU, but keep the script usable on CPU-only machines.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    print("CUDA not available; running on CPU (this will be slow).")
pipeline.to(device)

print(f"Running pyannote diarization on {INPUT_FILE}...")
result = pipeline(INPUT_FILE)

# pyannote v4 returns a DiarizeOutput whose Annotation lives under
# `.speaker_diarization`; otherwise the result itself is used as the Annotation.
diarization = getattr(result, 'speaker_diarization', result)

# One record per speaker turn (times rounded to milliseconds), ordered by
# start time.
segments = sorted(
    (
        {
            "speaker": speaker,
            "start": round(turn.start, 3),
            "end": round(turn.end, 3),
        }
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    ),
    key=lambda record: record["start"],
)

# Group the (start, end) pairs of every segment under its speaker label.
speakers = {}
for record in segments:
    label = record["speaker"]
    if label not in speakers:
        speakers[label] = []
    speakers[label].append((record["start"], record["end"]))

# Total speaking time per speaker, computed once and reused for both the
# ordering and the printed summary.
talk_time = {
    label: sum(end - start for start, end in spans)
    for label, spans in speakers.items()
}

print(f"\nFound {len(speakers)} speaker(s):")
# Most talkative speaker first.
for spk in sorted(speakers, key=talk_time.get, reverse=True):
    total = talk_time[spk]
    print(f"  {spk}: {total:.1f}s ({total/60:.1f} min), {len(speakers[spk])} segments")

# Persist the flat, time-ordered segment list.
json_path = os.path.join(OUTPUT_DIR, "diarization.json")
with open(json_path, "w") as f:
    json.dump(segments, f, indent=2)

# Extract speaker tracks: for each speaker, splice all of their segments
# back-to-back into a single WAV file inside OUTPUT_DIR.
audio, sr = sf.read(INPUT_FILE, dtype='float32')
if audio.ndim > 1:
    # More than one dimension: collapse to 1-D by averaging along axis 1.
    audio = audio.mean(axis=1)

n_samples = len(audio)
for spk in sorted(speakers):
    # Convert each (start, end) in seconds to sample indices; the end index is
    # clamped so it never runs past the end of the recording.
    chunks = [
        audio[int(start * sr):min(int(end * sr), n_samples)]
        for start, end in speakers[spk]
    ]
    if not chunks:
        continue
    out = np.concatenate(chunks)
    path = os.path.join(OUTPUT_DIR, f"{spk}.wav")
    sf.write(path, out, sr)
    print(f"  {spk} -> {path} ({len(out)/sr:.1f}s)")

print("\nDone!")
