import torch
import torchaudio
import json
import numpy as np
import soundfile as sf
from sam_audio import SAMAudio, SAMAudioProcessor
from pathlib import Path
import time
import gc

# Output directory for the per-speaker separated tracks.
OUT = Path("/home/ubuntu/sam_audio_test/sam_multiprompt_out")
# parents=True so a fresh machine without the parent directory does not fail here.
OUT.mkdir(parents=True, exist_ok=True)

CHUNK_SEC = 30  # length of each separation window, in seconds
GAP_SEC = 0.15  # silence inserted between kept segments in the compact output, in seconds

# Skip building the visual ranker only - KEEP the text ranker for reranking.
# SAMAudio.__init__ is wrapped so that cfg.visual_ranker is hidden during
# construction; no other cfg field is touched.
_orig_init = SAMAudio.__init__


def _patched_init(self, cfg, *args, **kwargs):
    """Construct SAMAudio with ``cfg.visual_ranker`` temporarily set to None.

    The original value is put back afterwards, also when the original
    ``__init__`` raises, so the caller's config object is never left mutated.
    Any extra arguments are forwarded unchanged to the original ``__init__``.
    """
    saved_vr = cfg.visual_ranker
    cfg.visual_ranker = None
    try:
        _orig_init(self, cfg, *args, **kwargs)
    finally:
        cfg.visual_ranker = saved_vr


SAMAudio.__init__ = _patched_init

# Load the model and its pre-processor from the same checkpoint, move the
# model to the GPU in eval mode, and report load time and GPU memory.
print("=== Loading SAM-Audio Large (with text ranker) ===")
t0 = time.time()
_ckpt = "facebook/sam-audio-large"
model = SAMAudio.from_pretrained(_ckpt)
processor = SAMAudioProcessor.from_pretrained(_ckpt)
model = model.eval().cuda()
# Every waveform fed to the processor is resampled to this rate below.
sr = processor.audio_sampling_rate
load_secs = time.time() - t0
print(f"Loaded in {load_secs:.1f}s, sr={sr}")
# torch.cuda.mem_get_info() -> (free_bytes, total_bytes)
mem = torch.cuda.mem_get_info()
free_bytes, total_bytes = mem
print(f"GPU: {total_bytes / 1e9:.1f}GB total, {free_bytes / 1e9:.1f}GB free")

# Diarization records: each entry has "speaker", "start" and "end" keys; start/end
# are treated as seconds below (compared against CHUNK_SEC-based offsets).
# JSON is UTF-8 by specification, so decode it explicitly rather than via the locale.
with open("/home/ubuntu/clip00_lalal_pipeline/diarization.json", encoding="utf-8") as f:
    diar = json.load(f)

speakers = sorted({d["speaker"] for d in diar})

waveform, orig_sr = torchaudio.load("/home/ubuntu/clip_00.mp3")
# Downmix to mono first: channel averaging and resampling are both linear, so the
# order does not change the result, but this way only one channel is resampled.
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)
if orig_sr != sr:
    waveform = torchaudio.functional.resample(waveform, orig_sr, sr)

total_samples = waveform.shape[1]
total_sec = total_samples / sr
# int() so tensor slicing works even if the sampling rate is reported as a float.
chunk_samples = int(CHUNK_SEC * sr)
# Ceiling division: the last chunk may be shorter than chunk_samples.
n_chunks = (total_samples + chunk_samples - 1) // chunk_samples

MAX_ANCHORS = 8  # at most this many span anchors are passed to the processor per chunk


def _chunk_anchors(spk_segs, chunk_start, chunk_end, min_len=0.1):
    """Return positive span anchors for the speaker segments overlapping one chunk.

    ``spk_segs`` are (start, end) pairs on the full-file timeline, in the same units
    as ``chunk_start``/``chunk_end`` (seconds). Each overlapping segment is clipped to
    the chunk, shifted to chunk-local time and kept only if at least ``min_len``
    remains. Returns ``["+", local_start, local_end]`` lists rounded to 2 decimals.
    """
    chunk_len = chunk_end - chunk_start
    anchors = []
    for seg_s, seg_e in spk_segs:
        if seg_e <= chunk_start or seg_s >= chunk_end:
            continue
        local_s = max(0, seg_s - chunk_start)
        # Clip to the real chunk length: the last chunk can be shorter than CHUNK_SEC.
        local_e = min(chunk_len, seg_e - chunk_start)
        if local_e - local_s >= min_len:
            anchors.append(["+", round(local_s, 2), round(local_e, 2)])
    return anchors


