#!/usr/bin/env python3
"""Fleet deployer: provision N GPU instances on Vast.ai, bootstrap, start production workers.

Requires VAST_KEY in the environment (a local .env file is loaded automatically).

Usage:
    python scripts/deploy_fleet.py --deploy --num-gpus 5 --max-price 0.55
    python scripts/deploy_fleet.py --deploy --offer-ids 123,456     # rent specific offers
    python scripts/deploy_fleet.py --num-gpus 5 --dry-run          # show what would be rented
    python scripts/deploy_fleet.py --status                        # check running fleet
    python scripts/deploy_fleet.py --destroy-all                   # tear down everything
"""
from __future__ import annotations

import argparse
import json
import logging
import os
import shlex
import subprocess
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field, asdict
from pathlib import Path

import requests
from dotenv import load_dotenv

# Populate os.environ from a .env file (if one is found) before anything reads it.
load_dotenv()

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("fleet")

# Vast.ai REST API root; every API request in this script is built from it.
BASE_URL = "https://console.vast.ai/api/v0"
# Repository root (this script lives in <root>/scripts/).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Container image every rented instance is created from.
DOCKER_IMAGE = "bharathkumar192/codecbench-sft:latest"
# JSON list of deployed instances, read back by --status / --destroy-all.
FLEET_STATE_FILE = PROJECT_ROOT / "results" / "fleet_state.json"

# GPU model names queried (exact gpu_name match) when auto-picking offers.
GPU_PREFERENCES = [
    "RTX 4090",
    "RTX 3090",
    "RTX 3090 Ti",
    "L40S",
    "L40",
    "RTX 4080",
    "RTX 4080S",
    "RTX A6000",
    "RTX A5000",
]
# Vast.ai machine ids that search_offers() filters out.
BAD_MACHINES: set[int] = {
    54269,
    24352,
    42359,
    30895,
    15501,
    42700,
}

def pick_batch_size(vram_gb: float, gpu_name: str = "") -> int:
    """Choose the sft-run batch size for one GPU.

    * 40 GB of VRAM or more (e.g. L40S / L40 / A6000)              -> 8
    * a GPU whose name contains "4090", "3090" or "A100" with
      at least 22 GB of VRAM                                       -> 4
    * anything else (including a name-less call below 40 GB)       -> 1

    ``gpu_name`` is matched by substring, so "RTX 3090 Ti" counts as a 3090.
    """
    if vram_gb >= 40:
        return 8
    is_high_end = "4090" in gpu_name or "3090" in gpu_name or "A100" in gpu_name
    return 4 if (is_high_end and vram_gb >= 22) else 1


@dataclass
class InstanceInfo:
    """Tracking record for one rented Vast.ai instance; save_state() serializes it with ``asdict``."""

    instance_id: int  # Vast.ai contract id (``new_contract`` from the create call)
    gpu_name: str
    machine_id: int  # host machine id copied from the offer
    price_per_hour: float  # offer's ``dph_total`` ($/hr)
    ssh_host: str = ""  # filled in once the instance reports running
    ssh_port: int = 0
    status: str = "created"  # lifecycle marker: created / running / failed / timeout / bootstrapped / worker_running
    batch_size: int = 2  # NOTE(review): placeholder; create_instance sets it via pick_batch_size, which never returns 2
    vram_gb: float = 0
    worker_pid: int = 0  # set by start_worker(); stays 0 when the onstart script launches the worker


