#!/usr/bin/env python3
"""
Run both pyannote and NeMo diarization on the same audio, compare results,
and output per-speaker audio tracks from each.
"""

import os
import sys
import json
import torch
import numpy as np
import soundfile as sf
import torchaudio
# Some torchaudio builds ship without list_audio_backends(); downstream
# libraries probe it, so provide a minimal shim advertising soundfile.
if not hasattr(torchaudio, 'list_audio_backends'):
    torchaudio.list_audio_backends = lambda: ["soundfile"]

INPUT_FILE = sys.argv[1] if len(sys.argv) > 1 else "demucs_out/htdemucs/clip_00/vocals.wav"
OUTPUT_DIR = "diarization_output"

# Fail fast with a readable message instead of a raw sf.read traceback.
if not os.path.isfile(INPUT_FILE):
    sys.exit(f"Error: input file not found: {INPUT_FILE}")

os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Input: {INPUT_FILE}")
print(f"Output: {OUTPUT_DIR}/")
print()

# Load the audio once; downmix multi-channel to mono so the per-speaker
# tracks built later are 1-D and match the diarization timeline.
audio_data, sample_rate = sf.read(INPUT_FILE, dtype="float32")
if audio_data.ndim > 1:
    audio_data = audio_data.mean(axis=1)
total_samples = len(audio_data)
duration = total_samples / sample_rate
print(f"Audio: {duration:.1f}s, {sample_rate}Hz, {total_samples} samples")
print()


def save_speaker_tracks(speakers, audio_data, sample_rate, prefix, out_dir):
    """Write one WAV per speaker containing only that speaker's segments.

    Each output track is the full length of ``audio_data``, silent except
    where the speaker was active, so all tracks stay time-aligned.

    Args:
        speakers: dict mapping speaker label -> list of (start_s, end_s) tuples.
        audio_data: 1-D float32 numpy array of mono samples.
        sample_rate: sample rate of ``audio_data`` in Hz.
        prefix: filename prefix, e.g. "pyannote" or "nemo".
        out_dir: existing directory to write the WAV files into.
    """
    for spk in sorted(speakers):
        segments = speakers[spk]
        track = np.zeros(len(audio_data), dtype=np.float32)
        for seg_start, seg_end in segments:
            # Clamp to the valid sample range: diarizers can emit slightly
            # negative starts (which would otherwise index from the end of
            # the array) or ends past the audio length.
            s = max(0, int(seg_start * sample_rate))
            e = min(int(seg_end * sample_rate), len(audio_data))
            if e > s:
                track[s:e] = audio_data[s:e]
        out_path = os.path.join(out_dir, f"{prefix}_{spk}.wav")
        sf.write(out_path, track, sample_rate)
        total_time = sum(end - start for start, end in segments)
        print(f"  {out_path} ({total_time:.1f}s, {len(segments)} segments)")


# ═══════════════════════════════════════════════════════
# 1. PYANNOTE DIARIZATION
# ═══════════════════════════════════════════════════════
print("=" * 60)
print("  PYANNOTE DIARIZATION")
print("=" * 60)
print()

from pyannote.audio import Pipeline as PyAnnotePipeline

print("Loading pyannote pipeline...")
pyannote_pipeline = PyAnnotePipeline.from_pretrained("pyannote/speaker-diarization-3.1")
# Fall back to CPU when no GPU is present instead of crashing on .to("cuda").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pyannote_pipeline.to(device)
print("Running pyannote diarization...")

result = pyannote_pipeline(INPUT_FILE)
# NOTE(review): assumes a pyannote.audio version whose pipeline output
# exposes .speaker_diarization (an Annotation) — verify against the
# installed version; older releases return the Annotation directly.
diarization = result.speaker_diarization

pyannote_speakers = {}
pyannote_segments = []

for turn, _, speaker in diarization.itertracks(yield_label=True):
    pyannote_speakers.setdefault(speaker, []).append((turn.start, turn.end))
    pyannote_segments.append({"speaker": speaker, "start": round(turn.start, 3), "end": round(turn.end, 3)})

# Sort chronologically so the JSON matches the NeMo output's ordering.
pyannote_segments.sort(key=lambda seg: seg["start"])

print(f"\nPyannote found {len(pyannote_speakers)} speaker(s):")
for spk in sorted(pyannote_speakers):
    total = sum(end - start for start, end in pyannote_speakers[spk])
    n = len(pyannote_speakers[spk])
    print(f"  {spk}: {total:.1f}s across {n} segments")

print("\nSaving pyannote speaker tracks...")
save_speaker_tracks(pyannote_speakers, audio_data, sample_rate, "pyannote", OUTPUT_DIR)

with open(os.path.join(OUTPUT_DIR, "pyannote_diarization.json"), "w") as f:
    json.dump(pyannote_segments, f, indent=2)

# Free GPU memory before loading the NeMo models.
del pyannote_pipeline, result, diarization
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print()

# ═══════════════════════════════════════════════════════
# 2. NEMO DIARIZATION
# ═══════════════════════════════════════════════════════
print("=" * 60)
print("  NEMO DIARIZATION")
print("=" * 60)
print()

from nemo.collections.asr.models import ClusteringDiarizer
import urllib.request

print("Setting up NeMo diarization...")

nemo_dir = os.path.join(OUTPUT_DIR, "nemo_work")
os.makedirs(nemo_dir, exist_ok=True)

manifest_path = os.path.join(nemo_dir, "input_manifest.json")
abs_input = os.path.abspath(INPUT_FILE)