def _fit_length(x, n):
    """Trim or zero-pad the last dimension of tensor ``x`` to exactly ``n`` samples."""
    cur = x.shape[-1]
    if cur > n:
        return x[..., :n]
    if cur < n:
        return torch.nn.functional.pad(x, (0, n - cur))
    return x


def _separate_speaker(spk_segs, tmp_path):
    """Run SAM-Audio chunk by chunk for one speaker; return a (1, total_samples) CPU tensor.

    Reads the module-level ``waveform``, ``sr``, chunk geometry, ``processor`` and
    ``model``. Chunks with no anchors for this speaker become silence so every chunk
    contributes exactly its own sample count and the output stays time-aligned.
    """
    target_chunks = []
    t_start = time.time()

    for i in range(n_chunks):
        chunk_start = i * CHUNK_SEC
        chunk_end = min((i + 1) * CHUNK_SEC, total_sec)
        s = i * chunk_samples
        e = min(s + chunk_samples, total_samples)
        n = e - s

        chunk_anchors = _chunk_anchors(spk_segs, chunk_start, chunk_end)
        if not chunk_anchors:
            target_chunks.append(torch.zeros(1, n))
            print(f"  Chunk {i+1}/{n_chunks} ({chunk_start:.0f}-{chunk_end:.0f}s) [skip]")
            continue
        if len(chunk_anchors) > MAX_ANCHORS:
            print(f"  Chunk {i+1}/{n_chunks}: using first {MAX_ANCHORS} of {len(chunk_anchors)} anchors")

        # The processor is given a file path, so write the chunk to a temp WAV
        # (only for chunks that are actually separated).
        torchaudio.save(str(tmp_path), waveform[:, s:e], sr)

        # Multi-prompt: text "speaking" + span anchors + predict_spans=True
        batch = processor(
            audios=[str(tmp_path)],
            descriptions=["speaking"],
            anchors=[chunk_anchors[:MAX_ANCHORS]],
        ).to("cuda")

        with torch.inference_mode():
            result = model.separate(
                batch,
                predict_spans=True,   # KEY: model finds additional spans
                reranking_candidates=1,
            )

        # NOTE(review): assumes result.target[0] is a 1-D waveform at `sr`. The
        # trim/pad is a no-op when the model returns exactly n samples and otherwise
        # keeps later chunks aligned with the diarization timeline. float() so the
        # result can be saved and converted to numpy regardless of model dtype.
        target = result.target[0].unsqueeze(0).float().cpu()
        target_chunks.append(_fit_length(target, n))
        elapsed = time.time() - t_start
        print(f"  Chunk {i+1}/{n_chunks} ({chunk_start:.0f}-{chunk_end:.0f}s) [{len(chunk_anchors)} anchors] [{elapsed:.1f}s]")

        del batch, result, target
        gc.collect()  # drop Python references first so the cache can actually be released
        torch.cuda.empty_cache()

    return torch.cat(target_chunks, dim=-1)


def _runs(mask):
    """Yield ``(start, stop, value)`` for each maximal run of equal values in a 1-D bool array."""
    if len(mask) == 0:
        return
    edges = np.flatnonzero(mask[1:] != mask[:-1]) + 1
    bounds = np.concatenate(([0], edges, [len(mask)]))
    for a, b in zip(bounds[:-1], bounds[1:]):
        yield int(a), int(b), bool(mask[a])


