import torch
import torchaudio
import json
import numpy as np
import soundfile as sf
from sam_audio import SAMAudio, SAMAudioProcessor
from pathlib import Path
import time
import gc

# Output directory for the separated stems. parents=True so a missing
# parent directory (e.g. a fresh machine) does not abort the run.
OUT = Path("/home/ubuntu/sam_audio_test/sam_multiprompt_fixed")
OUT.mkdir(parents=True, exist_ok=True)

CHUNK_SEC = 30   # length of each separation window fed to the model (seconds)
GAP_SEC = 0.15   # silence inserted between segments in the compact output (seconds)

# SAM-Audio's __init__ builds the visual ranker configured on cfg; this
# audio-only pipeline never uses it, so patch __init__ to construct the
# model with cfg.visual_ranker nulled out.
_orig_init = SAMAudio.__init__


def _patched_init(self, cfg):
    """Run the original __init__ with cfg.visual_ranker temporarily None.

    The original value is restored in a finally block so the caller's cfg
    object is left untouched even if construction raises.
    """
    saved_vr = cfg.visual_ranker
    cfg.visual_ranker = None
    try:
        _orig_init(self, cfg)
    finally:
        # Original code restored only on success, leaking the mutation on
        # failure; always restore.
        cfg.visual_ranker = saved_vr


SAMAudio.__init__ = _patched_init

print("=== Loading SAM-Audio Large ===")
t0 = time.time()
# NOTE(review): assumes a CUDA device is present — .cuda() raises otherwise.
model = SAMAudio.from_pretrained("facebook/sam-audio-large")
processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")
model = model.eval().cuda()
# The model's native sampling rate; all audio below is resampled to this.
sr = processor.audio_sampling_rate
print(f"Loaded in {time.time()-t0:.1f}s, sr={sr}")

# Diarization output: a list of dicts, each with "speaker", "start", "end"
# (start/end in seconds), produced by an earlier pipeline stage.
with open("/home/ubuntu/clip00_lalal_pipeline/diarization.json") as f:
    diar = json.load(f)

speakers = sorted(set(d["speaker"] for d in diar))

# Load the source mix, resample to the model rate, and downmix to mono.
waveform, orig_sr = torchaudio.load("/home/ubuntu/clip_00.mp3")
if orig_sr != sr:
    waveform = torchaudio.functional.resample(waveform, orig_sr, sr)
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)

total_samples = waveform.shape[1]
total_sec = total_samples / sr
chunk_samples = CHUNK_SEC * sr
# Ceiling division: the final chunk may be shorter than CHUNK_SEC.
n_chunks = (total_samples + chunk_samples - 1) // chunk_samples

for spk in speakers:
    # All diarized (start, end) spans (seconds) for this speaker.
    spk_segs = [(d["start"], d["end"]) for d in diar if d["speaker"] == spk]
    total_speech = sum(e - s for s, e in spk_segs)
    print(f"\n=== {spk}: {len(spk_segs)} segs, {total_speech:.1f}s ===")

    target_chunks = []
    t_start = time.time()

    for i in range(n_chunks):
        chunk_start = i * CHUNK_SEC
        chunk_end = min((i + 1) * CHUNK_SEC, total_sec)
        s = i * chunk_samples
        e = min(s + chunk_samples, total_samples)
        chunk = waveform[:, s:e]

        # Translate this speaker's diarization spans into chunk-local
        # "+" (positive) anchors, clipped to the chunk boundaries.
        chunk_anchors = []
        for seg_s, seg_e in spk_segs:
            if seg_e > chunk_start and seg_s < chunk_end:
                local_s = max(0.0, seg_s - chunk_start)
                # Clip to the actual chunk duration, not CHUNK_SEC: the final
                # chunk can be shorter, and an anchor extending past the end
                # of the audio would be out of range.
                local_e = min(chunk_end - chunk_start, seg_e - chunk_start)
                if local_e - local_s >= 0.1:  # drop sub-100ms slivers
                    chunk_anchors.append(["+", round(local_s, 2), round(local_e, 2)])

        chunk_path = OUT / "_tmp_chunk.wav"
        torchaudio.save(str(chunk_path), chunk, sr)

        if chunk_anchors:
            # Speaker talks in this chunk: prompt with explicit anchors
            # (capped at 8 per prompt, matching the original behavior).
            batch = processor(
                audios=[str(chunk_path)],
                descriptions=["speaking"],
                anchors=[chunk_anchors[:8]],
            ).to("cuda")
            with torch.inference_mode():
                result = model.separate(batch, predict_spans=False, reranking_candidates=1)
            mode = f"{len(chunk_anchors)} anchors"
        else:
            # No diarized speech in this chunk: let the model predict spans
            # itself so the output chunk still has the right length.
            batch = processor(
                audios=[str(chunk_path)],
                descriptions=["speaking"],
            ).to("cuda")
            with torch.inference_mode():
                result = model.separate(batch, predict_spans=True, reranking_candidates=1)
            mode = "predict_spans"

        target_chunks.append(result.target[0].unsqueeze(0).cpu())
        elapsed = time.time() - t_start
        print(f"  Chunk {i+1}/{n_chunks} ({chunk_start:.0f}-{chunk_end:.0f}s) [{mode}] [{elapsed:.1f}s]")

        # Release GPU memory between chunks.
        del batch, result
        torch.cuda.empty_cache()
        gc.collect()

    target_full = torch.cat(target_chunks, dim=-1)
    full_path = OUT / f"{spk}_full.wav"
    torchaudio.save(str(full_path), target_full, sr)

    # Build the "compact" version by cutting at diarization timestamps
    # (instead of energy-based silence detection) so that every segment
    # the speaker talks in is included, joined by short silent gaps.
    audio_np = target_full.numpy().T  # -> (samples, channels)
    gap_samples = int(GAP_SEC * sr)
    gap_shape = (gap_samples, audio_np.shape[1]) if audio_np.ndim > 1 else gap_samples
    # Match the audio dtype so np.concatenate does not upcast the whole
    # array to float64 (the np.zeros default).
    gap_arr = np.zeros(gap_shape, dtype=audio_np.dtype)

    parts = []
    for idx, (seg_s, seg_e) in enumerate(spk_segs):
        ss = int(seg_s * sr)
        se = min(int(seg_e * sr), len(audio_np))
        seg = audio_np[ss:se]
        if len(seg) > 0:
            parts.append(seg)
            if idx < len(spk_segs) - 1:  # no trailing gap after the last seg
                parts.append(gap_arr)

    compact = np.concatenate(parts, axis=0) if parts else np.zeros((1, 1), dtype=audio_np.dtype)
    compact_path = OUT / f"{spk}_compact.wav"
    sf.write(str(compact_path), compact, sr)
    compact_dur = compact.shape[0] / sr
    print(f"  Full: {total_sec:.1f}s | Compact: {compact_dur:.1f}s ({len(spk_segs)} segs) | Expected: {total_speech:.1f}s")

    (OUT / "_tmp_chunk.wav").unlink(missing_ok=True)
    del target_full, target_chunks
    gc.collect()

print("\n=== DONE ===")
