#!/usr/bin/env python3
"""
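Hot-swap worker.py on all running Vast.ai instances whose label contains "neucodec".
1. SCP the local worker.py to each instance
2. Kill all existing worker processes
3. Restart the workers with the new code

This avoids a Docker image rebuild and re-pull. Run with --dry-run to print the
planned worker counts without touching any instance.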
"""

import json
import os
import shlex
import subprocess
import sys
import threading
import time

# Environment variables exported to every worker process on the remote hosts.
#
# SECURITY: credentials must never be committed to source control. They are
# read from this script's own environment instead; export them before running:
#   export DATABASE_URL=... R2_ENDPOINT_URL=... R2_ACCESS_KEY_ID=... R2_SECRET_ACCESS_KEY=...
# Any credential that was ever hard-coded in this file should be treated as
# leaked and rotated. Unset variables are forwarded as empty strings.
ENV_VARS = {
    "DATABASE_URL": os.environ.get("DATABASE_URL", ""),
    "R2_ENDPOINT_URL": os.environ.get("R2_ENDPOINT_URL", ""),
    "R2_ACCESS_KEY_ID": os.environ.get("R2_ACCESS_KEY_ID", ""),
    "R2_SECRET_ACCESS_KEY": os.environ.get("R2_SECRET_ACCESS_KEY", ""),
    # The bucket name is not a secret, so it keeps its previous default.
    "R2_BUCKET_SFT_DATA": os.environ.get("R2_BUCKET_SFT_DATA", "finalsftdata"),
}

# Local copy of worker.py that gets pushed to each instance.
WORKER_FILE = "/home/ubuntu/neucodec/worker.py"

# How many workers to start per GPU model (keyed by Vast.ai's gpu_name).
# GPU fbank saturates the GPU with a single worker. Models not listed here
# fall back to 2 workers at the call sites.
WORKERS_PER_GPU = {
    "RTX 4090": 1,
    "RTX 3090": 1,
    "L40S": 1,
    "L40": 1,
    "RTX A6000": 1,
}


def get_instances():
    """Return the running Vast.ai instances whose label contains "neucodec".

    Calls ``vastai show instances --raw`` (30 s timeout), parses its JSON
    output, and keeps only dict entries whose label contains "neucodec" and
    whose ``actual_status`` is ``"running"``.

    Raises:
        RuntimeError: if the ``vastai`` CLI exits with a non-zero status.
        subprocess.TimeoutExpired: if the CLI takes longer than 30 s.
    """
    result = subprocess.run(
        ["vastai", "show", "instances", "--raw"],
        capture_output=True, text=True, timeout=30,
    )
    # Surface the CLI's own error (bad API key, network failure, ...) instead
    # of an opaque JSONDecodeError on whatever a failed command left on stdout.
    if result.returncode != 0:
        raise RuntimeError(
            f"`vastai show instances` exited with {result.returncode}: "
            f"{result.stderr.strip()[:200]}"
        )
    data = json.loads(result.stdout)
    return [
        inst for inst in data
        if isinstance(inst, dict)
        and "neucodec" in (inst.get("label") or "")
        and inst.get("actual_status") == "running"
    ]


def get_worker_id(label):
    """Pull the short worker id out of an instance label.

    The id is the first ``"nc-"`` in *label* followed by at most three more
    characters, e.g. ``"neucodec-nc-042"`` -> ``"nc-042"``. A label without
    ``"nc-"`` yields ``None``.
    """
    if "nc-" not in label:
        return None
    _, marker, tail = label.partition("nc-")
    return marker + tail[:3]


def _ssh_run(host, port, remote_cmd, timeout):
    """Run *remote_cmd* as root on *host*:*port* over SSH; return the CompletedProcess."""
    return subprocess.run(
        ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
         "-p", str(port), f"root@{host}", remote_cmd],
        capture_output=True, text=True, timeout=timeout,
        # Many ssh clients run concurrently; none of them may read this
        # script's terminal stdin.
        stdin=subprocess.DEVNULL,
    )


# Kill every "python3 worker.py" process: SIGTERM, wait 2 s, then SIGKILL any
# survivors. `grep -v grep` keeps the pipeline from matching its own grep;
# pkill -f is avoided because its pattern can also match (and kill) the SSH
# session's shell. The trailing `echo done` makes the remote command succeed
# even when nothing was running, so a non-zero exit points at SSH itself.
_KILL_WORKERS_CMD = (
    "kill $(ps aux | grep 'python3 worker.py' | grep -v grep | awk '{print $2}') 2>/dev/null; "
    "sleep 2; "
    "kill -9 $(ps aux | grep 'python3 worker.py' | grep -v grep | awk '{print $2}') 2>/dev/null; "
    "echo done"
)


def _start_worker_cmd(wid):
    """Build the remote command that launches one detached worker with id *wid*."""
    # shlex.quote each value: credentials may contain quotes, `$` or backticks
    # that would otherwise break or be expanded by the remote shell.
    parts = [f"export {key}={shlex.quote(value)}" for key, value in ENV_VARS.items()]
    parts.append("cd /app")
    log_path = shlex.quote(f"/app/worker_{wid}.log")
    parts.append(
        f"nohup python3 worker.py --worker-id {shlex.quote(wid)} "
        f"> {log_path} 2>&1 < /dev/null &"
    )
    script = "; ".join(parts)
    # Quote the whole script once so the remote login shell passes it to
    # bash -c as a single, intact argument.
    return f"bash -c {shlex.quote(script)}"


def hotswap_instance(inst, results, semaphore):
    """SCP worker.py to one instance, kill its old workers, and start new ones.

    All work happens while holding *semaphore*, which caps concurrent SSH
    sessions. Appends a ``(worker_id, status, message)`` tuple to *results*,
    where status is one of:

    * ``"skip"``       - no SSH host, or no ``nc-`` id in the label
    * ``"scp_fail"``   - copying worker.py failed
    * ``"kill_fail"``  - the SSH call that stops the old workers failed, so
      no new workers are started alongside possibly-live old ones
    * ``"spawn_fail"`` - none of the new workers could be launched
    * ``"partial"``    - only some of the new workers were launched
    * ``"ok"``         - every new worker was launched
    * ``"error"``      - an exception (e.g. a subprocess timeout) was raised
    """
    with semaphore:
        iid = inst.get("id")
        label = inst.get("label", "") or ""
        gpu = inst.get("gpu_name", "unknown")
        host = inst.get("ssh_host", "")
        port = inst.get("ssh_port", 22)
        base_id = get_worker_id(label)

        if not host or not base_id:
            results.append((base_id or label, "skip", "no SSH/ID"))
            print(f"  [{base_id or label or iid}] skipped: no SSH/ID")
            return

        try:
            # 1. SCP new worker.py
            r = subprocess.run(
                ["scp", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
                 "-P", str(port), WORKER_FILE, f"root@{host}:/app/worker.py"],
                capture_output=True, text=True, timeout=30,
                stdin=subprocess.DEVNULL,
            )
            if r.returncode != 0:
                results.append((base_id, "scp_fail", r.stderr[:100]))
                print(f"  [{base_id}] scp failed: {r.stderr.strip()[:100]}")
                return

            # 2. Kill all worker processes. If this fails the old workers may
            #    still be alive, so starting new ones would double them up.
            r = _ssh_run(host, port, _KILL_WORKERS_CMD, timeout=20)
            if r.returncode != 0:
                results.append((base_id, "kill_fail", r.stderr[:100]))
                print(f"  [{base_id}] kill failed: {r.stderr.strip()[:100]}")
                return

            # Small delay for GPU memory to free
            time.sleep(3)

            # 3. Start new workers: the first uses base_id, extras get b, c, ...
            target = WORKERS_PER_GPU.get(gpu, 2)  # unlisted GPUs get 2 workers
            spawned = 0
            for i in range(target):
                wid = base_id if i == 0 else f"{base_id}{chr(ord('b') + i - 1)}"
                r = _ssh_run(host, port, _start_worker_cmd(wid), timeout=15)
                # Exit 0 means the background launch was issued; it does not
                # confirm the worker stayed up (check /app/worker_<id>.log).
                if r.returncode == 0:
                    spawned += 1

            if spawned == target:
                status = "ok"
            elif spawned:
                status = "partial"
            else:
                status = "spawn_fail"
            results.append((base_id, status, f"{spawned}/{target} workers"))
            outcome = "swapped" if status == "ok" else status
            print(f"  [{base_id}] {gpu}: {outcome}, {spawned}/{target} workers started")

        except Exception as e:  # thread boundary: record and let other instances continue
            results.append((base_id, "error", str(e)[:100]))
            print(f"  [{base_id}] ERROR: {e}")


def main():
    """Hot-swap every running neucodec instance, or only plan it with --dry-run.

    With ``--dry-run`` the planned worker count per instance and the total are
    printed, and no instance is contacted. Otherwise each instance is handled
    by ``hotswap_instance`` in its own thread (at most 20 holding the semaphore
    at once), followed by a success/failure summary.
    """
    print("=== Fetching instances ===")
    instances = get_instances()
    print(f"Found {len(instances)} running instances\n")

    if "--dry-run" in sys.argv:
        total_workers = 0
        for inst in instances:
            label = inst.get("label", "")
            gpu = inst.get("gpu_name", "")
            base_id = get_worker_id(label) or label
            target = WORKERS_PER_GPU.get(gpu, 2)  # unlisted GPUs get 2 workers
            total_workers += target
            print(f"  {base_id}: {gpu} → {target} workers")
        print(f"\nTotal planned workers: {total_workers}")
        return

    # Old workers are killed before new ones start. Refuse to go ahead if the
    # new workers would be launched without their credentials.
    missing = [key for key, value in ENV_VARS.items() if not value]
    if missing:
        sys.exit(f"Refusing to hot-swap: empty environment variables: {', '.join(missing)}")

    results = []
    threads = []
    semaphore = threading.Semaphore(20)  # max 20 concurrent SSH

    for inst in instances:
        t = threading.Thread(
            target=hotswap_instance,
            args=(inst, results, semaphore),
            name=inst.get("label") or str(inst.get("id")),
        )
        t.start()
        threads.append(t)

    for t in threads:
        t.join(timeout=120)

    # join() gives up after 120 s, but these non-daemon threads keep running
    # and may not have reported yet. List them so they don't silently vanish
    # from the totals.
    unfinished = [t.name for t in threads if t.is_alive()]
    finished = list(results)  # snapshot: late threads may still append

    # Summary
    ok = sum(1 for _, status, _ in finished if status == "ok")
    fail = len(finished) - ok
    print(f"\n=== DONE: {ok} swapped, {fail} failed ===")

    if unfinished:
        print(f"Still running after join timeout ({len(unfinished)}): {', '.join(unfinished)}")

    if fail > 0:
        print("Failures:")
        for name, status, msg in finished:
            if status != "ok":
                print(f"  {name}: {status} - {msg}")


if __name__ == "__main__":
    main()
