#!/usr/bin/env python3
"""
Noise Cancellation Worker (Demucs)
audio_fetch_done → noise_cancelling → noise_done

Downloads raw audio from R2, runs Demucs vocal isolation on GPU,
uploads cleaned vocals to R2.
"""

import logging
import os
import shutil
import subprocess
import sys
import tempfile
import time
import uuid
from urllib.parse import urlparse

from common import (
    get_db, get_s3, claim_jobs, get_job, verify_claim,
    update_job, upload_to_r2, download_from_url, POLL_INTERVAL, BATCH_SIZE,
)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
log = logging.getLogger("denoiser")

# Identity passed to claim_jobs/verify_claim. Taken from DENOISER_WORKER_ID when
# set, otherwise a random 8-hex-digit suffix generated once per process.
WORKER_ID = os.environ.get(
    "DENOISER_WORKER_ID",
    "denoiser-" + uuid.uuid4().hex[:8],
)

# Status flow driven by this worker:
#   FROM_STATUS --claim_jobs(CLAIM_PREFIX)--> PROCESSING_STATUS --> DONE_STATUS
# with FAIL_STATUS recorded on any processing error.
FROM_STATUS = "audio_fetch_done"
CLAIM_PREFIX = "DENOISE_CLAIMED"
PROCESSING_STATUS = "noise_cancelling"
DONE_STATUS = "noise_done"
FAIL_STATUS = "failed"


