#!/usr/bin/env python3
"""
Spawn extra worker processes on existing Vast.ai instances to saturate GPU/CPU.

Each additional worker loads its own copy of the model (~2.8-3 GB of VRAM) and
claims shards independently, so a 24 GB GPU can safely run about 4 workers.
The per-GPU worker targets actually used are defined in WORKERS_PER_GPU.
"""

import json
import os
import re
import shlex
import subprocess
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor

# Environment exported to every spawned worker process. Credentials are read
# from the local environment rather than committed to source control; each
# entry is an empty string when the variable is unset locally.
# NOTE(review): credentials were previously hard-coded here and should be
# treated as leaked and rotated.
_FORWARDED_SECRET_NAMES = (
    "DATABASE_URL",
    "R2_ENDPOINT_URL",
    "R2_ACCESS_KEY_ID",
    "R2_SECRET_ACCESS_KEY",
)
ENV_VARS = {name: os.environ.get(name, "") for name in _FORWARDED_SECRET_NAMES}
# The bucket name is not a secret, so it keeps its previous default.
ENV_VARS["R2_BUCKET_SFT_DATA"] = os.environ.get("R2_BUCKET_SFT_DATA", "finalsftdata")

# Target TOTAL number of worker.py processes per instance, keyed by the GPU
# model name reported by Vast.ai (the instance's `gpu_name` field).
WORKERS_PER_GPU = dict([
    ("RTX 4090", 4),   # 24GB; fast GPU, so the CPU is usually the bottleneck
    ("RTX 3090", 3),   # 24GB; slower GPU
    ("L40S", 4),       # 48GB
    ("L40", 4),        # 48GB
    ("RTX A6000", 5),  # 48GB
])

def get_instances():
    """Return the running Vast.ai instances whose label contains "neucodec".

    Runs ``vastai show instances --raw`` and keeps only dict entries whose
    label contains "neucodec" and whose ``actual_status`` is "running".

    Raises:
        RuntimeError: if the vastai CLI exits non-zero (its stderr is included
            in the message) or prints output that is not valid JSON.
        subprocess.TimeoutExpired: if the CLI does not finish within 30 s.
    """
    result = subprocess.run(
        ["vastai", "show", "instances", "--raw"],
        capture_output=True, text=True, timeout=30,
    )
    if result.returncode != 0:
        # Without this check a CLI failure surfaces as an opaque JSON error
        # on empty stdout instead of the CLI's own explanation.
        raise RuntimeError(
            f"vastai show instances failed (exit {result.returncode}): "
            f"{result.stderr.strip()}"
        )
    try:
        data = json.loads(result.stdout)
    except json.JSONDecodeError as err:
        raise RuntimeError(
            f"vastai returned non-JSON output: {result.stdout[:200]!r}"
        ) from err
    return [
        inst for inst in data
        if isinstance(inst, dict)
        and "neucodec" in (inst.get("label", "") or "")
        and inst.get("actual_status") == "running"
    ]


def get_worker_id_from_label(label):
    """Extract the base worker ID (e.g. 'nc-048') from an instance label.

    Returns the first occurrence of ``nc-`` followed by one or more ASCII
    digits, e.g. 'neucodec-nc-048' -> 'nc-048' and 'neucodec-nc-1234' ->
    'nc-1234'. Returns None when *label* is empty/None or contains no such ID.
    """
    # Requiring digits avoids both truncating IDs longer than three digits
    # and matching an unrelated earlier "nc-" (e.g. in "sync-nc-007").
    match = re.search(r"nc-[0-9]+", label or "")
    return match.group(0) if match else None


def ssh_cmd(host, port, cmd, timeout=30):
    """Run *cmd* as root@host:port over SSH and return the CompletedProcess.

    stdout and stderr are captured as text. ``BatchMode=yes`` makes ssh fail
    immediately instead of prompting for a password or passphrase. Local stdin
    is detached, so many concurrent calls from worker threads cannot block on,
    or read from, the controlling terminal.

    NOTE(review): StrictHostKeyChecking=no skips host-key verification.

    Raises:
        subprocess.TimeoutExpired: if the session exceeds *timeout* seconds.
    """
    return subprocess.run(
        [
            "ssh",
            "-o", "StrictHostKeyChecking=no",
            "-o", "ConnectTimeout=10",
            "-o", "BatchMode=yes",
            "-p", str(port),
            f"root@{host}",
            cmd,
        ],
        stdin=subprocess.DEVNULL,
        capture_output=True, text=True, timeout=timeout,
    )


def count_running_workers(host, port):
    """Return how many worker.py processes are running on an instance.

    The pgrep pattern is written as ``[w]orker.py``. It still matches
    "worker.py" command lines, but not the remote shell that runs this very
    pipeline (whose command line contains the literal pattern text). A plain
    ``worker.py`` pattern would count that shell as one extra worker.

    Raises:
        RuntimeError: if the SSH command fails or prints something that is not
            an integer. Guessing 0 instead would make the caller spawn a full
            set of new workers on top of ones it could not see.
        subprocess.TimeoutExpired: propagated from ssh_cmd.
    """
    r = ssh_cmd(host, port, "pgrep -f '[w]orker.py' | wc -l")
    if r.returncode != 0:
        raise RuntimeError(
            f"worker count failed (ssh exit {r.returncode}): {r.stderr.strip()}"
        )
    try:
        return int(r.stdout.strip())
    except ValueError:
        raise RuntimeError(f"unexpected worker count output: {r.stdout!r}") from None


