#!/usr/bin/env python3
"""Production readiness gate checks.

Verifies all hard prerequisites for a production training launch.
Exits non-zero on any failure with clear actionable messages.

Usage:
  python3 scripts/check_prod_readiness.py
  python3 scripts/check_prod_readiness.py --require-r2-roundtrip
  python3 scripts/check_prod_readiness.py --skip-train-runtime-checks  # CI mode
"""

import argparse
import hashlib
import os
import shutil
import sys
import tempfile
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))


def check_disk(min_free_tb: float = 2.0) -> tuple[bool, str]:
    """Verify the root filesystem has at least *min_free_tb* TiB free.

    Returns (ok, message); the message always includes the current usage
    summary so it is useful in both the pass and fail cases.
    """
    usage = shutil.disk_usage("/")
    tib = 1024**4
    free_tb = usage.free / tib
    total_tb = usage.total / tib
    used_pct = 100 * usage.used / usage.total
    summary = f"Disk: {free_tb:.1f} TB free / {total_tb:.1f} TB total ({used_pct:.0f}% used)"
    if free_tb >= min_free_tb:
        return True, summary
    return False, f"{summary} — need >= {min_free_tb} TB free"


def check_import(module: str) -> tuple[bool, str]:
    """Report whether *module* is importable, with its version if it has one."""
    try:
        imported = __import__(module)
    except ImportError:
        # Covers ModuleNotFoundError too; point the operator at the fix.
        return False, f"{module} NOT INSTALLED — run: pip install -e '.[train,dev]'"
    version = getattr(imported, "__version__", "unknown")
    return True, f"{module} {version}"


def check_gpus(min_count: int = 8) -> tuple[bool, str]:
    """Verify at least *min_count* CUDA devices are visible to torch.

    Returns (ok, message). Fails with an actionable message when torch is
    missing or fewer than *min_count* devices are found.
    """
    # Narrow try: only the import can raise ImportError; everything after
    # should surface real errors rather than masquerade as a missing torch.
    try:
        import torch
    except ImportError:
        return False, "torch not installed"

    count = torch.cuda.device_count()
    if count < min_count:
        return False, f"GPUs: {count} found, need >= {min_count}"
    if count == 0:
        # min_count <= 0 with no devices present: the requirement is met but
        # there is no device name to report. (The previous version crashed
        # here with IndexError on names[0].)
        return True, "GPUs: 0 found (none required)"
    names = [torch.cuda.get_device_name(i) for i in range(count)]
    return True, f"GPUs: {count}x {names[0]}"


def check_config_resolved(config_path: Path) -> tuple[bool, str]:
    """Verify the training config has no unresolved (null) path placeholders
    and that every path it references exists on disk.

    Returns (ok, message); on failure the message lists every problem found.
    """
    try:
        from omegaconf import OmegaConf

        cfg = OmegaConf.load(str(config_path))
    except Exception as e:
        return False, f"Cannot load config {config_path}: {e}"

    def _unset(value) -> bool:
        # Placeholders appear either as a real None or the literal string "null".
        return value is None or value == "null"

    problems: list[str] = []

    tok_dir = cfg.model.tokenizer.get("dir")
    if _unset(tok_dir):
        problems.append("model.tokenizer.dir is null")
    elif not Path(tok_dir).exists():
        problems.append(f"model.tokenizer.dir={tok_dir} does not exist")

    train_mf = cfg.model.train_ds.get("manifest_filepath")
    train_cfg = cfg.model.train_ds.get("input_cfg")
    if _unset(train_mf) and _unset(train_cfg):
        # Either source of training data is acceptable, but one must be set.
        problems.append("model.train_ds: neither manifest_filepath nor input_cfg set")
    elif train_mf and train_mf != "null" and not Path(train_mf).exists():
        problems.append(f"model.train_ds.manifest_filepath={train_mf} does not exist")

    val_mf = cfg.model.validation_ds.get("manifest_filepath")
    if _unset(val_mf):
        problems.append("model.validation_ds.manifest_filepath is null")
    elif not Path(val_mf).exists():
        problems.append(f"model.validation_ds.manifest_filepath={val_mf} does not exist")

    if problems:
        return False, "Unresolved config placeholders:\n    " + "\n    ".join(problems)
    return True, f"Config {config_path}: all paths resolved"


