import torch
import torchaudio
import numpy as np
import soundfile as sf
from asteroid.models import ConvTasNet
from pathlib import Path
import time
import gc

# Output directory for the separated speaker tracks (full + compacted versions).
OUT = Path("/home/ubuntu/sam_audio_test/convtasnet_16k_out")
OUT.mkdir(exist_ok=True)

AUDIO_FILE = "/home/ubuntu/sam_audio_test/speech_target.wav"  # SAM-Audio denoised
CHUNK_SEC = 10     # length of each separation window fed to the model (seconds)
OVERLAP_SEC = 2    # cross-fade overlap between consecutive windows (seconds)
GAP_SEC = 0.15     # silence inserted between kept segments in the compact output

print("=== Loading Conv-TasNet 16kHz ===")
t0 = time.time()
import asteroid.models.base_models as _bm

# asteroid's checkpoint loader breaks under torch>=2.6, where torch.load
# defaults to weights_only=True.  Temporarily force full pickle loading
# for this one download.
# SECURITY NOTE: weights_only=False deserializes arbitrary pickle code —
# acceptable here only because the checkpoint source is trusted.
_orig_load = torch.load
def _patched_load(*a, **kw):
    kw["weights_only"] = False
    return _orig_load(*a, **kw)

# _bm.torch is the same module object as torch, so this patches torch.load
# globally; the try/finally guarantees the original loader is restored even
# if the download or deserialization raises (the original code leaked the
# patch on failure).
_bm.torch.load = _patched_load
try:
    model = ConvTasNet.from_pretrained("JorisCos/ConvTasNet_Libri2Mix_sepclean_16k")
finally:
    _bm.torch.load = _orig_load

model = model.eval().cuda()
model_sr = 16000  # sample rate the pretrained checkpoint expects
print(f"Loaded in {time.time()-t0:.1f}s, expects {model_sr}Hz")

# Load the input, fold any multi-channel audio down to mono, and resample
# to the model's 16 kHz rate.
audio, orig_sr = sf.read(AUDIO_FILE)
if audio.ndim == 2:
    audio = audio.mean(axis=1)

audio_16k = (
    torchaudio.functional.resample(
        torch.tensor(audio, dtype=torch.float32), orig_sr, model_sr
    )
    .numpy()
)
total_samples = len(audio_16k)
total_sec = total_samples / model_sr
print(f"Audio: {total_sec:.1f}s at {model_sr}Hz")

# Sliding-window parameters in samples: consecutive windows overlap by
# OVERLAP_SEC and are blended with linear cross-fades.
chunk_samples = CHUNK_SEC * model_sr
overlap_samples = OVERLAP_SEC * model_sr
step_samples = chunk_samples - overlap_samples

# Overlap-add accumulators: one output stream per estimated source, plus
# the per-sample sum of fade weights used to renormalize after the loop.
n_sources = 2
source_streams = [np.zeros(total_samples) for _ in range(n_sources)]
weight_map = np.zeros(total_samples)

pos = 0        # sample offset of the current window
chunk_idx = 0  # number of windows processed (for progress logging)
t_start = time.time()

# Overlap-add separation loop.  Each window is separated independently and
# blended into the output streams with linear cross-fades; weight_map
# accumulates the per-sample fade weights for renormalization afterwards.
# NOTE(review): source order is not guaranteed consistent across windows
# (permutation ambiguity of separation models); the cross-fade mitigates
# but does not eliminate speaker swaps at window boundaries.
while pos < total_samples:
    end = min(pos + chunk_samples, total_samples)
    chunk = audio_16k[pos:end]
    if len(chunk) < model_sr:
        # Trailing fragment shorter than 1 s: too short to separate
        # reliably — it is left silent in the output streams.
        break

    chunk_tensor = torch.tensor(chunk, dtype=torch.float32).unsqueeze(0).cuda()

    with torch.inference_mode():
        est = model(chunk_tensor)  # (1, n_sources, time)

    # Linear fade-in/out over the overlap regions, skipped at audio edges
    # so the very start and very end keep full weight.
    fade = min(overlap_samples, len(chunk))
    window = np.ones(len(chunk))
    if pos > 0:
        window[:fade] = np.linspace(0, 1, fade)
    if end < total_samples:
        window[-fade:] = np.linspace(1, 0, fade)

    for src_idx in range(n_sources):
        src = est[0, src_idx].cpu().numpy()[:len(chunk)]
        source_streams[src_idx][pos:pos+len(chunk)] += src * window

    weight_map[pos:pos+len(chunk)] += window
    chunk_idx += 1

    if chunk_idx % 10 == 0:
        elapsed = time.time() - t_start
        pct = end / total_samples * 100
        print(f"  Chunk {chunk_idx} ({pos/model_sr:.0f}-{end/model_sr:.0f}s) [{pct:.0f}%] [{elapsed:.1f}s]")

    del chunk_tensor, est
    torch.cuda.empty_cache()

    if end == total_samples:
        # BUGFIX: the final window already reached the end of the audio.
        # Advancing by step_samples would re-separate the tail as a short
        # duplicate chunk — redundant work, and the re-separated tail may
        # come back with a flipped source permutation, corrupting the blend.
        break
    pos += step_samples

