#!/usr/bin/env python3
"""
Pure SAM Audio iterative speaker separation.

No NeMo. No gender assumptions. No assumed speaker count.

Strategy:
  1. Use SAM Audio text prompt "person speaking" with predict_spans=True
     to extract one speaker from the mix.
  2. Take the residual (everything left), feed it back, and extract the next
     speaker with the same prompt.
  3. Repeat until the residual's overall RMS energy drops below MIN_ENERGY_DB,
     the extracted target itself is negligible, or MAX_SPEAKERS is reached.
  4. Save each extracted speaker as speaker_1.wav, speaker_2.wav, etc.
"""

import os
import types
import numpy as np
import torch
import torchaudio
import soundfile as sf

import torchvision.transforms

# Compatibility shim: when torchvision.transforms exposes no
# `functional_tensor` attribute, register a stand-in module under that name
# that re-exports the public contents of torchvision.transforms.functional.
# NOTE(review): presumably needed because a downstream dependency still
# imports `torchvision.transforms.functional_tensor` — confirm which one.
if not hasattr(torchvision.transforms, "functional_tensor"):
    import sys as _sys
    import torchvision.transforms.functional as _tf

    stub = types.ModuleType("torchvision.transforms.functional_tensor")
    # Copy every attribute except dunder names onto the stand-in module.
    for _name in dir(_tf):
        if _name.startswith("__"):
            continue
        setattr(stub, _name, getattr(_tf, _name))

    # Make the stand-in reachable both as an attribute of the package and
    # through the import system (`import torchvision.transforms.functional_tensor`).
    torchvision.transforms.functional_tensor = stub
    _sys.modules["torchvision.transforms.functional_tensor"] = stub

# Input recording to separate (loaded with torchaudio further down).
AUDIO_FILE = "cooking_input_48k.wav"
# Directory that receives the speaker_N.wav / residual_after_speaker_N.wav files.
OUTPUT_DIR = "sam_audio_pure_out"
# Identifier passed to SAMAudio / SAMAudioProcessor `from_pretrained`.
MODEL_ID = "facebook/sam-audio-large"
# Length, in seconds, of each chunk fed to the model.
CHUNK_SEC = 30
# Upper bound on the number of extraction passes.
MAX_SPEAKERS = 6
# Stop extracting when the RMS level (dB) of the residual or of the extracted
# target falls below this value.
MIN_ENERGY_DB = -40
# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

torch.set_float32_matmul_precision("high")
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Device: {DEVICE}")
print(f"Loading: {AUDIO_FILE}")
wav_full, sr = torchaudio.load(AUDIO_FILE)
# This script expects 48 kHz input. Validate with an explicit raise rather
# than `assert`, because assert statements are removed when Python runs
# with -O and the check would silently disappear.
if sr != 48000:
    raise ValueError(f"Expected 48kHz, got {sr}")
wav_full = wav_full.mean(0)  # average over dim 0 -> mono (T,)
total_dur = wav_full.shape[0] / sr
print(f"Duration: {total_dur:.1f}s ({total_dur/60:.1f} min)")

print(f"\nLoading SAM Audio: {MODEL_ID} ...")
# NOTE(review): imported here rather than at the top of the file — presumably
# so it runs after the torchvision shim and after the input file has been
# validated; confirm before moving it.
from sam_audio import SAMAudio, SAMAudioProcessor

# Both rankers are disabled; the model is moved to DEVICE and put in eval mode.
model = SAMAudio.from_pretrained(
    MODEL_ID, visual_ranker=None, text_ranker=None
).to(DEVICE).eval()
processor = SAMAudioProcessor.from_pretrained(MODEL_ID)
print("Model loaded.\n")

# Number of samples in one CHUNK_SEC-long processing chunk.
chunk_samples = int(CHUNK_SEC * sr)


def rms_db(signal: np.ndarray) -> float:
    """Return the RMS level of *signal* in decibels (20 * log10 of the RMS).

    A small epsilon (1e-10) is added both under the square root and before
    the logarithm, so an all-zero signal yields a finite floor of roughly
    -100 dB instead of -inf.

    Args:
        signal: Array of samples. An empty array is treated as silence.

    Returns:
        The level in dB as a plain Python float.
    """
    # Compute in float64 so low-precision or integer inputs neither lose
    # accuracy nor overflow when squared.
    samples = np.asarray(signal, dtype=np.float64)
    # np.mean of an empty array is NaN (with a RuntimeWarning), and NaN never
    # compares below a threshold; report the silence floor instead.
    mean_square = float(np.mean(np.square(samples))) if samples.size else 0.0
    rms = np.sqrt(mean_square + 1e-10)
    return float(20 * np.log10(rms + 1e-10))


def _fit_to_length(samples: np.ndarray, length: int) -> np.ndarray:
    """Trim or zero-pad *samples* at the end so it has exactly *length* items."""
    if len(samples) > length:
        return samples[:length]
    if len(samples) < length:
        return np.pad(samples, (0, length - len(samples)))
    return samples