class FleetDeployer:
    """Rent, track and tear down single-GPU Vast.ai instances that run the SFT worker.

    Successfully provisioned instances are persisted to ``FLEET_STATE_FILE`` so a
    later ``--status`` / ``--destroy-all`` invocation can find them again.
    """

    # GPU types taken first (in this order) by pick_best_offers, before cheaper fill-ins.
    _PRIORITY_GPUS = ("L40S", "RTX 4090", "L40")

    def __init__(self, api_key: str, max_price: float = 0.80):
        """Create a deployer.

        Args:
            api_key: Vast.ai API key, sent as a Bearer token on every API call.
            max_price: Upper bound on an offer's ``dph_total`` ($/hr) used by ``search_offers``.
        """
        self.api_key = api_key
        self.max_price = max_price
        self.headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        self.instances: list[InstanceInfo] = []
        # Private key passed to every ssh / rsync invocation.
        self.identity_file = Path.home() / ".ssh" / "id_ed25519"

    @staticmethod
    def _parse_env_text(text: str) -> dict[str, str]:
        """Parse ``KEY=VALUE`` lines of a .env file into a dict.

        Blank lines, ``#`` comments and lines without ``=`` are skipped. Keys and
        values are whitespace-stripped, a leading ``export`` is dropped, and one pair
        of matching surrounding quotes is removed from the value. Later keys win.
        """
        env: dict[str, str] = {}
        for raw in text.splitlines():
            line = raw.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            key = key.strip()
            if key.startswith("export "):
                key = key[len("export "):].strip()
            value = value.strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
                value = value[1:-1]
            if key:
                env[key] = value
        return env

    def search_offers(self, gpu_name: str, limit: int = 20) -> list[dict]:
        """Return rentable single-GPU on-demand offers for ``gpu_name`` priced <= ``max_price``.

        Offers on ``BAD_MACHINES``, or with reliability < 0.93, inet_down < 50 or
        disk_space < 40, are dropped. The rest are sorted by price, then by
        inet_down and reliability (both descending). HTTP 429 is retried up to 3
        times. A persistent 429 or any other HTTP error raises ``requests.HTTPError``.
        """
        body = {
            "limit": limit, "type": "on-demand",
            "rentable": {"eq": True}, "rented": {"eq": False},
            "num_gpus": {"eq": 1}, "gpu_name": {"eq": gpu_name},
            "dph_total": {"lte": self.max_price},
            "order": [["dph_total", "asc"]],
        }
        for attempt in range(3):
            r = requests.post(f"{BASE_URL}/bundles/", headers=self.headers, json=body, timeout=30)
            if r.status_code == 429:
                time.sleep(2 * (attempt + 1))
                continue
            r.raise_for_status()
            # `or 0` also covers keys that are present but null in the API response.
            filtered = [
                o for o in r.json().get("offers", [])
                if o.get("machine_id") not in BAD_MACHINES
                and float(o.get("reliability") or 0) >= 0.93
                and float(o.get("inet_down") or 0) >= 50
                and float(o.get("disk_space") or 0) >= 40
            ]
            return sorted(
                filtered,
                key=lambda o: (o["dph_total"], -(o.get("inet_down") or 0), -(o.get("reliability") or 0)),
            )
        # Still rate-limited after every retry: surface the 429 as an HTTPError.
        r.raise_for_status()
        return []

    def pick_best_offers(self, num_gpus: int) -> list[dict]:
        """Pick up to ``num_gpus`` offers, at most one per physical machine.

        All offers for the priority types (L40S, RTX 4090, L40) are taken first, in
        that order. Remaining slots are filled with the other ``GPU_PREFERENCES``
        types, cheapest first. Each offer is tagged with ``_gpu_type`` (the queried name).
        """
        per_type: dict[str, list[dict]] = {}
        for gpu in GPU_PREFERENCES:
            offers = self.search_offers(gpu, limit=50)
            for o in offers:
                o["_gpu_type"] = gpu
            per_type[gpu] = offers
            logger.info("  %s: %d offers available", gpu, len(offers))

        picked: list[dict] = []
        used_machines: set = set()

        def take(candidates: list[dict]) -> None:
            # Append candidates in order, skipping already-used machines, until full.
            for o in candidates:
                if len(picked) >= num_gpus:
                    return
                mid = o.get("machine_id")
                if mid not in used_machines:
                    picked.append(o)
                    used_machines.add(mid)

        # Pass 1: highest-throughput types, in priority order.
        for gpu in self._PRIORITY_GPUS:
            take(per_type.get(gpu, []))

        # Pass 2: every other preferred type, cheapest first.
        remaining = [
            o
            for gpu in GPU_PREFERENCES if gpu not in self._PRIORITY_GPUS
            for o in per_type.get(gpu, [])
        ]
        remaining.sort(key=lambda o: o["dph_total"])
        take(remaining)

        return picked[:num_gpus]

    def create_instance(self, offer: dict) -> InstanceInfo | None:
        """Rent ``offer`` and return a tracking record, or ``None`` if creation failed.

        Every key in the project ``.env`` is forwarded as the container environment.
        The ``onstart`` script writes a fixed subset of those variables to
        ``/app/.env`` and launches ``codecbench.pipeline.cli sft-run`` in the
        background, writing its output to ``/tmp/worker.log``.
        """
        env_file = PROJECT_ROOT / ".env"
        # NOTE(review): this forwards *every* .env entry (e.g. VAST_KEY / DOCKER_PAT if
        # present) to a third-party host; consider an explicit allow-list.
        env_vars = self._parse_env_text(env_file.read_text()) if env_file.exists() else {}
        docker_user = env_vars.get("DOCKER_USERNAME", "")
        docker_pat = env_vars.get("DOCKER_PAT", "")

        gpu_name = offer.get("gpu_name", "")
        # gpu_ram / 1024 is treated as GB (presumably the API reports MB).
        vram_gb = offer.get("gpu_ram", 0) / 1024
        bs = pick_batch_size(vram_gb, gpu_name)
        onstart = (
            "mkdir -p /tmp/pipeline; "
            "cd /app && "
            "python3 - <<'PY'\n"
            "import os\n"
            "keys = [\n"
            "    'HF_TOKEN','R2_ENDPOINT_URL','R2_BUCKET_DESTINATION','R2_ACCESS_KEY_ID',\n"
            "    'R2_SECRET_ACCESS_KEY','ACCOUNT_ID','S3_API','DATABASE_URL','VAST_API'\n"
            "]\n"
            "with open('/app/.env', 'w') as f:\n"
            "    for k in keys:\n"
            "        f.write(f\"{k}={os.environ.get(k, '')}\\n\")\n"
            "PY\n"
            f"python3 -m codecbench.pipeline.cli sft-run "
            f"--batch-size {bs} "
            f"--offer-id vast_$CONTAINER_ID "
            f">/tmp/worker.log 2>&1 &"
        )

        create_body: dict = {
            "image": DOCKER_IMAGE,
            "disk": 60,
            "runtype": "ssh_direct",
            "label": f"fleet-{offer.get('gpu_name', 'gpu')}",
            "env": env_vars,
            "onstart": onstart,
        }
        if docker_user and docker_pat:
            # Registry credentials so the image can be pulled if it is private.
            create_body["docker_login_repo"] = "https://index.docker.io/v1/"
            create_body["docker_login_user"] = docker_user
            create_body["docker_login_pass"] = docker_pat

        r = None
        for attempt in range(3):
            # Explicit timeout so a stalled connection cannot hang this worker thread forever.
            r = requests.put(
                f"{BASE_URL}/asks/{offer['id']}/",
                headers=self.headers,
                json=create_body,
                timeout=60,
            )
            if r.status_code != 429:
                break
            if attempt < 2:
                time.sleep(3 * (attempt + 1))
        # Compare with None explicitly: a non-OK Response is itself falsy, so a plain
        # truthiness test would report "no response" instead of the API's error body.
        if r is None or not r.ok:
            detail = r.text[:200] if r is not None else "no response"
            logger.error("Create failed for offer %s: %s", offer["id"], detail)
            return None

        iid = r.json().get("new_contract")
        if not iid:
            logger.error("Create for offer %s returned no contract id: %s", offer["id"], r.text[:200])
            return None

        info = InstanceInfo(
            instance_id=iid,
            gpu_name=offer.get("gpu_name", "unknown"),
            machine_id=offer.get("machine_id", 0),
            price_per_hour=offer.get("dph_total", 0),
            vram_gb=vram_gb,
            batch_size=bs,
        )
        logger.info("Created instance %d: %s @ $%.3f/hr (BS=%d)",
                    iid, info.gpu_name, info.price_per_hour, info.batch_size)
        return info

    def wait_instance_ready(self, info: InstanceInfo, timeout: int = 600) -> bool:
        """Poll the instance every 15 s until it is running and exposes SSH details.

        On success, fills ``info.ssh_host``/``ssh_port``, sets status "running" and
        returns True. Returns False with status "failed" if Vast reports an error or
        an exited/offline/dead state, or with status "timeout" after ``timeout``
        seconds. Transient API/network errors are logged and retried, not raised.
        """
        t0 = time.time()
        while time.time() - t0 < timeout:
            try:
                r = requests.get(f"{BASE_URL}/instances/{info.instance_id}/",
                                 headers=self.headers, timeout=30)
                payload = r.json() if r.ok else None
            except (requests.RequestException, ValueError) as e:
                logger.warning("  Instance %d: status poll failed: %s", info.instance_id, e)
                payload = None
            if not isinstance(payload, dict):
                time.sleep(15)
                continue
            inst = payload.get("instances", payload) or {}
            if isinstance(inst, list):
                inst = inst[0] if inst else {}
            actual = inst.get("actual_status") or ""
            msg = inst.get("status_msg") or ""
            if actual == "running" and inst.get("ssh_host") and inst.get("ssh_port"):
                info.ssh_host = inst["ssh_host"]
                info.ssh_port = inst["ssh_port"]
                info.status = "running"
                return True
            if "Error" in msg or actual in {"exited", "offline", "dead"}:
                logger.error("Instance %d failed: %s / %s", info.instance_id, actual, msg[:150])
                info.status = "failed"
                return False
            elapsed = int(time.time() - t0)
            logger.info("  Instance %d: %s [%ds] %s",
                        info.instance_id, actual or "pending", elapsed, msg[:80])
            time.sleep(15)
        info.status = "timeout"
        return False

    def wait_ssh(self, info: InstanceInfo, timeout: int = 180) -> bool:
        """Return True once ``ssh ... echo OK`` succeeds; retry every 10 s for up to ``timeout`` s."""
        probe = ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
                 "-i", str(self.identity_file), "-p", str(info.ssh_port), f"root@{info.ssh_host}", "echo OK"]
        t0 = time.time()
        while time.time() - t0 < timeout:
            try:
                if subprocess.run(probe, capture_output=True, text=True, timeout=20).returncode == 0:
                    return True
            except (subprocess.SubprocessError, OSError):
                pass  # host not reachable yet or ssh hung: keep retrying until the deadline
            time.sleep(10)
        return False

    def run_ssh(self, info: InstanceInfo, cmd: str, timeout: int = 3600) -> bool:
        """Run ``cmd`` on the instance over SSH and return True iff it exits 0.

        On a non-zero exit the last 500 chars of stderr are logged. A timeout is
        logged and returns False.
        """
        full = ["ssh", "-i", str(self.identity_file), "-o", "StrictHostKeyChecking=no", "-o", "ServerAliveInterval=30",
                "-p", str(info.ssh_port), f"root@{info.ssh_host}", cmd]
        try:
            r = subprocess.run(full, text=True, capture_output=True, timeout=timeout)
        except subprocess.TimeoutExpired:
            logger.error("SSH cmd timed out on %d", info.instance_id)
            return False
        if r.returncode != 0:
            logger.error("SSH cmd failed on %d (%s): %s",
                         info.instance_id, info.gpu_name, r.stderr.strip()[-500:])
            return False
        return True

    def sync_code(self, info: InstanceInfo) -> bool:
        """rsync the project tree (minus data/venv/caches) to ``/app/`` on the instance.

        Returns True on success. rsync errors, a 120 s timeout or a missing rsync
        binary are logged and return False.
        """
        src = str(PROJECT_ROOT) + "/"
        excludes = [
            "--exclude=venv/", "--exclude=repos/", "--exclude=data/",
            "--exclude=results/", "--exclude=metafiles/",
            "--exclude=__pycache__/", "--exclude=.git/",
            "--exclude=*.pyc", "--exclude=.cursor/", "--exclude=models/",
        ]
        cmd = [
            "rsync", "-avz", "--timeout=60",
            "-e", f"ssh -i {self.identity_file} -o StrictHostKeyChecking=no -p {info.ssh_port}",
            *excludes, src, f"root@{info.ssh_host}:/app/",
        ]
        try:
            r = subprocess.run(cmd, capture_output=True, text=True, timeout=120)
        except (subprocess.TimeoutExpired, OSError) as e:
            logger.error("rsync to %d failed: %s", info.instance_id, e)
            return False
        if r.returncode != 0:
            logger.error("rsync to %d failed (rc=%d): %s",
                         info.instance_id, r.returncode, r.stderr.strip()[-300:])
            return False
        return True

    @staticmethod
    def _detached_wrapper(cmd: str, log_file: str) -> str:
        """Build the remote shell line that runs ``cmd`` under nohup in the background.

        Output goes to ``log_file``, followed by a final ``__DONE_EXIT_<code>`` line.
        The line itself echoes the background PID. ``cmd`` and ``log_file`` are
        shell-quoted, so quotes or spaces in them cannot break the wrapper.
        """
        inner = f"{cmd}; echo __DONE_EXIT_$?"
        return f"nohup bash -c {shlex.quote(inner)} > {shlex.quote(log_file)} 2>&1 & echo $!"

    def run_ssh_detached(self, info: InstanceInfo, cmd: str, log_file: str,
                         timeout: int = 3600, poll_interval: int = 15) -> bool:
        """Run a command via nohup, poll for completion. Survives SSH proxy drops.

        Returns True when the ``__DONE_EXIT_`` marker in ``log_file`` reports exit
        code 0. Returns False on a non-zero code, a launch failure, or after ``timeout`` s.
        """
        ssh_base = ["ssh", "-i", str(self.identity_file), "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
                    "-p", str(info.ssh_port), f"root@{info.ssh_host}"]
        try:
            r = subprocess.run([*ssh_base, self._detached_wrapper(cmd, log_file)],
                               text=True, capture_output=True, timeout=30)
            if r.returncode != 0:
                logger.error("Failed to launch detached cmd on %d: %s", info.instance_id, r.stderr[:200])
                return False
        except Exception as e:
            logger.error("SSH launch error on %d: %s", info.instance_id, e)
            return False

        tail_cmd = f"tail -5 {shlex.quote(log_file)} 2>/dev/null"
        t0 = time.time()
        while time.time() - t0 < timeout:
            time.sleep(poll_interval)
            try:
                check = subprocess.run([*ssh_base, tail_cmd], text=True, capture_output=True, timeout=20)
            except (subprocess.SubprocessError, OSError):
                continue  # SSH hiccup: the detached job keeps running, poll again
            if check.returncode != 0 or "__DONE_EXIT_" not in check.stdout:
                continue
            markers = [ln for ln in check.stdout.strip().split("\n") if "__DONE_EXIT_" in ln]
            if not markers:
                continue
            code = markers[-1].split("__DONE_EXIT_")[-1].strip()
            if code == "0":
                return True
            logger.error("Detached cmd failed (exit=%s) on %d. Tail:\n%s",
                         code, info.instance_id, check.stdout[-500:])
            return False
        logger.error("Detached cmd timed out on %d after %ds", info.instance_id, timeout)
        return False

    def bootstrap_instance(self, info: InstanceInfo) -> bool:
        """Lightweight bootstrap: pip install -e . (code from rsync) + verify GPU.
        Docker image already has all deps and Spark-TTS repo. Models download on first worker.setup().
        """
        logger.info("  [%d %s] pip-editable ...", info.instance_id, info.gpu_name)
        # pipefail: otherwise the pipeline reports tail's exit status and a failed
        # pip install would look successful.
        pip_cmd = "bash -o pipefail -c " + shlex.quote(
            "cd /app && python3 -m pip install --no-cache-dir -e . 2>&1 | tail -3"
        )
        if not self.run_ssh(info, pip_cmd, timeout=120):
            logger.error("pip install -e . failed for %d", info.instance_id)
            return False

        logger.info("  [%d %s] verify-gpu ...", info.instance_id, info.gpu_name)
        if not self.run_ssh(info, 'python3 -c "import torch; assert torch.cuda.is_available(); print(\'GPU:\', torch.cuda.get_device_name(0))"', timeout=60):
            logger.error("GPU verification failed for %d", info.instance_id)
            return False

        info.status = "bootstrapped"
        return True

    def start_worker(self, info: InstanceInfo) -> bool:
        """Start the sft-run worker over SSH in the background and record its PID.

        stdin is redirected from /dev/null and stdout/stderr to /tmp/worker.log, so
        the SSH call can return while the worker keeps running. The trailing
        ``echo $!`` prints the background PID. Up to 3 attempts, 5 s apart.

        NOTE(review): provision_one relies on the onstart script instead; calling this
        on such an instance would start a second worker.
        """
        worker_cmd = (
            f"cd /app && python3 -m codecbench.pipeline.cli sft-run "
            f"--batch-size {info.batch_size} "
            f"--offer-id vast_{info.instance_id} "
            f"</dev/null >/tmp/worker.log 2>&1 & echo $!"
        )
        full = ["ssh", "-i", str(self.identity_file), "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=30",
                "-p", str(info.ssh_port), f"root@{info.ssh_host}", worker_cmd]
        for attempt in range(3):
            try:
                r = subprocess.run(full, text=True, capture_output=True, timeout=60)
                out = r.stdout.strip()
                last_line = out.split('\n')[-1].strip() if out else ""
                if r.returncode == 0 and last_line.isdigit():
                    info.worker_pid = int(last_line)
                    info.status = "worker_running"
                    logger.info("Worker started on %d (%s) PID=%d BS=%d",
                                info.instance_id, info.gpu_name, info.worker_pid, info.batch_size)
                    return True
                logger.warning("Worker start attempt %d on %d: rc=%d out=[%s] err=[%s]",
                               attempt + 1, info.instance_id, r.returncode,
                               out[-200:], r.stderr.strip()[-200:])
            except subprocess.TimeoutExpired:
                logger.warning("Worker start attempt %d timed out on %d", attempt + 1, info.instance_id)
            except Exception as e:
                logger.warning("Worker start attempt %d error on %d: %s", attempt + 1, info.instance_id, e)
            if attempt < 2:
                time.sleep(5)
        logger.error("Failed to start worker on %d after 3 attempts", info.instance_id)
        return False

    def provision_one(self, offer: dict) -> InstanceInfo | None:
        """Create an instance for ``offer`` and wait until Vast reports it running.

        The worker is launched by the instance's onstart script (built in
        create_instance). Status is set to "worker_running" without checking the
        process over SSH. If the instance fails to come up, or anything raises
        while waiting, it is destroyed and ``None`` is returned.
        """
        info = self.create_instance(offer)
        if not info:
            return None
        try:
            if not self.wait_instance_ready(info, timeout=600):
                self.destroy_instance(info.instance_id)
                return None
            logger.info("Instance %d running at %s:%d — onstart worker auto-launching",
                        info.instance_id, info.ssh_host, info.ssh_port)
            info.status = "worker_running"
            return info
        except Exception as e:
            logger.error("Provision failed for %d: %s\n%s",
                         info.instance_id, e, traceback.format_exc())
            self.destroy_instance(info.instance_id)
            return None

    def destroy_instance(self, instance_id: int) -> bool:
        """Delete ``instance_id`` on Vast.ai.

        Returns True if the instance was destroyed or is already gone (HTTP 404),
        and False otherwise. HTTP 429 is retried up to 3 times. Failures are logged
        rather than raised, so one bad call cannot abort a teardown loop.
        """
        r = None
        for attempt in range(3):
            try:
                r = requests.delete(f"{BASE_URL}/instances/{instance_id}/", headers=self.headers, timeout=30)
            except requests.RequestException as e:
                logger.error("Destroy request for instance %d failed: %s", instance_id, e)
                return False
            if r.status_code != 429 or attempt == 2:
                break
            time.sleep(2 * (attempt + 1))
        if r.ok:
            logger.info("Destroyed instance %d", instance_id)
            return True
        if r.status_code == 404:
            logger.info("Instance %d already gone", instance_id)
            return True
        logger.error("Failed to destroy instance %d: HTTP %d %s", instance_id, r.status_code, r.text[:200])
        return False

    def destroy_all(self) -> None:
        """Destroy every instance recorded in the state file.

        Instances that could not be destroyed are kept in the state file (and a
        warning is logged) so a re-run can retry them. The file is deleted only
        when every deletion succeeded.
        """
        state = self.load_state()
        leftover = [inst for inst in state if not self.destroy_instance(inst["instance_id"])]
        if leftover:
            self._write_state(leftover)
            logger.warning("%d instance(s) could not be destroyed; kept in %s",
                           len(leftover), FLEET_STATE_FILE)
            return
        FLEET_STATE_FILE.unlink(missing_ok=True)
        logger.info("All fleet instances destroyed")

    @staticmethod
    def _write_state(entries: list[dict]) -> None:
        """Atomically replace the state file with ``entries`` (temp file + rename)."""
        FLEET_STATE_FILE.parent.mkdir(parents=True, exist_ok=True)
        tmp = FLEET_STATE_FILE.with_name(FLEET_STATE_FILE.name + ".tmp")
        tmp.write_text(json.dumps(entries, indent=2))
        tmp.replace(FLEET_STATE_FILE)

    def save_state(self) -> None:
        """Write this deployer's instances to the state file, keeping older entries.

        Entries whose ``instance_id`` this deployer does not track (e.g. from an
        earlier deployment) are preserved, so a later deploy cannot orphan a
        still-running fleet from ``--status`` / ``--destroy-all``. An unreadable
        existing file is treated as empty.
        """
        current_ids = {i.instance_id for i in self.instances}
        try:
            prior = [
                e for e in self.load_state()
                if isinstance(e, dict) and e.get("instance_id") not in current_ids
            ]
        except (OSError, ValueError):
            logger.warning("Could not read %s; overwriting it", FLEET_STATE_FILE)
            prior = []
        self._write_state(prior + [asdict(i) for i in self.instances])

    def load_state(self) -> list[dict]:
        """Return the saved instance records, or ``[]`` if there is no state file."""
        if FLEET_STATE_FILE.exists():
            return json.loads(FLEET_STATE_FILE.read_text())
        return []

    def get_offer_details(self, offer_id: int) -> dict | None:
        """Look up a rentable single-GPU on-demand offer by id; ``None`` if not found.

        The search is repeated with increasing price caps (1000 results each), so an
        offer is not missed just because cheaper offers filled the result page.
        Failed or unreachable searches for a tier are skipped.
        """
        for max_price in [0.30, 0.80, 2.00, 5.00]:
            body = {
                "limit": 1000, "type": "on-demand",
                "rentable": {"eq": True}, "rented": {"eq": False},
                "num_gpus": {"eq": 1},
                "dph_total": {"lte": max_price},
                "order": [["dph_total", "asc"]],
            }
            try:
                r = requests.post(f"{BASE_URL}/bundles/", headers=self.headers, json=body, timeout=30)
            except requests.RequestException as e:
                logger.warning("Offer search (<= $%.2f/hr) failed: %s", max_price, e)
                continue
            if not r.ok:
                continue
            for o in r.json().get("offers", []):
                if o.get("id") == offer_id:
                    return o
        return None

    def _provision_all(self, offers: list[dict]) -> None:
        """Provision ``offers`` concurrently (8 at a time), saving state after each success."""
        logger.info("Provisioning %d instances (8 at a time)...", len(offers))
        with ThreadPoolExecutor(max_workers=8, thread_name_prefix="prov") as pool:
            futures = {pool.submit(self.provision_one, o): o for o in offers}
            for fut in as_completed(futures):
                offer = futures[fut]
                try:
                    info = fut.result()
                    if info:
                        self.instances.append(info)
                        self.save_state()
                except Exception as e:
                    logger.error("Provision error for %s: %s", offer.get("gpu_name"), e)

    def _print_result(self, num_requested: int) -> None:
        """Print the post-deployment summary for ``num_requested`` attempted offers."""
        running = [i for i in self.instances if i.status == "worker_running"]
        failed = num_requested - len(running)
        print(f"\n{'='*70}")
        print(f"  DEPLOYMENT RESULT: {len(running)}/{num_requested} workers running")
        print(f"{'='*70}")
        for info in running:
            print(f"  Instance {info.instance_id}: {info.gpu_name:<20} "
                  f"BS={info.batch_size}  ${info.price_per_hour:.3f}/hr  "
                  f"PID={info.worker_pid}  {info.ssh_host}:{info.ssh_port}")
        if failed:
            print(f"\n  {failed} instance(s) failed to provision")
        print(f"{'='*70}\n")

    def deploy_by_offer_ids(self, offer_ids: list[int], dry_run: bool = False) -> None:
        """Deploy specific hand-picked offers by ID (ids that cannot be found are skipped)."""
        offers = []
        for oid in offer_ids:
            o = self.get_offer_details(oid)
            if o:
                offers.append(o)
            else:
                logger.warning("Offer %d not found or no longer available", oid)

        if not offers:
            logger.error("No valid offers found!")
            return

        print(f"\n{'='*70}")
        print(f"  MANUAL FLEET DEPLOYMENT ({len(offers)} GPUs)")
        print(f"{'='*70}")
        total_cost = 0
        for i, o in enumerate(offers):
            vram = o.get("gpu_ram", 0) / 1024
            # Pass the GPU name so the plan shows the batch size create_instance will use.
            bs = pick_batch_size(vram, o.get("gpu_name", ""))
            price = o["dph_total"]
            total_cost += price
            dl = o.get("inet_down", 0)
            ul = o.get("inet_up", 0)
            rel = o.get("reliability", 0)
            geo = str(o.get("geolocation", "?"))[:18]
            print(f"  {i+1}. {o['gpu_name']:<20} {vram:.0f}GB  BS={bs}  "
                  f"${price:.3f}/hr  DL={dl:.0f} UL={ul:.0f}  rel={rel:.3f}  {geo}")
        print(f"\n  Total fleet cost: ${total_cost:.3f}/hr (${total_cost*24:.2f}/day)")
        print(f"{'='*70}\n")

        if dry_run:
            print("  DRY RUN — no instances created")
            return

        self._provision_all(offers)
        self._print_result(len(offers))
        self.save_state()

    def deploy_fleet(self, num_gpus: int, dry_run: bool = False) -> None:
        """Deploy N GPU instances across diverse GPU types (chosen by pick_best_offers)."""
        offers = self.pick_best_offers(num_gpus)
        if not offers:
            logger.error("No suitable offers found!")
            return

        print(f"\n{'='*70}")
        print(f"  FLEET DEPLOYMENT PLAN ({len(offers)} GPUs)")
        print(f"{'='*70}")
        total_cost = 0
        for i, o in enumerate(offers):
            vram = o.get("gpu_ram", 0) / 1024
            # Pass the GPU name so the plan shows the batch size create_instance will use.
            bs = pick_batch_size(vram, o.get("gpu_name", ""))
            price = o["dph_total"]
            total_cost += price
            print(f"  {i+1}. {o['gpu_name']:<20} {vram:.0f}GB VRAM  BS={bs}  "
                  f"${price:.3f}/hr  machine={o.get('machine_id')}  "
                  f"rel={o.get('reliability',0):.3f}")
        print(f"\n  Total fleet cost: ${total_cost:.3f}/hr (${total_cost*24:.2f}/day)")
        print(f"{'='*70}\n")

        if dry_run:
            print("  DRY RUN — no instances created")
            return

        self._provision_all(offers)
        self._print_result(len(offers))
        self.save_state()

    def check_status(self) -> None:
        """Print Vast status, worker liveness and recent log lines for each saved instance."""
        state = self.load_state()
        if not state:
            print("No fleet state found. Deploy first with --deploy")
            return

        print(f"\n{'='*70}")
        print(f"  FLEET STATUS ({len(state)} instances)")
        print(f"{'='*70}")

        for inst in state:
            iid = inst["instance_id"]
            gpu = inst.get("gpu_name", "?")
            ssh_h = inst.get("ssh_host", "")
            ssh_p = inst.get("ssh_port", 0)

            # Check Vast status
            try:
                r = requests.get(f"{BASE_URL}/instances/{iid}/", headers=self.headers, timeout=15)
                if r.ok:
                    payload = r.json()
                    d = payload.get("instances", payload)
                    if isinstance(d, list):
                        d = d[0] if d else {}
                    d = d or {}
                    vast_status = d.get("actual_status") or d.get("cur_state") or "unknown"
                else:
                    vast_status = f"api_error_{r.status_code}"
            except Exception:
                vast_status = "api_unreachable"

            # Check worker process. ';' (not '&&') so the process count is still
            # printed when /tmp/worker.log does not exist yet.
            worker_alive = False
            recent_log = ""
            if ssh_h and ssh_p and vast_status == "running":
                try:
                    r = subprocess.run(
                        ["ssh", "-i", str(self.identity_file), "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
                         "-p", str(ssh_p), f"root@{ssh_h}",
                         "tail -3 /tmp/worker.log 2>/dev/null; "
                         "ps aux | grep 'codecbench.pipeline.cli' | grep -v grep | wc -l"],
                        capture_output=True, text=True, timeout=15,
                    )
                    if r.returncode == 0:
                        lines = r.stdout.strip().split("\n")
                        proc_count = lines[-1].strip()
                        # Only a positive integer count means the worker is alive.
                        worker_alive = proc_count.isdigit() and int(proc_count) > 0
                        recent_log = "\n".join(lines[:-1])[-200:]
                except (subprocess.SubprocessError, OSError):
                    pass  # unreachable over SSH: reported as WORKER_DOWN below

            status_emoji = "OK" if worker_alive else ("VAST:" + vast_status if vast_status != "running" else "WORKER_DOWN")
            print(f"\n  [{status_emoji}] Instance {iid} ({gpu}) BS={inst.get('batch_size', '?')}")
            print(f"       Vast: {vast_status}  SSH: {ssh_h}:{ssh_p}")
            if recent_log:
                for line in recent_log.split("\n")[-2:]:
                    print(f"       Log: {line[:120]}")

        print(f"\n{'='*70}\n")


