#!/usr/bin/env python3
"""Watch experiment checkpoint directories and upload new files to R2.

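Uses file-level state tracking: each .ckpt/.nemo file is tracked by
{name, size, mtime} in a .watcher_state.json file inside its checkpoints
directory. A new or changed file is uploaded once its size and mtime have
stayed the same across two consecutive polls. Files that become ready together
are uploaded in a batch under a step prefix parsed from a .ckpt filename;
companion files (e.g. a .nemo export) share the step of a .ckpt uploaded
with them.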

Usage:
  python3 scripts/auto_upload_checkpoints.py \
    --exp-dir experiments/stage1_prod/maya_asr_stage1 \
    --model-name maya-asr-stage1 \
    --poll-interval 60 \
    --delete-after-upload
"""

import argparse
import hashlib
import json
import re
import subprocess
import sys
import time
from pathlib import Path

from dotenv import load_dotenv

# Name of the JSON file, kept inside each checkpoints directory, that records
# which files were already uploaded (read by load_state, written by save_state).
STATE_FILENAME = ".watcher_state.json"

# Optional .env one level above this script's directory; when present its
# variables are loaded into the process environment.
# NOTE(review): presumably carries the R2 credentials used by the upload script.
ENV_FILE = Path(__file__).resolve().parent.parent.joinpath(".env")
if ENV_FILE.exists():
    load_dotenv(ENV_FILE)


def extract_step_from_file(filepath: Path) -> int:
    """Extract a step number from a single checkpoint filename.

    Only the final path component is inspected. ``step=N`` / ``step_N`` wins;
    otherwise ``epoch=N`` / ``epoch_N`` is used. If neither appears, a stable
    pseudo-step in ``[0, 100_000_000)`` is derived from the filename, so the same
    name always maps to the same prefix.
    """
    name = filepath.name
    for key in ("step", "epoch"):
        m = re.search(rf"{key}[=_](\d+)", name)
        if m:
            return int(m.group(1))
    # Fallback: hash the filename for a stable pseudo-step. MD5 serves purely as
    # a deterministic hash here; marking it as non-security use keeps this working
    # on FIPS-restricted Python/OpenSSL builds where a plain md5() call raises.
    digest = hashlib.md5(name.encode(), usedforsecurity=False).hexdigest()
    return int(digest[:8], 16) % 100000000


def file_signature(f: Path) -> dict:
    """Describe *f* by its size in bytes and its modification time.

    Repeated calls return equal dicts while the file is untouched; that equality
    is what the upload-state and stability checks compare.
    """
    stat_result = f.stat()
    return dict(size=stat_result.st_size, mtime=stat_result.st_mtime)


def load_state(ckpt_dir: Path) -> dict:
    """Load watcher state for a checkpoint directory.

    Returns the parsed JSON object from ``ckpt_dir / STATE_FILENAME``. Returns an
    empty dict if any of the following holds: the file is missing or unreadable,
    it is not valid UTF-8 or JSON, or its top-level value is not a JSON object.
    """
    state_file = ckpt_dir / STATE_FILENAME
    try:
        state = json.loads(state_file.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError from a corrupted file,
        # which previously escaped and stopped the watcher.
        return {}
    # Callers use dict methods on the result; ignore JSON that is not an object.
    return state if isinstance(state, dict) else {}


def save_state(ckpt_dir: Path, state: dict):
    """Save watcher state to ``ckpt_dir / STATE_FILENAME`` as indented JSON.

    The JSON is written to a sibling ``.tmp`` file first and then renamed over
    the real state file. A crash mid-write therefore leaves the previous state
    intact, rather than a truncated file that load_state would discard (which
    would cause every checkpoint to be uploaded again).
    """
    state_file = ckpt_dir / STATE_FILENAME
    tmp_file = state_file.with_name(state_file.name + ".tmp")
    tmp_file.write_text(json.dumps(state, indent=2), encoding="utf-8")
    # Path.replace uses os.replace: an atomic rename on POSIX when both paths
    # are on the same filesystem (they share a directory here).
    tmp_file.replace(state_file)


def find_checkpoint_dirs(exp_dir: Path) -> list[Path]:
    """Find all checkpoint directories under the experiment dir.

    Handles two NeMo layouts:
    1. exp_dir/checkpoints/ (direct — NeMo default with resume_if_exists)
    2. exp_dir/<run_dir>/checkpoints/ (versioned runs)

    The direct directory (if any) comes first, followed by versioned run
    directories in sorted path order. Returns an empty list when *exp_dir* does
    not exist yet or is not a directory.
    """
    # is_dir() rather than exists(): a regular file at exp_dir would otherwise
    # make iterdir() raise NotADirectoryError and stop the watcher.
    if not exp_dir.is_dir():
        return []

    direct = exp_dir / "checkpoints"
    ckpt_dirs = [direct] if direct.is_dir() else []
    ckpt_dirs.extend(
        run_dir / "checkpoints"
        for run_dir in sorted(exp_dir.iterdir())
        # Skip the direct dir (already handled) and anything without checkpoints/.
        if run_dir.is_dir()
        and run_dir.name != "checkpoints"
        and (run_dir / "checkpoints").is_dir()
    )
    return ckpt_dirs


def _collect_stable_files(
    ckpt_dir: Path, pending: dict[str, dict[str, dict]], state: dict
) -> dict[str, dict]:
    """Return {filename: signature} for files in *ckpt_dir* that passed the stability gate.

    A .ckpt/.nemo file is ready when both of these hold:
    - It is not already recorded in ``state["uploaded"]`` with its current signature.
    - Its signature equals the one stored in *pending* on the previous poll.

    *pending* is updated in place.
    """
    try:
        candidates = [
            f for f in ckpt_dir.iterdir() if f.is_file() and f.suffix in (".ckpt", ".nemo")
        ]
    except OSError:
        # Directory vanished or became unreadable since discovery; retry next poll.
        return {}

    uploaded = state.get("uploaded", {})
    seen = pending.setdefault(str(ckpt_dir), {})
    present: set[str] = set()
    stable: dict[str, dict] = {}
    for f in candidates:
        try:
            sig = file_signature(f)
        except FileNotFoundError:
            # Removed between the listing and stat() (e.g. an older checkpoint
            # deleted by the training run); previously this crashed the watcher.
            continue
        fname = f.name
        present.add(fname)
        if uploaded.get(fname) == sig:
            continue  # already uploaded in exactly this form
        if seen.get(fname) != sig:
            # First observation or size/mtime changed — record and wait a poll.
            seen[fname] = sig
            continue
        # Same signature across 2 polls — ready to upload.
        del seen[fname]
        stable[fname] = sig

    # Forget files that disappeared before becoming stable so the gate map does
    # not grow without bound as checkpoints are rotated out.
    for fname in [n for n in seen if n not in present]:
        del seen[fname]
    return stable


def _group_by_step(ckpt_dir: Path, fnames: list[str]) -> dict[int, list[str]]:
    """Group ready files by the step prefix they should be uploaded under.

    Each .ckpt file goes under the step parsed from its own name. Files that
    become stable in the same poll (e.g. right after the watcher starts) are
    therefore not all uploaded under one arbitrary step. Other files (e.g. a
    .nemo export) join the group of the highest-step .ckpt in this batch. If the
    batch has no .ckpt, each such file uses its own extracted step.
    """
    groups: dict[int, list[str]] = {}
    for fname in fnames:
        if fname.endswith(".ckpt"):
            groups.setdefault(extract_step_from_file(ckpt_dir / fname), []).append(fname)

    latest = max(groups) if groups else None
    for fname in fnames:
        if fname.endswith(".ckpt"):
            continue
        step = latest if latest is not None else extract_step_from_file(ckpt_dir / fname)
        groups.setdefault(step, []).append(fname)
    return groups


def _upload_group(
    args: argparse.Namespace,
    project_root: Path,
    upload_script: Path,
    ckpt_dir: Path,
    step: int,
    sigs: dict[str, dict],
) -> None:
    """Upload one group of files under *step*, then record (and optionally delete) them.

    *sigs* maps each filename to the signature verified stable before the upload.
    """
    fnames = list(sigs)
    # flush so our log line lands before the child process's own output.
    print(f"Uploading group: {fnames} (step={step}) from {ckpt_dir}", flush=True)

    cmd = [
        sys.executable,
        str(upload_script),
        "--checkpoint-dir",
        str(ckpt_dir),
        "--model-name",
        args.model_name,
        "--step",
        str(step),
        "--files",
        *fnames,
    ]
    if args.dry_run:
        cmd.append("--dry-run")

    result = subprocess.run(cmd, cwd=str(project_root))
    if result.returncode != 0:
        print(
            f"  Upload failed (exit code {result.returncode}); will retry on a later poll",
            flush=True,
        )
        return
    if args.dry_run:
        return

    # Reload so groups uploaded earlier in this poll are not overwritten.
    state = load_state(ckpt_dir)
    uploaded = state.setdefault("uploaded", {})
    for fname, sig in sigs.items():
        # Record the signature verified *before* the upload. If the file changed
        # while uploading, it no longer matches and is re-uploaded later instead
        # of being wrongly marked done. This also avoids stat()-ing a file that
        # may have been deleted in the meantime.
        uploaded[fname] = sig
    save_state(ckpt_dir, state)

    if args.delete_after_upload:
        for fname, sig in sigs.items():
            path = ckpt_dir / fname
            try:
                current = file_signature(path)
            except FileNotFoundError:
                continue  # already gone
            if current != sig:
                # Never delete content that differs from what was uploaded.
                print(f"  Not deleting {fname}: it changed during upload", flush=True)
                continue
            print(f"  Deleting: {fname}", flush=True)
            path.unlink(missing_ok=True)


def main():
    """Poll the experiment dir forever, uploading checkpoint files once they are stable.

    Each poll proceeds in three steps:
    1. Every checkpoints directory is scanned.
    2. Files whose size/mtime were unchanged since the previous poll are grouped
       by step.
    3. Each group is uploaded via scripts/upload_checkpoint.py, after which the
       watcher sleeps for ``--poll-interval`` seconds.
    """
    parser = argparse.ArgumentParser(description="Auto-upload checkpoints to R2")
    parser.add_argument("--exp-dir", type=Path, required=True)
    parser.add_argument("--model-name", type=str, required=True)
    parser.add_argument("--poll-interval", type=int, default=60)
    parser.add_argument("--delete-after-upload", action="store_true", default=False)
    parser.add_argument("--dry-run", action="store_true", default=False)
    args = parser.parse_args()

    project_root = Path(__file__).resolve().parent.parent
    upload_script = project_root / "scripts" / "upload_checkpoint.py"

    print(f"Watching {args.exp_dir} for new checkpoints...")
    print(f"  Model: {args.model_name}")
    print(f"  Poll interval: {args.poll_interval}s")
    print(f"  Delete after upload: {args.delete_after_upload}")
    print()

    # Stability gate: {str(ckpt_dir): {filename: signature seen on the last poll}}
    pending: dict[str, dict[str, dict]] = {}

    while True:
        # Evaluate the gate for every directory before any (possibly slow) upload.
        ready: list[tuple[Path, dict[str, dict]]] = []
        for ckpt_dir in find_checkpoint_dirs(args.exp_dir):
            stable = _collect_stable_files(ckpt_dir, pending, load_state(ckpt_dir))
            if stable:
                ready.append((ckpt_dir, stable))

        for ckpt_dir, stable in ready:
            for step, fnames in _group_by_step(ckpt_dir, sorted(stable)).items():
                _upload_group(
                    args,
                    project_root,
                    upload_script,
                    ckpt_dir,
                    step,
                    {fname: stable[fname] for fname in fnames},
                )

        time.sleep(args.poll_interval)


if __name__ == "__main__":
    # Script entry point; main() returns None, so this exits with status 0.
    raise SystemExit(main())
