#!/usr/bin/env python3
"""
NeMo diarization + ultraclean speaker extraction pipeline.

Step 1 runs NVIDIA NeMo speaker diarization (clustering + MSDD) on the input
audio and records per-speaker time segments.  Step 2 slices those segments
out of the original audio, energy-trims silence/noise, and writes one
concatenated "clean" WAV per speaker.

Usage: script.py [input_wav] [output_dir]
"""

# Compatibility shim: some torchaudio builds do not expose
# list_audio_backends; provide a stub so a downstream import (presumably
# NeMo's torchaudio probe — TODO confirm) does not fail at import time.
# Must run before anything imports NeMo.
import torchaudio
if not hasattr(torchaudio, 'list_audio_backends'):
    torchaudio.list_audio_backends = lambda: ["soundfile"]

import os
import json
import glob
import numpy as np
import soundfile as sf
import torch
import librosa
import wget
from omegaconf import OmegaConf
# NeuralDiarizer (MSDD) imported later after config is loaded

import sys

# CLI: argv[1] = input audio (e.g. a Demucs vocals stem), argv[2] = output dir.
INPUT_FILE = sys.argv[1] if len(sys.argv) > 1 else "demucs_out/htdemucs/_Cooking_With_Fun_Vlog_256KBPS/vocals.wav"
OUTPUT_DIR = sys.argv[2] if len(sys.argv) > 2 else "cooking_diarized"
MIN_SEGMENT_SEC = 0.8   # drop diarized segments shorter than this (seconds)
ENERGY_TRIM_DB = -35    # speech threshold relative to the loudest frame (dB)
ENERGY_FRAME_MS = 25    # RMS analysis window length (ms)
ENERGY_HOP_MS = 10      # hop between analysis windows (ms)
MIN_SPEECH_FRAMES = 5   # minimum consecutive above-threshold frames to keep a run
CROSSFADE_MS = 10       # fade-in/out applied to each kept chunk (ms)

os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Input: {INPUT_FILE}")
# Load as float32 and mix down to mono — the rest of the pipeline assumes a
# 1-D signal at the file's native sample rate.
audio_data, sample_rate = sf.read(INPUT_FILE, dtype="float32")
if audio_data.ndim > 1:
    audio_data = audio_data.mean(axis=1)
duration = len(audio_data) / sample_rate
print(f"Audio: {duration:.1f}s, {sample_rate}Hz")

# ── NEMO DIARIZATION ─────────────────────────────────
print("\n[Step 1] Running NeMo diarization...")
nemo_dir = os.path.join(OUTPUT_DIR, "nemo_work")
os.makedirs(nemo_dir, exist_ok=True)

# Resample to 16kHz for NeMo (its models expect 16 kHz input); the original
# audio/sample_rate are kept untouched for the extraction step later.
if sample_rate != 16000:
    print(f"  Resampling {sample_rate}Hz -> 16kHz...")
    audio_16k = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
    resampled_path = os.path.join(nemo_dir, "input_16k.wav")
    sf.write(resampled_path, audio_16k, 16000)
    abs_input = os.path.abspath(resampled_path)
else:
    abs_input = os.path.abspath(INPUT_FILE)

# NeMo reads its input from a JSON-lines manifest.  num_speakers=None lets
# clustering estimate the speaker count; the absolute path keeps the manifest
# valid regardless of NeMo's working directory.
manifest_path = os.path.join(nemo_dir, "input_manifest.json")
with open(manifest_path, "w") as f:
    entry = {"audio_filepath": abs_input, "offset": 0, "duration": None,
             "label": "infer", "text": "-", "num_speakers": None,
             "rttm_filepath": None, "uem_filepath": None}
    f.write(json.dumps(entry) + "\n")

# Fetch NeMo's reference telephonic inference config on first run; it is
# cached in the work dir so subsequent runs skip the download.
MODEL_CONFIG = os.path.join(nemo_dir, "diar_infer_telephonic.yaml")
if not os.path.exists(MODEL_CONFIG):
    config_url = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/diar_infer_telephonic.yaml"
    wget.download(config_url, MODEL_CONFIG, bar=None)

