#!/usr/bin/env python3
"""Chunked Demucs: convert to WAV, split into 10-min overlapping chunks, run Demucs, stitch."""

import os
import sys
import subprocess
import glob
import time
import shutil
import numpy as np
import soundfile as sf
from concurrent.futures import ThreadPoolExecutor, as_completed

CHUNK_DURATION = 10 * 60    # core length of each chunk, in seconds (10 minutes)
OVERLAP = 5                 # extra context kept on each side of a chunk, in seconds
PARALLEL_CHUNKS = 6         # worker-pool size for concurrent Demucs runs

def flush_print(*args, **kwargs):
    """Print like the built-in ``print`` but flush the stream by default.

    Accepts exactly the arguments ``print`` accepts. ``flush`` defaults to
    True so progress lines show up immediately when output is piped or
    redirected; a caller may still pass ``flush=False`` explicitly.
    """
    # setdefault (rather than passing flush=True alongside **kwargs) avoids a
    # "multiple values for keyword argument 'flush'" TypeError when the caller
    # supplies its own flush value.
    kwargs.setdefault("flush", True)
    print(*args, **kwargs)


def get_duration(path):
    """Return the duration of a media file in seconds, as reported by ffprobe.

    Args:
        path: Path of the media file to inspect.

    Returns:
        The container-level duration (``format=duration``) as a float.

    Raises:
        ValueError: If ffprobe did not print a numeric duration, e.g. because
            the file is missing, unreadable, or has no known duration.
        FileNotFoundError: If the ``ffprobe`` executable cannot be found.
    """
    result = subprocess.run(
        ["ffprobe", "-v", "quiet", "-show_entries", "format=duration",
         "-of", "default=noprint_wrappers=1:nokey=1", path],
        capture_output=True, text=True
    )
    raw = result.stdout.strip()
    try:
        return float(raw)
    except ValueError as err:
        # ffprobe runs with "-v quiet", so on failure stdout is simply empty;
        # surface the path and exit code instead of a bare float('') error.
        raise ValueError(
            f"ffprobe reported no usable duration for {path!r} "
            f"(exit code {result.returncode}, output {raw!r})"
        ) from err


def convert_to_wav(src, wav_path):
    """Transcode ``src`` to a 44100 Hz, 2-channel file at ``wav_path`` with ffmpeg.

    An existing ``wav_path`` is overwritten (``-y``). ffmpeg's output is
    captured rather than printed. Raises ``subprocess.CalledProcessError`` on
    a non-zero exit and ``subprocess.TimeoutExpired`` after 300 seconds.
    """
    ffmpeg_cmd = ["ffmpeg", "-y"]
    ffmpeg_cmd += ["-i", src]
    ffmpeg_cmd += ["-ar", "44100", "-ac", "2"]
    ffmpeg_cmd.append(wav_path)
    subprocess.run(ffmpeg_cmd, check=True, timeout=300, capture_output=True)


def split_from_wav(wav_path, tmp_dir, duration):
    """Split a WAV file into overlapping chunk files on disk.

    The input is read one chunk at a time through a seekable
    ``soundfile.SoundFile`` handle, so only a single chunk is held in memory
    instead of the whole recording.

    Every chunk covers CHUNK_DURATION seconds of "core" audio plus up to
    OVERLAP seconds of extra context before and after it (clamped to the
    file bounds). Chunks are written as ``chunk_NNN.wav`` inside ``tmp_dir``.

    Args:
        wav_path: Path of the WAV file to split.
        tmp_dir: Existing directory that receives the chunk files.
        duration: Unused; kept so existing callers keep working.

    Returns:
        A ``(chunks, sr, shape)`` tuple where ``chunks`` is a list of dicts
        with keys ``idx``, ``path``, ``start_sample``/``end_sample`` (the
        range actually written, including overlap) and
        ``orig_start_sample``/``orig_end_sample`` (the core range without
        overlap); ``sr`` is the sample rate; ``shape`` is the shape the full
        audio array would have: ``(frames,)`` for mono, ``(frames, channels)``
        otherwise.
    """
    chunks = []
    with sf.SoundFile(wav_path) as wav:
        sr = wav.samplerate
        total_samples = wav.frames
        # Mirror the shape sf.read() yields: 1-D for mono, 2-D otherwise.
        if wav.channels > 1:
            shape = (total_samples, wav.channels)
        else:
            shape = (total_samples,)

        chunk_samples = CHUNK_DURATION * sr
        overlap_samples = OVERLAP * sr

        for idx, t in enumerate(range(0, total_samples, chunk_samples)):
            start = max(0, t - overlap_samples)
            end = min(total_samples, t + chunk_samples + overlap_samples)

            # Jump straight to the chunk and read only the frames it needs.
            wav.seek(start)
            block = wav.read(frames=end - start, dtype="float32")

            chunk_path = os.path.join(tmp_dir, f"chunk_{idx:03d}.wav")
            sf.write(chunk_path, block, sr)
            chunks.append({
                "idx": idx,
                "path": chunk_path,
                "start_sample": start,
                "end_sample": end,
                "orig_start_sample": t,
                "orig_end_sample": min(total_samples, t + chunk_samples),
            })
    return chunks, sr, shape


