#!/usr/bin/env python3
"""
Check all instances for dead workers and restart them.
Only starts new workers where no worker.py process is running.
Also ensures the latest worker.py is deployed.
"""

import json
import subprocess
import threading
import time

# SECURITY NOTE(review): live credentials (Supabase DB password, R2 access
# key + secret) are hardcoded in source. These should be rotated and loaded
# from the environment / a secrets manager instead of being committed.
# Values are exported into the remote worker's shell before launch.
ENV_VARS = {
    "DATABASE_URL": "postgresql://postgres.exlkkfpymkpqlxulurel:Chibhakaku%402001@aws-0-us-west-2.pooler.supabase.com:6543/postgres",
    "R2_ENDPOINT_URL": "https://cb908ed13329eb7b186e06ab51bda190.r2.cloudflarestorage.com",
    "R2_ACCESS_KEY_ID": "c3c9190ae7ff98b10271ea8db6940210",
    "R2_SECRET_ACCESS_KEY": "eab9394d02b48a865634105b92c74751ec9a311c56884f7aead5d76476c6b576",
    "R2_BUCKET_SFT_DATA": "finalsftdata",
}

# Local path of the worker script that gets scp'd to each instance.
WORKER_FILE = "/home/ubuntu/neucodec/worker.py"


def get_instances():
    """Return running vast.ai instances whose label contains "neucodec".

    Shells out to the ``vastai`` CLI and filters its raw JSON listing.
    Raises if the CLI output is not valid JSON (e.g. the command failed).
    """
    proc = subprocess.run(
        ["vastai", "show", "instances", "--raw"],
        capture_output=True, text=True, timeout=30
    )
    listing = json.loads(proc.stdout)

    selected = []
    for item in listing:
        if not isinstance(item, dict):
            continue
        label = item.get("label", "") or ""
        if "neucodec" in label and item.get("actual_status") == "running":
            selected.append(item)
    return selected


def get_worker_id(label):
    """Extract the worker id token ("nc-" plus up to 3 chars) from *label*.

    Returns the 6-character slice starting at the first "nc-" occurrence
    (shorter if the label ends early), or None when no "nc-" is present.
    """
    pos = label.find("nc-")
    if pos == -1:
        return None
    return label[pos:pos + 6]


def check_and_restart(inst, results, semaphore):
    """Check one instance over SSH and restart its worker if it is down.

    Appends a ``(worker_id, status, detail)`` tuple to *results*:
      * ``"running"``   - worker already up (detail = process count)
      * ``"restarted"`` - latest code deployed and worker launched
      * ``"failed"``    - scp deploy or ssh start command failed
      * ``"error"``     - unexpected exception (detail = truncated message)

    *semaphore* bounds the number of concurrent SSH sessions.
    Instances without an ssh host or a recognizable "nc-" label are skipped
    silently (no entry in *results*).
    """
    with semaphore:
        label = inst.get("label", "")
        host = inst.get("ssh_host", "")
        port = inst.get("ssh_port", 22)
        base_id = get_worker_id(label)
        if not host or not base_id:
            return

        try:
            # Count live worker.py processes on the remote host.
            r = subprocess.run(
                ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
                 "-p", str(port), f"root@{host}",
                 "ps aux | grep 'python3 worker.py' | grep -v grep | wc -l"],
                capture_output=True, text=True, timeout=15
            )
            # BUG FIX: int() previously raised ValueError on non-numeric
            # output (ssh banners/warnings), misreporting a live check as
            # "error". Treat unparseable output as "no worker found".
            try:
                count = int(r.stdout.strip()) if r.returncode == 0 else 0
            except ValueError:
                count = 0

            if count > 0:
                results.append((base_id, "running", count))
                return

            # Worker not running — deploy the latest code, then start it.
            scp = subprocess.run(
                ["scp", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
                 "-P", str(port), WORKER_FILE, f"root@{host}:/app/worker.py"],
                capture_output=True, text=True, timeout=30
            )
            # BUG FIX: the scp result was previously ignored, so a failed
            # copy still launched stale (or missing) code. Record a failure
            # instead, per the module's "latest worker.py" guarantee.
            if scp.returncode != 0:
                results.append((base_id, "failed", 0))
                return

            # Inject credentials into the remote shell, then launch the
            # worker detached so it survives the ssh session ending.
            env_export = "; ".join(f'export {k}="{v}"' for k, v in ENV_VARS.items())
            cmd = f'{env_export}; cd /app; nohup python3 worker.py --worker-id {base_id} > /app/worker.log 2>&1 &'
            r = subprocess.run(
                ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
                 "-p", str(port), f"root@{host}", f"bash -c '{cmd}'"],
                capture_output=True, text=True, timeout=15
            )
            if r.returncode == 0:
                results.append((base_id, "restarted", 1))
                print(f"  [{base_id}] restarted")
            else:
                results.append((base_id, "failed", 0))
        except Exception as e:
            # Best-effort: one bad host must not kill the sweep; surface
            # the error in the tally instead.
            results.append((base_id, "error", str(e)[:50]))


def main():
    """Fan out a health check/restart across every neucodec instance.

    Spawns one thread per instance (SSH concurrency capped at 20 by a
    semaphore), waits up to 60s per thread, then prints a summary tally.
    """
    instances = get_instances()
    print(f"Checking {len(instances)} instances...")

    results = []
    sem = threading.Semaphore(20)  # cap concurrent SSH sessions

    workers = [
        threading.Thread(target=check_and_restart, args=(inst, results, sem))
        for inst in instances
    ]
    for t in workers:
        t.start()
    for t in workers:
        t.join(timeout=60)  # don't hang forever on a wedged SSH

    statuses = [status for _, status, _ in results]
    running = statuses.count("running")
    restarted = statuses.count("restarted")
    failed = statuses.count("failed") + statuses.count("error")
    print(f"\nDone: {running} already running, {restarted} restarted, {failed} failed")


if __name__ == "__main__":
    main()
