#!/usr/bin/env python3
"""
Modi Voice Data Pipeline for VibeVoice Finetuning (v2 - No Demucs)
===================================================================
Smart trim intro/outro → Resample to 24 kHz mono → Silero VAD → Segmentation (5-30 s)

Fully resumable: tracks progress per file in progress.json.
"""

import json
import os
import subprocess
import time
import csv
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import soundfile as sf
import torch
import torchaudio

# ─── CONFIG ───
SRC_DIR = Path("/home/ubuntu/modi")
OUT_DIR = Path("/home/ubuntu/modi_processed")
SEGMENTS_DIR = OUT_DIR / "segments"
TRIMMED_DIR = OUT_DIR / "trimmed_24k"
PROGRESS_FILE = OUT_DIR / "progress.json"
MANIFEST_FILE = OUT_DIR / "manifest.csv"

TARGET_SR = 24000
TRIM_HEAD_SEC = 120.0   # cut ~2 min from start
TRIM_TAIL_SEC = 120.0   # cut ~2 min from end
SILENCE_SEARCH_SEC = 10.0  # one-sided window from each cut point in which to look for silence
MIN_SEGMENT_SEC = 5.0
MAX_SEGMENT_SEC = 30.0
MERGE_THRESHOLD_SEC = 2.0
SILENCE_PAD_SEC = 0.15

WORKERS = 4  # for resample (CPU, fast)

# ─── PROGRESS TRACKING ───

def load_progress():
    if PROGRESS_FILE.exists():
        return json.loads(PROGRESS_FILE.read_text())
    return {}

def save_progress(progress):
    PROGRESS_FILE.write_text(json.dumps(progress, indent=2))

def file_stage(progress, name):
    return progress.get(name, {}).get("stage", "pending")

def mark_stage(progress, name, stage, **extra):
    if name not in progress:
        progress[name] = {}
    progress[name]["stage"] = stage
    progress[name]["updated"] = time.strftime("%Y-%m-%d %H:%M:%S")
    progress[name].update(extra)
    save_progress(progress)
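
# Optional hardening (a sketch, not wired in): save_progress() truncates and rewrites
# progress.json in place, so a crash mid-write can leave it corrupt. Writing to a temp
# file and swapping it in with os.replace() (atomic on POSIX) avoids that:
#
#     def save_progress_atomic(progress):
#         tmp = PROGRESS_FILE.with_suffix(".json.tmp")
#         tmp.write_text(json.dumps(progress, indent=2))
#         os.replace(tmp, PROGRESS_FILE)  # readers never observe a partial file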

# ─── STEP 1: SMART TRIM + RESAMPLE 24kHz MONO ───

def find_silence_near(audio, sr, target_sec, search_sec=10, direction="after"):
    """Find the best silence point near target_sec.
    direction='after': search from target_sec forward (for head trim)
    direction='before': search from target_sec backward (for tail trim)
    """
    # Clamp the nominal cut point into the valid sample range (guards very short files).
    target_sample = max(0, min(int(target_sec * sr), len(audio)))
    search_samples = int(search_sec * sr)
    window = int(0.05 * sr)  # 50 ms energy window

    if direction == "after":
        # Head trim: search forward only, so intro material is never kept.
        start = target_sample
        end = min(len(audio), target_sample + search_samples)
    else:
        # Tail trim: search backward only, so outro material is never kept.
        start = max(0, target_sample - search_samples)
        end = target_sample

    if start >= end:
        return target_sample

    best_pos = target_sample
    min_energy = float('inf')

    for pos in range(start, end - window, window // 2):
        chunk = audio[pos:pos + window]
        energy = float(np.mean(chunk ** 2))
        if energy < min_energy:
            min_energy = energy
            best_pos = pos

    return best_pos
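
# Performance note (a sketch, equivalent up to hop alignment): for multi-hour files the
# Python loop above can be vectorized with a moving-average energy:
#
#     sq = audio[start:end] ** 2
#     frame_energy = np.convolve(sq, np.ones(window) / window, mode="valid")[::window // 2]
#     best_pos = start + int(np.argmin(frame_energy)) * (window // 2)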


def trim_and_resample(src_wav: Path, out_wav: Path):
    """Trim intro/outro at silence boundaries, resample to 24kHz mono."""
    # Resume support: skip files whose output already exists and is non-trivial.
    if out_wav.exists() and out_wav.stat().st_size > 1000:
        return True, "already exists"

    try:
        info = sf.info(str(src_wav))
        orig_sr = info.samplerate
        duration = info.duration

        if duration < TRIM_HEAD_SEC + TRIM_TAIL_SEC + 60:
            # File too short to trim 2+2 min, just trim 30s each side
            head_sec = 30.0
            tail_sec = 30.0
        else:
            head_sec = TRIM_HEAD_SEC
            tail_sec = TRIM_TAIL_SEC

        # Load full audio (mono)
        audio, sr = sf.read(str(src_wav), dtype='float32')
        if audio.ndim == 2:
            audio = audio.mean(axis=1)

        # Find silence near head cut point
        head_cut = find_silence_near(audio, sr, head_sec, SILENCE_SEARCH_SEC, "after")

        # Find silence near tail cut point
        tail_target = len(audio) / sr - tail_sec
        tail_cut = find_silence_near(audio, sr, tail_target, SILENCE_SEARCH_SEC, "before")

        if tail_cut <= head_cut:
            tail_cut = len(audio)

        trimmed = audio[head_cut:tail_cut]
        trimmed_dur = len(trimmed) / sr

        # Resample to 24kHz
        trimmed_tensor = torch.from_numpy(trimmed).unsqueeze(0).float()
        if sr != TARGET_SR:
            trimmed_tensor = torchaudio.functional.resample(trimmed_tensor, sr, TARGET_SR)

        trimmed_np = trimmed_tensor.squeeze(0).numpy()

        out_wav.parent.mkdir(parents=True, exist_ok=True)
        sf.write(str(out_wav), trimmed_np, TARGET_SR)

        return True, f"{duration:.0f}s -> {trimmed_dur:.0f}s (cut {head_cut/sr:.1f}s head, {(len(audio)-tail_cut)/sr:.1f}s tail)"

    except Exception as e:
        return False, str(e)
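
# Manual spot-check (hypothetical filename): run one file through the trimmer and
# inspect the message it returns.
#
#     ok, msg = trim_and_resample(SRC_DIR / "speech_0001.wav",
#                                 TRIMMED_DIR / "speech_0001.wav")
#     print(ok, msg)  # e.g. True "3600s -> 3355s (cut 122.3s head, 121.9s tail)"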


# ─── STEP 2: SILERO VAD + SEGMENTATION ───

_vad_model = None
_vad_utils = None

def get_vad():
    global _vad_model, _vad_utils
    if _vad_model is None:
        model, utils = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            trust_repo=True
        )
        _vad_model = model
        _vad_utils = utils
    return _vad_model, _vad_utils


def vad_segment(wav_path: Path, out_dir: Path):
    """Run Silero VAD, merge nearby speech chunks, split into 5-30s segments."""
    out_dir.mkdir(parents=True, exist_ok=True)

    model, utils = get_vad()
    # utils is (get_speech_timestamps, save_audio, read_audio, VADIterator, collect_chunks)
    get_speech_timestamps = utils[0]

    audio_24k, sr = sf.read(str(wav_path), dtype='float32')
    total_samples = len(audio_24k)

    # Silero VAD expects 16 kHz input; detection runs on a resampled copy, cuts are made in the 24 kHz audio
    audio_tensor = torch.from_numpy(audio_24k).float()
    audio_16k = torchaudio.functional.resample(audio_tensor, sr, 16000)

    speech_timestamps = get_speech_timestamps(
        audio_16k, model,
        threshold=0.35,
        min_speech_duration_ms=300,
        min_silence_duration_ms=200,
        speech_pad_ms=100,
        return_seconds=False,
        sampling_rate=16000
    )

    if not speech_timestamps:
        return []

    # Convert 16kHz indices to 24kHz
    ratio = TARGET_SR / 16000.0
    chunks = [(int(ts['start'] * ratio), int(ts['end'] * ratio)) for ts in speech_timestamps]

    # Merge close chunks
    merged = [chunks[0]]
    merge_gap = int(MERGE_THRESHOLD_SEC * TARGET_SR)
    for start, end in chunks[1:]:
        prev_start, prev_end = merged[-1]
        if start - prev_end < merge_gap:
            merged[-1] = (prev_start, end)
        else:
            merged.append((start, end))

    # Split into 5-30s segments
    pad = int(SILENCE_PAD_SEC * TARGET_SR)
    segments = []

    for chunk_start, chunk_end in merged:
        chunk_dur = (chunk_end - chunk_start) / TARGET_SR

        # Below minimum length: try gluing onto the previous segment; otherwise keep it if >= 1 s.
        if chunk_dur < MIN_SEGMENT_SEC:
            if segments and (chunk_start - segments[-1][1]) / TARGET_SR < MERGE_THRESHOLD_SEC:
                prev = segments[-1]
                if (chunk_end - prev[0]) / TARGET_SR <= MAX_SEGMENT_SEC:
                    segments[-1] = (prev[0], chunk_end)
                    continue
            if chunk_dur >= 1.0:
                segments.append((chunk_start, chunk_end))
            continue

        if chunk_dur <= MAX_SEGMENT_SEC:
            segments.append((chunk_start, chunk_end))
        else:
            # Over maximum length: split at the lowest-energy 50 ms window near the 30 s mark.
            pos = chunk_start
            while pos < chunk_end:
                remaining = (chunk_end - pos) / TARGET_SR
                if remaining <= MAX_SEGMENT_SEC:
                    segments.append((pos, chunk_end))
                    break

                target_end = pos + int(MAX_SEGMENT_SEC * TARGET_SR)
                search_start = pos + int((MAX_SEGMENT_SEC - 3.0) * TARGET_SR)
                search_end = min(target_end + int(1.0 * TARGET_SR), chunk_end)

                best_split = target_end
                if search_start < search_end and search_end <= total_samples:
                    window = int(0.05 * TARGET_SR)
                    min_energy = float('inf')
                    for sp in range(search_start, search_end, window):
                        end_sp = min(sp + window, total_samples)
                        energy = np.mean(audio_24k[sp:end_sp] ** 2)
                        if energy < min_energy:
                            min_energy = energy
                            best_split = sp

                segments.append((pos, best_split))
                pos = best_split

    # Post-merge short segments
    final_segments = []
    for seg in segments:
        if final_segments:
            prev = final_segments[-1]
            prev_dur = (prev[1] - prev[0]) / TARGET_SR
            this_dur = (seg[1] - seg[0]) / TARGET_SR
            gap = (seg[0] - prev[1]) / TARGET_SR
            combined_dur = (seg[1] - prev[0]) / TARGET_SR

            if this_dur < MIN_SEGMENT_SEC and combined_dur <= MAX_SEGMENT_SEC and gap < 3.0:
                final_segments[-1] = (prev[0], seg[1])
                continue
            if prev_dur < MIN_SEGMENT_SEC and combined_dur <= MAX_SEGMENT_SEC and gap < 3.0:
                final_segments[-1] = (prev[0], seg[1])
                continue

        final_segments.append(seg)

    # Drop leftover fragments shorter than 2 s that could not be merged anywhere.
    final_segments = [s for s in final_segments if (s[1] - s[0]) / TARGET_SR >= 2.0]

    # Write segments
    written = []
    for i, (start, end) in enumerate(final_segments):
        s = max(0, start - pad)
        e = min(total_samples, end + pad)
        segment_audio = audio_24k[s:e]

        # Peak-normalize to 0.9 full scale; skip near-silent segments to avoid boosting noise.
        peak = np.max(np.abs(segment_audio))
        if peak > 0.01:
            segment_audio = segment_audio * (0.9 / peak)

        seg_name = f"seg_{i:04d}.wav"
        seg_path = out_dir / seg_name
        sf.write(str(seg_path), segment_audio, TARGET_SR)

        duration = len(segment_audio) / TARGET_SR
        written.append({
            "filename": seg_name,
            "duration": round(duration, 2),
        })

    return written
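
# Quick sanity check (a sketch; "<name>" is any source stem that reached "segmented"):
#
#     durs = [sf.info(str(p)).duration
#             for p in sorted((SEGMENTS_DIR / "<name>").glob("seg_*.wav"))]
#     print(len(durs), min(durs), max(durs))  # lengths should fall roughly in 2-30 s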


# ─── MANIFEST ───

def build_manifest(progress):
    rows = []
    for name, info in sorted(progress.items()):
        if info.get("stage") != "segmented":
            continue
        seg_dir = SEGMENTS_DIR / name
        if not seg_dir.exists():
            continue
        for wav in sorted(seg_dir.glob("seg_*.wav")):
            try:
                # sf.info reads only the header, far cheaper than decoding the audio
                dur = sf.info(str(wav)).duration
            except Exception:
                dur = 0.0
            rows.append({
                "source": name,
                "filename": str(wav),
                "duration": round(dur, 2),
            })

    with open(MANIFEST_FILE, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["source", "filename", "duration"])
        writer.writeheader()
        writer.writerows(rows)

    total_dur = sum(r["duration"] for r in rows)
    print(f"\nManifest: {len(rows)} segments, {total_dur:.0f}s ({total_dur/3600:.1f}hrs) -> {MANIFEST_FILE}")
    return rows
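
# Reading the manifest back (a sketch): each row carries the source stem, the absolute
# segment path, and its duration in seconds.
#
#     with open(MANIFEST_FILE) as f:
#         rows = list(csv.DictReader(f))
#     print(sum(float(r["duration"]) for r in rows) / 3600, "hours")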


# ─── MAIN ───

def get_source_files():
    files = []
    for f in sorted(SRC_DIR.glob("*.wav")):
        # Skip intermediate artifacts left behind by the earlier Demucs-based pipeline.
        if "_demucs" in f.name or "_16k" in f.name:
            continue
        files.append(f)
    return files


def main():
    for d in [SEGMENTS_DIR, TRIMMED_DIR]:
        d.mkdir(parents=True, exist_ok=True)

    progress = load_progress()
    source_files = get_source_files()

    already_done = sum(1 for f in source_files if file_stage(progress, f.stem) == "segmented")

    print(f"{'=' * 60}")
    print(f"Modi Voice Pipeline v2 (no Demucs)")
    print(f"{'=' * 60}")
    print(f"Source: {SRC_DIR} ({len(source_files)} files)")
    print(f"Output: {OUT_DIR}")
    print(f"Progress: {already_done}/{len(source_files)} done")
    print(f"{'=' * 60}", flush=True)

    t_start = time.time()

    # Phase 1: Trim + Resample (threaded; soundfile/torchaudio release the GIL, so threads parallelize)
    need_trim = [f for f in source_files if file_stage(progress, f.stem) == "pending"]
    if need_trim:
        print(f"\n--- Trim + Resample: {len(need_trim)} files, {WORKERS} parallel ---", flush=True)

        def _trim_one(src):
            name = src.stem
            out = TRIMMED_DIR / f"{name}.wav"
            t0 = time.time()
            ok, msg = trim_and_resample(src, out)
            elapsed = time.time() - t0
            return name, ok, msg, elapsed

        with ThreadPoolExecutor(max_workers=WORKERS) as pool:
            futures = {pool.submit(_trim_one, f): f.stem for f in need_trim}
            done = 0
            for future in as_completed(futures):
                done += 1
                name, ok, msg, elapsed = future.result()
                if ok:
                    mark_stage(progress, name, "trimmed", trim_info=msg, trim_time=round(elapsed))
                    print(f"  [{done}/{len(need_trim)}] OK {name}: {msg} ({elapsed:.0f}s)", flush=True)
                else:
                    mark_stage(progress, name, "trim_failed", error=msg)
                    print(f"  [{done}/{len(need_trim)}] FAIL {name}: {msg}", flush=True)

    # Phase 2: VAD + Segmentation (sequential; the Silero VAD model is loaded once and runs on CPU)
    need_vad = [f for f in source_files if file_stage(progress, f.stem) == "trimmed"]
    if need_vad:
        print(f"\n--- VAD + Segmentation: {len(need_vad)} files ---", flush=True)

        for i, src in enumerate(need_vad, 1):
            name = src.stem
            trimmed_wav = TRIMMED_DIR / f"{name}.wav"
            seg_dir = SEGMENTS_DIR / name

            if not trimmed_wav.exists():
                mark_stage(progress, name, "trim_failed", error="trimmed file missing")
                continue

            try:
                t0 = time.time()
                seg_info = vad_segment(trimmed_wav, seg_dir)
                elapsed = time.time() - t0
                total_dur = sum(s["duration"] for s in seg_info)
                mark_stage(progress, name, "segmented",
                           num_segments=len(seg_info),
                           total_duration=round(total_dur, 1),
                           vad_time=round(elapsed))
                print(f"  [{i}/{len(need_vad)}] {name}: {len(seg_info)} segs, {total_dur:.0f}s ({elapsed:.0f}s)", flush=True)
            except Exception as e:
                mark_stage(progress, name, "error", error=str(e))
                print(f"  [{i}/{len(need_vad)}] ERR {name}: {e}", flush=True)

    # Build manifest
    print(f"\n{'=' * 60}")
    print("Building manifest...", flush=True)
    rows = build_manifest(progress)

    elapsed = time.time() - t_start
    done_count = sum(1 for v in progress.values() if v.get("stage") == "segmented")
    failed_count = sum(1 for v in progress.values() if "fail" in v.get("stage", "") or v.get("stage") == "error")
    total_hrs = sum(v.get("total_duration", 0) for v in progress.values() if v.get("stage") == "segmented") / 3600
    print(f"\nDone: {done_count}/{len(source_files)} files, {failed_count} failed")
    print(f"Total speech: {total_hrs:.1f} hours")
    print(f"Time: {elapsed/60:.1f} min")


if __name__ == "__main__":
    main()