def run_demucs(
    input_path: str,
    output_dir: str,
    model: str = "htdemucs",
    device: str = "cuda",
    timeout: float = 600,
) -> str:
    """Isolate the vocals in ``input_path`` with the Demucs CLI.

    Runs ``demucs --two-stems vocals`` in a subprocess and returns the path of
    the resulting vocals stem.

    Args:
        input_path: Audio file to separate.
        output_dir: Directory passed to Demucs via ``-o``. The vocals stem is
            expected at ``<output_dir>/<model>/<input basename>/vocals.wav``.
        model: Demucs model name, passed via ``-n``.
        device: Device string passed via ``--device``.
        timeout: Seconds to wait for the subprocess before aborting it.

    Returns:
        Path to the isolated ``vocals.wav``.

    Raises:
        RuntimeError: Demucs exited non-zero, or the expected output is missing.
        subprocess.TimeoutExpired: Demucs ran longer than ``timeout`` seconds.
    """
    # Run Demucs under the same interpreter as this worker, so the module is
    # resolved from the worker's environment (venv/conda) rather than whatever
    # "python3" happens to be first on PATH.
    python = sys.executable or "python3"
    cmd = [
        python, "-m", "demucs",
        "--two-stems", "vocals",
        "-n", model,
        "--device", device,
        "-o", output_dir,
        input_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    if result.returncode != 0:
        # The actionable line of a Python traceback is the last one, so keep
        # the tail of stderr rather than its head.
        raise RuntimeError(f"Demucs failed: {result.stderr[-300:].strip()}")

    basename = os.path.splitext(os.path.basename(input_path))[0]
    vocals_path = os.path.join(output_dir, model, basename, "vocals.wav")
    if not os.path.exists(vocals_path):
        raise RuntimeError(f"Demucs output not found at {vocals_path}")
    return vocals_path


def _audio_extension(url: str, default: str = "m4a") -> str:
    """Return a filesystem-safe extension (no dot) for the audio at ``url``.

    Only the URL *path* is inspected. Dots in the host name or the query string
    (e.g. presigned-URL parameters) are ignored. Returns ``default`` when the
    path has no plain alphanumeric extension.
    """
    suffix = os.path.splitext(urlparse(url).path)[1].lstrip(".")
    # Reject anything that could smuggle path separators or other odd
    # characters into the local filename.
    if suffix and suffix.isascii() and suffix.isalnum() and len(suffix) <= 8:
        return suffix
    return default


def process_job(job_id: str, conn, s3):
    """Denoise one claimed job from start to finish.

    Steps:
      1. Re-read the job and confirm this worker still holds the claim.
      2. Mark the job PROCESSING_STATUS.
      3. Download ``raw_audio_url`` into a private temp dir.
      4. Run Demucs on it.
      5. Upload the vocals stem to R2 as ``<job_id>/noise_cancelled.wav``.
      6. Mark the job DONE_STATUS with ``noise_cancelled_audio_url`` set.

    A job without ``raw_audio_url`` is marked FAIL_STATUS. Any exception raised
    during steps 2-6 is logged, and the job is marked FAIL_STATUS with a
    truncated ``error_message``. The temp dir is always removed.

    Exceptions from ``get_job``/``verify_claim``, or from the failure-status
    update itself, propagate to the caller.

    Args:
        job_id: Identifier of a job previously claimed by this worker.
        conn: Database connection obtained from ``get_db()``.
        s3: Client obtained from ``get_s3()``, used for the R2 upload.
    """
    job = get_job(conn, job_id)
    if not job:
        log.warning(f"[{WORKER_ID}] Job {job_id[:12]} not found, skipping")
        return
    raw_url = job.get("raw_audio_url")
    log.info(f"[{WORKER_ID}] Processing {job_id[:12]}...")

    # Another worker may have taken the job over since it was claimed.
    if not verify_claim(conn, job_id, CLAIM_PREFIX, WORKER_ID):
        log.warning(f"[{WORKER_ID}] Job {job_id[:12]} no longer claimed, skipping")
        return

    if not raw_url:
        log.error(f"  Job {job_id[:12]} has no raw_audio_url")
        update_job(conn, job_id, status=FAIL_STATUS, error_message="No raw_audio_url")
        return

    work_dir = tempfile.mkdtemp(prefix=f"denoise_{job_id[:8]}_")
    try:
        update_job(conn, job_id, status=PROCESSING_STATUS)

        # Keep the source extension so the decoder can identify the format.
        # It is parsed from the URL path only.
        local_audio = os.path.join(work_dir, f"input.{_audio_extension(raw_url)}")
        log.info("  Downloading raw audio...")
        download_from_url(raw_url, local_audio)
        log.info(f"  Downloaded {os.path.getsize(local_audio)/1024/1024:.1f} MB")

        log.info("  Running Demucs...")
        t0 = time.time()
        vocals_path = run_demucs(local_audio, work_dir)
        elapsed = time.time() - t0
        log.info(f"  Demucs done in {elapsed:.1f}s ({os.path.getsize(vocals_path)/1024/1024:.1f} MB)")

        r2_key = f"{job_id}/noise_cancelled.wav"
        log.info(f"  Uploading to R2: {r2_key}")
        nc_url = upload_to_r2(s3, vocals_path, r2_key)

        update_job(conn, job_id,
                   status=DONE_STATUS,
                   noise_cancelled_audio_url=nc_url)
        log.info(f"  Job {job_id[:12]} → {DONE_STATUS}")

    except Exception as e:
        log.error(f"  Job {job_id[:12]} failed: {e}", exc_info=True)
        update_job(conn, job_id, status=FAIL_STATUS, error_message=str(e)[:500])
    finally:
        shutil.rmtree(work_dir, ignore_errors=True)


def _close_quietly(conn) -> None:
    """Best-effort close of a DB connection that may already be broken."""
    if conn is None:
        return
    try:
        conn.close()
    except Exception as exc:  # already dead / no close(): nothing more to do
        log.debug(f"Ignoring error while closing DB connection: {exc}")


def main():
    """Poll for jobs in FROM_STATUS and process them until interrupted.

    Each iteration (re)connects to the DB if needed and claims up to BATCH_SIZE
    jobs. It sleeps POLL_INTERVAL seconds when none are available, otherwise it
    processes the claimed jobs one by one.

    On any unexpected error the current connection is closed and dropped, so
    the next iteration reconnects after a POLL_INTERVAL backoff. A
    KeyboardInterrupt at any point, including during that backoff, stops the
    loop cleanly and closes the connection.
    """
    log.info(f"Denoiser starting | worker={WORKER_ID} batch={BATCH_SIZE} poll={POLL_INTERVAL}s")
    conn = None
    s3 = get_s3()

    try:
        while True:
            try:
                if conn is None:
                    conn = get_db()
                jobs = claim_jobs(conn, FROM_STATUS, CLAIM_PREFIX, WORKER_ID, BATCH_SIZE)
                if not jobs:
                    time.sleep(POLL_INTERVAL)
                    continue
                log.info(f"[{WORKER_ID}] Claimed {len(jobs)} jobs")
                for job_id in jobs:
                    process_job(job_id, conn, s3)
            except Exception as e:
                log.error(f"Loop error: {e}", exc_info=True)
                # The connection may be in a broken state. Close it explicitly
                # before dropping it so the next iteration starts fresh.
                _close_quietly(conn)
                conn = None
                time.sleep(POLL_INTERVAL)
    except KeyboardInterrupt:
        # The handler sits outside the loop so that Ctrl-C during the error
        # backoff sleep also exits cleanly.
        log.info("Denoiser shutting down")
    finally:
        _close_quietly(conn)


if __name__ == "__main__":
    # Script entry point: run the polling loop. main() returns None, so the
    # process exits with status 0 once the loop stops.
    sys.exit(main())
