#!/usr/bin/env python3
"""
Extract ultra-clean per-speaker audio from NeMo diarization.
- Filter out short segments (crosstalk, laughs, single words)
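- Energy-based VAD (threshold relative to each segment's loudest frame) to trim quiet parts within segments
- Kept chunks are concatenated back-to-back, with no silence inserted between them
- Short linear fade-in/fade-out (CROSSFADE_MS) at the edges of every kept chunk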
"""

import json
import numpy as np
import soundfile as sf
import os

INPUT_FILE = "demucs_out/htdemucs/clip_00/vocals.wav"
DIAR_FILE = "diarization_output/nemo_diarization.json"
OUTPUT_DIR = "diarization_output"

MIN_SEGMENT_SEC = 0.8       # drop segments shorter than this
ENERGY_TRIM_DB = -35        # trim frames more than this many dB below the segment's loudest frame
ENERGY_FRAME_MS = 25        # frame size for energy calculation
ENERGY_HOP_MS = 10          # hop size for energy calculation
MIN_SPEECH_FRAMES = 5       # minimum consecutive speech frames to keep
CROSSFADE_MS = 10           # length of the linear fade-in/fade-out applied to each kept chunk

print(f"Input: {INPUT_FILE}")
# Read as float32; multi-channel input is downmixed to mono by averaging channels.
audio_data, sample_rate = sf.read(INPUT_FILE, dtype="float32")
if audio_data.ndim > 1:
    audio_data = audio_data.mean(axis=1)

# NOTE(review): the code below treats this as a JSON list of dicts with
# "speaker", "start" and "end" keys, times in seconds — confirm against the
# NeMo export step. JSON is UTF-8, so don't rely on the platform default encoding.
with open(DIAR_FILE, encoding="utf-8") as f:
    segments = json.load(f)

segments.sort(key=lambda seg: seg["start"])


def compute_frame_energy(signal, frame_size, hop_size):
    """Compute the RMS level of each analysis frame, in absolute dB.

    The level is absolute (0 dB corresponds to an RMS of 1.0), NOT relative to
    the peak; callers that want a peak-relative threshold must subtract the
    maximum themselves. Frames start every ``hop_size`` samples and span
    ``frame_size`` samples; only frames that fit entirely inside ``signal`` are
    analysed. A signal shorter than one frame is treated as a single short
    frame, and an empty signal yields a single frame at the silence floor
    (about -100 dB) instead of NaN.

    Args:
        signal: 1-D sequence of samples.
        frame_size: Frame length in samples (>= 1).
        hop_size: Distance between frame starts in samples (>= 1).

    Returns:
        1-D float64 array with one level per frame (always at least one element).

    Raises:
        ValueError: If ``frame_size`` or ``hop_size`` is smaller than 1.
    """
    if frame_size < 1 or hop_size < 1:
        raise ValueError(
            f"frame_size and hop_size must be >= 1 (got {frame_size}, {hop_size})"
        )

    samples = np.asarray(signal, dtype=np.float64)
    if len(samples) < frame_size:
        # Too short for a full frame: analyse whatever is there as one frame.
        frames = samples[np.newaxis, :]
    else:
        # Strided view (no copy) of shape (n_frames, frame_size), where
        # n_frames == (len - frame_size) // hop_size + 1.
        frames = np.lib.stride_tricks.sliding_window_view(samples, frame_size)[::hop_size]

    width = frames.shape[1]
    if width == 0:
        mean_sq = np.zeros(1)
    else:
        # Row-wise sum of squares without materialising frames ** 2.
        mean_sq = np.einsum("ij,ij->i", frames, frames) / width

    # The epsilons keep log10 finite for digital silence (floor ~ -100 dB).
    rms = np.sqrt(mean_sq + 1e-10)
    return 20 * np.log10(rms + 1e-10)


def energy_trim(signal, sr, threshold_db, frame_ms, hop_ms, min_speech_frames):
    """Remove low-energy regions from a signal and return the remaining chunks.

    The signal is analysed in frames of ``frame_ms`` every ``hop_ms``. A frame
    counts as speech when its level exceeds ``peak + threshold_db``, where
    ``peak`` is the loudest frame of this signal (so ``threshold_db`` is
    expected to be negative, e.g. -35). Runs of consecutive speech frames
    shorter than ``min_speech_frames`` are discarded.

    Each kept run becomes one chunk spanning from the start of its first frame
    to the end of the first non-speech frame after it (or to the end of the
    signal). Because frames overlap, such spans could overlap the next chunk;
    each chunk's start is therefore clamped to the previous chunk's end, so no
    sample is returned twice.

    Args:
        signal: 1-D numpy array of samples.
        sr: Sample rate in Hz.
        threshold_db: Speech threshold in dB relative to the loudest frame.
        frame_ms: Analysis frame length in milliseconds.
        hop_ms: Analysis hop in milliseconds.
        min_speech_frames: Minimum run length (in frames) to keep.

    Returns:
        List of non-overlapping 1-D arrays in signal order. Each is an
        independent copy, so callers may modify them in place safely.
    """
    signal = np.asarray(signal)
    # Clamp to at least one sample so very low sample rates cannot produce a
    # zero frame/hop (which would divide by zero during framing).
    frame_size = max(1, int(frame_ms / 1000 * sr))
    hop_size = max(1, int(hop_ms / 1000 * sr))
    n_samples = len(signal)

    if n_samples == 0:
        return []
    if n_samples < frame_size:
        # Too short to frame: keep it whole unless its peak amplitude is tiny.
        # NOTE: this is an absolute gate, unlike the peak-relative threshold below.
        return [signal.copy()] if np.max(np.abs(signal)) > 0.01 else []

    energies = compute_frame_energy(signal, frame_size, hop_size)
    is_speech = energies > np.max(energies) + threshold_db

    chunks = []
    prev_end = 0
    n_frames = len(is_speech)
    i = 0
    while i < n_frames:
        if not is_speech[i]:
            i += 1
            continue

        run_start = i
        while i < n_frames and is_speech[i]:
            i += 1
        # i is now one past the run (the first non-speech frame, or n_frames).

        if i - run_start < min_speech_frames:
            continue

        start = max(run_start * hop_size, prev_end)
        end = min(i * hop_size + frame_size, n_samples)
        if end > start:
            # Copy so chunks never share memory with the input or each other.
            chunks.append(signal[start:end].copy())
            prev_end = end

    return chunks


# Group segment time spans (seconds) by speaker label, keeping start-time order.
speakers = {}
for seg in segments:
    speakers.setdefault(seg["speaker"], []).append((seg["start"], seg["end"]))

crossfade_samples = int(CROSSFADE_MS / 1000 * sample_rate)
os.makedirs(OUTPUT_DIR, exist_ok=True)

for spk in sorted(speakers):
    all_chunks = []
    skipped_short = 0
    skipped_quiet = 0
    skipped_empty = 0
    kept = 0

    for seg_start, seg_end in speakers[spk]:
        seg_dur = seg_end - seg_start

        if seg_dur < MIN_SEGMENT_SEC:
            skipped_short += 1
            continue

        # Clamp to the audio bounds: a negative start index would otherwise
        # slice from the end of the array.
        s = max(0, int(seg_start * sample_rate))
        e = min(int(seg_end * sample_rate), len(audio_data))
        raw_chunk = audio_data[s:e].copy()

        if len(raw_chunk) == 0:
            # Segment lies entirely outside the loaded audio.
            skipped_empty += 1
            continue

        # Energy-based VAD within segment
        clean_chunks = energy_trim(
            raw_chunk, sample_rate,
            ENERGY_TRIM_DB, ENERGY_FRAME_MS, ENERGY_HOP_MS, MIN_SPEECH_FRAMES
        )

        if not clean_chunks:
            skipped_quiet += 1
            continue

        for chunk in clean_chunks:
            # Copy before fading in place: chunks may be views that share
            # samples, and an in-place fade would then corrupt a neighbour.
            chunk = chunk.copy()
            # Linear fade-in/out, capped at a quarter of the chunk per side.
            fade_len = min(crossfade_samples, len(chunk) // 4)
            if fade_len > 0:
                chunk[:fade_len] *= np.linspace(0, 1, fade_len)
                chunk[-fade_len:] *= np.linspace(1, 0, fade_len)
            all_chunks.append(chunk)

        kept += 1

    if not all_chunks:
        print(f"  {spk}: no speech found")
        continue

    clean_audio = np.concatenate(all_chunks)
    total_dur = len(clean_audio) / sample_rate

    out_path = os.path.join(OUTPUT_DIR, f"nemo_{spk}_ultraclean.wav")
    sf.write(out_path, clean_audio, sample_rate)

    orig_segs = len(speakers[spk])
    dropped = f"dropped {skipped_short} short, {skipped_quiet} quiet"
    if skipped_empty:
        dropped += f", {skipped_empty} outside audio"
    print(f"  {spk}:")
    print(f"    Kept {kept}/{orig_segs} segments ({dropped})")
    print(f"    {out_path} — {total_dur:.1f}s pure speech")

print("\nDone!")