cfg = OmegaConf.load(MODEL_CONFIG)
cfg.diarizer.manifest_filepath = manifest_path
cfg.diarizer.out_dir = nemo_dir
cfg.diarizer.speaker_embeddings.model_path = "titanet_large"    # pretrained speaker-embedding model
cfg.diarizer.oracle_vad = False                                 # no ground-truth VAD — run the model below
cfg.diarizer.vad.model_path = "vad_multilingual_marblenet"
cfg.diarizer.clustering.parameters.oracle_num_speakers = False  # estimate speaker count
cfg.diarizer.clustering.parameters.max_num_speakers = 8
cfg.diarizer.msdd_model.model_path = "diar_msdd_telephonic"
cfg.diarizer.msdd_model.parameters.sigmoid_threshold = [0.7]    # MSDD decision threshold(s)

# Deferred import: NeMo is only pulled in once the config is ready
# (see comment in the import block at the top of the file).
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
nemo_diarizer = NeuralDiarizer(cfg=cfg).to("cuda")
nemo_diarizer.diarize()

# Parse RTTM.  A SPEAKER record looks like:
#   SPEAKER <file> <chan> <onset> <duration> <na> <na> <label> ...
# so fields[3] = start (s), fields[4] = duration (s), fields[7] = speaker label.
rttm_files = glob.glob(os.path.join(nemo_dir, "pred_rttms", "*.rttm"))
segments = []
for rttm_file in rttm_files:
    with open(rttm_file) as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 8 and parts[0] == "SPEAKER":
                segments.append({
                    "speaker": parts[7],
                    "start": round(float(parts[3]), 3),
                    "end": round(float(parts[3]) + float(parts[4]), 3)
                })
segments.sort(key=lambda x: x["start"])

# Group the (start, end) intervals by speaker label.
speakers = {}
for seg in segments:
    speakers.setdefault(seg["speaker"], []).append((seg["start"], seg["end"]))

print(f"\nNeMo found {len(speakers)} speaker(s):")
for spk in sorted(speakers):
    total = sum(e - s for s, e in speakers[spk])
    print(f"  {spk}: {total:.1f}s across {len(speakers[spk])} segments")

# Persist the raw diarization result for inspection/debugging.
with open(os.path.join(OUTPUT_DIR, "diarization.json"), "w") as f:
    json.dump(segments, f, indent=2)

# Release the diarizer and its GPU memory before the extraction pass.
del nemo_diarizer
torch.cuda.empty_cache()

# ── ULTRACLEAN EXTRACTION ─────────────────────────────
# Step 2: slice each diarized segment from the original-rate audio, keep only
# the high-energy speech inside it, and write one WAV per speaker.
print("\n[Step 2] Extracting ultraclean speaker tracks...")

