import torch
import torchaudio
import json
import numpy as np
import soundfile as sf
from sam_audio import SAMAudio, SAMAudioProcessor
from pathlib import Path
import time
import gc

# Directory that receives the temporary chunk WAV plus the per-speaker
# "<speaker>_sam_full.wav" / "<speaker>_sam_compact.wav" outputs.
OUT: Path = Path("/home/ubuntu/sam_audio_test")
# Checkpoint name handed to both SAMAudio and SAMAudioProcessor.from_pretrained.
MODEL_NAME: str = "facebook/sam-audio-large"
# The input is separated in consecutive windows of this many seconds.
CHUNK_SEC: int = 30
# Seconds of zeros inserted between kept speech segments in the compact output.
GAP_SEC: float = 0.15

# Wrap SAMAudio.__init__ so the constructor sees no ranker configuration.
# NOTE(review): presumably this avoids building/loading the ranker models,
# which this script never uses (every separate() call passes
# reranking_candidates=1) — confirm against the sam_audio implementation.
_orig_init = SAMAudio.__init__


def _patched_init(self, cfg):
    """Run the original ``SAMAudio.__init__`` with both rankers set to None.

    ``cfg.visual_ranker`` and ``cfg.text_ranker`` are cleared only for the
    duration of the original constructor and are always restored afterwards,
    even when construction raises, so the caller's config object is left
    exactly as it was passed in.
    """
    saved = (cfg.visual_ranker, cfg.text_ranker)
    cfg.visual_ranker = None
    cfg.text_ranker = None
    try:
        _orig_init(self, cfg)
    finally:
        # Restore unconditionally; otherwise a failed construction would leave
        # the shared cfg permanently stripped of its ranker settings.
        cfg.visual_ranker, cfg.text_ranker = saved


SAMAudio.__init__ = _patched_init

# --- Load model + processor -------------------------------------------------
print("=== Loading SAM-Audio ===")
t0 = time.time()
model = SAMAudio.from_pretrained(MODEL_NAME)
processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
model = model.eval().cuda()
# Coerced to int: it is used as a torchaudio sample rate and in sample-index
# arithmetic (slicing) below, both of which require integers.
sr = int(processor.audio_sampling_rate)
print(f"Loaded in {time.time()-t0:.1f}s, sr={sr}")

# --- Diarization ------------------------------------------------------------
# Each entry is read for "speaker", "start" and "end"; start/end are treated as
# seconds throughout this script. JSON is UTF-8 by spec, so don't rely on the
# platform's default locale encoding.
with open("/home/ubuntu/clip00_lalal_pipeline/diarization.json", encoding="utf-8") as f:
    diar = json.load(f)

speakers = sorted(set(d["speaker"] for d in diar))
print(f"Speakers from diarization: {speakers}")

for spk in speakers:
    segs = [(d["start"], d["end"]) for d in diar if d["speaker"] == spk]
    total = sum(e - s for s, e in segs)
    print(f"  {spk}: {len(segs)} segments, {total:.1f}s total")

# --- Input audio: resample to the model rate and downmix to mono -------------
waveform, orig_sr = torchaudio.load("/home/ubuntu/clip_00.mp3")
if orig_sr != sr:
    waveform = torchaudio.functional.resample(waveform, orig_sr, sr)
if waveform.shape[0] > 1:
    # Average all channels into one, keeping a leading channel dim of size 1.
    waveform = waveform.mean(dim=0, keepdim=True)

total_samples = waveform.shape[1]
total_sec = total_samples / sr
chunk_samples = CHUNK_SEC * sr
# Ceiling division: the final chunk may be shorter than CHUNK_SEC.
n_chunks = (total_samples + chunk_samples - 1) // chunk_samples

# torchaudio.save / sf.write below fail if the output directory is missing.
OUT.mkdir(parents=True, exist_ok=True)


def _local_anchors(anchors, chunk_start, chunk_end, min_len=1.0):
    """Convert absolute-time anchors into positive chunk-local anchors.

    Every ``(seg_s, seg_e)`` pair in *anchors* (seconds) that overlaps the
    window ``[chunk_start, chunk_end)`` is clipped to that window and shifted
    so the window starts at 0. Clipped spans shorter than *min_len* seconds
    are dropped. Returns a list of ``["+", start, end]`` rounded to 0.01 s.
    """
    chunk_len = chunk_end - chunk_start
    local = []
    for seg_s, seg_e in anchors:
        if seg_e <= chunk_start or seg_s >= chunk_end:
            continue  # no overlap with this window
        local_s = max(0, seg_s - chunk_start)
        # Clamp to the real chunk length, not CHUNK_SEC: the last chunk is
        # usually shorter and an anchor past its end would point at no audio.
        local_e = min(chunk_len, seg_e - chunk_start)
        if local_e - local_s >= min_len:
            local.append(["+", round(local_s, 2), round(local_e, 2)])
    return local


def _runs(mask):
    """Return ``(start, end)`` index pairs for each maximal run of equal values."""
    runs = []
    start = 0
    n = len(mask)
    while start < n:
        end = start + 1
        while end < n and mask[end] == mask[start]:
            end += 1
        runs.append((start, end))
        start = end
    return runs


