#!/usr/bin/env python3
"""Vast.ai benchmark orchestrator: deploy benchmarks on different GPU types.

Workflow:
  1. Search for cheapest offers matching target GPU types
  2. Spin up SSH instances with our pipeline image
  3. Upload .env (plus local code with --sync-local), run the SFT benchmark on each
  4. Collect results, compare cross-GPU
  5. Tear down instances

Usage:
  python scripts/vastai_benchmark.py --api-key <key> --gpus "RTX 4090,A100 SXM4"
  python scripts/vastai_benchmark.py --api-key <key> --list-gpus  # show available
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import re
import shlex
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path

import requests

# Root logging config for the script: timestamped INFO-level lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("vastai_bench")

# Vast.ai REST API root.
BASE_URL = "https://console.vast.ai/api/v0"
# Pre-built pipeline image launched on each instance unless --image overrides it.
DOCKER_IMAGE = "bharathkumar192/codecbench-sft:latest"


def _parse_sft_benchmark_log(log_path: Path) -> dict:
    text = log_path.read_text()

    def grab(pattern: str, cast=float, default=None):
        m = re.search(pattern, text)
        if not m:
            return default
        return cast(m.group(1))

    return {
        "gpu_name": grab(r"GPU:\s+(.+)", str, ""),
        "vram_peak_mb": grab(r"VRAM:\s+([0-9.]+)\s+/\s+[0-9.]+\s+MB", float, 0.0),
        "vram_total_mb": grab(r"VRAM:\s+[0-9.]+\s+/\s+([0-9.]+)\s+MB", float, 0.0),
        "batch_size": grab(r"Batch size:\s+([0-9]+)", int, 0),
        "download_workers": grab(r"Download workers:\s+([0-9]+)", int, 0),
        "download_s": grab(r"Download:\s+([0-9.]+)s", float, 0.0),
        "decode_s": grab(r"Decode:\s+([0-9.]+)s", float, 0.0),
        "encode_s": grab(r"Encode:\s+([0-9.]+)s", float, 0.0),
        "upload_s": grab(r"Upload:\s+([0-9.]+)s", float, 0.0),
        "audio_s": grab(r"Audio:\s+([0-9.]+)s", float, 0.0),
        "segments": grab(r"Segments:\s+([0-9]+)", int, 0),
        "encode_rtf": grab(r"Encode:\s+[0-9.]+s\s+\(RTF=([0-9.]+)x\)", float, 0.0),
        "effective_per_shard_s": grab(r"Effective per shard:\s+([0-9.]+)s", float, 0.0),
        "next_download_ready": "Next download ready by upload finish: True" in text,
        "eta_100_h": grab(r"ETA\s+100 GPUs:\s+([0-9.]+)h", float, 0.0),
        "eta_200_h": grab(r"ETA\s+200 GPUs:\s+([0-9.]+)h", float, 0.0),
    }


@dataclass
class GPUTarget:
    """A GPU type to benchmark and the limits used when choosing an offer.

    The orchestrator's offer search compares ``name`` with the offer's
    ``gpu_name``, ``min_vram_gb`` (converted to MB) with ``gpu_ram`` and
    ``max_price`` with ``dph_total``.
    """

    name: str  # GPU model name as Vast.ai spells it, e.g. "RTX 4090"
    min_vram_gb: float  # smallest acceptable VRAM, in GB
    max_price: float  # most we are willing to pay, in $/hr


# Default benchmark set: GPU types with higher TFLOPS than an RTX 3060.
DEFAULT_TARGETS = [
    GPUTarget(name="RTX 4090", min_vram_gb=24, max_price=0.60),
    GPUTarget(name="RTX 3090", min_vram_gb=24, max_price=0.40),
    GPUTarget(name="RTX 3090 Ti", min_vram_gb=24, max_price=0.45),
]


class VastAIBenchOrchestrator:
    """Run the SFT benchmark on one rented Vast.ai instance per GPU type.

    Every instance created through ``create_instance`` is recorded in
    ``self.instances`` (GPU name -> instance id) until it has been
    successfully destroyed, so ``destroy_all`` can clean up after an
    interrupt or crash.
    """

    # Per-request timeout (seconds) for Vast.ai REST calls. ``requests`` waits
    # indefinitely by default, which would stall the run while instances bill.
    REQUEST_TIMEOUT = 30

    def __init__(self, api_key: str, docker_image: str = DOCKER_IMAGE, sync_local: bool = False):
        self.api_key = api_key
        self.docker_image = docker_image
        self.sync_local = sync_local  # rsync local code over the image before benchmarking
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        self.instances: dict[str, int] = {}  # gpu_name -> instance_id
        self.project_root = Path(__file__).resolve().parent.parent
        # Private key used by ssh/scp/rsync. ensure_ssh_key() repoints it at the
        # private half of whichever public key it registers with Vast.ai.
        self.identity_file = Path.home() / ".ssh" / "id_ed25519"

    def _ssh_opts(self) -> list[str]:
        """Return the options shared by every ssh/scp/rsync call to an instance."""
        return [
            "-i", str(self.identity_file),
            "-o", "StrictHostKeyChecking=no",
            # Instances are short-lived; don't accumulate their host keys in the
            # user's known_hosts file...
            "-o", "UserKnownHostsFile=/dev/null",
            # ...and suppress the resulting "Permanently added" warning so it
            # doesn't clutter the stderr tails we log on failure.
            "-o", "LogLevel=ERROR",
        ]

    def search_offers(self, target: GPUTarget) -> list[dict]:
        """Return rentable single-GPU on-demand offers for *target*, cheapest first.

        Underscores in ``target.name`` are treated as spaces, so CLI-style names
        such as ``RTX_4090`` match Vast.ai's ``RTX 4090``.

        Raises:
            requests.HTTPError: if the search request is rejected.
        """
        resp = requests.post(
            f"{BASE_URL}/bundles/",
            headers=self.headers,
            json={
                "limit": 20,
                "type": "on-demand",
                "verified": {"eq": True},
                "rentable": {"eq": True},
                "rented": {"eq": False},
                "num_gpus": {"eq": 1},
                "gpu_name": {"eq": target.name.replace("_", " ")},
                "gpu_ram": {"gte": target.min_vram_gb * 1024},  # GB -> MB
                "dph_total": {"lte": target.max_price},
                "reliability": {"gt": 0.95},
                "inet_down": {"gte": 200},
                "disk_space": {"gte": 50},
                "cuda_vers": {"gte": 12.0},
                "order": [["dph_total", "asc"]],
            },
            timeout=self.REQUEST_TIMEOUT,
        )
        resp.raise_for_status()
        return resp.json().get("offers", [])

    def list_available_gpus(self) -> None:
        """Print offer count, cheapest price and VRAM for a fixed set of GPU types.

        The offer count is capped by the search limit used in ``search_offers``.
        """
        gpu_names = [
            "RTX 4090", "RTX 3090", "RTX 3090 Ti", "L40S",
            "A100 SXM4", "A100 PCIe",
        ]
        print(f"\n{'GPU':<20} {'Offers':>7} {'Cheapest':>10} {'VRAM':>8}")
        print("-" * 50)
        for gn in gpu_names:
            # No VRAM floor and an effectively unlimited price: show everything.
            offers = self.search_offers(GPUTarget(gn, 0, 999))
            if offers:
                cheapest = min(o["dph_total"] for o in offers)
                vram = offers[0].get("gpu_ram", 0) / 1024
                print(f"  {gn:<18} {len(offers):>5} ${cheapest:>8.3f}/hr {vram:>6.0f}GB")
            else:
                print(f"  {gn:<18}     0    ---       ---")
        print()

    def ensure_ssh_key(self) -> None:
        """Ensure a local SSH key exists and is registered with Vast.ai.

        Uses ``~/.ssh/id_ed25519.pub``, else ``~/.ssh/id_rsa.pub``, else
        generates a new ed25519 key. ``self.identity_file`` is updated to the
        matching private key so later connections authenticate with the key
        that was actually registered. Registration failures are logged, not
        raised.
        """
        ssh_dir = Path.home() / ".ssh"
        candidates = (ssh_dir / "id_ed25519.pub", ssh_dir / "id_rsa.pub")
        pub_key_path = next((p for p in candidates if p.exists()), None)
        if pub_key_path is None:
            logger.info("Generating SSH key for Vast.ai...")
            subprocess.run(
                ["ssh-keygen", "-t", "ed25519", "-f", str(ssh_dir / "id_ed25519"), "-N", ""],
                check=True, capture_output=True,
            )
            pub_key_path = ssh_dir / "id_ed25519.pub"

        # Private key path is the public key path without its ".pub" suffix.
        self.identity_file = pub_key_path.with_suffix("")

        pub_key = pub_key_path.read_text().strip()
        # Match on "<type> <base64-key>" only. A fixed-length prefix is mostly
        # the key-type header shared by every key of that type, and the
        # trailing comment field may differ.
        key_id = " ".join(pub_key.split()[:2])

        # Check if already registered
        resp = requests.get(f"{BASE_URL}/ssh/", headers=self.headers, timeout=self.REQUEST_TIMEOUT)
        try:
            existing = resp.json() if resp.ok else []
        except ValueError:
            existing = []
        for k in existing:
            if isinstance(k, dict) and key_id in (k.get("public_key") or ""):
                logger.info("SSH key already registered with Vast.ai")
                return

        resp = requests.post(
            f"{BASE_URL}/ssh/",
            headers=self.headers,
            json={"ssh_key": pub_key},
            timeout=self.REQUEST_TIMEOUT,
        )
        if resp.ok:
            logger.info("SSH key registered with Vast.ai")
        else:
            logger.warning("Failed to register SSH key: %s", resp.text)

    def _load_env_vars(self) -> dict[str, str]:
        """Parse ``<project_root>/.env`` into a dict (empty if the file is absent).

        Accepts ``KEY=value`` lines with an optional leading ``export``; blank
        lines, ``#`` comments and lines without ``=`` are skipped. Whitespace
        around keys and values is stripped, as is one pair of matching quotes
        around the value.
        """
        env_file = self.project_root / ".env"
        env_vars: dict[str, str] = {}
        if not env_file.exists():
            return env_vars
        for raw in env_file.read_text().splitlines():
            line = raw.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            if line.startswith("export "):
                line = line[len("export "):].lstrip()
            key, value = line.split("=", 1)
            key, value = key.strip(), value.strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            if key:
                env_vars[key] = value
        return env_vars

    def create_instance(self, offer_id: int, gpu_name: str) -> int | None:
        """Rent *offer_id* with the configured image and the project's .env vars.

        Returns the new instance id (also recorded in ``self.instances``), or
        None if the request failed or returned no contract id.
        """
        resp = requests.put(
            f"{BASE_URL}/asks/{offer_id}/",
            headers=self.headers,
            json={
                "image": self.docker_image,
                "disk": 50,
                "runtype": "ssh_direct",
                "label": f"bench-{gpu_name}",
                "env": self._load_env_vars(),
            },
            timeout=self.REQUEST_TIMEOUT,
        )

        if not resp.ok:
            logger.error("Failed to create instance for %s: %s", gpu_name, resp.text)
            return None

        instance_id = resp.json().get("new_contract")
        if not instance_id:
            logger.error("No instance id returned for %s: %s", gpu_name, resp.text)
            return None
        logger.info("Created instance %s for %s (image: %s)",
                    instance_id, gpu_name, self.docker_image)
        self.instances[gpu_name] = instance_id
        return instance_id

    def wait_for_instance(self, instance_id: int, timeout: int = 600) -> dict | None:
        """Poll until the instance is running and exposes SSH host/port.

        Returns the instance record, or None if it reports ``exited`` /
        ``offline`` or *timeout* seconds elapse. Transient HTTP/network errors
        are logged and retried rather than aborting the whole run.
        """
        url = f"{BASE_URL}/instances/{instance_id}/"
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                resp = requests.get(url, headers=self.headers, timeout=self.REQUEST_TIMEOUT)
                payload = resp.json() if resp.ok else None
            except (requests.RequestException, ValueError) as e:
                logger.warning("Polling instance %s failed: %s", instance_id, e)
                payload = None
            if not isinstance(payload, dict):
                time.sleep(10)
                continue

            # The record is normally nested under "instances"; fall back to the
            # top-level payload when that key is missing or empty.
            inst = payload.get("instances") or payload
            if not isinstance(inst, dict):
                time.sleep(10)
                continue
            status = inst.get("actual_status", "")

            if status == "running" and inst.get("ssh_host") and inst.get("ssh_port"):
                logger.info("Instance %s is running: ssh -p %s root@%s",
                            instance_id, inst["ssh_port"], inst["ssh_host"])
                return inst
            if status in ("exited", "offline"):
                logger.error("Instance %s failed: %s", instance_id, status)
                return None

            logger.info("Instance %s status: %s, waiting...", instance_id, status)
            time.sleep(15)

        logger.error("Instance %s timed out", instance_id)
        return None

    def wait_for_ssh(self, ssh_host: str, ssh_port: int, timeout: int = 180) -> bool:
        """Return True once ``python3 --version`` succeeds over SSH, False on timeout."""
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                result = subprocess.run(
                    ["ssh", *self._ssh_opts(), "-o", "ConnectTimeout=10",
                     "-p", str(ssh_port), f"root@{ssh_host}", "python3 --version"],
                    capture_output=True, text=True, timeout=20,
                )
            except (subprocess.TimeoutExpired, OSError):
                result = None
            if result is not None and result.returncode == 0:
                logger.info("SSH ready on %s:%s (%s)",
                            ssh_host, ssh_port, result.stdout.strip())
                return True
            time.sleep(10)
        return False

    def sync_code(self, ssh_host: str, ssh_port: int) -> bool:
        """Rsync the local project (minus data/venv/etc.) over ``/app/`` on the instance.

        Returns True on success, False if rsync fails or times out.
        """
        src = str(self.project_root) + "/"
        excludes = [
            "--exclude=venv/", "--exclude=repos/", "--exclude=data/",
            "--exclude=results/", "--exclude=metafiles/",
            "--exclude=__pycache__/", "--exclude=.git/",
            "--exclude=*.pyc", "--exclude=.cursor/", "--exclude=models/",
        ]
        # rsync splits the -e string itself; quote each piece so paths with
        # spaces (e.g. the identity file) survive.
        remote_shell = " ".join(
            shlex.quote(part) for part in ["ssh", *self._ssh_opts(), "-p", str(ssh_port)]
        )
        cmd = [
            "rsync", "-avz", "--timeout=60",
            "-e", remote_shell,
            *excludes,
            src, f"root@{ssh_host}:/app/",
        ]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
        except subprocess.TimeoutExpired:
            logger.error("rsync to %s timed out", ssh_host)
            return False
        if result.returncode != 0:
            logger.error("rsync failed: %s", result.stderr[-500:])
            return False
        logger.info("Code synced to %s", ssh_host)
        return True

    def run_benchmark_remote(
        self,
        ssh_host: str,
        ssh_port: int,
        gpu_name: str,
        batch_size: int = 2,
    ) -> dict | None:
        """Run the SFT benchmark remotely at *batch_size* and fetch and parse its log.

        Returns ``{"gpu_name", "batch_size", "log_path", "summary"}`` on success.
        Returns None if the .env upload, the benchmark itself, or the log fetch
        fails.
        """
        env_path = self.project_root / ".env"
        if env_path.exists():
            try:
                put_env = subprocess.run(
                    ["scp", *self._ssh_opts(), "-P", str(ssh_port),
                     str(env_path), f"root@{ssh_host}:/app/.env"],
                    capture_output=True, text=True, timeout=60,
                )
            except subprocess.TimeoutExpired:
                logger.error("Timed out copying .env to %s", gpu_name)
                return None
            if put_env.returncode != 0:
                logger.error("Failed to copy .env to %s: %s", gpu_name, put_env.stderr[-500:])
                return None

        safe_name = gpu_name.replace(" ", "_")
        remote_log = "/tmp/sft_benchmark.log"
        # Write the log, then cat it and exit with the benchmark's own status.
        # With "| tee" the pipeline's status was tee's, so failures looked
        # like successes.
        remote_cmd = (
            "cd /app && "
            "mkdir -p /tmp/pipeline && "
            "python3 -m codecbench.pipeline.cli sft-run "
            f"--benchmark --batch-size={int(batch_size)} "
            f"--offer-id {shlex.quote('vast_bench_' + safe_name)} "
            f"> {remote_log} 2>&1; "
            f"rc=$?; cat {remote_log}; exit $rc"
        )

        logger.info("Running SFT benchmark on %s (BS=%d)...", gpu_name, batch_size)
        try:
            result = subprocess.run(
                ["ssh", *self._ssh_opts(), "-o", "ServerAliveInterval=30",
                 "-p", str(ssh_port), f"root@{ssh_host}", remote_cmd],
                capture_output=True, text=True, timeout=7200,
            )
        except subprocess.TimeoutExpired:
            logger.error("Benchmark timed out on %s", gpu_name)
            return None
        print(result.stdout[-3000:])  # Print the report
        if result.returncode != 0:
            # The benchmark's stderr is merged into stdout remotely, so fall back
            # to the stdout tail when ssh itself wrote nothing to stderr.
            logger.error("Benchmark failed (exit %s): %s",
                         result.returncode, (result.stderr or result.stdout)[-500:])
            return None

        # One local log per batch size; a shared path would make every result's
        # log_path point at the last run's log.
        local_log = Path("/tmp") / f"bench_{safe_name}_bs{batch_size}.log"
        try:
            fetch = subprocess.run(
                ["scp", *self._ssh_opts(), "-P", str(ssh_port),
                 f"root@{ssh_host}:{remote_log}", str(local_log)],
                capture_output=True, text=True, timeout=30,
            )
            if fetch.returncode != 0:
                logger.error("Failed to fetch results from %s: %s", gpu_name, fetch.stderr[-500:])
                return None
            summary = _parse_sft_benchmark_log(local_log)
        except (subprocess.TimeoutExpired, OSError, ValueError) as e:
            logger.error("Failed to fetch results from %s: %s", gpu_name, e)
            return None

        return {
            "gpu_name": gpu_name,
            "batch_size": batch_size,
            "log_path": str(local_log),
            "summary": summary,
        }

    def destroy_instance(self, instance_id: int) -> bool:
        """Delete an instance. Returns True on success.

        Network errors are logged rather than raised, so a cleanup loop can keep
        going and still attempt the remaining instances.
        """
        try:
            resp = requests.delete(
                f"{BASE_URL}/instances/{instance_id}/",
                headers=self.headers,
                timeout=self.REQUEST_TIMEOUT,
            )
        except requests.RequestException as e:
            logger.warning("Failed to destroy %s: %s", instance_id, e)
            return False
        if resp.ok:
            logger.info("Destroyed instance %s", instance_id)
            return True
        logger.warning("Failed to destroy %s: %s", instance_id, resp.text)
        return False

    def destroy_all(self) -> None:
        """Destroy every tracked instance; keep any that failed in ``self.instances``."""
        failed = {}
        for gpu, iid in list(self.instances.items()):
            if not self.destroy_instance(iid):
                failed[gpu] = iid
        self.instances.clear()
        self.instances.update(failed)
        if failed:
            logger.error("Instances still running, destroy them manually: %s", failed)

    def run_full_benchmark(
        self,
        targets: list[GPUTarget],
        batch_sizes: list[int] | None = None,
    ) -> dict[str, list[dict]]:
        """Full multi-GPU benchmark: search → create → setup → benchmark → collect → destroy.

        Targets are processed one at a time. Each target's instance is destroyed
        before moving on, even when a step fails or raises.

        Returns:
            GPU name -> list of per-batch-size results from ``run_benchmark_remote``.
        """
        if batch_sizes is None:
            batch_sizes = [1, 2, 4]

        self.ensure_ssh_key()

        all_results: dict[str, list[dict]] = {}

        for target in targets:
            logger.info("=" * 60)
            logger.info("Searching for %s (max $%.2f/hr)...", target.name, target.max_price)
            offers = self.search_offers(target)
            if not offers:
                logger.warning("No offers found for %s, skipping", target.name)
                continue

            offer = offers[0]  # search_offers sorts cheapest first
            logger.info("Best offer: ID=%s, $%.3f/hr, %s, %.0fGB VRAM",
                        offer["id"], offer["dph_total"], offer.get("gpu_name", "?"),
                        offer.get("gpu_ram", 0) / 1024)

            instance_id = self.create_instance(offer["id"], target.name)
            if not instance_id:
                continue

            try:
                inst = self.wait_for_instance(instance_id)
                if not inst:
                    continue

                ssh_host = inst["ssh_host"]
                ssh_port = inst["ssh_port"]

                # Image has everything baked in -- just wait for SSH
                if not self.wait_for_ssh(ssh_host, ssh_port):
                    logger.error("SSH timed out on %s, skipping", target.name)
                    continue

                # A failed sync would silently benchmark the image's stale code.
                if self.sync_local and not self.sync_code(ssh_host, ssh_port):
                    logger.error("Code sync failed on %s, skipping", target.name)
                    continue

                gpu_results = []
                for bs in batch_sizes:
                    result = self.run_benchmark_remote(
                        ssh_host, ssh_port, target.name,
                        batch_size=bs,
                    )
                    if result:
                        gpu_results.append(result)

                all_results[target.name] = gpu_results

            finally:
                # Stop tracking the instance only once it is really gone, so a
                # later destroy_all() can retry a failed teardown.
                if self.destroy_instance(instance_id):
                    self.instances.pop(target.name, None)

        if self.instances:
            logger.error("Instances still running, destroy them manually: %s", self.instances)
        return all_results


def print_comparison(all_results: dict[str, list[dict]]) -> None:
    """Print a cross-GPU comparison table, one row per (GPU, batch size) run.

    Args:
        all_results: GPU name -> list of result dicts. Each dict carries a
            ``"summary"`` dict with the parsed benchmark metrics and,
            optionally, the requested ``"batch_size"``.
    """
    rule = "=" * 100
    print(f"\n{rule}")
    print("  SFT CROSS-GPU BENCHMARK COMPARISON")
    print(rule)
    print(f"  {'GPU':<20} {'BS':>3} {'VRAM_pk':>8} {'Enc_RTF':>8} {'dl_s':>6} {'dec_s':>6} {'enc_s':>7} {'eff_s':>7} {'ETA100':>8}")
    print(f"  {'-'*95}")

    for gpu_name, results in all_results.items():
        for r in results:
            s = r["summary"]
            # Fall back to the requested batch size when the log didn't report one.
            bs = s.get("batch_size") or r.get("batch_size", 0)
            # Each value uses its header's width so the columns line up.
            print(f"  {gpu_name:<20} {bs:>3} "
                  f"{s['vram_peak_mb']:>8.0f} {s['encode_rtf']:>8.1f} "
                  f"{s['download_s']:>6.1f} {s['decode_s']:>6.1f} {s['encode_s']:>7.1f} "
                  f"{s['effective_per_shard_s']:>7.1f} {s['eta_100_h']:>7.1f}h")
    print(f"{rule}\n")


def main():
    """CLI entry point: list GPU availability or run the multi-GPU benchmark.

    On Ctrl-C or an unexpected error, every instance the orchestrator still
    tracks is destroyed before exiting. Ctrl-C exits with status 130.
    """
    parser = argparse.ArgumentParser(description="Vast.ai multi-GPU benchmark orchestrator")
    parser.add_argument("--api-key", type=str, required=True, help="Vast.ai API key")
    parser.add_argument("--gpus", type=str, default=None,
                        help='Comma-separated GPU names (e.g. "RTX 4090,A100 SXM4")')
    parser.add_argument("--list-gpus", action="store_true", help="List available GPUs and prices")
    # NOTE: accepted for CLI compatibility; not forwarded to the remote benchmark.
    parser.add_argument("--num-videos", type=int, default=10, help="(currently unused)")
    parser.add_argument("--batch-sizes", type=str, default="1,2,4",
                        help="Comma-separated batch sizes to test")
    parser.add_argument("--max-price", type=float, default=1.0,
                        help="Max $/hr per GPU (applies to --gpus; default targets have their own)")
    parser.add_argument("--image", type=str, default=DOCKER_IMAGE,
                        help=f"Docker image to use (default: {DOCKER_IMAGE})")
    parser.add_argument("--sync-local", action="store_true",
                        help="Rsync local code changes over the baked-in image code")
    parser.add_argument("--output", type=str, default="results/vastai_benchmark.json")
    args = parser.parse_args()

    orch = VastAIBenchOrchestrator(args.api_key, docker_image=args.image,
                                   sync_local=args.sync_local)

    if args.list_gpus:
        orch.list_available_gpus()
        return

    # Parse targets; blank entries (e.g. a trailing comma) are ignored.
    if args.gpus:
        gpu_names = [g.strip() for g in args.gpus.split(",") if g.strip()]
        if not gpu_names:
            parser.error("--gpus must list at least one GPU name")
        targets = [GPUTarget(g, 0, args.max_price) for g in gpu_names]
    else:
        targets = DEFAULT_TARGETS

    try:
        batch_sizes = [int(x) for x in args.batch_sizes.split(",") if x.strip()]
    except ValueError:
        parser.error(f"--batch-sizes must be comma-separated integers, got {args.batch_sizes!r}")
    if not batch_sizes or any(bs <= 0 for bs in batch_sizes):
        parser.error("--batch-sizes must list one or more positive integers")

    try:
        all_results = orch.run_full_benchmark(targets, batch_sizes)
    except KeyboardInterrupt:
        logger.info("Interrupted! Destroying all instances...")
        orch.destroy_all()
        sys.exit(130)  # conventional exit status for SIGINT
    except Exception as e:
        logger.error("Fatal error: %s", e, exc_info=True)
        orch.destroy_all()
        raise

    if all_results:
        # Persist first: the results cost real GPU time, and a formatting error
        # in the table must not lose them.
        out = Path(args.output)
        out.parent.mkdir(parents=True, exist_ok=True)
        with open(out, "w", encoding="utf-8") as f:
            json.dump(all_results, f, indent=2)
        logger.info("All results saved to %s", out)

        print_comparison(all_results)


if __name__ == "__main__":
    main()