def _get_r2_client():
    """Build and return (s3_client, bucket) or raise on failure.

    Credentials come from the environment, optionally seeded from a .env
    file at the repository root. Raises RuntimeError when any required
    credential is missing; raises ImportError when boto3/dotenv are absent.
    """
    import boto3
    from dotenv import load_dotenv

    dotenv_path = Path(__file__).resolve().parent.parent / ".env"
    if dotenv_path.exists():
        load_dotenv(dotenv_path)

    endpoint = os.environ.get("R2_ENDPOINT_URL")
    key_id = os.environ.get("R2_ACCESS_KEY_ID")
    secret = os.environ.get("R2_SECRET_ACCESS_KEY")
    bucket = os.environ.get("R2_BUCKET_CHECKPOINTS", "ptcheckpoints")

    if not (endpoint and key_id and secret):
        raise RuntimeError("R2 credentials not set in .env")

    client = boto3.client(
        "s3",
        endpoint_url=endpoint,
        aws_access_key_id=key_id,
        aws_secret_access_key=secret,
        region_name="auto",  # Cloudflare R2 uses the "auto" pseudo-region
    )
    return client, bucket


def check_r2_connectivity() -> tuple[bool, str]:
    """Lightweight R2 probe: confirm the checkpoint bucket is reachable."""
    try:
        client, bucket = _get_r2_client()
        client.head_bucket(Bucket=bucket)
    except ImportError:
        return False, "boto3/python-dotenv not installed"
    except Exception as e:
        # Covers missing credentials (RuntimeError) and any S3/network error.
        return False, f"R2 access failed: {e}"
    return True, f"R2 bucket '{bucket}' accessible"


def check_r2_roundtrip() -> tuple[bool, str]:
    """Full upload → download → SHA-256 verify → delete probe."""
    try:
        s3, bucket = _get_r2_client()
        test_key = "_maya_asr_pipeline_test/readiness_probe.bin"
        data = os.urandom(4096)
        local_hash = hashlib.sha256(data).hexdigest()

        with tempfile.NamedTemporaryFile(delete=False, suffix=".bin") as f:
            f.write(data)
            local_path = f.name

        try:
            s3.upload_file(local_path, bucket, test_key)
            resp = s3.head_object(Bucket=bucket, Key=test_key)
            if resp["ContentLength"] != len(data):
                return False, "R2 roundtrip: size mismatch after upload"

            with tempfile.NamedTemporaryFile(delete=False, suffix=".bin") as f:
                dl_path = f.name
            s3.download_file(bucket, test_key, dl_path)
            dl_hash = hashlib.sha256(Path(dl_path).read_bytes()).hexdigest()
            Path(dl_path).unlink(missing_ok=True)

            if local_hash != dl_hash:
                return False, "R2 roundtrip: SHA-256 mismatch after download"

            s3.delete_object(Bucket=bucket, Key=test_key)
            return True, "R2 roundtrip: upload+download+SHA256+delete OK"
        finally:
            Path(local_path).unlink(missing_ok=True)
            try:
                s3.delete_object(Bucket=bucket, Key=test_key)
            except Exception:
                pass
    except ImportError:
        return False, "boto3/python-dotenv not installed"
    except Exception as e:
        return False, f"R2 roundtrip failed: {e}"


def check_artifact(path: Path, label: str) -> tuple[bool, str]:
    """Report whether the artifact at *path* exists, labeled for display."""
    exists = path.exists()
    detail = f"{label}: {path}" if exists else f"{label}: MISSING at {path}"
    return exists, detail