def parse_offer_ids(raw: str) -> list[int]:
    """Parse a comma-separated list of Vast.ai offer ids, ignoring blank entries.

    >>> parse_offer_ids("123, 456,")
    [123, 456]

    Raises:
        ValueError: if a non-blank entry is not an integer.
    """
    return [int(tok) for tok in (part.strip() for part in raw.split(",")) if tok]


def main():
    """Command-line entry point: parse flags and dispatch to FleetDeployer.

    One action runs, checked in this order: --destroy-all, --status, then
    --deploy/--dry-run (deploying --offer-ids when given, otherwise auto-picking
    --num-gpus offers). With no action flag the help text is printed.
    """
    parser = argparse.ArgumentParser(description="Fleet deployer for Vast.ai GPU instances")
    parser.add_argument("--deploy", action="store_true", help="Deploy new fleet")
    parser.add_argument("--offer-ids", type=str, default=None,
                        help="Comma-separated Vast.ai offer IDs to deploy (manual pick)")
    parser.add_argument("--num-gpus", type=int, default=5)
    parser.add_argument("--max-price", type=float, default=0.55)
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--status", action="store_true", help="Check fleet health")
    parser.add_argument("--destroy-all", action="store_true", help="Destroy all fleet instances")
    args = parser.parse_args()

    # Validate --offer-ids up front so a typo or trailing comma yields a usage
    # error (exit 2) instead of a traceback halfway through.
    offer_ids: list[int] = []
    if args.offer_ids:
        try:
            offer_ids = parse_offer_ids(args.offer_ids)
        except ValueError:
            parser.error(f"--offer-ids must be comma-separated integers, got {args.offer_ids!r}")
        if not offer_ids:
            parser.error("--offer-ids did not contain any offer id")

    api_key = os.environ.get("VAST_KEY", "")
    if not api_key:
        print("ERROR: VAST_KEY not set")
        sys.exit(1)

    deployer = FleetDeployer(api_key, max_price=args.max_price)

    if args.destroy_all:
        deployer.destroy_all()
    elif args.status:
        deployer.check_status()
    elif args.deploy or args.dry_run:
        if offer_ids:
            deployer.deploy_by_offer_ids(offer_ids, dry_run=args.dry_run)
        else:
            deployer.deploy_fleet(args.num_gpus, dry_run=args.dry_run)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
