#!/usr/bin/env python3
"""
SAM Audio multi-prompt speaker separation.

Strategy:
  1. Load existing NeMo diarization timestamps.
  2. For each speaker, use SAM Audio span/temporal prompting:
       - "+" anchors = time ranges where this speaker is active (from diarization)
       - "-" anchors = time ranges where OTHER speakers are active
  3. Process the full audio in CHUNK_SEC-second windows, passing relevant
     anchors per chunk, and stitch the separated tracks.
  4. Also run a secondary text-prompt pass ("person speaking") on each chunk
     for a clean voice extraction baseline.

Output directory: sam_audio_out/
  <speaker>_sam_span.wav  — one per diarized speaker (Phase 1, span prompting)
  all_voices_sam.wav, man_voice_sam.wav, woman_voice_sam.wav
                          — text-prompt passes (Phase 2)
"""

import json
import os
import sys
import types
import numpy as np
import torch
import torchaudio
import soundfile as sf

# ── Compatibility patches ────────────────────────────────────────────────────
# torchvision removed functional_tensor in newer versions; stub it so pytorchvideo
# (needed by imagebind) doesn't crash on import.
import torchvision.transforms

if not hasattr(torchvision.transforms, "functional_tensor"):
    import torchvision.transforms.functional as _functional

    # Build a shim module that re-exports everything public from the current
    # `functional` module under the legacy `functional_tensor` name.
    _shim = types.ModuleType("torchvision.transforms.functional_tensor")
    for _name in dir(_functional):
        if not _name.startswith("__"):
            setattr(_shim, _name, getattr(_functional, _name))
    # Register both as an importable module path and as an attribute, so
    # `import torchvision.transforms.functional_tensor` and attribute access
    # both resolve.
    sys.modules["torchvision.transforms.functional_tensor"] = _shim
    torchvision.transforms.functional_tensor = _shim

# ─── CONFIG ──────────────────────────────────────────────────────────────────
AUDIO_48K   = "cooking_input_48k.wav"   # input mix; must be 48 kHz (validated after load)
DIAR_JSON   = "diarization_output/nemo_diarization.json"  # NeMo output: list of {"speaker", "start", "end"}
OUTPUT_DIR  = "sam_audio_out"           # all separated .wav tracks are written here
MODEL_ID    = "facebook/sam-audio-large"  # Hugging Face model id for SAM Audio
CHUNK_SEC   = 30          # process in 30-second windows
OVERLAP_SEC = 1           # 1-second crossfade overlap between chunks
MIN_ANCHOR_SEC = 1.0      # minimum anchor duration to bother including
MAX_ANCHORS_PER_TYPE = 4  # max +/- anchors per chunk (keep prompt focused)
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
# ─────────────────────────────────────────────────────────────────────────────

# Create the output directory up front; idempotent if it already exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Device: {DEVICE}")
print(f"Loading audio: {AUDIO_48K}")
wav_full, sr = torchaudio.load(AUDIO_48K)
# `assert` is stripped under `python -O`, so enforce the 48 kHz requirement
# with an explicit raise instead of an assertion.
if sr != 48000:
    raise ValueError(f"Expected 48kHz, got {sr}")
wav_full = wav_full.mean(0)   # (T,) mono — average channels to a single track
total_dur = wav_full.shape[0] / sr
print(f"Audio duration: {total_dur:.1f}s")

print(f"\nLoading diarization: {DIAR_JSON}")
with open(DIAR_JSON) as f:
    segments = json.load(f)

# Build per-speaker segment list: speaker id -> [(start_sec, end_sec), ...]
speakers: dict[str, list[tuple[float, float]]] = {}
for seg in segments:
    spk = seg["speaker"]
    speakers.setdefault(spk, []).append((seg["start"], seg["end"]))

print(f"Speakers found: {sorted(speakers.keys())}")
for spk, segs in sorted(speakers.items()):
    total = sum(e - s for s, e in segs)
    print(f"  {spk}: {total:.1f}s across {len(segs)} segments")

print(f"\nLoading SAM Audio model: {MODEL_ID} ...")
# Imported here (not at top of file) so the torchvision compat patch above is
# in place before sam_audio's transitive imports run.
from sam_audio import SAMAudio, SAMAudioProcessor
# Disable visual/text rankers — not needed when reranking_candidates=1
model = SAMAudio.from_pretrained(
    MODEL_ID, visual_ranker=None, text_ranker=None
).to(DEVICE).eval()
processor = SAMAudioProcessor.from_pretrained(MODEL_ID)
print("Model loaded.\n")


