"""Unit tests for checkpoint upload and watcher scripts."""

import importlib.util
import subprocess
import sys
from pathlib import Path

# Absolute directory that the tested scripts are resolved against: the third
# ancestor of this file's resolved path (two directories above the directory
# containing this test file). Subprocess tests use it as their working
# directory, and _import_watcher() locates scripts/ relative to it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent


def _import_watcher():
    """Load ``scripts/auto_upload_checkpoints.py`` as a module object.

    The file is loaded directly by path, following the importlib
    "import a source file directly" recipe, since ``scripts/`` is presumably
    not an importable package (TODO confirm). The module is registered in
    ``sys.modules`` before execution so that code looking itself up there
    (e.g. dataclasses with string annotations, pickling) works. Each call
    re-executes the file and replaces that entry.

    Returns:
        The freshly executed module.

    Raises:
        ImportError: If no import spec/loader can be built for the file.
    """
    module_name = "auto_upload_checkpoints"
    module_path = PROJECT_ROOT / "scripts" / "auto_upload_checkpoints.py"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {module_path}")
    mod = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Don't leave a half-initialised module behind for later lookups.
        sys.modules.pop(module_name, None)
        raise
    return mod


def test_upload_checkpoint_fails_on_missing_dir(tmp_path):
    """Should fail with clear error when checkpoint dir doesn't exist.

    The script runs under the same interpreter as pytest (``sys.executable``)
    so the active virtualenv's packages are used, and a timeout keeps a hung
    script from stalling the suite.
    """
    result = subprocess.run(
        [
            sys.executable,
            "scripts/upload_checkpoint.py",
            "--checkpoint-dir",
            str(tmp_path / "nonexistent"),
            "--model-name",
            "test",
            "--step",
            "1",
            "--dry-run",
        ],
        capture_output=True,
        text=True,
        cwd=str(PROJECT_ROOT),
        timeout=120,
    )
    assert result.returncode != 0, result.stdout
    assert "not found" in result.stderr, result.stderr


def test_upload_checkpoint_fails_on_empty_dir(tmp_path):
    """Should fail when no .ckpt or .nemo files found.

    Uses ``sys.executable`` so the script runs under pytest's interpreter,
    with a timeout so a hang fails the test instead of blocking the run.
    """
    empty_dir = tmp_path / "empty_ckpt"
    empty_dir.mkdir()
    result = subprocess.run(
        [
            sys.executable,
            "scripts/upload_checkpoint.py",
            "--checkpoint-dir",
            str(empty_dir),
            "--model-name",
            "test",
            "--step",
            "1",
            "--dry-run",
        ],
        capture_output=True,
        text=True,
        cwd=str(PROJECT_ROOT),
        timeout=120,
    )
    assert result.returncode != 0, result.stdout
    assert "No .ckpt or .nemo" in result.stderr, result.stderr


def test_upload_checkpoint_dry_run_8digit_step(tmp_path):
    """Dry run should use 8-digit step in key.

    Step 42 is expected to appear zero-padded as ``step_00000042`` in the
    dry-run output. Runs under ``sys.executable`` with a timeout.
    """
    ckpt_dir = tmp_path / "checkpoints"
    ckpt_dir.mkdir()
    (ckpt_dir / "model.ckpt").write_bytes(b"fake" * 100)
    (ckpt_dir / "model.nemo").write_bytes(b"fake" * 50)

    result = subprocess.run(
        [
            sys.executable,
            "scripts/upload_checkpoint.py",
            "--checkpoint-dir",
            str(ckpt_dir),
            "--model-name",
            "test-model",
            "--step",
            "42",
            "--dry-run",
        ],
        capture_output=True,
        text=True,
        cwd=str(PROJECT_ROOT),
        timeout=120,
    )
    assert result.returncode == 0, result.stderr
    assert "DRY RUN" in result.stdout, result.stdout
    assert "step_00000042" in result.stdout, result.stdout