def _compact_speech(audio_np, sr, gap_sec, threshold_db=-40.0, frame_sec=0.02,
                    min_sil_sec=0.3, min_sp_sec=0.1, pad_sec=0.02):
    """Strip silence from *audio_np* and join the remaining speech with gaps.

    *audio_np* is indexed as (samples, channels). Frames of *frame_sec* with a
    50 % hop are classified as speech when their mono RMS exceeds
    *threshold_db* (amplitude relative to 1.0). Pauses shorter than
    *min_sil_sec* are bridged, then speech runs shorter than *min_sp_sec* are
    dropped. Each kept run is padded by *pad_sec* on both sides and the runs
    are joined with *gap_sec* of zeros.

    Returns ``(compact, n_segments)``; ``compact`` is None when no speech
    survives (np.concatenate cannot join an empty list).
    """
    mono = audio_np.mean(axis=1)
    threshold = 10 ** (threshold_db / 20)
    frame_len = int(frame_sec * sr)
    hop = frame_len // 2
    n_frames = (len(mono) - frame_len) // hop + 1
    rms = np.array([np.sqrt(np.mean(mono[j * hop:j * hop + frame_len] ** 2))
                    for j in range(n_frames)])
    is_speech = rms > threshold
    min_sil = int(min_sil_sec / (hop / sr))
    min_sp = int(min_sp_sec / (hop / sr))

    # Bridge short pauses. Runs are computed up front and are disjoint, so
    # filling one run never changes the classification of another. (The old
    # for-loop reassigned its own loop variable, which Python ignores, and so
    # also filled the trailing frames of every *long* silence.)
    for start, end in _runs(is_speech):
        if not is_speech[start] and end - start < min_sil:
            is_speech[start:end] = True
    # Drop speech blips that are too short to be real speech.
    for start, end in _runs(is_speech):
        if is_speech[start] and end - start < min_sp:
            is_speech[start:end] = False

    pad = int(pad_sec * sr)
    segs = [audio_np[max(0, start * hop - pad):min(len(audio_np), end * hop + pad)]
            for start, end in _runs(is_speech) if is_speech[start]]
    if not segs:
        return None, 0

    gap = np.zeros((int(gap_sec * sr), audio_np.shape[1]), dtype=audio_np.dtype)
    parts = []
    for idx, seg in enumerate(segs):
        if idx:
            parts.append(gap)
        parts.append(seg)
    return np.concatenate(parts, axis=0), len(segs)


chunk_path = OUT / "_tmp_chunk.wav"

for spk in speakers:
    spk_segs = [(d["start"], d["end"]) for d in diar if d["speaker"] == spk]
    # Anchors: this speaker's diarized segments of at least 3 s. Only length
    # is checked; overlap with other speakers is not.
    strong_anchors = [(s, e) for s, e in spk_segs if (e - s) >= 3.0]

    print(f"\n=== Separating {spk} ({len(strong_anchors)} anchor segments) ===")
    t_start = time.time()
    target_chunks = []

    for i in range(n_chunks):
        chunk_start = i * CHUNK_SEC
        chunk_end = min((i + 1) * CHUNK_SEC, total_sec)
        s = i * chunk_samples
        e = min(s + chunk_samples, total_samples)

        chunk_anchors = _local_anchors(strong_anchors, chunk_start, chunk_end)
        if not chunk_anchors:
            # No anchor for this speaker here: emit silence and skip both
            # preprocessing and inference. NOTE(review): assumes the model
            # returns as many samples as it is given, so skipped and separated
            # chunks stay time-aligned after concatenation — confirm.
            target_chunks.append(torch.zeros(1, e - s))
            elapsed = time.time() - t_start
            print(f"  Chunk {i+1}/{n_chunks} ({chunk_start:.0f}-{chunk_end:.0f}s) [skip - no anchors] [{elapsed:.1f}s]")
            continue

        used_anchors = chunk_anchors[:5]  # at most 5 anchors per chunk
        # The processor is fed a file path, so write the chunk to disk first.
        torchaudio.save(str(chunk_path), waveform[:, s:e], sr)
        batch = processor(
            audios=[str(chunk_path)],
            descriptions=["speaking"],
            anchors=[used_anchors],
        ).to("cuda")

        # Exactly one separation per chunk (the old code ran it twice).
        with torch.inference_mode():
            result = model.separate(batch, predict_spans=False, reranking_candidates=1)

        target_chunks.append(result.target[0].unsqueeze(0).cpu())
        elapsed = time.time() - t_start
        print(f"  Chunk {i+1}/{n_chunks} ({chunk_start:.0f}-{chunk_end:.0f}s) [{len(used_anchors)} anchors] [{elapsed:.1f}s]")

        del batch, result
        torch.cuda.empty_cache()
        gc.collect()

    target_full = torch.cat(target_chunks, dim=-1)
    full_path = OUT / f"{spk}_sam_full.wav"
    torchaudio.save(str(full_path), target_full, sr)
    print(f"  Saved full track: {full_path} ({target_full.shape[-1]/sr:.1f}s)")

    # Compact version: silence stripped, speech joined with GAP_SEC gaps.
    audio_np = target_full.numpy().T  # (samples, channels)
    compact, n_segs = _compact_speech(audio_np, sr, GAP_SEC)
    if compact is None:
        print(f"  Compact: no speech above threshold, skipped {spk}_sam_compact.wav")
    else:
        compact_path = OUT / f"{spk}_sam_compact.wav"
        sf.write(str(compact_path), compact, sr)
        print(f"  Compact: {len(audio_np)/sr:.1f}s -> {len(compact)/sr:.1f}s ({n_segs} segments)")

    del target_full, target_chunks, audio_np, compact
    gc.collect()

chunk_path.unlink(missing_ok=True)

print("\n=== DONE ===")