def get_anchors_for_speaker_in_window(
    target_speaker: str,
    all_speakers: dict[str, list[tuple[float, float]]],
    win_start: float,
    win_end: float,
) -> list[list]:
    """
    Build anchor list for SAM Audio span prompting.
    "+" = target speaker is speaking in this range (positive examples)
    "-" = another speaker is speaking, target is NOT (negative examples)
    Anchors are clipped to [win_start, win_end] and shifted to window-relative time.
    """
    anchors = []

    def clip_and_shift(s, e):
        cs = max(s, win_start) - win_start
        ce = min(e, win_end) - win_start
        return cs, ce

    # Positive anchors: target speaker's segments in this window
    pos = []
    for s, e in all_speakers.get(target_speaker, []):
        if e <= win_start or s >= win_end:
            continue
        cs, ce = clip_and_shift(s, e)
        if ce - cs >= MIN_ANCHOR_SEC:
            pos.append(["+", round(cs, 3), round(ce, 3)])
    # Sort by duration descending, keep top N
    pos.sort(key=lambda x: x[2] - x[1], reverse=True)
    anchors.extend(pos[:MAX_ANCHORS_PER_TYPE])

    # Negative anchors: OTHER speakers' segments in this window
    neg = []
    for spk, segs in all_speakers.items():
        if spk == target_speaker:
            continue
        for s, e in segs:
            if e <= win_start or s >= win_end:
                continue
            cs, ce = clip_and_shift(s, e)
            if ce - cs >= MIN_ANCHOR_SEC:
                neg.append(["-", round(cs, 3), round(ce, 3)])
    neg.sort(key=lambda x: x[2] - x[1], reverse=True)
    anchors.extend(neg[:MAX_ANCHORS_PER_TYPE])

    return anchors


def process_speaker(target_speaker: str) -> np.ndarray:
    """
    Separate target_speaker from the full recording via chunked span prompting.

    Chunks are extracted with an OVERLAP_SEC overlap (hop = chunk - overlap)
    so the stitching crossfade blends real audio from consecutive chunks.
    (Previously chunks were disjoint, so the "crossfade" blended each chunk's
    head with silence and attenuated its first OVERLAP_SEC.)

    Returns a numpy array of the separated waveform at 48kHz.
    """
    chunk_samples   = int(CHUNK_SEC * sr)
    overlap_samples = int(OVERLAP_SEC * sr)
    total_samples   = wav_full.shape[0]

    # Hop between chunk starts; consecutive chunks share `overlap_samples`.
    hop = max(chunk_samples - overlap_samples, 1)
    chunk_starts = list(range(0, total_samples, hop))
    # Drop trailing starts whose audio is already fully covered by the
    # previous chunk (the last hop can land inside it).
    while len(chunk_starts) > 1 and chunk_starts[-2] + chunk_samples >= total_samples:
        chunk_starts.pop()

    separated_chunks = []
    chunk_positions  = []

    n_chunks = len(chunk_starts)
    print(f"  Processing {n_chunks} chunks of {CHUNK_SEC}s each ...")

    for i, c_start in enumerate(chunk_starts):
        c_end       = min(c_start + chunk_samples, total_samples)
        win_start_t = c_start / sr
        win_end_t   = c_end / sr

        chunk_wav = wav_full[c_start:c_end].unsqueeze(0)  # (1, T)

        anchors = get_anchors_for_speaker_in_window(
            target_speaker, speakers, win_start_t, win_end_t
        )

        if not anchors:
            # No diarization info for this window — emit silence.
            separated_chunks.append(np.zeros(c_end - c_start, dtype=np.float32))
            chunk_positions.append((c_start, c_end))
            print(f"    chunk {i+1}/{n_chunks} [{win_start_t:.1f}-{win_end_t:.1f}s] — no anchors, silence")
            continue

        # Use both span anchors + text description for stronger signal
        description = "person speaking"

        try:
            inputs = processor(
                audios=[chunk_wav],
                descriptions=[description],
                anchors=[anchors],
            ).to(DEVICE)

            with torch.inference_mode():
                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)

            target_wav = result.target[0].cpu()  # (T,)
            separated_chunks.append(target_wav.numpy())
            chunk_positions.append((c_start, c_end))

            n_pos = sum(1 for a in anchors if a[0] == "+")
            n_neg = sum(1 for a in anchors if a[0] == "-")
            print(f"    chunk {i+1}/{n_chunks} [{win_start_t:.1f}-{win_end_t:.1f}s] — {n_pos}+ {n_neg}- anchors ✓")

        except Exception as ex:
            # Best-effort: a failed chunk becomes silence instead of aborting
            # the whole speaker.
            print(f"    chunk {i+1}/{n_chunks} ERROR: {ex} — using silence")
            separated_chunks.append(np.zeros(c_end - c_start, dtype=np.float32))
            chunk_positions.append((c_start, c_end))

        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    # Stitch: crossfade each chunk's head against the previous chunk's tail
    # over the shared overlap region, then copy the remainder verbatim.
    output = np.zeros(total_samples, dtype=np.float32)

    for idx, (chunk, (c_start, c_end)) in enumerate(zip(separated_chunks, chunk_positions)):
        chunk_len = c_end - c_start
        # Trim/pad the model output to the exact expected length.
        if len(chunk) > chunk_len:
            chunk = chunk[:chunk_len]
        elif len(chunk) < chunk_len:
            chunk = np.pad(chunk, (0, chunk_len - len(chunk)))

        if idx == 0 or overlap_samples == 0:
            output[c_start:c_end] = chunk
        else:
            # Linear crossfade over the overlap with the previous chunk, which
            # has already written real audio into output[c_start:ov_end].
            ov_end  = min(c_start + overlap_samples, c_end)
            ov_len  = ov_end - c_start
            fade_in = np.linspace(0, 1, ov_len)
            output[c_start:ov_end] = (
                output[c_start:ov_end] * (1 - fade_in)
                + chunk[:ov_len] * fade_in
            )
            output[ov_end:c_end] = chunk[ov_len:]

    return output