def compute_frame_energy(signal, frame_size, hop_size):
    """Return the RMS level of each analysis frame of *signal*, in dB.

    Frames begin every *hop_size* samples and span *frame_size* samples.
    At least one frame is always produced — for a signal shorter than one
    frame, the single frame covers whatever samples exist.  Small epsilons
    keep the sqrt/log well-defined on silent frames.
    """
    total = max(1, (len(signal) - frame_size) // hop_size + 1)

    def frame_level_db(idx):
        # RMS of one window, converted to decibels.
        window = signal[idx * hop_size:idx * hop_size + frame_size]
        rms = np.sqrt(np.mean(window ** 2) + 1e-10)
        return 20 * np.log10(rms + 1e-10)

    return np.fromiter(
        (frame_level_db(i) for i in range(total)),
        dtype=np.float64,
        count=total,
    )

def energy_trim(signal, sr):
    """Carve *signal* into chunks of sustained speech energy.

    Frame RMS levels (via compute_frame_energy) are thresholded at
    ENERGY_TRIM_DB dB below the loudest frame, and only runs of at least
    MIN_SPEECH_FRAMES consecutive loud frames survive.  Each surviving run
    is returned as a slice (view) of *signal*.  A signal shorter than one
    analysis frame is returned whole when it has appreciable amplitude,
    otherwise the result is empty.
    """
    win = int(ENERGY_FRAME_MS / 1000 * sr)
    hop = int(ENERGY_HOP_MS / 1000 * sr)

    # Too short to frame: keep it only if it is not essentially silent.
    if len(signal) < win:
        return [signal] if np.max(np.abs(signal)) > 0.01 else []

    levels = compute_frame_energy(signal, win, hop)
    loud = levels > (np.max(levels) + ENERGY_TRIM_DB)

    # Debounce: mark only runs of >= MIN_SPEECH_FRAMES consecutive loud frames.
    keep = np.zeros(len(loud), dtype=bool)
    run_start = None
    for idx, flag in enumerate(loud):
        if flag:
            if run_start is None:
                run_start = idx
        else:
            if run_start is not None and idx - run_start >= MIN_SPEECH_FRAMES:
                keep[run_start:idx] = True
            run_start = None
    if run_start is not None and len(loud) - run_start >= MIN_SPEECH_FRAMES:
        keep[run_start:] = True

    # Convert each kept frame run into a sample-range slice of the signal.
    pieces = []
    total = len(keep)
    idx = 0
    while idx < total:
        if not keep[idx]:
            idx += 1
            continue
        begin = idx
        while idx < total and keep[idx]:
            idx += 1
        first_sample = begin * hop
        last_sample = min(idx * hop + win, len(signal))
        if last_sample > first_sample:
            pieces.append(signal[first_sample:last_sample])
    return pieces

# Fade length (in samples) for the per-chunk fade-in/out below.
crossfade_samples = int(CROSSFADE_MS / 1000 * sample_rate)

for spk in sorted(speakers):
    all_chunks = []
    kept = 0      # segments that yielded at least one clean chunk
    dropped = 0   # segments skipped: too short, or silent after trimming

    for seg_start, seg_end in speakers[spk]:
        # Skip diarized segments below the minimum duration.
        if seg_end - seg_start < MIN_SEGMENT_SEC:
            dropped += 1
            continue

        # Slice at the ORIGINAL sample rate — diarization times are seconds.
        s = int(seg_start * sample_rate)
        e = min(int(seg_end * sample_rate), len(audio_data))
        raw_chunk = audio_data[s:e].copy()  # copy: the fades below mutate in place
        if len(raw_chunk) == 0:
            continue

        clean_chunks = energy_trim(raw_chunk, sample_rate)
        if not clean_chunks:
            dropped += 1
            continue

        # Linear fade-in/out on each chunk to avoid clicks at the joins.
        # NOTE(review): energy_trim returns views into raw_chunk that can
        # overlap by up to frame−hop samples when silent gaps are short, so
        # overlapped samples may be faded twice and duplicated in the output
        # — confirm this is acceptable.
        for chunk in clean_chunks:
            fade_len = min(crossfade_samples, len(chunk) // 4)
            if fade_len > 0:
                chunk[:fade_len] *= np.linspace(0, 1, fade_len)
                chunk[-fade_len:] *= np.linspace(1, 0, fade_len)
            all_chunks.append(chunk)
        kept += 1

    if not all_chunks:
        print(f"  {spk}: no speech")
        continue

    # Butt-join all clean chunks into one continuous per-speaker track.
    clean_audio = np.concatenate(all_chunks)
    total_dur = len(clean_audio) / sample_rate
    out_path = os.path.join(OUTPUT_DIR, f"{spk}_clean.wav")
    sf.write(out_path, clean_audio, sample_rate)
    print(f"  {spk}: kept {kept}, dropped {dropped} → {out_path} ({total_dur:.1f}s)")

print("\nDone!")
