#!/usr/bin/env python3
"""
Extract clean per-speaker audio from NeMo diarization.
- No silence between segments (concatenated speech only)
- Skip overlapping regions
- Adds small crossfade between segments to avoid clicks
"""

import json
import numpy as np
import soundfile as sf
import os

# Path to the Demucs-separated vocal stem used as input audio.
INPUT_FILE = "demucs_out/htdemucs/clip_00/vocals.wav"
# NeMo diarization output: JSON list of segments with "start", "end", "speaker".
DIAR_FILE = "diarization_output/nemo_diarization.json"
# Directory where the per-speaker WAV files are written.
OUTPUT_DIR = "diarization_output"
# Fade-in/out length applied to each extracted chunk, in milliseconds.
CROSSFADE_MS = 10

# Load the separated vocal track; fold multi-channel audio down to mono.
print(f"Input: {INPUT_FILE}")
audio_data, sample_rate = sf.read(INPUT_FILE, dtype="float32")
if audio_data.ndim > 1:
    audio_data = np.mean(audio_data, axis=1)

# Read the diarization segments and order them chronologically by start time.
with open(DIAR_FILE) as f:
    segments = json.load(f)
segments.sort(key=lambda seg: seg["start"])

# Find overlapping regions to exclude.
#
# Fixes two defects in the original pairwise scan:
#   1. Only consecutive segments were compared, so one long segment
#      overlapping several later segments was partly missed. We now compare
#      each segment against every later segment that starts before it ends
#      (valid because the list is sorted by start time).
#   2. The recorded overlap end was the current segment's end rather than
#      the true intersection end min(curr_end, next_end), which over-extended
#      the excluded region when the later segment ended first.
overlaps = []
for i, seg in enumerate(segments):
    seg_end = seg["end"]
    for later in segments[i + 1:]:
        if later["start"] >= seg_end:
            break  # sorted by start: no further segment can overlap this one
        overlaps.append((later["start"], min(seg_end, later["end"])))

print(f"Found {len(overlaps)} overlapping regions")

def is_in_overlap(t, overlaps):
    """Return True when time ``t`` falls inside any ``(start, end)`` window (inclusive)."""
    return any(start <= t <= end for start, end in overlaps)

# Group (start, end) intervals by speaker label.
speakers = {}
for seg in segments:
    speakers.setdefault(seg["speaker"], []).append((seg["start"], seg["end"]))

# Crossfade length in samples; each chunk gets a fade of at most this many
# samples at each end to avoid clicks when chunks are concatenated.
crossfade_samples = int(CROSSFADE_MS / 1000 * sample_rate)

for spk in sorted(speakers):
    chunks = []
    total_dur = 0

    for seg_start, seg_end in speakers[spk]:
        # Trim overlap from both ends.
        # NOTE(review): the segment bounds are mutated while scanning the
        # overlap list in order, so the outcome depends on overlap ordering.
        # A segment fully inside an overlap is skipped (start set to end).
        # An overlap that sits in the *middle* of a segment clamps the end
        # to the overlap start, which also drops the speech after the
        # overlap — confirm this loss is intended.
        for os_, oe in overlaps:
            if seg_start < oe and seg_end > os_:
                if seg_start >= os_ and seg_end <= oe:
                    seg_start = seg_end
                    break
                elif seg_start < os_:
                    seg_end = min(seg_end, os_)
                elif seg_end > oe:
                    seg_start = max(seg_start, oe)

        # Discard anything shorter than 50 ms after trimming.
        if seg_end <= seg_start + 0.05:
            continue

        # Convert seconds to sample indices, clamping the end to the audio length.
        s = int(seg_start * sample_rate)
        e = min(int(seg_end * sample_rate), len(audio_data))
        chunk = audio_data[s:e].copy()

        if len(chunk) == 0:
            continue

        # Apply short fade in/out to avoid clicks; never fade more than a
        # quarter of the chunk so very short chunks keep most of their level.
        fade_len = min(crossfade_samples, len(chunk) // 4)
        if fade_len > 0:
            fade_in = np.linspace(0, 1, fade_len)
            fade_out = np.linspace(1, 0, fade_len)
            chunk[:fade_len] *= fade_in
            chunk[-fade_len:] *= fade_out

        chunks.append(chunk)
        total_dur += len(chunk) / sample_rate

    if not chunks:
        print(f"  {spk}: no speech found")
        continue

    # Concatenate speech-only chunks back to back — no inter-segment silence.
    clean_audio = np.concatenate(chunks)

    out_path = os.path.join(OUTPUT_DIR, f"nemo_{spk}_clean.wav")
    sf.write(out_path, clean_audio, sample_rate)
    n_segs = len(chunks)
    print(f"  {out_path} — {total_dur:.1f}s of speech, {n_segs} segments, no silence")

print("\nDone!")
