#!/usr/bin/env python3
"""Run Demucs (htdemucs) vocal separation on all downloaded YouTube audios."""

import glob
import os
import shutil
import subprocess
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed

# Root directory scanned by find_all_inputs(); each immediate child entry is
# treated as one video's folder and its basename is used as the video id.
BASE = "/home/ubuntu/Speech_maker_pipeline/pawan_kalyan"
# Passed as max_workers to the ProcessPoolExecutor in main().
MAX_PARALLEL = 3  # Demucs is GPU-heavy; limit concurrency


def find_all_inputs(base=None):
    """Collect the video folders that still need a Demucs vocals file.

    Every immediate sub-directory of *base* is treated as one video; its
    basename is the video id. A folder is selected when it contains a
    ``yt_downloaded_256kbps.<ext>`` source file and does not yet contain a
    ``yt_downloaded_256kbps_demucs.wav`` larger than 1000 bytes. A ``[SKIP]``
    line is printed for each folder that is not selected.

    Args:
        base: Directory to scan. Defaults to the module-level ``BASE``.

    Returns:
        A list of ``(vid, src, folder, dst)`` tuples ordered by folder path,
        where ``src`` is the source audio path and ``dst`` is the path the
        separated vocals should be written to.
    """
    if base is None:
        base = BASE

    inputs = []
    # glob.escape keeps characters such as "[" or "*" in directory names from
    # being interpreted as pattern syntax.
    for folder in sorted(glob.glob(os.path.join(glob.escape(base), "*"))):
        if not os.path.isdir(folder):
            # Stray files next to the video folders are not videos.
            continue
        vid = os.path.basename(folder)

        # Sort so the chosen source is deterministic when several match.
        candidates = sorted(
            glob.glob(os.path.join(glob.escape(folder), "yt_downloaded_256kbps.*"))
        )
        if not candidates:
            print(f"[SKIP] {vid}: no source audio found")
            continue
        src = candidates[0]

        dst = os.path.join(folder, "yt_downloaded_256kbps_demucs.wav")
        # A single stat call; a missing/unreadable output counts as size 0.
        try:
            dst_size = os.path.getsize(dst)
        except OSError:
            dst_size = 0
        # Outputs of 1000 bytes or less are treated as absent and redone.
        if dst_size > 1000:
            print(f"[SKIP] {vid}: demucs output already exists ({dst_size/(1024*1024):.1f}MB)")
            continue

        inputs.append((vid, src, folder, dst))
    return inputs


def run_demucs(args):
    """Separate the vocals of one source file with Demucs (htdemucs model).

    Demucs is run as a subprocess writing into ``<folder>/demucs_tmp``; the
    resulting ``vocals.wav`` is then moved to ``dst``. The scratch directory
    is removed on every exit path, successful or not.

    Args:
        args: A ``(vid, src, folder, dst)`` tuple as produced by
            ``find_all_inputs``: video id, source audio path, the video's
            folder, and the destination path for the vocals file.

    Returns:
        ``(vid, ok, msg)`` where ``ok`` is True on success and ``msg`` is
        either a size/duration summary or a description of the failure.
        Failures are reported through the return value, not raised.
    """
    vid, src, folder, dst = args
    tmp_dir = os.path.join(folder, "demucs_tmp")
    t0 = time.time()
    try:
        # Start from an empty scratch dir so leftovers of an aborted earlier
        # run can never be mistaken for this run's output.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        os.makedirs(tmp_dir, exist_ok=True)

        cmd = [
            # Use the interpreter running this script so the same environment
            # (and its demucs install) is used; fall back to PATH lookup.
            sys.executable or "python3", "-m", "demucs",
            "--two-stems", "vocals",
            "-n", "htdemucs",
            "-o", tmp_dir,
            src,
        ]
        # errors="replace": undecodable bytes in the captured output must not
        # turn an otherwise successful run into an exception.
        result = subprocess.run(
            cmd, capture_output=True, text=True, errors="replace", timeout=1800
        )
        if result.returncode != 0:
            return vid, False, f"demucs failed: {result.stderr[-500:]}"

        escaped_tmp = glob.escape(tmp_dir)
        vocals_candidates = sorted(
            glob.glob(os.path.join(escaped_tmp, "htdemucs", "*", "vocals.wav"))
        )
        if not vocals_candidates:
            # Fall back to searching the whole scratch tree.
            vocals_candidates = sorted(
                glob.glob(os.path.join(escaped_tmp, "**", "vocals.wav"), recursive=True)
            )
        if not vocals_candidates:
            return vid, False, f"vocals.wav not found in {tmp_dir}"

        # os.replace overwrites an existing (e.g. truncated) dst on every OS.
        os.replace(vocals_candidates[0], dst)

        elapsed = time.time() - t0
        size_mb = os.path.getsize(dst) / (1024 * 1024)
        return vid, True, f"{size_mb:.1f}MB in {elapsed:.0f}s"
    except subprocess.TimeoutExpired:
        return vid, False, "timeout (30min)"
    except Exception as e:
        # Worker boundary: report any failure to the caller as a result tuple.
        return vid, False, str(e)
    finally:
        # Always drop the scratch dir; failed runs used to leave it behind.
        shutil.rmtree(tmp_dir, ignore_errors=True)


def main():
    """Run Demucs on every pending video and print a summary.

    Pending inputs come from ``find_all_inputs``; each is processed by
    ``run_demucs`` in a process pool of at most ``MAX_PARALLEL`` workers.
    A status line is printed as each job finishes, followed by a summary
    with totals and a per-video list sorted by video id.
    """
    inputs = find_all_inputs()
    print(f"\nRunning Demucs on {len(inputs)} videos (max {MAX_PARALLEL} parallel)...\n")

    if not inputs:
        print("Nothing to do.")
        return

    results = {}
    with ProcessPoolExecutor(max_workers=MAX_PARALLEL) as pool:
        futures = {pool.submit(run_demucs, inp): inp[0] for inp in inputs}
        for future in as_completed(futures):
            vid = futures[future]
            try:
                _, ok, msg = future.result()
            except Exception as e:
                # A crashed worker (e.g. a broken process pool) must not
                # abort the whole batch or lose the summary.
                ok, msg = False, f"worker crashed: {e!r}"
            status = "OK" if ok else "FAIL"
            results[vid] = (ok, msg)
            print(f"  [{status}] {vid}: {msg}")

    print(f"\n{'='*60}")
    ok_count = sum(1 for v in results.values() if v[0])
    total = len(results)
    # include skipped (already done)
    all_done = glob.glob(os.path.join(BASE, "*", "yt_downloaded_256kbps_demucs.wav"))
    print(f"Done: {ok_count}/{total} new + {len(all_done) - ok_count} already existed = {len(all_done)} total demucs outputs")
    for vid_id in sorted(results):
        ok, msg = results[vid_id]
        print(f"  [{'OK' if ok else 'FAIL'}] {vid_id}: {msg}")


# Run the batch only when executed as a script, not when this module is
# imported.
if __name__ == "__main__":
    main()
