#!/usr/bin/env python3
"""
NeMo diarization on lalal.ai-cleaned Jalsa voice track.
"""

import torchaudio
# Compatibility shim: when the installed torchaudio does not expose
# `list_audio_backends`, install a stub that reports only "soundfile".
# NOTE(review): presumably needed because a library imported further down
# calls this function at import/run time -- confirm which one.
try:
    torchaudio.list_audio_backends
except AttributeError:
    def _soundfile_only_backends():
        """Stand-in for the missing torchaudio API; reports one backend."""
        return ["soundfile"]

    torchaudio.list_audio_backends = _soundfile_only_backends

import os, json, glob
import torch

# Input audio and the directory that receives every output of this script.
INPUT_FILE = "jalsa_lalal/voice_16k.wav"
OUTPUT_DIR = "jalsa_diarized_v2"

# NeMo's working directory is nested inside OUTPUT_DIR, so creating it also
# creates OUTPUT_DIR when that does not exist yet.
nemo_dir = os.path.join(OUTPUT_DIR, "nemo_work")
os.makedirs(nemo_dir, exist_ok=True)

# The manifest refers to the audio by absolute path.
abs_input = os.path.abspath(INPUT_FILE)
manifest_path = os.path.join(nemo_dir, "input_manifest.json")

# Single manifest record for the one input file. Fields left as None are
# serialised as JSON null.
entry = {
    "audio_filepath": abs_input,
    "offset": 0,
    "duration": None,
    "label": "infer",
    "text": "-",
    "num_speakers": None,
    "rttm_filepath": None,
    "uem_filepath": None,
}

# The manifest is written as one JSON object per line (here: exactly one).
with open(manifest_path, "w") as f:
    f.write(f"{json.dumps(entry)}\n")

import wget
from omegaconf import OmegaConf

# The inference config is downloaded once and then reused from the NeMo
# working directory on later runs.
MODEL_CONFIG = os.path.join(nemo_dir, "diar_infer_telephonic.yaml")
config_is_cached = os.path.exists(MODEL_CONFIG)
if not config_is_cached:
    config_url = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml"
    wget.download(config_url, MODEL_CONFIG, bar=None)

cfg = OmegaConf.load(MODEL_CONFIG)

# Aliases into the loaded config tree; assigning through them updates `cfg`.
diarizer_cfg = cfg.diarizer
clustering_params = diarizer_cfg.clustering.parameters

# Where to read the manifest from and where NeMo writes its outputs.
diarizer_cfg.manifest_filepath = manifest_path
diarizer_cfg.out_dir = nemo_dir

# Model selection for speaker embeddings, VAD and the MSDD stage. VAD is
# model-based here (oracle_vad disabled).
diarizer_cfg.oracle_vad = False
diarizer_cfg.speaker_embeddings.model_path = "titanet_large"
diarizer_cfg.vad.model_path = "vad_multilingual_marblenet"
diarizer_cfg.msdd_model.model_path = "diar_msdd_telephonic"
diarizer_cfg.msdd_model.parameters.sigmoid_threshold = [0.5]

# Clustering overrides: the speaker count is not supplied up front
# (oracle_num_speakers disabled), with max_num_speakers set to 20.
clustering_params.oracle_num_speakers = False
clustering_params.max_num_speakers = 20
clustering_params.enhanced_count_thres = 40
clustering_params.max_rp_threshold = 0.15
clustering_params.sparse_search_volume = 50

print("Starting NeMo diarization...")
# Imported at this point in the script rather than with the other imports at
# the top of the file.
from nemo.collections.asr.models.msdd_models import NeuralDiarizer

# Use the GPU when torch can see one; otherwise fall back to the CPU instead
# of failing on a hard-coded "cuda" device.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("CUDA is not available; running diarization on CPU (this may be slow).")

nemo_diarizer = NeuralDiarizer(cfg=cfg).to(device)
nemo_diarizer.diarize()

# Collect the predicted speaker turns that the diarizer wrote as RTTM files.
rttm_dir = os.path.join(nemo_dir, "pred_rttms")

# The working directory is reused between runs, so `pred_rttms` may still hold
# predictions for other inputs. Prefer the RTTM named after the current input
# (NOTE(review): relies on NeMo naming predictions by the audio file's
# basename -- confirm); if it is absent, fall back to every RTTM in the
# directory, sorted so the processing order is deterministic.
expected_rttm = os.path.join(
    rttm_dir, os.path.splitext(os.path.basename(INPUT_FILE))[0] + ".rttm"
)
if os.path.isfile(expected_rttm):
    rttm_files = [expected_rttm]
else:
    rttm_files = sorted(glob.glob(os.path.join(rttm_dir, "*.rttm")))

segments = []
for rttm_file in rttm_files:
    with open(rttm_file) as f:
        for line in f:
            parts = line.split()
            # Only SPEAKER records with at least 8 whitespace-separated fields
            # are used: field 3 is the onset, field 4 the duration and
            # field 7 the speaker label.
            if len(parts) < 8 or parts[0] != "SPEAKER":
                continue
            segments.append({
                "speaker": parts[7],
                "start": round(float(parts[3]), 3),
                "end": round(float(parts[3]) + float(parts[4]), 3)
            })

# Time-order the turns by their start time.
segments.sort(key=lambda x: x["start"])

# An empty result would otherwise only show up as an empty JSON file.
if not segments:
    print(f"WARNING: no speaker segments found in {rttm_dir}")

# Group the (start, end) bounds of every segment under its speaker label.
speakers = {}
for seg in segments:
    label = seg["speaker"]
    if label not in speakers:
        speakers[label] = []
    speakers[label].append((seg["start"], seg["end"]))

# Total speaking time per speaker, computed once and reused both for the
# ordering and for the printed report.
talk_time = {
    label: sum(end - begin for begin, end in turns)
    for label, turns in speakers.items()
}

# Report speakers from most to least total speaking time.
print(f"\nFound {len(speakers)} speaker(s):")
for spk in sorted(speakers, key=talk_time.get, reverse=True):
    total = talk_time[spk]
    print(f"  {spk}: {total:.1f}s ({total/60:.1f} min) across {len(speakers[spk])} segments")

# Write the time-ordered segments as indented JSON into the output directory.
diarization_path = os.path.join(OUTPUT_DIR, "diarization.json")
with open(diarization_path, "w") as f:
    f.write(json.dumps(segments, indent=2))
print(f"\nSaved: {OUTPUT_DIR}/diarization.json")

# Drop the reference to the diarizer first, then ask torch to release its
# cached CUDA memory.
del nemo_diarizer
torch.cuda.empty_cache()
print("Done!")
