import torch
import torchaudio
import numpy as np
import soundfile as sf
from sam_audio import SAMAudio, SAMAudioProcessor
from pathlib import Path
import time
import gc

# Output directory for full-length and silence-compacted separated tracks.
OUT = Path("/home/ubuntu/sam_audio_test/sam_v3_out")
OUT.mkdir(exist_ok=True)

# Processing window length, in seconds (audio is fed to the model chunk by chunk).
CHUNK_SEC = 10
# Silence inserted between kept speech segments in the "compact" render, in seconds.
GAP_SEC = 0.15

# Skip the visual ranker only: temporarily null out cfg.visual_ranker while
# the original __init__ runs, so the model is built without it.
_orig_init = SAMAudio.__init__
def _patched_init(self, cfg):
    """Run SAMAudio.__init__ with cfg.visual_ranker disabled.

    The saved value is restored in a finally-block, so the caller's config
    object is never left mutated even if the original __init__ raises.
    """
    saved_vr = cfg.visual_ranker
    cfg.visual_ranker = None
    try:
        _orig_init(self, cfg)
    finally:
        cfg.visual_ranker = saved_vr
SAMAudio.__init__ = _patched_init

print("=== Loading SAM-Audio Large ===")
t0 = time.time()
model = SAMAudio.from_pretrained("facebook/sam-audio-large")
processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")
model = model.eval().cuda()
sr = processor.audio_sampling_rate
print(f"Loaded in {time.time()-t0:.1f}s, sr={sr}")

waveform, orig_sr = torchaudio.load("/home/ubuntu/clip_00.mp3")
if orig_sr != sr:
    waveform = torchaudio.functional.resample(waveform, orig_sr, sr)
if waveform.shape[0] > 1:
    waveform = waveform.mean(dim=0, keepdim=True)

total_samples = waveform.shape[1]
total_sec = total_samples / sr
chunk_samples = CHUNK_SEC * sr
n_chunks = (total_samples + chunk_samples - 1) // chunk_samples

# Pure SAM Audio: no anchors; the separate() call below actually runs with
# predict_spans=False and reranking_candidates=1 — text descriptions alone
# drive the separation.
# Try different speaker descriptions
prompts = [
    ("first_speaker", "first speaker"),
    ("second_speaker", "second speaker"),
]

for label, description in prompts:
    # Banner reflects the actual separate() settings used below.
    print(f"\n=== '{description}' (predict_spans=False, reranking=1) ===")
    target_chunks = []
    t_start = time.time()

    # ---- Separate each fixed-length chunk independently ----
    for i in range(n_chunks):
        chunk_start = i * CHUNK_SEC
        chunk_end = min((i + 1) * CHUNK_SEC, total_sec)
        s = i * chunk_samples
        e = min(s + chunk_samples, total_samples)
        chunk = waveform[:, s:e]

        # The processor consumes file paths, so round-trip each chunk through
        # a temporary wav file.
        chunk_path = OUT / "_tmp_chunk.wav"
        torchaudio.save(str(chunk_path), chunk, sr)

        # NO anchors - let the text description do the work
        batch = processor(
            audios=[str(chunk_path)],
            descriptions=[description],
        ).to("cuda")

        with torch.inference_mode():
            result = model.separate(
                batch,
                predict_spans=False,
                reranking_candidates=1,
            )

        target_chunks.append(result.target[0].unsqueeze(0).cpu())
        elapsed = time.time() - t_start
        print(f"  Chunk {i+1}/{n_chunks} ({chunk_start:.0f}-{chunk_end:.0f}s) [{elapsed:.1f}s]")

        # Release GPU memory between chunks.
        del batch, result
        torch.cuda.empty_cache()
        gc.collect()

    # Stitch the separated chunks back into one full-length track.
    target_full = torch.cat(target_chunks, dim=-1)
    full_path = OUT / f"{label}_full.wav"
    torchaudio.save(str(full_path), target_full, sr)

    # ---- Compact: strip silence and join speech with short gaps ----
    audio_np = target_full.numpy().T  # (samples, channels)
    mono = audio_np.mean(axis=1) if audio_np.ndim > 1 else audio_np
    threshold = 10 ** (-38 / 20)  # -38 dBFS RMS gate
    frame_len = int(0.02 * sr)    # 20 ms analysis frames
    hop = frame_len // 2          # 50% overlap
    # NOTE(review): assumes mono has at least frame_len samples — confirm.
    n_frames = (len(mono) - frame_len) // hop + 1
    rms = np.array([np.sqrt(np.mean(mono[j*hop:j*hop+frame_len]**2)) for j in range(n_frames)])
    is_speech = rms > threshold
    min_sil = int(0.3 / (hop / sr))  # silences shorter than 0.3 s are bridged
    min_sp = int(0.1 / (hop / sr))   # speech bursts shorter than 0.1 s are dropped

    # Bridge short silence runs. BUGFIX: this must be a while-loop so the scan
    # pointer advances past each run exactly once; the previous
    # `for j in range(...)` version discarded the inner while's advancement of
    # j, re-scanned suffixes of long silence runs, and wrongly flipped the
    # trailing (min_sil - 1) frames of every long silence to speech.
    j = 0
    while j < len(is_speech):
        if not is_speech[j]:
            start = j
            while j < len(is_speech) and not is_speech[j]:
                j += 1
            if (j - start) < min_sil:
                is_speech[start:j] = True
        else:
            j += 1

    # Drop speech runs too short to be real speech.
    j = 0
    while j < len(is_speech):
        if is_speech[j]:
            start = j
            while j < len(is_speech) and is_speech[j]:
                j += 1
            if (j - start) < min_sp:
                is_speech[start:j] = False
        else:
            j += 1

    # Collect remaining speech runs with 20 ms of padding on each side.
    segs = []
    j = 0
    while j < len(is_speech):
        if is_speech[j]:
            start = j
            while j < len(is_speech) and is_speech[j]:
                j += 1
            ss = max(0, start * hop - int(0.02 * sr))
            se = min(len(audio_np), j * hop + int(0.02 * sr))
            segs.append(audio_np[ss:se])
        else:
            j += 1

    # Join segments separated by GAP_SEC of silence.
    gap_shape = (int(GAP_SEC * sr), audio_np.shape[1]) if audio_np.ndim > 1 else int(GAP_SEC * sr)
    gap_arr = np.zeros(gap_shape)
    parts = []
    for idx, seg in enumerate(segs):
        parts.append(seg)
        if idx < len(segs) - 1:
            parts.append(gap_arr)
    compact = np.concatenate(parts, axis=0) if parts else np.zeros((1, 1))
    compact_path = OUT / f"{label}_compact.wav"
    sf.write(str(compact_path), compact, sr)
    compact_dur = compact.shape[0] / sr
    print(f"  Full: {total_sec:.1f}s | Compact: {compact_dur:.1f}s ({len(segs)} segs)")

    # Clean up the temp chunk file and free the stitched tensors.
    (OUT / "_tmp_chunk.wav").unlink(missing_ok=True)
    del target_full, target_chunks
    gc.collect()

print("\n=== DONE ===")
