#!/usr/bin/env python3
"""
Ultra-clean speaker extraction using:
1. NeMo diarization (coarse speaker assignment)
2. Pyannote overlap detection (find where 2+ people talk)
3. Speaker embedding verification (drop segments that don't match speaker)
"""

import json
import numpy as np
import soundfile as sf
import torch
import torchaudio
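# Shim: some torchaudio builds lack list_audio_backends, which downstream
# imports (speechbrain in particular) may probe; point it at soundfile.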
if not hasattr(torchaudio, 'list_audio_backends'):
    torchaudio.list_audio_backends = lambda: ["soundfile"]
import os

INPUT_FILE = "demucs_out/htdemucs/clip_00/vocals.wav"
DIAR_FILE = "diarization_output/nemo_diarization.json"
OUTPUT_DIR = "diarization_output"

MIN_SEGMENT_SEC = 0.8
CROSSFADE_MS = 10
EMBED_SIMILARITY_THRESHOLD = 0.25  # cosine sim threshold to keep segment
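
# DEVICE is introduced here so the hard-coded CUDA placement below degrades
# gracefully on CPU-only machines.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"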

print(f"Input: {INPUT_FILE}")
audio_data, sample_rate = sf.read(INPUT_FILE, dtype="float32")
if audio_data.ndim > 1:
    audio_data = audio_data.mean(axis=1)  # mix multichannel down to mono

with open(DIAR_FILE) as f:
    segments = json.load(f)
segments.sort(key=lambda x: x["start"])

# ── 1. PYANNOTE OVERLAP DETECTION ────────────────────
print("\n[Step 1] Running pyannote overlap detection...")
from pyannote.audio import Pipeline as PyAnnotePipeline

# Note: this is a gated checkpoint; downloading it requires an authenticated
# Hugging Face token (e.g. via `huggingface-cli login`).
pipeline = PyAnnotePipeline.from_pretrained("pyannote/speaker-diarization-3.1")
pipeline.to(torch.device(DEVICE))

result = pipeline(INPUT_FILE)
# pyannote.audio 3.x returns the Annotation directly; newer releases wrap it
# in a result object with a .speaker_diarization attribute, so accept both.
diarization = getattr(result, "speaker_diarization", result)

# Find overlapping regions from pyannote
# Overlap = any time where pyannote says 2+ speakers are active
timeline_points = []
for turn, _, speaker in diarization.itertracks(yield_label=True):
    timeline_points.append((turn.start, 'start', speaker))
    timeline_points.append((turn.end, 'end', speaker))
timeline_points.sort(key=lambda x: (x[0], x[1] == 'start'))  # ends before starts at ties

overlap_regions = []
active_speakers = set()
overlap_start = None

for t, event, spk in timeline_points:
    if event == 'start':
        active_speakers.add(spk)
        if len(active_speakers) >= 2 and overlap_start is None:
            overlap_start = t
    else:
        active_speakers.discard(spk)
        if len(active_speakers) < 2 and overlap_start is not None:
            overlap_regions.append((overlap_start, t))
            overlap_start = None

total_overlap = sum(e - s for s, e in overlap_regions)
print(f"  Found {len(overlap_regions)} overlap regions ({total_overlap:.1f}s total)")

# Free GPU memory before loading the embedding model.
del pipeline, result, diarization
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# ── 2. SPEAKER EMBEDDING MODEL ───────────────────────
print("\n[Step 2] Loading speaker embedding model...")
from speechbrain.inference.speaker import EncoderClassifier

embed_model = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb",
    run_opts={"device": "cuda"}
)
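
# ECAPA-TDNN produces one 192-dim embedding per utterance; cosine similarity
# between embeddings is the usual speaker-verification score.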


def get_embedding(audio_segment, sr):
    """Return a speaker embedding, or None if the clip is too short (<0.3 s)
    to embed reliably."""
    if len(audio_segment) < sr * 0.3:
        return None
    waveform = torch.tensor(audio_segment, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        emb = embed_model.encode_batch(waveform.to(DEVICE))
    return emb.squeeze().cpu().numpy()


def cosine_sim(a, b):
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8)


# ── 3. BUILD SPEAKER CENTROIDS ───────────────────────
print("\n[Step 3] Computing speaker centroids from long segments...")
speakers = {}
for seg in segments:
    speakers.setdefault(seg["speaker"], []).append((seg["start"], seg["end"]))

speaker_centroids = {}
for spk in sorted(speakers):
    embeddings = []
    # Use only segments > 3s for centroid (most reliable)
    long_segs = [(s, e) for s, e in speakers[spk] if e - s > 3.0]
    for seg_start, seg_end in long_segs[:10]:  # use up to 10 long segments
        s = int(seg_start * sample_rate)
        e = min(int(seg_end * sample_rate), len(audio_data))
        emb = get_embedding(audio_data[s:e], sample_rate)
        if emb is not None:
            embeddings.append(emb)

    if embeddings:
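        # Embeddings are averaged as-is; L2-normalizing each one before
        # averaging is a common alternative for cosine-based matching.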
        speaker_centroids[spk] = np.mean(embeddings, axis=0)
        print(f"  {spk}: centroid from {len(embeddings)} long segments")
    else:
        print(f"  {spk}: WARNING - no long segments for centroid")

# ── 4. EXTRACT & VERIFY ──────────────────────────────
print("\n[Step 4] Extracting and verifying segments...")
crossfade_samples = int(CROSSFADE_MS / 1000 * sample_rate)


def subtract_regions(seg_start, seg_end, regions):
    """Subtract overlap regions from [seg_start, seg_end]; return clean parts."""
    clean_parts = [(seg_start, seg_end)]
    for ov_start, ov_end in regions:
        new_parts = []
        for ps, pe in clean_parts:
            if pe <= ov_start or ps >= ov_end:
                # No intersection with this region: keep the part whole.
                new_parts.append((ps, pe))
            else:
                # Keep whatever sticks out on either side of the region.
                if ps < ov_start:
                    new_parts.append((ps, ov_start))
                if pe > ov_end:
                    new_parts.append((ov_end, pe))
        clean_parts = new_parts
    return clean_parts
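
# Quick sanity check of the splitting logic:
#   subtract_regions(0.0, 10.0, [(4.0, 6.0)]) -> [(0.0, 4.0), (6.0, 10.0)]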


for spk in sorted(speakers):
    all_chunks = []
    stats = {"total": 0, "short": 0, "overlap_trimmed": 0, "embedding_rejected": 0, "kept": 0}

    centroid = speaker_centroids.get(spk)

    for seg_start, seg_end in speakers[spk]:
        stats["total"] += 1

        if seg_end - seg_start < MIN_SEGMENT_SEC:
            stats["short"] += 1
            continue

        # Cut the pyannote overlap regions out of this segment.
        clean_parts = subtract_regions(seg_start, seg_end, overlap_regions)

        for part_start, part_end in clean_parts:
            if part_end - part_start < MIN_SEGMENT_SEC:
                stats["overlap_trimmed"] += 1
                continue

            s = int(part_start * sample_rate)
            e = min(int(part_end * sample_rate), len(audio_data))
            chunk = audio_data[s:e].copy()

            if len(chunk) == 0:
                continue

            # Verify short parts against the speaker centroid; parts >= 5 s
            # are assumed reliably attributed and pass unchecked.
            if centroid is not None and part_end - part_start < 5.0:
                emb = get_embedding(chunk, sample_rate)
                if emb is not None:
                    sim = cosine_sim(emb, centroid)
                    if sim < EMBED_SIMILARITY_THRESHOLD:
                        stats["embedding_rejected"] += 1
                        continue

            # Fade the chunk edges so concatenation doesn't click (an edge
            # fade, not a true overlap-add crossfade).
            fade_len = min(crossfade_samples, len(chunk) // 4)
            if fade_len > 0:
                chunk[:fade_len] *= np.linspace(0.0, 1.0, fade_len, dtype=np.float32)
                chunk[-fade_len:] *= np.linspace(1.0, 0.0, fade_len, dtype=np.float32)

            all_chunks.append(chunk)
            stats["kept"] += 1

    if not all_chunks:
        print(f"\n  {spk}: no speech found")
        continue

    clean_audio = np.concatenate(all_chunks)
    total_dur = len(clean_audio) / sample_rate

    out_path = os.path.join(OUTPUT_DIR, f"nemo_{spk}_verified.wav")
    sf.write(out_path, clean_audio, sample_rate)

    print(f"\n  {spk}:")
    print(f"    Total segments:     {stats['total']}")
    print(f"    Dropped (short):    {stats['short']}")
    print(f"    Dropped (overlap):  {stats['overlap_trimmed']}")
    print(f"    Dropped (wrong spk):{stats['embedding_rejected']}")
    print(f"    Kept:               {stats['kept']}")
    print(f"    → {out_path} — {total_dur:.1f}s")

print("\nDone!")
