#!/usr/bin/env python3
"""
NeMo diarization + ultraclean speaker extraction pipeline.
"""

# Compatibility shim: some torchaudio builds no longer expose
# list_audio_backends(), which NeMo's import chain may call. Patch it in
# before importing NeMo below.
import torchaudio
if not hasattr(torchaudio, 'list_audio_backends'):
    torchaudio.list_audio_backends = lambda: ["soundfile"]

import os
import sys
import json
import glob
import numpy as np
import soundfile as sf
import torch
import librosa
import wget
from omegaconf import OmegaConf
from nemo.collections.asr.models import ClusteringDiarizer

INPUT_FILE = sys.argv[1] if len(sys.argv) > 1 else "demucs_out/htdemucs/_Cooking_With_Fun_Vlog_256KBPS/vocals.wav"
OUTPUT_DIR = sys.argv[2] if len(sys.argv) > 2 else "cooking_diarized"
MIN_SEGMENT_SEC = 0.8   # drop diarized segments shorter than this
ENERGY_TRIM_DB = -35    # keep frames within this many dB of the segment peak
ENERGY_FRAME_MS = 25    # RMS analysis frame length
ENERGY_HOP_MS = 10      # RMS analysis hop
MIN_SPEECH_FRAMES = 5   # min consecutive loud frames to count as speech
CROSSFADE_MS = 10       # fade-in/out applied to each kept chunk

os.makedirs(OUTPUT_DIR, exist_ok=True)
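
# Small guard (not part of the original flow): fail fast with a clear message
# when the input path is wrong, e.g. a differently named Demucs folder.
if not os.path.isfile(INPUT_FILE):
    sys.exit(f"Input file not found: {INPUT_FILE}")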

print(f"Input: {INPUT_FILE}")
audio_data, sample_rate = sf.read(INPUT_FILE, dtype="float32")
if audio_data.ndim > 1:
    audio_data = audio_data.mean(axis=1)
duration = len(audio_data) / sample_rate
print(f"Audio: {duration:.1f}s, {sample_rate}Hz")

# ── NEMO DIARIZATION ─────────────────────────────────
print("\n[Step 1] Running NeMo diarization...")
nemo_dir = os.path.join(OUTPUT_DIR, "nemo_work")
os.makedirs(nemo_dir, exist_ok=True)

# Resample to 16kHz for NeMo
if sample_rate != 16000:
    print(f"  Resampling {sample_rate}Hz -> 16kHz...")
    audio_16k = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
    resampled_path = os.path.join(nemo_dir, "input_16k.wav")
    sf.write(resampled_path, audio_16k, 16000)
    abs_input = os.path.abspath(resampled_path)
else:
    abs_input = os.path.abspath(INPUT_FILE)

manifest_path = os.path.join(nemo_dir, "input_manifest.json")
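# NeMo takes a JSON-lines manifest: one JSON object per audio file.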
with open(manifest_path, "w") as f:
    entry = {"audio_filepath": abs_input, "offset": 0, "duration": None,
             "label": "infer", "text": "-", "num_speakers": None,
             "rttm_filepath": None, "uem_filepath": None}
    f.write(json.dumps(entry) + "\n")

MODEL_CONFIG = os.path.join(nemo_dir, "diar_infer_meeting.yaml")
if not os.path.exists(MODEL_CONFIG):
    config_url = "https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/diar_infer_meeting.yaml"
    wget.download(config_url, MODEL_CONFIG, bar=None)

cfg = OmegaConf.load(MODEL_CONFIG)
cfg.diarizer.manifest_filepath = manifest_path
cfg.diarizer.out_dir = nemo_dir
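
# Multiscale speaker embeddings: TitaNet-Large over six window/shift scales,
# weighted equally.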
cfg.diarizer.speaker_embeddings.model_path = "titanet_large"
cfg.diarizer.speaker_embeddings.parameters.window_length_in_sec = [3.0, 2.5, 2.0, 1.5, 1.0, 0.5]
cfg.diarizer.speaker_embeddings.parameters.shift_length_in_sec = [1.5, 1.25, 1.0, 0.75, 0.5, 0.25]
cfg.diarizer.speaker_embeddings.parameters.multiscale_weights = [1, 1, 1, 1, 1, 1]
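
# System VAD (no oracle segments). The high onset threshold (0.9) makes
# speech detection conservative, favoring clean single-speaker audio.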
cfg.diarizer.oracle_vad = False
cfg.diarizer.vad.model_path = "vad_multilingual_marblenet"
cfg.diarizer.vad.parameters.onset = 0.9
cfg.diarizer.vad.parameters.offset = 0.5
cfg.diarizer.vad.parameters.pad_onset = 0
cfg.diarizer.vad.parameters.pad_offset = 0
cfg.diarizer.vad.parameters.min_duration_on = 0
cfg.diarizer.vad.parameters.min_duration_off = 0.6
cfg.diarizer.vad.parameters.filter_speech_first = True
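
# Clustering: estimate the speaker count automatically, capped at 8.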
cfg.diarizer.clustering.parameters.oracle_num_speakers = False
cfg.diarizer.clustering.parameters.max_num_speakers = 8
cfg.diarizer.clustering.parameters.enhanced_count_thres = 80
cfg.diarizer.clustering.parameters.max_rp_threshold = 0.25
cfg.diarizer.clustering.parameters.sparse_search_volume = 30

nemo_diarizer = ClusteringDiarizer(cfg=cfg)
nemo_diarizer.diarize()

# Parse RTTM output. RTTM "SPEAKER" lines are space-separated; field 3 is the
# start time (s), field 4 the duration (s), and field 7 the speaker label.
rttm_files = glob.glob(os.path.join(nemo_dir, "pred_rttms", "*.rttm"))
segments = []
for rttm_file in rttm_files:
    with open(rttm_file) as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) >= 8 and parts[0] == "SPEAKER":
                segments.append({
                    "speaker": parts[7],
                    "start": round(float(parts[3]), 3),
                    "end": round(float(parts[3]) + float(parts[4]), 3)
                })
segments.sort(key=lambda x: x["start"])

speakers = {}
for seg in segments:
    speakers.setdefault(seg["speaker"], []).append((seg["start"], seg["end"]))

print(f"\nNeMo found {len(speakers)} speaker(s):")
for spk in sorted(speakers):
    total = sum(e - s for s, e in speakers[spk])
    print(f"  {spk}: {total:.1f}s across {len(speakers[spk])} segments")

with open(os.path.join(OUTPUT_DIR, "diarization.json"), "w") as f:
    json.dump(segments, f, indent=2)

# Free the diarizer and any cached GPU memory before post-processing.
del nemo_diarizer
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# ── ULTRACLEAN EXTRACTION ─────────────────────────────
print("\n[Step 2] Extracting ultraclean speaker tracks...")

def compute_frame_energy(signal, frame_size, hop_size):
    """Per-frame RMS energy in dB (frame_size and hop_size in samples)."""
    n_frames = max(1, (len(signal) - frame_size) // hop_size + 1)
    energies = np.zeros(n_frames)
    for i in range(n_frames):
        start = i * hop_size
        frame = signal[start:start + frame_size]
        rms = np.sqrt(np.mean(frame ** 2) + 1e-10)
        energies[i] = 20 * np.log10(rms + 1e-10)  # epsilons avoid log(0) on silence
    return energies

def energy_trim(signal, sr):
    """Split a segment into speech chunks, discarding low-energy spans.

    A frame counts as speech if its energy is within ENERGY_TRIM_DB of the
    segment peak; speech runs shorter than MIN_SPEECH_FRAMES are dropped.
    Returns a list of sample arrays (views into `signal`).
    """
    frame_size = int(ENERGY_FRAME_MS / 1000 * sr)
    hop_size = int(ENERGY_HOP_MS / 1000 * sr)
    if len(signal) < frame_size:
        # Too short to frame: keep only if it is not near-silent.
        return [signal] if np.max(np.abs(signal)) > 0.01 else []
    energies = compute_frame_energy(signal, frame_size, hop_size)
    peak_energy = np.max(energies)
    threshold = peak_energy + ENERGY_TRIM_DB  # ENERGY_TRIM_DB is negative
    is_speech = energies > threshold
    # Keep only runs of at least MIN_SPEECH_FRAMES consecutive speech frames.
    smoothed = np.zeros_like(is_speech)
    count = 0
    for i in range(len(is_speech)):
        if is_speech[i]:
            count += 1
        else:
            if count >= MIN_SPEECH_FRAMES:
                smoothed[i - count:i] = True
            count = 0
    if count >= MIN_SPEECH_FRAMES:
        smoothed[len(is_speech) - count:len(is_speech)] = True
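    # Convert the frame-level speech mask into sample-level chunks.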
    chunks = []
    in_speech = False
    speech_start = 0
    for i in range(len(smoothed)):
        if smoothed[i] and not in_speech:
            speech_start = i * hop_size
            in_speech = True
        elif not smoothed[i] and in_speech:
            speech_end = min(i * hop_size + frame_size, len(signal))
            if speech_end > speech_start:
                chunks.append(signal[speech_start:speech_end])
            in_speech = False
    if in_speech:
        speech_end = min(len(smoothed) * hop_size + frame_size, len(signal))
        if speech_end > speech_start:
            chunks.append(signal[speech_start:speech_end])
    return chunks
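
# (librosa.effects.split(signal, top_db=-ENERGY_TRIM_DB) performs a similar
# peak-referenced split, though without the minimum-run smoothing above.)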

crossfade_samples = int(CROSSFADE_MS / 1000 * sample_rate)

for spk in sorted(speakers):
    all_chunks = []
    kept = 0
    dropped = 0

    for seg_start, seg_end in speakers[spk]:
        if seg_end - seg_start < MIN_SEGMENT_SEC:
            dropped += 1
            continue

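        # Timestamps are in seconds, so chunks are cut from the original-rate
        # audio even though diarization ran at 16 kHz.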
        s = int(seg_start * sample_rate)
        e = min(int(seg_end * sample_rate), len(audio_data))
        raw_chunk = audio_data[s:e].copy()
        if len(raw_chunk) == 0:
            continue

        clean_chunks = energy_trim(raw_chunk, sample_rate)
        if not clean_chunks:
            dropped += 1
            continue

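        # Short linear fades at chunk boundaries prevent clicks when the
        # chunks are concatenated.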
        for chunk in clean_chunks:
            fade_len = min(crossfade_samples, len(chunk) // 4)
            if fade_len > 0:
                chunk[:fade_len] *= np.linspace(0, 1, fade_len)
                chunk[-fade_len:] *= np.linspace(1, 0, fade_len)
            all_chunks.append(chunk)
        kept += 1

    if not all_chunks:
        print(f"  {spk}: no speech")
        continue

    clean_audio = np.concatenate(all_chunks)
    total_dur = len(clean_audio) / sample_rate
    out_path = os.path.join(OUTPUT_DIR, f"{spk}_clean.wav")
    sf.write(out_path, clean_audio, sample_rate)
    print(f"  {spk}: kept {kept}, dropped {dropped} → {out_path} ({total_dur:.1f}s)")

print("\nDone!")