# NeMo's VAD and speaker-embedding models expect 16 kHz input; resample
# into a temp WAV when needed. Segment times are in seconds, so the
# speaker tracks can still be cut from the original-rate audio later.
if sample_rate != 16000:
    print(f"Resampling {sample_rate}Hz -> 16000Hz for NeMo...")
    import librosa
    audio_16k = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
    resampled_path = os.path.join(nemo_dir, "input_16k.wav")
    sf.write(resampled_path, audio_16k, 16000)
    abs_input = os.path.abspath(resampled_path)

# One-line JSON manifest describing the single file to diarize.
with open(manifest_path, "w") as f:
    entry = {
        "audio_filepath": abs_input,
        "offset": 0,
        "duration": None,       # None -> use the whole file
        "label": "infer",
        "text": "-",
        "num_speakers": None,   # unknown; let clustering decide
        "rttm_filepath": None,
        "uem_filepath": None,
    }
    f.write(json.dumps(entry) + "\n")

from omegaconf import OmegaConf

# Download the reference inference config once and cache it locally.
MODEL_CONFIG = os.path.join(nemo_dir, "diar_infer_meeting.yaml")
if not os.path.exists(MODEL_CONFIG):
    config_url = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/diar_infer_meeting.yaml"
    urllib.request.urlretrieve(config_url, MODEL_CONFIG)

cfg = OmegaConf.load(MODEL_CONFIG)

cfg.diarizer.manifest_filepath = manifest_path
cfg.diarizer.out_dir = nemo_dir

# Multi-scale speaker embeddings: five window/shift scales, equally weighted.
cfg.diarizer.speaker_embeddings.model_path = "titanet_large"
cfg.diarizer.speaker_embeddings.parameters.window_length_in_sec = [1.5, 1.25, 1.0, 0.75, 0.5]
cfg.diarizer.speaker_embeddings.parameters.shift_length_in_sec = [0.75, 0.625, 0.5, 0.375, 0.25]
cfg.diarizer.speaker_embeddings.parameters.multiscale_weights = [1, 1, 1, 1, 1]

# System VAD (no oracle segments available).
cfg.diarizer.oracle_vad = False
cfg.diarizer.vad.model_path = "vad_multilingual_marblenet"
cfg.diarizer.vad.parameters.onset = 0.8
cfg.diarizer.vad.parameters.offset = 0.6
cfg.diarizer.vad.parameters.pad_offset = -0.05

# Speaker count is unknown; let clustering estimate it up to a cap.
cfg.diarizer.clustering.parameters.oracle_num_speakers = False
cfg.diarizer.clustering.parameters.max_num_speakers = 8

print("Running NeMo diarization (this may take a few minutes)...")
nemo_diarizer = ClusteringDiarizer(cfg=cfg)
nemo_diarizer.diarize()

import glob

rttm_files = glob.glob(os.path.join(nemo_dir, "pred_rttms", "*.rttm"))

nemo_speakers = {}
nemo_segments = []

# RTTM fields: SPEAKER <file> <chan> <tbeg> <tdur> <ortho> <stype> <name> ...
for rttm_path in rttm_files:
    with open(rttm_path, "r") as rttm:
        for raw_line in rttm:
            fields = raw_line.strip().split()
            if len(fields) < 8 or fields[0] != "SPEAKER":
                continue
            seg_start = float(fields[3])
            seg_end = seg_start + float(fields[4])
            spk_label = fields[7]
            nemo_speakers.setdefault(spk_label, []).append((seg_start, seg_end))
            nemo_segments.append({"speaker": spk_label, "start": round(seg_start, 3), "end": round(seg_end, 3)})

nemo_segments.sort(key=lambda seg: seg["start"])

print(f"\nNeMo found {len(nemo_speakers)} speaker(s):")
for spk in sorted(nemo_speakers):
    spk_segs = nemo_speakers[spk]
    total = sum(seg_end - seg_start for seg_start, seg_end in spk_segs)
    n = len(spk_segs)
    print(f"  {spk}: {total:.1f}s across {n} segments")

print("\nSaving NeMo speaker tracks...")
save_speaker_tracks(nemo_speakers, audio_data, sample_rate, "nemo", OUTPUT_DIR)

with open(os.path.join(OUTPUT_DIR, "nemo_diarization.json"), "w") as f:
    json.dump(nemo_segments, f, indent=2)

# ═══════════════════════════════════════════════════════
# 3. COMPARISON
# ═══════════════════════════════════════════════════════
print()
print("=" * 60)
print("  COMPARISON")
print("=" * 60)
print()

# Side-by-side summary table: one row per statistic, one column per engine.
print(f"{'':>20s}  {'Pyannote':>10s}  {'NeMo':>10s}")
print("-" * 45)
print(f"{'Speakers found':>20s}  {len(pyannote_speakers):>10d}  {len(nemo_speakers):>10d}")

py_speech = sum(sum(b - a for a, b in seg_list) for seg_list in pyannote_speakers.values())
nm_speech = sum(sum(b - a for a, b in seg_list) for seg_list in nemo_speakers.values())
print(f"{'Total speech (s)':>20s}  {py_speech:>10.1f}  {nm_speech:>10.1f}")

py_count = sum(len(seg_list) for seg_list in pyannote_speakers.values())
nm_count = sum(len(seg_list) for seg_list in nemo_speakers.values())
print(f"{'Total segments':>20s}  {py_count:>10d}  {nm_count:>10d}")

print()
print("All outputs saved to:", OUTPUT_DIR)
print("Done!")