# ─── MULTI-PROMPT TEXT PASS ───────────────────────────────────────────────────
def run_text_prompt_pass(description: str, out_name: str):
    """
    Alternative: run a single text-prompt pass across all chunks.
    Good for getting a general 'voice' track regardless of diarization.
    """
    chunk_samples = int(CHUNK_SEC * sr)
    total_samples = wav_full.shape[0]
    starts        = range(0, total_samples, chunk_samples)
    n_chunks      = len(starts)
    output        = np.zeros(total_samples, dtype=np.float32)

    print(f"\n[Text prompt] '{description}' → {out_name}")
    for i, c_start in enumerate(starts):
        c_end  = min(c_start + chunk_samples, total_samples)
        window = wav_full[c_start:c_end].unsqueeze(0)

        try:
            batch = processor(
                audios=[window],
                descriptions=[description],
            ).to(DEVICE)
            with torch.inference_mode():
                result = model.separate(batch, predict_spans=True, reranking_candidates=1)
            estimate = result.target[0].cpu().numpy()
            # Force the model output to the exact window length before writing.
            expected = c_end - c_start
            if len(estimate) > expected:
                estimate = estimate[:expected]
            elif len(estimate) < expected:
                estimate = np.pad(estimate, (0, expected - len(estimate)))
            output[c_start:c_end] = estimate
            print(f"  chunk {i+1}/{n_chunks} [{c_start/sr:.1f}-{c_end/sr:.1f}s] ✓")
        except Exception as ex:
            # A failed chunk stays silent; keep going.
            print(f"  chunk {i+1}/{n_chunks} ERROR: {ex}")
        torch.cuda.empty_cache()

    out_path = os.path.join(OUTPUT_DIR, out_name)
    sf.write(out_path, output, sr)
    print(f"  Saved: {out_path} ({len(output)/sr:.1f}s)")


# ─── MAIN ─────────────────────────────────────────────────────────────────────

# === PHASE 1: Span-prompted speaker separation (one track per diarized speaker) ===
print("\n" + "="*60)
print("PHASE 1: Span-prompted multi-speaker separation")
print("="*60)

speaker_outputs = {}
for speaker_id in sorted(speakers):
    print(f"\n[Speaker] {speaker_id}")
    separated = process_speaker(speaker_id)
    dest = os.path.join(OUTPUT_DIR, f"{speaker_id}_sam_span.wav")
    sf.write(dest, separated, sr)
    speaker_outputs[speaker_id] = dest
    print(f"  → Saved: {dest}  ({len(separated) / sr:.1f}s)")

# === PHASE 2: Text multi-prompting for voice types ===
print("\n" + "="*60)
print("PHASE 2: Text multi-prompting (voice type separation)")
print("="*60)

text_prompts = [
    ("person speaking",         "all_voices_sam.wav"),
    ("man speaking",            "man_voice_sam.wav"),
    ("woman speaking",          "woman_voice_sam.wav"),
]

for prompt_text, wav_name in text_prompts:
    run_text_prompt_pass(prompt_text, wav_name)

# ─── SUMMARY ─────────────────────────────────────────────────────────────────
print("\n" + "="*60)
print("DONE — SAM Audio speaker separation complete")
print(f"Output directory: {OUTPUT_DIR}/")
print("Files:")
for wav_file in sorted(os.listdir(OUTPUT_DIR)):
    if not wav_file.endswith(".wav"):
        continue
    samples, rate = sf.read(os.path.join(OUTPUT_DIR, wav_file))
    print(f"  {wav_file}  ({len(samples)/rate:.1f}s)")
