#!/usr/bin/env python3
"""Continuous fleet monitor — runs indefinitely, recycles dead workers, reclaims shards."""
import json, subprocess, time, psycopg2, os, sys

DATABASE_URL = os.environ["DATABASE_URL"]
CHECK_INTERVAL = 300  # 5 min between checks

def db():
    """Open and return a new autocommit connection to the fleet database."""
    c = psycopg2.connect(DATABASE_URL)
    c.autocommit = True  # every UPDATE below must land immediately
    return c

def vast(args):
    """Run a ``vastai`` CLI command with ``--raw`` and parse its JSON output.

    Args:
        args: list of CLI arguments, e.g. ``["show", "instances"]``.

    Returns:
        The decoded JSON value, or None on any failure — missing/broken
        ``vastai`` binary, command timeout, or non-JSON output.  The monitor
        loop treats None as "no data this cycle" and must never crash here.
    """
    try:
        r = subprocess.run(["vastai"] + args + ["--raw"],
                           capture_output=True, text=True, timeout=30)
        return json.loads(r.stdout)
    except (OSError, subprocess.SubprocessError, ValueError):
        # OSError: vastai not installed / not executable.
        # SubprocessError: covers TimeoutExpired (previously uncaught —
        # the subprocess call sat outside the try and could kill the loop).
        # ValueError: covers json.JSONDecodeError on empty/garbled stdout.
        return None

def start_worker(iid, wid):
    """Push start_worker.sh to Vast instance ``iid`` and launch worker ``wid``.

    Best-effort: looks up the instance's SSH endpoint, scp's the launch
    script into /app/, then runs it remotely.

    Returns:
        True only if the remote script printed "Started"; False on any
        SSH/SCP failure, timeout, or an unparseable ssh-url.
    """
    try:
        ssh_url = subprocess.run(
            ["vastai", "ssh-url", str(iid)],
            capture_output=True, text=True, timeout=10,
        ).stdout.strip()
        if not ssh_url:
            return False
        # ssh_url has the form ssh://root@HOST:PORT — parse it once.
        host, _, port = ssh_url.replace("ssh://root@", "").partition(":")
        if not port:
            return False  # no ":PORT" suffix; previously an IndexError -> False
        subprocess.run(
            ["scp", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=8",
             "-P", port,
             "/home/ubuntu/neucodec/start_worker.sh", f"root@{host}:/app/"],
            capture_output=True, timeout=20,
        )
        r = subprocess.run(
            ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=8",
             "-p", port, f"root@{host}",
             "chmod +x /app/start_worker.sh && /app/start_worker.sh " + wid],
            capture_output=True, text=True, timeout=30,
        )
        return "Started" in r.stdout
    except (OSError, subprocess.SubprocessError):
        # Narrowed from a bare except: only subprocess/exec failures and
        # timeouts mean "couldn't start"; real bugs now surface instead of
        # being silently swallowed.
        return False

# ---- main monitor loop -----------------------------------------------------
# Each cycle: read fleet state from Postgres, requeue failed/stuck shards,
# restart or destroy dead Vast instances, print a one-line status, sleep.
cycle = 0
while True:
    cycle += 1
    conn = db()
    try:
        cur = conn.cursor()

        # Workers heartbeating within the last 3 minutes count as alive.
        cur.execute("SELECT worker_id FROM neucodec_workers WHERE state='working' AND EXTRACT(EPOCH FROM (NOW()-last_heartbeat)) < 180")
        active = {row[0] for row in cur.fetchall()}

        # Shard counts by status.  `total` replaces the old hard-coded 5563
        # so the progress % stays correct if the shard table changes size.
        cur.execute("SELECT status, COUNT(*) FROM neucodec_shards GROUP BY status")
        counts = dict(cur.fetchall())
        completed = counts.get('completed', 0)
        pending = counts.get('pending', 0)
        processing = counts.get('processing', 0)
        claimed = counts.get('claimed', 0)
        failed = counts.get('failed', 0)
        total = sum(counts.values()) or 1  # guard the % display against an empty table

        # Failed shards go back to pending so another worker retries them.
        reclaimed = 0
        if failed > 0:
            cur.execute("UPDATE neucodec_shards SET status='pending', worker_id=NULL, claimed_at=NULL, started_at=NULL, failed_at=NULL, error_message=NULL WHERE status='failed'")
            reclaimed = cur.rowcount

        # Shards claimed >10 min ago are presumed abandoned — requeue them.
        cur.execute("UPDATE neucodec_shards SET status='pending', worker_id=NULL, claimed_at=NULL WHERE status IN ('claimed','processing') AND claimed_at < NOW() - INTERVAL '600 seconds'")
        stuck = cur.rowcount

        # Mean realtime factor across live workers, for the ETA estimate.
        cur.execute("SELECT AVG(progress_rtf) FROM neucodec_workers WHERE state='working' AND progress_rtf > 0")
        avg_rtf = cur.fetchone()[0] or 0
        cur.close()
    finally:
        conn.close()  # never leak the connection if a query above raises

    # Live Vast.ai instances labelled for this fleet.
    instances = vast(["show", "instances"]) or []
    nc = [i for i in instances if "neucodec" in (i.get("label", "") or "")]
    running_instances = [i for i in nc if i.get("actual_status") == "running"]

    # A running instance whose worker stopped heartbeating is dead: try to
    # relaunch the worker; failing that, destroy the instance to stop billing.
    restarted = destroyed = 0
    for inst in running_instances:
        wid = (inst.get("label", "") or "").replace("neucodec-", "")
        if wid not in active:
            if start_worker(inst["id"], wid):
                restarted += 1
            else:
                subprocess.run(["vastai", "destroy", "instance", str(inst["id"])], capture_output=True, timeout=10)
                destroyed += 1

    # Done only when NO shard is outstanding in any form.  The old check
    # (pending==0 and processing==0) used the stale pre-reclaim counts and
    # ignored 'claimed'/'failed' entirely, so it could tear the fleet down
    # right after requeueing failed or stuck shards back to pending.
    if pending == 0 and processing == 0 and claimed == 0 and failed == 0 and stuck == 0:
        print(f"\n{'='*60}")
        print(f"  ALL SHARDS COMPLETED! {completed} done, {failed} failed")
        print(f"{'='*60}")
        # Tear down every fleet instance, running or not.
        for inst in nc:
            subprocess.run(["vastai", "destroy", "instance", str(inst["id"])], capture_output=True, timeout=10)
        print("All instances destroyed.")
        break

    # ETA: assumes ~25 h of audio per shard (25*3600 s) processed at avg_rtf
    # realtime, split evenly across active workers; 999 means "unknown".
    eta_h = (pending * 25 * 3600 / avg_rtf / len(active)) / 3600 if active and avg_rtf > 0 else 999
    cost_hr = len(active) * 0.30  # assumes ~$0.30/hr per instance — TODO confirm

    ts = time.strftime("%H:%M:%S")
    print(f"[{ts}] #{cycle} | Workers: {len(active)} | RTF: {avg_rtf:.0f}x | Combined: {avg_rtf*len(active):.0f}x | Done: {completed} ({completed/total*100:.1f}%) | Pending: {pending} | ETA: {eta_h:.1f}h ({eta_h/24:.1f}d) | ${cost_hr:.0f}/hr | reclaimed={reclaimed}+{stuck} restarted={restarted} destroyed={destroyed}")
    sys.stdout.flush()

    time.sleep(CHECK_INTERVAL)
