import subprocess
import numpy as np
import soundfile as sf
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
import time
import os

# Directory that is scanned for input ``*.wav`` files and that receives the
# ``<stem>_demucs.wav`` outputs.
MODI = Path("/home/ubuntu/modi")
# Length, in seconds, of the zero-filled gap that replaces an over-long silence.
GAP_SEC = 0.2
# Frame RMS level in dB (converted with 10 ** (dB / 20)); frames above it count
# as speech, frames at or below it count as silence.
SILENCE_THRESHOLD_DB = -40
# Silences longer than this many seconds are replaced by a GAP_SEC gap;
# shorter ones are kept unchanged.
MAX_SILENCE_SEC = 0.5

def _frame_rms(mono, frame_len, hop):
    """Return the RMS level of every ``frame_len``-sample frame of ``mono``.

    Frames start every ``hop`` samples; only frames that fit entirely inside
    the signal are produced, so a signal shorter than one frame yields an
    empty array.
    """
    if len(mono) < frame_len:
        return np.zeros(0)
    # Strided view: no copy of the audio is made.
    frames = np.lib.stride_tricks.sliding_window_view(mono, frame_len)[::hop]
    # Row-wise sum of squares without materialising ``frames ** 2``.
    energy = np.einsum("ij,ij->i", frames, frames)
    return np.sqrt(energy / frame_len)


def _compress_silence(audio, sr):
    """Return ``audio`` with every over-long silence replaced by a short gap.

    Args:
        audio: 1-D (mono) or 2-D ``(samples, channels)`` sample array.
        sr: Sample rate in Hz.

    Returns:
        The compressed array with the same channel layout as ``audio``, or
        ``None`` when the signal is too short to contain a single analysis
        frame.
    """
    mono = audio.mean(axis=1) if audio.ndim == 2 else audio

    threshold = 10 ** (SILENCE_THRESHOLD_DB / 20)
    # 20 ms frames with ~50 % overlap; clamp so tiny sample rates cannot
    # produce a zero-length frame or a zero hop.
    frame_len = max(1, int(0.02 * sr))
    hop = max(1, frame_len // 2)

    is_speech = _frame_rms(mono, frame_len, hop) > threshold
    n_frames = len(is_speech)
    if n_frames == 0:
        return None

    total = len(audio)
    pad = int(0.01 * sr)  # context kept on each side of a speech region
    max_silence_samples = int(MAX_SILENCE_SEC * sr)
    gap = np.zeros((int(GAP_SEC * sr),) + audio.shape[1:], dtype=audio.dtype)

    # Frame indices at which the speech/silence state flips, framed by the
    # first and one-past-last frame, so consecutive pairs delimit the regions.
    flips = np.flatnonzero(is_speech[1:] != is_speech[:-1]) + 1
    bounds = [0, *flips.tolist(), n_frames]

    parts = []
    cursor = 0  # first sample not yet copied to the output
    for first, stop in zip(bounds[:-1], bounds[1:]):
        begin = first * hop
        # The final region runs to the end of the file so that the samples
        # after the last full frame are not dropped.
        end = total if stop == n_frames else stop * hop

        if is_speech[first]:
            begin -= pad
            end += pad
        elif end - begin > max_silence_samples:
            # Long silence: replace it with a short gap.
            parts.append(gap)
            continue

        # Never go back before ``cursor``: the padding of a speech region
        # overlaps the neighbouring kept silence and would otherwise be
        # copied twice.
        begin = max(cursor, begin)
        end = min(total, end)
        if end > begin:
            parts.append(audio[begin:end])
            cursor = end

    if not parts:
        return None
    return np.concatenate(parts, axis=0)


def process_one(wav_path):
    """Extract the vocals of one WAV file with demucs and compress its silences.

    The result is peak-normalised to 0.9 and written to
    ``MODI / "<stem>_demucs.wav"``. Files whose output already exists are
    skipped.

    Args:
        wav_path: ``pathlib.Path`` of the input WAV file.

    Returns:
        A one-line status string starting with ``SKIP``, ``FAIL`` or ``OK``.

    Raises:
        Whatever ``soundfile`` raises when the demucs output cannot be read or
        the result cannot be written; the demucs temp directory is removed
        either way.
    """
    import shutil

    name = wav_path.stem
    out_path = MODI / f"{name}_demucs.wav"
    if out_path.exists():
        return f"SKIP {name}"

    demucs_dir = Path(f"/tmp/demucs_tmp/htdemucs/{name}")
    try:
        # Run demucs - extract vocals
        try:
            result = subprocess.run(
                ["demucs", "--two-stems", "vocals", "-n", "htdemucs",
                 "--out", "/tmp/demucs_tmp", str(wav_path)],
                capture_output=True, text=True, timeout=600
            )
        except (OSError, subprocess.SubprocessError) as e:
            return f"FAIL demucs {name}: {e}"

        # A non-zero exit must not fall through: a stale vocals.wav left by an
        # earlier run could otherwise be processed as if it were fresh.
        if result.returncode != 0:
            stderr_lines = result.stderr.strip().splitlines()
            detail = stderr_lines[-1] if stderr_lines else f"exit code {result.returncode}"
            return f"FAIL demucs {name}: {detail}"

        vocals_path = demucs_dir / "vocals.wav"
        if not vocals_path.exists():
            return f"FAIL no vocals {name}"

        # Load vocals
        audio, sr = sf.read(str(vocals_path))

        compressed = _compress_silence(audio, sr)
        if compressed is None:
            return f"FAIL empty {name}"

        # Normalize
        peak = float(np.max(np.abs(compressed))) if compressed.size else 0.0
        if peak > 0:
            compressed = compressed * (0.9 / peak)

        # Write under a temporary name and rename afterwards, so an
        # interrupted run never leaves a truncated file that later runs would
        # SKIP. The temporary name contains "_demucs" and still ends in
        # ".wav" so soundfile picks the WAV format from the extension.
        tmp_path = out_path.with_name(f"{name}_demucs.partial.wav")
        sf.write(str(tmp_path), compressed, sr)
        tmp_path.replace(out_path)
    finally:
        # Cleanup demucs temp, on success and on every failure path
        shutil.rmtree(demucs_dir, ignore_errors=True)

    orig_dur = len(audio) / sr
    new_dur = len(compressed) / sr
    return f"OK {name}: {orig_dur:.0f}s -> {new_dur:.0f}s ({(1-new_dur/orig_dur)*100:.0f}% reduced)"


if __name__ == "__main__":
    # Inputs: every WAV in MODI whose name does not mark it as a previous output.
    wavs = sorted(f for f in MODI.glob("*.wav") if "_demucs" not in f.name)
    print(f"=== Processing {len(wavs)} files with demucs + silence compression ===\n")

    # Process 2 at a time (demucs is GPU heavy)
    WORKERS = 2
    t0 = time.time()
    done = 0

    with ThreadPoolExecutor(max_workers=WORKERS) as pool:
        # Map each future to its file name so failures can be attributed.
        futures = {pool.submit(process_one, w): w.name for w in wavs}
        for future in as_completed(futures):
            done += 1
            try:
                result = future.result()
            except Exception as e:
                # Batch boundary: one unreadable or unwritable file must not
                # abort the remaining jobs, so report it and carry on.
                result = f"FAIL {futures[future]}: {e!r}"
            elapsed = time.time() - t0
            print(f"  [{done}/{len(wavs)}] {result} [{elapsed:.0f}s]")

    elapsed = time.time() - t0
    print(f"\n=== DONE in {elapsed:.0f}s ({elapsed/60:.1f}min) ===")