def run_demucs_on_chunk(chunk, demucs_out_dir, timeout=300):
    """Run Demucs vocal separation on a single chunk file.

    Demucs is launched as a subprocess (``-m demucs --two-stems vocals
    -n htdemucs``) using the interpreter that is running this script, so it
    resolves in the same environment.

    Args:
        chunk: Dict with at least ``"idx"`` and ``"path"``. On success the
            key ``"vocals_path"`` is added, pointing at the separated vocals.
        demucs_out_dir: Directory passed to Demucs as its output root (``-o``).
        timeout: Maximum number of seconds to wait for Demucs to finish.

    Returns:
        ``(idx, ok, message)``: the chunk index, whether separation
        succeeded, and a short human-readable status. Failures, including
        timeouts, are reported through ``ok=False`` rather than raised.
    """
    t0 = time.time()
    cmd = [
        sys.executable, "-m", "demucs",
        "--two-stems", "vocals",
        "-n", "htdemucs",
        "-o", demucs_out_dir,
        chunk["path"],
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        # Report a timeout like any other per-chunk failure so one slow chunk
        # does not abort the caller's whole run.
        elapsed = time.time() - t0
        return chunk["idx"], False, f"timed out after {elapsed:.0f}s"

    elapsed = time.time() - t0
    if result.returncode != 0:
        return chunk["idx"], False, f"failed in {elapsed:.0f}s: {result.stderr[-200:]}"

    # Expected output location: <out>/htdemucs/<chunk file stem>/vocals.wav
    vocal_name = os.path.splitext(os.path.basename(chunk["path"]))[0]
    vocal_path = os.path.join(demucs_out_dir, "htdemucs", vocal_name, "vocals.wav")
    if not os.path.exists(vocal_path):
        return chunk["idx"], False, "vocals.wav not found"

    chunk["vocals_path"] = vocal_path
    return chunk["idx"], True, f"ok in {elapsed:.0f}s"


def stitch_chunks(chunks, sr, shape, dst):
    """Reassemble the separated vocals into one file written to ``dst``.

    Only the core region of each chunk (its span without the overlap
    padding) is copied into the output buffer. Chunks that carry no
    ``"vocals_path"`` are skipped, leaving zeros in their region.

    ``shape`` is the shape of the full-length audio; ``sr`` is the sample
    rate used when writing ``dst``.
    """
    n_total = shape[0]
    n_channels = 1 if len(shape) <= 1 else shape[1]
    buffer_shape = (n_total, n_channels) if n_channels > 1 else n_total
    mixed = np.zeros(buffer_shape, dtype=np.float32)

    ordered = sorted(chunks, key=lambda c: c["idx"])
    for piece in ordered:
        if "vocals_path" in piece:
            vocals, _ = sf.read(piece["vocals_path"], dtype="float32")

            # Samples of leading overlap to skip inside this chunk's audio.
            lead_in = piece["orig_start_sample"] - piece["start_sample"]
            core_len = piece["orig_end_sample"] - piece["orig_start_sample"]
            # Never copy more than the chunk's audio actually provides.
            count = min(core_len, len(vocals) - lead_in)
            write_at = piece["orig_start_sample"]

            if count > 0:
                mixed[write_at:write_at + count] = vocals[lead_in:lead_in + count]

    sf.write(dst, mixed, sr)


def process_one_movie(src, dst):
    """Extract the vocals of one media file with Demucs, chunk by chunk.

    Steps: convert ``src`` to WAV, split it into overlapping chunks, run
    Demucs on the chunks in parallel, then stitch the vocal chunks into
    ``dst``. Intermediate files live in ``demucs_chunks_tmp`` next to
    ``src``; that directory is removed when the function exits, whether it
    succeeds or fails.

    Args:
        src: Path of the input media file.
        dst: Path of the vocals file to write.

    Returns:
        True if every chunk was separated successfully. False if at least
        one chunk failed; ``dst`` is still written, with the failed chunks'
        regions left silent.

    Raises:
        subprocess.CalledProcessError, subprocess.TimeoutExpired: If the
            initial WAV conversion fails or times out.
    """
    vid = os.path.basename(os.path.dirname(src))
    folder = os.path.dirname(src)
    tmp_dir = os.path.join(folder, "demucs_chunks_tmp")

    # Start from a clean scratch directory in case an earlier run left one.
    if os.path.exists(tmp_dir):
        shutil.rmtree(tmp_dir)
    os.makedirs(tmp_dir)

    t0 = time.time()
    failed = []
    try:
        demucs_out = os.path.join(tmp_dir, "demucs_out")
        os.makedirs(demucs_out)

        # Step 1: Convert to WAV (one pass)
        flush_print(f"[{vid}] Converting to WAV...")
        wav_path = os.path.join(tmp_dir, "full.wav")
        t1 = time.time()
        convert_to_wav(src, wav_path)
        flush_print(f"[{vid}] Converted in {time.time()-t1:.1f}s")

        # Step 2: Split the WAV into chunk files
        flush_print(f"[{vid}] Splitting into {CHUNK_DURATION}s chunks with {OVERLAP}s overlap...")
        t2 = time.time()
        chunks, sr, shape = split_from_wav(wav_path, tmp_dir, 0)
        flush_print(f"[{vid}] Split into {len(chunks)} chunks in {time.time()-t2:.1f}s")

        # Remove full WAV to free disk
        os.remove(wav_path)

        # Step 3: Demucs on each chunk in parallel
        flush_print(f"[{vid}] Running Demucs on {len(chunks)} chunks ({PARALLEL_CHUNKS} parallel)...")
        t3 = time.time()
        with ThreadPoolExecutor(max_workers=PARALLEL_CHUNKS) as pool:
            futures = {pool.submit(run_demucs_on_chunk, c, demucs_out): c for c in chunks}
            done_count = 0
            for future in as_completed(futures):
                try:
                    idx, ok, msg = future.result()
                except (subprocess.SubprocessError, OSError) as exc:
                    # A worker that raised (e.g. a subprocess timeout) counts
                    # as one failed chunk instead of aborting the whole run.
                    idx, ok, msg = futures[future]["idx"], False, f"error: {exc}"
                done_count += 1
                if not ok:
                    failed.append(idx)
                status = "OK" if ok else "FAIL"
                flush_print(f"[{vid}]   chunk {idx:03d} [{status}] {msg}  ({done_count}/{len(chunks)})")
        flush_print(f"[{vid}] All Demucs done in {time.time()-t3:.1f}s")

        if failed:
            flush_print(
                f"[{vid}] WARNING: {len(failed)} chunk(s) failed and will be "
                f"silent in the output: {sorted(failed)}"
            )

        # Step 4: Stitch
        flush_print(f"[{vid}] Stitching vocals...")
        t4 = time.time()
        stitch_chunks(chunks, sr, shape, dst)
        flush_print(f"[{vid}] Stitched in {time.time()-t4:.1f}s")
    finally:
        # Scratch data can be large; never leave it behind. ignore_errors
        # keeps a cleanup problem from masking the original exception.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    elapsed = time.time() - t0
    size_mb = os.path.getsize(dst) / (1024 * 1024)
    flush_print(f"[{vid}] DONE: {size_mb:.1f}MB in {elapsed:.0f}s ({elapsed/60:.1f}min)")
    return not failed


if __name__ == "__main__":
    # Usage: script.py [video_id]
    # With no argument, the first entry (in sorted order) under the pipeline
    # directory is processed.
    if len(sys.argv) > 1:
        video_id = sys.argv[1]
        folder = f"/home/ubuntu/Speech_maker_pipeline/pawan_kalyan/{video_id}"
    else:
        folders = sorted(glob.glob("/home/ubuntu/Speech_maker_pipeline/pawan_kalyan/*"))
        if not folders:
            sys.exit("No entries found under /home/ubuntu/Speech_maker_pipeline/pawan_kalyan")
        folder = folders[0]

    # Sort so the choice is deterministic when several extensions match.
    candidates = sorted(glob.glob(os.path.join(folder, "yt_downloaded_256kbps.*")))
    if not candidates:
        sys.exit(f"No yt_downloaded_256kbps.* file found in {folder}")
    src = candidates[0]
    dst = os.path.join(folder, "yt_downloaded_256kbps_demucs.wav")
    flush_print(f"Processing: {src}")
    flush_print(f"Output:     {dst}")
    # Propagate the outcome to the shell as the exit status.
    sys.exit(0 if process_one_movie(src, dst) else 1)