def test_upload_checkpoint_large_step_number(tmp_path):
    """Step 200000 should format as step_00200000.

    Runs the script under ``sys.executable`` with a timeout, and surfaces the
    script's output when an assertion fails.
    """
    ckpt_dir = tmp_path / "checkpoints"
    ckpt_dir.mkdir()
    (ckpt_dir / "model.nemo").write_bytes(b"data")

    result = subprocess.run(
        [
            sys.executable,
            "scripts/upload_checkpoint.py",
            "--checkpoint-dir",
            str(ckpt_dir),
            "--model-name",
            "prod",
            "--step",
            "200000",
            "--dry-run",
        ],
        capture_output=True,
        text=True,
        cwd=str(PROJECT_ROOT),
        timeout=120,
    )
    assert result.returncode == 0, result.stderr
    assert "step_00200000" in result.stdout, result.stdout


def test_watcher_file_level_tracking(tmp_path):
    """Upload state is keyed by individual checkpoint file, not by directory.

    A file recorded as uploaded must survive a save/load round trip, while a
    file created after the save must not appear in the reloaded state.
    """
    mod = _import_watcher()

    run_ckpts = tmp_path / "run1" / "checkpoints"
    run_ckpts.mkdir(parents=True)
    first = run_ckpts / "step10.ckpt"
    first.write_bytes(b"data1" * 100)

    # A directory with no saved state loads as an empty mapping.
    persisted = mod.load_state(run_ckpts)
    assert persisted == {}

    # Mark only the first checkpoint as uploaded, then persist.
    persisted["uploaded"] = {first.name: mod.file_signature(first)}
    mod.save_state(run_ckpts, persisted)

    # A checkpoint written after the save must be absent from the state.
    second = run_ckpts / "step20.ckpt"
    second.write_bytes(b"data2" * 100)

    uploaded = mod.load_state(run_ckpts)["uploaded"]
    assert first.name in uploaded
    assert second.name not in uploaded


def test_watcher_stability_gate(tmp_path):
    """Repeated signatures of an unchanged file agree; growing it changes them.

    Exercises ``file_signature``, the per-file fingerprint compared across
    observations to judge whether a checkpoint has stopped changing.
    """
    signature = _import_watcher().file_signature

    target = tmp_path / "run1" / "checkpoints" / "model.ckpt"
    target.parent.mkdir(parents=True)
    target.write_bytes(b"data" * 100)

    before = signature(target)
    # Two observations of an untouched file must match.
    assert signature(target) == before

    # Rewriting with twice the payload must be detected as a change.
    target.write_bytes(b"data" * 200)
    assert signature(target) != before


def test_per_file_step_extraction():
    """``extract_step_from_file`` derives a step number from each filename."""
    extract = _import_watcher().extract_step_from_file

    expected_steps = {
        "step_5000.ckpt": 5000,
        "model-step=10000.ckpt": 10000,
        "model-epoch=3.ckpt": 3,
    }
    for filename, step in expected_steps.items():
        assert extract(Path(filename)) == step, filename

    # Distinct filenames must not collapse onto the same step.
    epoch_one, epoch_two = (
        extract(Path(name)) for name in ("model-epoch=1.ckpt", "model-epoch=2.ckpt")
    )
    assert epoch_one != epoch_two


def test_validate_r2_pipeline_dry_run():
    """R2 pipeline validation in dry-run should check credentials only.

    Runs the validator under pytest's own interpreter (``sys.executable``)
    with a timeout, and only checks that the report header is printed.
    """
    result = subprocess.run(
        [sys.executable, "scripts/validate_r2_pipeline.py", "--dry-run"],
        capture_output=True,
        text=True,
        cwd=str(PROJECT_ROOT),
        timeout=120,
    )
    # NOTE(review): the exit status is not asserted — presumably because the
    # credential check may fail (non-zero exit) in environments without R2
    # credentials; confirm against the script.
    assert "R2 Checkpoint Pipeline Validation" in result.stdout, result.stderr
