"""Unit tests for check_prod_readiness.py — uses mocked/temp state."""

import subprocess
from pathlib import Path

import yaml

PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent


def _make_minimal_config(tmp_path, tok_dir=None, train_mf=None, val_mf=None):
    """Write a minimal prod config into *tmp_path* and return its path.

    Any of tok_dir/train_mf/val_mf left as None shows up as a YAML null,
    i.e. an unresolved placeholder for the readiness checker to flag.
    """
    model_section = {
        "model_defaults": {"enc_hidden": 256, "pred_hidden": 256},
        "preprocessor": {},
        "encoder": {"n_layers": 4, "d_model": 256},
        "decoder": {},
        "joint": {},
        "tokenizer": {"dir": tok_dir, "type": "bpe"},
        "train_ds": {
            "manifest_filepath": train_mf,
            "batch_size": 4,
            "shuffle": True,
        },
        "validation_ds": {
            "manifest_filepath": val_mf,
            "batch_size": 4,
            "shuffle": False,
        },
        "optim": {"name": "adamw", "lr": 0.001},
    }
    full_config = {
        "name": "test",
        "trainer": {"devices": 1, "max_steps": 10},
        "model": model_section,
        "exp_manager": {},
    }
    out_path = tmp_path / "test_config.yaml"
    out_path.write_text(yaml.dump(full_config))
    return out_path


def _readiness_cmd(config, extra_args=None):
    """Build readiness command with skip-train-runtime-checks for CI."""
    cmd = [
        "python3",
        "scripts/check_prod_readiness.py",
        "--config",
        str(config),
        "--skip-train-runtime-checks",
        "--min-free-tb",
        "0",
    ]
    if extra_args:
        cmd.extend(extra_args)
    return cmd


def test_readiness_fails_on_unresolved_placeholders(tmp_path):
    """A config whose tokenizer/manifest fields are still null must be rejected."""
    cfg_path = _make_minimal_config(tmp_path)
    proc = subprocess.run(
        _readiness_cmd(cfg_path),
        cwd=str(PROJECT_ROOT),
        capture_output=True,
        text=True,
    )
    assert proc.returncode != 0
    # The report must both render a NO-GO verdict and mark failing checks.
    for marker in ("NO-GO", "FAIL"):
        assert marker in proc.stdout


def test_readiness_fails_on_missing_artifact(tmp_path):
    """Paths that resolve to nonexistent files/dirs must fail the check."""
    missing = {
        "tok_dir": str(tmp_path / "nonexistent_tok"),
        "train_mf": str(tmp_path / "nonexistent_train.jsonl"),
        "val_mf": str(tmp_path / "nonexistent_val.jsonl"),
    }
    cfg_path = _make_minimal_config(tmp_path, **missing)
    proc = subprocess.run(
        _readiness_cmd(cfg_path),
        cwd=str(PROJECT_ROOT),
        capture_output=True,
        text=True,
    )
    assert proc.returncode != 0
    assert "does not exist" in proc.stdout


def test_readiness_fails_on_disk_threshold(tmp_path):
    """An absurdly high free-disk requirement must produce a FAIL."""
    cfg_path = _make_minimal_config(tmp_path)
    # Built by hand (not via _readiness_cmd) because this test needs a
    # --min-free-tb value other than the helper's hard-coded "0".
    cmd = [
        "python3",
        "scripts/check_prod_readiness.py",
        "--config",
        str(cfg_path),
        "--skip-train-runtime-checks",
        "--min-free-tb",
        "999",
    ]
    proc = subprocess.run(cmd, cwd=str(PROJECT_ROOT), capture_output=True, text=True)
    assert proc.returncode != 0
    assert "Disk" in proc.stdout
    assert "FAIL" in proc.stdout


def test_readiness_passes_with_resolved_config(tmp_path):
    """Config with all paths resolved and relaxed thresholds should pass.

    Any prod artifact missing from the repo is stubbed out for the duration
    of the check; stubs are always removed afterwards, even if the
    subprocess call raises.
    """
    tok_dir = tmp_path / "tok"
    tok_dir.mkdir()
    (tok_dir / "tokenizer.model").write_bytes(b"dummy")
    (tok_dir / "vocab.txt").write_text("dummy")

    train_mf = tmp_path / "train.jsonl"
    val_mf = tmp_path / "val.jsonl"
    train_mf.write_text('{"audio_filepath":"a","text":"b","duration":1}\n')
    val_mf.write_text('{"audio_filepath":"a","text":"b","duration":1}\n')

    config = _make_minimal_config(
        tmp_path,
        tok_dir=str(tok_dir),
        train_mf=str(train_mf),
        val_mf=str(val_mf),
    )

    # Ensure prod artifacts exist (real or stub); only files we create here
    # are recorded so pre-existing real artifacts are never deleted.
    artifacts = [
        "data/manifests/stage1_prod_full.jsonl",
        "data/manifests/stage1_prod_train.jsonl",
        "data/manifests/stage1_prod_val.jsonl",
        "tokenizers/stage1_prod_bpe/tokenizer.model",
        "tokenizers/stage1_prod_bpe/vocab.txt",
        "configs/data/stage1_prod_input_cfg.yaml",
    ]
    created_stubs = []
    for a in artifacts:
        p = PROJECT_ROOT / a
        p.parent.mkdir(parents=True, exist_ok=True)
        if not p.exists():
            p.write_bytes(b"stub")
            created_stubs.append(p)

    try:
        result = subprocess.run(
            _readiness_cmd(config),
            capture_output=True,
            text=True,
            cwd=str(PROJECT_ROOT),
        )
    finally:
        # Clean up stubs even when subprocess.run itself raises (e.g. OSError),
        # so the repo checkout is never left polluted with stub artifacts.
        for p in created_stubs:
            p.unlink(missing_ok=True)

    assert result.returncode == 0, (
        f"Expected pass:\nstdout: {result.stdout}\nstderr: {result.stderr}"
    )
    assert "GO" in result.stdout


def test_readiness_r2_roundtrip_flag():
    """--require-r2-roundtrip should perform full upload/download probe."""
    cmd = [
        "python3",
        "scripts/check_prod_readiness.py",
        "--skip-train-runtime-checks",
        "--min-free-tb",
        "0",
        "--require-r2-roundtrip",
    ]
    proc = subprocess.run(cmd, cwd=str(PROJECT_ROOT), capture_output=True, text=True)
    # Pass/fail depends on the environment's R2 credentials; only assert
    # that the "R2 roundtrip" check was actually attempted and reported.
    assert "R2 roundtrip" in proc.stdout
