#!/usr/bin/env python3
"""NeMo diarization on ORIGINAL audio (not lalal.ai processed)."""

import torchaudio
# Compatibility shim: older torchaudio builds lack list_audio_backends().
# Presumably a downstream import (NeMo) probes it at load time — TODO confirm.
if not hasattr(torchaudio, 'list_audio_backends'):
    torchaudio.list_audio_backends = lambda: ["soundfile"]

import os, json, glob, torch, numpy as np, soundfile as sf, wget
from omegaconf import OmegaConf

# Paths: everything this run produces lands under OUTPUT_DIR.
INPUT_FILE = "jalsa_original_16k.wav"
OUTPUT_DIR = "jalsa_nemo_original"

nemo_dir = os.path.join(OUTPUT_DIR, "nemo_work")
for directory in (OUTPUT_DIR, nemo_dir):
    os.makedirs(directory, exist_ok=True)

# NeMo's diarizer consumes a JSONL manifest: one record per input file.
# Null num_speakers / rttm let the pipeline infer everything itself.
abs_input = os.path.abspath(INPUT_FILE)
manifest_path = os.path.join(nemo_dir, "input_manifest.json")
manifest_entry = {
    "audio_filepath": abs_input,
    "offset": 0,
    "duration": None,
    "label": "infer",
    "text": "-",
    "num_speakers": None,
    "rttm_filepath": None,
    "uem_filepath": None,
}
with open(manifest_path, "w") as f:
    json.dump(manifest_entry, f)
    f.write("\n")

# Fetch NVIDIA's reference telephonic inference config (downloaded once,
# then reused from the work directory on later runs).
CONFIG_URL = (
    "https://raw.githubusercontent.com/NVIDIA/NeMo/main/"
    "examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml"
)
MODEL_CONFIG = os.path.join(nemo_dir, "diar_infer_telephonic.yaml")
if not os.path.exists(MODEL_CONFIG):
    wget.download(CONFIG_URL, MODEL_CONFIG, bar=None)

# Load the stock telephonic config, then override it for this run.
cfg = OmegaConf.load(MODEL_CONFIG)

diar = cfg.diarizer
diar.manifest_filepath = manifest_path
diar.out_dir = nemo_dir
diar.oracle_vad = False  # no ground-truth VAD; run the VAD model instead

# Model identifiers for speaker embeddings, VAD, and MSDD stages.
diar.speaker_embeddings.model_path = "titanet_large"
diar.vad.model_path = "vad_multilingual_marblenet"
diar.msdd_model.model_path = "diar_msdd_telephonic"
diar.msdd_model.parameters.sigmoid_threshold = [0.5]

# Clustering: estimate the speaker count (up to 20) rather than fixing it.
clustering = diar.clustering.parameters
clustering.oracle_num_speakers = False
clustering.max_num_speakers = 20
clustering.max_rp_threshold = 0.15
clustering.sparse_search_volume = 50

print("Starting NeMo diarization on ORIGINAL audio...")
# Deferred import — presumably so the torchaudio shim at the top of the file
# is installed before NeMo loads; confirm before moving to the top.
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
diarizer = NeuralDiarizer(cfg=cfg).to("cuda")  # requires a CUDA GPU
# Runs the full pipeline (VAD -> embeddings -> clustering -> MSDD); writes
# RTTM predictions under cfg.diarizer.out_dir.
diarizer.diarize()

# Collect predicted segments from every RTTM file the diarizer emitted.
# RTTM "SPEAKER" lines: field 3 = onset (s), field 4 = duration (s),
# field 7 = speaker label.
segments = []
for rttm_path in glob.glob(os.path.join(nemo_dir, "pred_rttms", "*.rttm")):
    with open(rttm_path) as rttm:
        for raw_line in rttm:
            fields = raw_line.strip().split()
            if len(fields) < 8 or fields[0] != "SPEAKER":
                continue
            onset = float(fields[3])
            segments.append({
                "speaker": fields[7],
                "start": round(onset, 3),
                "end": round(onset + float(fields[4]), 3),
            })
segments.sort(key=lambda seg: seg["start"])

# Group segment intervals by speaker label.
speakers = {}
for seg in segments:
    speakers.setdefault(seg["speaker"], []).append((seg["start"], seg["end"]))

# Total speech time per speaker, computed once and reused for both the
# sort order and the report line (avoids re-summing inside the loop).
totals = {spk: sum(end - start for start, end in intervals)
          for spk, intervals in speakers.items()}

print(f"\nFound {len(speakers)} speaker(s):")
for spk in sorted(totals, key=totals.get, reverse=True):
    total = totals[spk]
    print(f"  {spk}: {total:.1f}s ({total/60:.1f} min), {len(speakers[spk])} segments")

# Persist the full segment list for downstream use.
with open(os.path.join(OUTPUT_DIR, "diarization.json"), "w") as f:
    json.dump(segments, f, indent=2)

# Extract speaker tracks: re-read the source audio and downmix to mono
# if it has multiple channels.
audio, sr = sf.read(INPUT_FILE, dtype='float32')
if audio.ndim > 1:
    audio = audio.mean(axis=1)

# Write one WAV per speaker, concatenating that speaker's intervals.
for spk in sorted(speakers):
    pieces = [
        audio[int(start * sr):min(int(end * sr), len(audio))]
        for start, end in speakers[spk]
    ]
    if pieces:
        track = np.concatenate(pieces)
        out_path = os.path.join(OUTPUT_DIR, f"{spk}.wav")
        sf.write(out_path, track, sr)
        print(f"  {spk} -> {out_path} ({len(track)/sr:.1f}s)")

# Drop the model reference and release PyTorch's cached CUDA memory.
del diarizer
torch.cuda.empty_cache()
print("\nDone!")