def spawn_worker(host, port, worker_id):
    """Start one additional, backgrounded worker.py process on an instance.

    Every entry of ENV_VARS is exported into the worker's environment. The
    worker runs from /app and writes its output to
    ``/app/worker_<worker_id>.log``.

    Quoting: each value and the worker ID are ``shlex.quote``-d inside the
    ``bash -c`` script, and the script itself is quoted again for the remote
    login shell that ssh hands it to. Values containing quotes, ``$`` or
    backticks therefore reach the worker verbatim.

    The worker's stdin comes from /dev/null and its stdout/stderr go to the
    log file, so none of its standard streams stay attached to the SSH session.

    Returns:
        True if the remote launch command exited 0. Because the worker is
        backgrounded, this only means it was started, not that it stays up.
    """
    exports = "; ".join(
        f"export {name}={shlex.quote(str(value))}" for name, value in ENV_VARS.items()
    )
    log_path = f"/app/worker_{worker_id}.log"
    script = (
        f"{exports}; cd /app; "
        f"nohup python3 worker.py --worker-id {shlex.quote(worker_id)} "
        f"> {shlex.quote(log_path)} 2>&1 < /dev/null &"
    )
    r = ssh_cmd(host, port, f"bash -c {shlex.quote(script)}", timeout=15)
    return r.returncode == 0


def process_instance(inst, results):
    """Top up one instance to its target number of worker.py processes.

    Appends exactly one ``(id, gpu, status, spawned_count)`` tuple to
    *results*. The id is the raw label when SSH info or a worker ID is
    missing, and the base worker ID otherwise. Exceptions raised while
    counting or spawning are recorded as an ``"error: ..."`` status rather
    than propagated.
    """
    instance_id = inst["id"]  # not used below; a record without "id" raises KeyError
    label = inst.get("label", "")
    gpu = inst.get("gpu_name", "unknown")
    host = inst.get("ssh_host", "")
    port = inst.get("ssh_port", 22)

    if not host:
        results.append((label, gpu, "no SSH", 0))
        return

    base_id = get_worker_id_from_label(label)
    if not base_id:
        results.append((label, gpu, "no worker ID", 0))
        return

    # GPU models missing from the table get 3 workers.
    target = WORKERS_PER_GPU.get(gpu, 3)

    try:
        current = count_running_workers(host, port)
        missing = target - current
        if missing <= 0:
            results.append((base_id, gpu, f"already {current}/{target}", 0))
            return

        started = 0
        # Existing workers occupy slots 0..current-1. A new worker in slot k
        # gets suffix chr(ord("a") + k): slot 1 -> "b", slot 2 -> "c", ...
        for slot in range(current, current + missing):
            new_id = base_id + chr(ord("a") + slot)
            if not spawn_worker(host, port, new_id):
                print(f"  [{new_id}] FAILED on {host}:{port}")
                continue
            started += 1
            print(f"  [{new_id}] spawned on {host}:{port} ({gpu})")

        results.append((base_id, gpu, f"{current}→{current+started}/{target}", started))
    except Exception as exc:
        results.append((base_id, gpu, f"error: {exc}", 0))


def main():
    """Top up every running neucodec instance to its per-GPU worker target.

    Exits with an error if any environment value forwarded to workers is
    empty. Instances are processed concurrently, with at most 20 SSH sessions
    at once. The summary is printed only after every instance has finished:
    total workers spawned, then a per-GPU breakdown sorted by count.
    """
    missing_env = [name for name, value in ENV_VARS.items() if not value]
    if missing_env:
        sys.exit("Missing required environment variables: " + ", ".join(missing_env))

    print("=== Fetching instances ===")
    instances = get_instances()
    print(f"Found {len(instances)} running instances")

    results = []
    # The pool caps concurrent SSH sessions at 20. Every SSH call carries its
    # own timeout, so waiting for all tasks is bounded. A fixed per-thread join
    # timeout could print the summary while threads were still appending to
    # `results`.
    with ThreadPoolExecutor(max_workers=20) as pool:
        futures = [pool.submit(process_instance, inst, results) for inst in instances]
        for fut in futures:
            exc = fut.exception()  # blocks until this task is done
            if exc is not None:
                # process_instance records SSH errors itself. Anything that
                # escapes it (e.g. an instance record without "id") would
                # otherwise be swallowed by the future.
                print(f"  unexpected error: {exc!r}", file=sys.stderr)

    # Summary
    total_spawned = sum(r[3] for r in results)
    print(f"\n=== DONE: Spawned {total_spawned} extra workers across {len(instances)} instances ===")

    # Count by GPU
    by_gpu = {}
    for _, gpu, _, spawned in results:
        by_gpu[gpu] = by_gpu.get(gpu, 0) + spawned
    for gpu, count in sorted(by_gpu.items(), key=lambda x: -x[1]):
        print(f"  {gpu}: +{count} workers")


# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()
