import torch
import torchaudio
from sam_audio import SAMAudio, SAMAudioProcessor
from sam_audio.ranking import create_ranker
from pathlib import Path
import time
import gc

# --- Configuration ---------------------------------------------------------

# Directory that receives all separated stems and scratch files.
OUT = Path("/home/ubuntu/sam_audio_test")
# parents=True: also create missing ancestor directories instead of failing
# with FileNotFoundError when the expected layout is absent.
OUT.mkdir(parents=True, exist_ok=True)

# Input recording to separate.
AUDIO_FILE = "/home/ubuntu/clip_00.mp3"
# Process the audio in fixed-size windows to bound peak GPU memory.
CHUNK_SEC = 30
# Hugging Face model identifier to load.
MODEL_NAME = "facebook/sam-audio-large"

# Monkey-patch to skip loading the visual ranker (ImageBind) - not needed for
# text prompting.  The ranker fields are temporarily nulled in the config so
# the original __init__ skips building them, then restored so the cfg object
# still round-trips (e.g. for later serialization).
_orig_init = SAMAudio.__init__
def _patched_init(self, cfg):
    """Run SAMAudio.__init__ with cfg's ranker fields temporarily cleared.

    The saved values are restored in a ``finally`` so that cfg is never left
    in a mutated state, even when the wrapped __init__ raises.
    """
    saved = (cfg.visual_ranker, cfg.text_ranker)
    cfg.visual_ranker = None
    cfg.text_ranker = None
    try:
        _orig_init(self, cfg)
    finally:
        # Restore on both success and failure paths.
        cfg.visual_ranker, cfg.text_ranker = saved
SAMAudio.__init__ = _patched_init

# Load the model and its processor onto the GPU, then report load time and
# available GPU memory headroom.
print(f"=== Loading SAM-Audio ({MODEL_NAME}) - no visual ranker ===")
load_start = time.time()
model = SAMAudio.from_pretrained(MODEL_NAME).eval().cuda()
processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
sr = processor.audio_sampling_rate
print(f"Model loaded in {time.time()-load_start:.1f}s, target sr={sr}")

# torch.cuda.mem_get_info() returns (free_bytes, total_bytes).
free_bytes, total_bytes = torch.cuda.mem_get_info()
print(f"GPU memory: {total_bytes/1e9:.1f}GB total, {free_bytes/1e9:.1f}GB free")

# Load the source audio, resample it to the model's rate, and downmix to mono.
waveform, file_sr = torchaudio.load(AUDIO_FILE)
if file_sr != sr:
    waveform = torchaudio.functional.resample(waveform, file_sr, sr)
if waveform.shape[0] > 1:
    # Average all channels into a single mono track.
    waveform = waveform.mean(dim=0, keepdim=True)

total_samples = waveform.shape[1]
total_sec = total_samples / sr
chunk_samples = CHUNK_SEC * sr
n_chunks = -(-total_samples // chunk_samples)  # ceiling division
print(f"Audio: {total_sec:.1f}s at {sr}Hz -> {n_chunks} chunks of {CHUNK_SEC}s")

# (output_label, text_prompt) pairs; each prompt drives one separation pass.
# dict preserves insertion order, so .items() yields them in this order.
prompts = list(
    {
        "speech": "speech",
        "man_speaking": "man speaking",
        "woman_speaking": "woman speaking",
    }.items()
)

# Main loop: run one text-prompted separation pass per prompt, processing the
# audio in CHUNK_SEC windows to bound GPU memory, then stitch and save the
# target/residual stems.
for label, description in prompts:
    print(f"\n=== Separating: '{description}' ===")
    t_start = time.time()

    target_chunks = []    # per-chunk separated source, moved to CPU
    residual_chunks = []  # per-chunk remainder, moved to CPU

    for i in range(n_chunks):
        # Sample range of this chunk; the final chunk may be shorter.
        s = i * chunk_samples
        e = min(s + chunk_samples, total_samples)
        chunk = waveform[:, s:e]

        # The processor is fed a file path, so round-trip the chunk through a
        # scratch wav file (overwritten every iteration).
        chunk_path = OUT / "_tmp_chunk.wav"
        torchaudio.save(str(chunk_path), chunk, sr)

        batch = processor(
            audios=[str(chunk_path)],
            descriptions=[description],
        ).to("cuda")

        with torch.inference_mode():
            # NOTE(review): reranking_candidates=1 / predict_spans=False
            # presumably avoids the ranker models skipped at load time —
            # confirm against the sam_audio API.
            result = model.separate(batch, predict_spans=False, reranking_candidates=1)

        # Move outputs to CPU immediately so their GPU memory can be freed.
        target_chunks.append(result.target[0].cpu())
        residual_chunks.append(result.residual[0].cpu())

        elapsed = time.time() - t_start
        print(f"  Chunk {i+1}/{n_chunks} ({s/sr:.0f}-{e/sr:.0f}s) [{elapsed:.1f}s]")

        # Drop GPU references, then reclaim caching-allocator blocks and
        # collect cycles before the next chunk to keep peak memory low.
        del batch, result
        torch.cuda.empty_cache()
        gc.collect()

    # Stitch the per-chunk outputs back together along the time axis.
    target_full = torch.cat(target_chunks, dim=-1)
    residual_full = torch.cat(residual_chunks, dim=-1)

    # torchaudio.save expects (channels, time); add a channel dimension when
    # the model returned 1-D audio.
    if target_full.dim() == 1:
        target_full = target_full.unsqueeze(0)
    if residual_full.dim() == 1:
        residual_full = residual_full.unsqueeze(0)

    torchaudio.save(str(OUT / f"{label}_target.wav"), target_full, sr)
    torchaudio.save(str(OUT / f"{label}_residual.wav"), residual_full, sr)

    print(f"  Saved {label}: target={target_full.shape[-1]/sr:.1f}s, residual={residual_full.shape[-1]/sr:.1f}s [{time.time()-t_start:.1f}s]")

    # Remove the scratch file and release the large tensors before the next
    # prompt's pass.
    (OUT / "_tmp_chunk.wav").unlink(missing_ok=True)
    del target_full, residual_full, target_chunks, residual_chunks
    gc.collect()

# Final summary: list every rendered wav with its duration and size on disk.
print(f"\n=== All outputs in {OUT} ===")
for wav_path in sorted(OUT.glob("*.wav")):
    meta = torchaudio.info(str(wav_path))
    seconds = meta.num_frames / meta.sample_rate
    megabytes = wav_path.stat().st_size / 1024 / 1024
    print(f"  {wav_path.name}: {seconds:.1f}s, {megabytes:.1f}MB")

print("\n=== COMPLETE ===")