def main():
    """Run all readiness checks, print a PASS/FAIL report, and exit.

    Exit code 0 means GO (all checks passed); 1 means NO-GO.
    """
    parser = argparse.ArgumentParser(description="Production readiness gate")
    parser.add_argument(
        "--config",
        type=Path,
        default=Path("configs/train/stage1_prod_8xh200.yaml"),
    )
    parser.add_argument("--min-free-tb", type=float, default=2.0)
    parser.add_argument("--min-gpus", type=int, default=8)
    parser.add_argument(
        "--skip-train-runtime-checks",
        action="store_true",
        default=False,
        help="Skip torch/nemo/GPU checks (for CI or preflight-only)",
    )
    parser.add_argument(
        "--require-r2-roundtrip",
        action="store_true",
        default=False,
        help="Require full R2 upload/download/SHA-256/delete probe",
    )
    parser.add_argument(
        "--train-parquet",
        type=Path,
        default=Path("artifacts/phase3/production_train_final.parquet"),
        help="Production training parquet (used by train_prod.py)",
    )
    args = parser.parse_args()

    print("============================================")
    print("  Production Readiness Gate")
    print("============================================")
    print()

    # Each entry: (display name, passed?, detail message).
    checks: list[tuple[str, bool, str]] = []

    # Disk
    ok, msg = check_disk(args.min_free_tb)
    checks.append(("Disk headroom", ok, msg))

    # Training runtime (optional — skip in CI)
    if not args.skip_train_runtime_checks:
        ok, msg = check_import("torch")
        checks.append(("torch", ok, msg))
        ok, msg = check_import("nemo")
        checks.append(("nemo", ok, msg))
        ok, msg = check_gpus(args.min_gpus)
        checks.append(("GPUs", ok, msg))
    else:
        checks.append(("Train runtime", True, "skipped (--skip-train-runtime-checks)"))

    # Config
    ok, msg = check_config_resolved(args.config)
    checks.append(("Config resolved", ok, msg))

    # R2 — roundtrip probe is stricter (and slower) than a head_bucket ping.
    if args.require_r2_roundtrip:
        ok, msg = check_r2_roundtrip()
        checks.append(("R2 roundtrip", ok, msg))
    else:
        ok, msg = check_r2_connectivity()
        checks.append(("R2 checkpoint storage", ok, msg))

    # Artifacts — only check what the config actually references
    try:
        from omegaconf import OmegaConf

        cfg = OmegaConf.load(str(args.config))

        # Tokenizer
        tok_dir = cfg.model.tokenizer.get("dir", "")
        if tok_dir and tok_dir != "null":
            ok, msg = check_artifact(Path(tok_dir) / "tokenizer.model", "Tokenizer model")
            checks.append(("Tokenizer model", ok, msg))
            ok, msg = check_artifact(Path(tok_dir) / "vocab.txt", "Tokenizer vocab")
            checks.append(("Tokenizer vocab", ok, msg))

        # Train data — check whichever the config actually uses
        train_mf = cfg.model.train_ds.get("manifest_filepath")
        train_cfg = cfg.model.train_ds.get("input_cfg")
        if train_mf and train_mf != "null":
            ok, msg = check_artifact(Path(train_mf), "Train manifest")
            checks.append(("Train manifest", ok, msg))
        if train_cfg and train_cfg != "null":
            ok, msg = check_artifact(Path(train_cfg), "Train input_cfg")
            checks.append(("Train input_cfg", ok, msg))

        # Val data
        val_mf = cfg.model.validation_ds.get("manifest_filepath")
        if val_mf and val_mf != "null":
            ok, msg = check_artifact(Path(val_mf), "Val manifest")
            checks.append(("Val manifest", ok, msg))
    except Exception as e:
        checks.append(("Artifact check", False, f"Cannot parse config for artifact check: {e}"))

    # Production parquet (the ACTUAL training input for train_prod.py)
    ok, msg = check_artifact(args.train_parquet, "Production train parquet")
    checks.append(("Prod train parquet", ok, msg))
    if ok:
        # Quick sanity: row count from the parquet footer only — no data
        # scan. NOTE: blank-transcript validation is NOT performed here;
        # that would require reading the transcript column.
        try:
            import pyarrow.parquet as pq_lib

            meta = pq_lib.read_metadata(str(args.train_parquet))
            n_rows = meta.num_rows
            if n_rows < 1_000_000:
                checks.append(("Parquet row count", False,
                               f"Only {n_rows:,} rows — expected millions"))
            else:
                checks.append(("Parquet row count", True, f"{n_rows:,} rows"))
        except Exception as e:
            checks.append(("Parquet validation", False, f"Cannot read parquet: {e}"))

    # Print results and tally the verdict.
    passed = 0
    failed = 0
    for name, ok, msg in checks:
        status = "PASS" if ok else "FAIL"
        print(f"  [{status}] {name}")
        if not ok:
            print(f"         {msg}")
            failed += 1
        else:
            passed += 1

    print()
    print(f"Results: {passed} passed, {failed} failed")
    print()

    if failed > 0:
        print("STATUS: NO-GO")
        print()
        print("Fix the FAIL items above before launching production training.")
        sys.exit(1)
    print("STATUS: GO")
    print()
    print("All checks passed. Ready to launch production training.")
    sys.exit(0)


# Entry point when run as a script (see Usage in the module docstring).
if __name__ == "__main__":
    main()
