#!/usr/bin/env python3
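"""
Rolling restart of all running neucodec workers in controlled batches.

Processes the running instances BATCH_SIZE at a time. For each instance it
kills the running worker, copies the updated worker.py over SCP, and starts a
new worker. After each batch it pauses WAIT_BETWEEN_BATCHES seconds so that
database connections held by the old workers can be released before the next
batch is restarted.
"""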

import json
import os
import shlex
import subprocess
import threading
import time
from concurrent.futures import ThreadPoolExecutor

# Environment variables exported to every restarted worker.
# Values come from the environment of whoever runs this script: credentials
# must never be committed to source control (any that have been should be
# treated as leaked and rotated). Unset secrets become "" here; main() refuses
# to restart anything while a value is empty.
_FORWARDED_ENV = (
    "DATABASE_URL",
    "R2_ENDPOINT_URL",
    "R2_ACCESS_KEY_ID",
    "R2_SECRET_ACCESS_KEY",
)
ENV_VARS = {name: os.environ.get(name, "") for name in _FORWARDED_ENV}
# The bucket name is not a secret, so it keeps a default.
ENV_VARS["R2_BUCKET_SFT_DATA"] = os.environ.get("R2_BUCKET_SFT_DATA", "finalsftdata")

# Local path of the worker code that is copied to /app/worker.py on each host.
WORKER_FILE = "/home/ubuntu/neucodec/worker.py"
# Number of instances restarted per batch.
BATCH_SIZE = 30
WAIT_BETWEEN_BATCHES = 15  # seconds to pause between batches


def get_instances():
    """Return the running vast.ai instances whose label contains "neucodec".

    Shells out to ``vastai show instances --raw`` and filters its JSON list
    to dict entries with a "neucodec" label and ``actual_status == "running"``.

    Raises:
        RuntimeError: if the ``vastai`` CLI exits non-zero, or prints output
            that is not valid JSON.
        subprocess.TimeoutExpired: if the CLI does not finish within 30 s.
    """
    result = subprocess.run(
        ["vastai", "show", "instances", "--raw"],
        capture_output=True, text=True, timeout=30,
    )
    # Without this check a CLI failure shows up as a JSONDecodeError on empty
    # stdout, hiding the real reason (which is on stderr).
    if result.returncode != 0:
        raise RuntimeError(
            f"vastai show instances failed (exit {result.returncode}): "
            f"{result.stderr.strip()}"
        )
    try:
        data = json.loads(result.stdout)
    except json.JSONDecodeError as err:
        raise RuntimeError(
            f"vastai returned non-JSON output: {result.stdout[:200]!r}"
        ) from err
    return [
        inst for inst in data
        if isinstance(inst, dict)
        and "neucodec" in (inst.get("label") or "")
        and inst.get("actual_status") == "running"
    ]


def get_worker_id(label):
    """Pull the short worker id out of an instance label.

    The id is the first ``"nc-"`` in *label* plus up to three following
    characters, e.g. ``"neucodec-nc-042"`` -> ``"nc-042"``. Returns None
    when *label* has no ``"nc-"`` in it.
    """
    marker = "nc-"
    if marker not in label:
        return None
    start = label.index(marker)
    # NOTE(review): fixed-width slice (marker + 3 chars) assumes ids are never
    # longer than three characters after "nc-" — confirm with label scheme.
    return label[start:start + len(marker) + 3]


def restart_instance(inst):
    """Kill the worker on one instance, upload new code, and start a new worker.

    Args:
        inst: one instance dict from ``vastai show instances --raw``; its
            ``label``, ``ssh_host`` and ``ssh_port`` keys are read.

    Returns:
        ``(worker_id, status)``. status is ``"ok"``; ``"skip"`` when there is
        no SSH host or no worker id in the label (the first element is then
        the label if no id was found); ``"kill_fail"``, ``"scp_fail"`` or
        ``"start_fail"`` when that step's ssh/scp call exits non-zero; or
        ``"error: <message>"`` for an exception such as a subprocess timeout.
    """
    label = inst.get("label") or ""
    host = inst.get("ssh_host") or ""
    port = inst.get("ssh_port") or 22
    wid = get_worker_id(label)
    if not host or not wid:
        return (wid or label, "skip")

    ssh_opts = ["-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10"]
    target = f"root@{host}"

    try:
        # Kill existing workers. The trailing `echo done` makes the remote
        # command succeed even when no worker was running, so a non-zero code
        # means ssh itself failed. Continuing in that case could start a
        # second worker next to the old one, so stop here.
        r = subprocess.run(
            ["ssh", *ssh_opts, "-p", str(port), target,
             "kill $(ps aux | grep 'python3 worker.py' | grep -v grep | awk '{print $2}') 2>/dev/null; echo done"],
            capture_output=True, text=True, timeout=15,
        )
        if r.returncode != 0:
            return (wid, "kill_fail")

        # SCP new code
        r = subprocess.run(
            ["scp", *ssh_opts, "-P", str(port), WORKER_FILE, f"{target}:/app/worker.py"],
            capture_output=True, text=True, timeout=30,
        )
        if r.returncode != 0:
            return (wid, "scp_fail")

        # Start the new worker in the background. The command string is parsed
        # twice (remote login shell, then `bash -c`), so every interpolated
        # value is shell-quoted; a quote, `$` or backtick in a credential or
        # label would otherwise break or alter the command.
        # NOTE(review): the credentials still appear in the ssh/bash argv and
        # are visible to `ps` on both machines.
        env_str = "; ".join(
            f"export {k}={shlex.quote(v)}" for k, v in ENV_VARS.items()
        )
        cmd = (
            f"{env_str}; cd /app; nohup python3 worker.py "
            f"--worker-id {shlex.quote(wid)} > /app/worker.log 2>&1 &"
        )
        r = subprocess.run(
            ["ssh", *ssh_opts, "-p", str(port), target, f"bash -c {shlex.quote(cmd)}"],
            capture_output=True, text=True, timeout=15,
        )
        if r.returncode != 0:
            return (wid, "start_fail")
        return (wid, "ok")
    except Exception as e:
        return (wid, f"error: {str(e)[:50]}")


def main():
    """Restart every running neucodec instance, BATCH_SIZE at a time.

    Up to 15 instances in a batch are restarted concurrently. Each batch is
    waited on in full before the next starts, and the script sleeps
    WAIT_BETWEEN_BATCHES seconds between batches. Prints per-batch and total
    ok/failed counts and names every instance that did not report "ok".

    Raises:
        SystemExit: if any ENV_VARS value is empty, before any worker is killed.
    """
    # Workers would be killed and relaunched without working credentials, so
    # refuse up front rather than take the whole fleet down.
    missing = [name for name, value in ENV_VARS.items() if not value]
    if missing:
        raise SystemExit(
            f"Refusing to restart: empty environment values for {', '.join(missing)}"
        )

    instances = get_instances()
    print(f"Rolling restart of {len(instances)} instances in batches of {BATCH_SIZE}")

    total_ok = 0
    total_fail = 0
    total_batches = (len(instances) + BATCH_SIZE - 1) // BATCH_SIZE

    for batch_idx in range(0, len(instances), BATCH_SIZE):
        batch = instances[batch_idx:batch_idx + BATCH_SIZE]
        batch_num = batch_idx // BATCH_SIZE + 1
        print(f"\n--- Batch {batch_num}/{total_batches} ({len(batch)} instances) ---")

        # Leaving the `with` block waits for every submitted restart, so no
        # instance goes uncounted or keeps running into the next batch.
        with ThreadPoolExecutor(max_workers=15) as pool:
            futures = [pool.submit(restart_instance, inst) for inst in batch]

        results = []
        for inst, fut in zip(batch, futures):
            try:
                results.append(fut.result())
            except Exception as exc:
                # Anything raised outside restart_instance's own try block
                # still counts as a failure instead of vanishing.
                results.append((inst.get("label"), f"error: {str(exc)[:50]}"))

        failed = [(wid, status) for wid, status in results if status != "ok"]
        ok = len(results) - len(failed)
        total_ok += ok
        total_fail += len(failed)
        print(f"  Batch done: {ok} ok, {len(failed)} failed")
        for wid, status in failed:
            print(f"    {wid}: {status}")

        if batch_idx + BATCH_SIZE < len(instances):
            print(f"  Waiting {WAIT_BETWEEN_BATCHES}s for connections to settle...")
            time.sleep(WAIT_BETWEEN_BATCHES)

    print(f"\n=== ROLLING RESTART COMPLETE: {total_ok} ok, {total_fail} failed ===")


if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with
    # status 0 unless main() raises.
    raise SystemExit(main())