def separate_from_tensor(audio_tensor: torch.Tensor, description: str) -> tuple[np.ndarray, np.ndarray]:
    """
    Run SAM Audio on audio_tensor (1D, 48kHz) in chunks.

    The input is split into consecutive chunks of `chunk_samples` samples.
    Each chunk is run through the module-level `processor` and `model` on
    `DEVICE`, and the per-chunk outputs are written back at the same offsets.

    Processing is best-effort: when a chunk fails, the error is printed and
    the remaining chunks are still processed. For a failed chunk nothing is
    extracted, so its target stays silent and the unprocessed input is passed
    through to the residual. That keeps the audio available to later passes.

    Args:
        audio_tensor: 1-D waveform tensor to separate.
        description: Text prompt describing the sound to extract.

    Returns:
        (target, residual) float32 numpy arrays, each with the same number
        of samples as audio_tensor.
    """
    total_samples = audio_tensor.shape[0]
    n_chunks = int(np.ceil(total_samples / chunk_samples))

    target_out   = np.zeros(total_samples, dtype=np.float32)
    residual_out = np.zeros(total_samples, dtype=np.float32)

    for i in range(n_chunks):
        c_start = i * chunk_samples
        c_end   = min(c_start + chunk_samples, total_samples)
        chunk_len = c_end - c_start
        chunk_wav = audio_tensor[c_start:c_end].unsqueeze(0)  # (1, T)

        try:
            inputs = processor(
                audios=[chunk_wav],
                descriptions=[description],
            ).to(DEVICE)

            with torch.inference_mode():
                result = model.separate(inputs, predict_spans=True, reranking_candidates=1)

            # assumes result.target[0] / result.residual[0] are 1-D waveforms — TODO confirm
            t = result.target[0].cpu().numpy()
            r = result.residual[0].cpu().numpy()

            # The model output may not match the chunk length exactly;
            # trim or zero-pad so it fits the output slice.
            target_out[c_start:c_end] = _fit_to_length(t, chunk_len)
            residual_out[c_start:c_end] = _fit_to_length(r, chunk_len)

            print(f"    [{i+1}/{n_chunks}] {c_start/sr:.0f}-{c_end/sr:.0f}s done")
        except Exception as ex:
            # Nothing usable was extracted for this chunk: clear any partial
            # target write and keep the original audio in the residual rather
            # than dropping it as silence.
            target_out[c_start:c_end] = 0.0
            residual_out[c_start:c_end] = chunk_wav[0].detach().cpu().numpy()
            print(f"    [{i+1}/{n_chunks}] {c_start/sr:.0f}-{c_end/sr:.0f}s ERROR: {ex}")

        # Release cached GPU memory between chunks.
        torch.cuda.empty_cache()

    return target_out, residual_out


# ─── ITERATIVE SPEAKER EXTRACTION ────────────────────────────────────────────

# Audio still to be processed; replaced by the residual after each pass.
current_audio = wav_full.clone()
# Number of speakers actually accepted and saved so far. It is advanced only
# after an extraction passes the energy check, so the final summary does not
# count a discarded attempt.
speaker_num = 0

print("="*60)
print("Iterative speaker extraction — SAM Audio only")
print("Prompt: \"person speaking\" (let the model find whoever is loudest)")
print("="*60)

while speaker_num < MAX_SPEAKERS:
    energy = rms_db(current_audio.numpy())
    print(f"\nRemaining audio energy: {energy:.1f} dB")

    # Stop once what is left of the mix is quieter than the threshold.
    if energy < MIN_ENERGY_DB:
        print(f"Residual energy below {MIN_ENERGY_DB} dB — stopping extraction.")
        break

    candidate_num = speaker_num + 1
    print(f"\n--- Extracting speaker {candidate_num} ---")

    target, residual = separate_from_tensor(current_audio, "person speaking")

    target_energy  = rms_db(target)
    residual_energy = rms_db(residual)
    print(f"  Extracted energy: {target_energy:.1f} dB")
    print(f"  Residual energy:  {residual_energy:.1f} dB")

    # A near-silent extraction means no further speaker was found: discard it
    # without counting it and stop.
    if target_energy < MIN_ENERGY_DB:
        print("  Extracted speaker has negligible energy — discarding, stopping.")
        break

    speaker_num = candidate_num

    target_path = os.path.join(OUTPUT_DIR, f"speaker_{speaker_num}.wav")
    sf.write(target_path, target, sr)
    print(f"  Saved: {target_path}")

    residual_path = os.path.join(OUTPUT_DIR, f"residual_after_speaker_{speaker_num}.wav")
    sf.write(residual_path, residual, sr)
    print(f"  Saved: {residual_path}")

    # Feed the leftover audio into the next pass.
    current_audio = torch.from_numpy(residual)

# ─── SUMMARY ─────────────────────────────────────────────────────────────────
# Final report: how many speakers were extracted, followed by the name and
# duration of every .wav file present in the output directory.
banner = "=" * 60
print(f"\n{banner}")
print(f"DONE — Extracted {speaker_num} speaker(s)")
print(f"Output: {OUTPUT_DIR}/")
print(banner)
wav_names = sorted(
    name for name in os.listdir(OUTPUT_DIR) if name.endswith(".wav")
)
for wav_name in wav_names:
    duration = sf.info(os.path.join(OUTPUT_DIR, wav_name)).duration
    print(f"  {wav_name}  ({duration:.1f}s)")