def _speech_mask(mono, sr, *, threshold_db=-38.0, frame_sec=0.02, min_sil_sec=0.3, min_sp_sec=0.1):
    """Return ``(is_speech, hop)``: a frame-level speech mask from RMS energy.

    Frames are ``int(frame_sec * sr)`` samples long with hop ``frame_len // 2``;
    ``is_speech[k]`` covers samples ``[k*hop, k*hop + frame_len)``. Silence runs
    shorter than ``min_sil_sec`` are bridged, then speech runs shorter than
    ``min_sp_sec`` are dropped.
    """
    frame_len = int(frame_sec * sr)
    hop = frame_len // 2
    n_frames = (len(mono) - frame_len) // hop + 1
    if n_frames <= 0:
        return np.zeros(0, dtype=bool), hop

    # Per-frame RMS via a cumulative sum of squares: O(N) instead of one np.mean
    # per frame. float64 keeps the running sum accurate over long files.
    csum = np.concatenate(([0.0], np.cumsum(mono.astype(np.float64) ** 2)))
    starts = np.arange(n_frames) * hop
    energy = (csum[starts + frame_len] - csum[starts]) / frame_len
    rms = np.sqrt(np.maximum(energy, 0.0))  # clamp tiny negative rounding error
    is_speech = rms > 10 ** (threshold_db / 20)

    min_sil = int(min_sil_sec / (hop / sr))
    min_sp = int(min_sp_sec / (hop / sr))
    # Runs are materialized before mutating so each pass sees the mask as it was
    # at the start of that pass.
    for a, b, speech in list(_runs(is_speech)):
        if not speech and b - a < min_sil:
            is_speech[a:b] = True
    for a, b, speech in list(_runs(is_speech)):
        if speech and b - a < min_sp:
            is_speech[a:b] = False
    return is_speech, hop


def _compact(audio_np, sr, is_speech, hop, *, pad_sec=0.02, gap_sec=GAP_SEC):
    """Concatenate the speech regions of ``audio_np`` with ``gap_sec`` of silence between them.

    ``audio_np`` is indexed by sample along axis 0. Each speech run is widened by
    ``pad_sec`` on both sides (clipped to the signal). Returns ``(compact, n_segments)``;
    when no speech is found ``compact`` is a single zero sample so a file can still be written.
    """
    pad = int(pad_sec * sr)
    segs = [
        audio_np[max(0, a * hop - pad):min(len(audio_np), b * hop + pad)]
        for a, b, speech in _runs(is_speech)
        if speech
    ]
    if not segs:
        return np.zeros((1,) + audio_np.shape[1:], dtype=audio_np.dtype), 0
    gap = np.zeros((int(gap_sec * sr),) + audio_np.shape[1:], dtype=audio_np.dtype)
    parts = []
    for idx, seg in enumerate(segs):
        if idx:
            parts.append(gap)
        parts.append(seg)
    return np.concatenate(parts, axis=0), len(segs)


tmp_chunk_path = OUT / "_tmp_chunk.wav"
try:
    for spk in speakers:
        spk_segs = [(d["start"], d["end"]) for d in diar if d["speaker"] == spk]
        total_speech = sum(e - s for s, e in spk_segs)
        print(f"\n=== {spk}: {len(spk_segs)} segs, {total_speech:.1f}s ===")

        target_full = _separate_speaker(spk_segs, tmp_chunk_path)
        full_path = OUT / f"{spk}_full.wav"
        torchaudio.save(str(full_path), target_full, sr)

        # Compact: keep only the energetic regions, separated by GAP_SEC of silence.
        audio_np = target_full.numpy().T
        mono = audio_np.mean(axis=1) if audio_np.ndim > 1 else audio_np
        is_speech, hop = _speech_mask(mono, sr)
        compact, n_segs = _compact(audio_np, sr, is_speech, hop)
        compact_path = OUT / f"{spk}_compact.wav"
        sf.write(str(compact_path), compact, sr)
        compact_dur = compact.shape[0] / sr
        print(f"  Full: {total_sec:.1f}s | Compact: {compact_dur:.1f}s ({n_segs} segs) | Expected: {total_speech:.1f}s")

        del target_full, audio_np, mono, compact
        gc.collect()
finally:
    # Remove the scratch chunk file even if separation fails part-way.
    tmp_chunk_path.unlink(missing_ok=True)

print("\n=== DONE ===")