elapsed = time.time() - t_start
print(f"  Separation done: {chunk_idx} chunks in {elapsed:.1f}s")

# Renormalize the overlap-add sums by the accumulated fade weights.
# Samples never covered by any window keep weight 0 and stay silent.
covered = weight_map > 0
for stream in source_streams:
    stream[covered] /= weight_map[covered]

# Rank sources loudest-first so speaker_0 is always the dominant voice.
energies = [float(np.sum(stream ** 2)) for stream in source_streams]
order = sorted(range(n_sources), key=energies.__getitem__, reverse=True)

def _flip_short_runs(mask, value, min_len):
    """In place, flip every run of `value` in boolean array `mask` that is
    shorter than `min_len` frames to the opposite value."""
    j = 0
    n = len(mask)
    while j < n:
        if mask[j] == value:
            start = j
            while j < n and mask[j] == value:
                j += 1
            if (j - start) < min_len:
                mask[start:j] = not value
        else:
            j += 1


for rank, src_idx in enumerate(order):
    src = source_streams[src_idx]
    # Peak-normalize to 0.9 full scale (skip if the stream is all zeros).
    peak = np.max(np.abs(src))
    if peak > 0:
        src = src / peak * 0.9

    # Save at native 16kHz
    sf.write(str(OUT / f"speaker_{rank}_full_16k.wav"), src, model_sr)

    # Compact version: energy-gated VAD over 20 ms frames with 50% hop;
    # keep speech segments and join them with short fixed gaps.
    threshold = 10 ** (-35 / 20)   # -35 dBFS RMS gate
    frame_len = int(0.02 * model_sr)
    hop = frame_len // 2
    n_frames = (len(src) - frame_len) // hop + 1
    rms = np.array([np.sqrt(np.mean(src[j*hop:j*hop+frame_len]**2)) for j in range(n_frames)])
    is_speech = rms > threshold
    min_sil = int(0.3 / (hop / model_sr))   # bridge silences shorter than 0.3 s
    min_sp = int(0.15 / (hop / model_sr))   # drop speech blips shorter than 0.15 s

    # BUGFIX: the original silence-merge pass used a `for` loop that
    # mutated its own index, which the `for` statement then discarded —
    # every suffix of a long silence run was re-scanned and the last
    # min_sil-1 frames of each long silence were wrongly marked as speech.
    _flip_short_runs(is_speech, False, min_sil)  # close short silences
    _flip_short_runs(is_speech, True, min_sp)    # drop too-short speech

    # Collect speech segments with 20 ms of padding on each side.
    segs = []
    pad = int(0.02 * model_sr)
    j = 0
    while j < len(is_speech):
        if is_speech[j]:
            start = j
            while j < len(is_speech) and is_speech[j]:
                j += 1
            ss = max(0, start * hop - pad)
            se = min(len(src), j * hop + pad)
            segs.append(src[ss:se])
        else:
            j += 1

    # Concatenate segments separated by GAP_SEC of silence; fall back to a
    # single zero sample when no speech was detected at all.
    gap = np.zeros(int(GAP_SEC * model_sr))
    parts = []
    for idx, seg in enumerate(segs):
        parts.append(seg)
        if idx < len(segs) - 1:
            parts.append(gap)
    compact = np.concatenate(parts) if parts else np.zeros(1)

    sf.write(str(OUT / f"speaker_{rank}_compact_16k.wav"), compact, model_sr)
    print(f"  Speaker {rank}: full={len(src)/model_sr:.1f}s | compact={len(compact)/model_sr:.1f}s ({len(segs)} segs)")

# Final summary: list every WAV written, with its size in megabytes.
print(f"\nOutput: {OUT}")
for wav_path in sorted(OUT.glob("*.wav")):
    size_mb = wav_path.stat().st_size / 1024 / 1024
    print(f"  {wav_path.name}: {size_mb:.1f}MB")
print("=== DONE ===")
