#!/usr/bin/env python3
"""
Launch and manage neucodec worker fleet on Vast.ai.

Usage:
    python fleet.py launch N         # Launch N new workers
    python fleet.py status           # Show all instances + worker status
    python fleet.py health           # Report stale workers, reclaim their stuck shards
    python fleet.py destroy-all      # Destroy all instances
"""

import json
import os
import shlex
import subprocess
import sys
import time
from urllib.parse import urlparse

# NOTE(review): VAST_KEY is read here but not referenced by the code in this
# file; the vastai CLI is presumably authenticated some other way.
VAST_KEY = os.environ.get("VAST_KEY", "")
DOCKER_IMAGE = "bharathkumar192/neucodec-worker:latest"

# Environment forwarded to every worker, taken from this process's
# environment (second element = fallback when the variable is unset).
ENV_VARS = {
    name: os.environ.get(name, fallback)
    for name, fallback in (
        ("DATABASE_URL", ""),
        ("R2_ENDPOINT_URL", ""),
        ("R2_ACCESS_KEY_ID", ""),
        ("R2_SECRET_ACCESS_KEY", ""),
        ("R2_BUCKET_SFT_DATA", "finalsftdata"),
    )
}

# Docker-style "-e KEY=VALUE ..." string for `vastai create instance --env`;
# variables with empty values are left out entirely.
_env_flags = [f"-e {name}={value}" for name, value in ENV_VARS.items() if value]
ENV_STRING = " ".join(_env_flags)


def vast_cmd(args, raw=False):
    """Run the ``vastai`` CLI with *args* and return what it printed.

    With ``raw=True`` the ``--raw`` flag is appended and non-blank stdout is
    decoded as JSON (falling back to the plain text if it is not valid JSON).
    In every other case stdout is returned as a string. The exit status and
    stderr are not inspected. ``subprocess.TimeoutExpired`` propagates if the
    CLI runs longer than 30 seconds.
    """
    argv = ["vastai", *args]
    if raw:
        argv.append("--raw")
    completed = subprocess.run(argv, capture_output=True, text=True, timeout=30)
    output = completed.stdout

    wants_json = raw and output.strip()
    if not wants_json:
        return output
    try:
        return json.loads(output)
    except json.JSONDecodeError:
        return output


def find_offers(
    n=100,
    gpus=("RTX_4090", "RTX_3090", "L40S", "RTX_A6000", "L40"),
    min_disk_gb=80,
):
    """Return up to *n* of the cheapest rentable single-GPU offers.

    Runs one ``vastai search offers`` query per GPU model in *gpus* (default:
    RTX 4090, RTX 3090, L40S, RTX A6000, L40), merges the results and sorts
    them by ``dph_total`` ascending; offers without a numeric price sort last.

    Filters: 1 GPU, reliability > 0.9, inet_down > 100, CUDA >= 12.0,
    rentable, and ``disk_space >= min_disk_gb``. The default of 80 matches the
    ``--disk 80`` that launch_one requests, so every returned offer can
    actually host the instance (the old 50 GB floor admitted hosts that could
    not satisfy that request).
    """
    def price(offer):
        cost = offer.get("dph_total")
        # Missing/null prices must not break the sort; push them to the end.
        return cost if isinstance(cost, (int, float)) else float("inf")

    all_offers = []
    for gpu in gpus:
        result = vast_cmd([
            "search", "offers",
            f"gpu_name={gpu} num_gpus=1 reliability>0.9 inet_down>100 "
            f"disk_space>={min_disk_gb} cuda_vers>=12.0 rentable=true",
            "-o", "dph_total",
        ], raw=True)
        if isinstance(result, list):
            all_offers.extend(result)
        else:
            # A non-list result is an error message from the CLI, not "no offers".
            print(f"  WARNING: offer search for {gpu} failed: {repr(result)[:200]}")

    all_offers.sort(key=price)
    return all_offers[:n]


def launch_one(offer_id, worker_num):
    """Rent *offer_id* and boot DOCKER_IMAGE on it as worker ``nc-NNN``.

    The instance is labelled ``neucodec-nc-NNN`` (NNN = *worker_num*,
    zero-padded to three digits), which is what get_instances filters on.

    Returns:
        ``(instance_id, worker_id)`` when the API reports success, otherwise
        ``(None, worker_id)`` after printing the failure.
    """
    worker_id = f"nc-{worker_num:03d}"
    create_args = [
        "create", "instance", str(offer_id),
        "--image", DOCKER_IMAGE,
        "--disk", "80",
        "--ssh", "--direct",
        "--env", ENV_STRING,
        "--label", f"neucodec-{worker_id}",
    ]
    response = vast_cmd(create_args, raw=True)

    succeeded = isinstance(response, dict) and response.get("success")
    if not succeeded:
        print(f"  [{worker_id}] FAILED to launch from offer {offer_id}: {response}")
        return None, worker_id

    instance_id = response["new_contract"]
    print(f"  [{worker_id}] Launched instance {instance_id} from offer {offer_id}")
    return instance_id, worker_id


def _parse_ssh_url(ssh_url):
    """Split ``ssh://root@host:port`` into ``(host, port)`` strings.

    The port defaults to ``"22"`` when absent. Returns ``(None, None)`` when
    no plausible hostname can be extracted (e.g. the CLI printed an error).
    """
    url = ssh_url.strip()
    if "://" not in url:
        url = "ssh://" + url
    parsed = urlparse(url)
    try:
        port = parsed.port
    except ValueError:  # non-numeric or out-of-range port
        return None, None
    host = parsed.hostname
    if not host or any(ch.isspace() for ch in host):
        return None, None
    return host, (str(port) if port is not None else "22")


def _wait_for_ssh(host, port, attempts=10, delay=5):
    """Return True once ``ssh -p port root@host echo ok`` prints ``ok``.

    Tries up to *attempts* times, sleeping *delay* seconds after each failed
    probe. A probe exceeding its 15 s timeout counts as a failed attempt
    instead of raising.
    """
    probe = [
        "ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
        "-p", port, f"root@{host}", "echo ok",
    ]
    for _ in range(attempts):
        try:
            result = subprocess.run(probe, capture_output=True, text=True, timeout=15)
        except subprocess.TimeoutExpired:
            result = None
        if result is not None and "ok" in result.stdout:
            return True
        time.sleep(delay)
    return False


def _remote_start_command(worker_id):
    """Build the remote command that exports ENV_VARS and backgrounds worker.py.

    Every interpolated value is shell-quoted, so secrets containing quotes,
    ``$`` or spaces reach the worker verbatim. The command is valid even when
    no variables are set. stdin is redirected from /dev/null so the background
    process does not hold on to the SSH session's stdin.
    """
    steps = [f"export {k}={shlex.quote(v)}" for k, v in ENV_VARS.items() if v]
    steps.append("cd /app")
    steps.append(
        f"nohup python3 worker.py --worker-id {shlex.quote(worker_id)} "
        "< /dev/null > /app/worker.log 2>&1 &"
    )
    return f"bash -c {shlex.quote('; '.join(steps))}"


def wait_and_start(instance_id, worker_id, timeout=300):
    """Wait for *instance_id* to be running, then start worker *worker_id* on it.

    Steps:
      1. poll ``vastai show instance`` every 10 s for up to *timeout* seconds
         until ``actual_status == "running"``;
      2. resolve the SSH endpoint with ``vastai ssh-url``;
      3. wait until SSH accepts connections;
      4. scp the local worker.py to /app/worker.py (best effort: a failure is
         only warned about);
      5. start worker.py in the background over SSH.

    Returns:
        True if the start command completed successfully; False if any
        required step failed. Every failure is printed, and none raises.
    """
    ssh_opts = ["-o", "StrictHostKeyChecking=no"]

    # 1. Poll until Vast.ai reports the container as running.
    t0 = time.time()
    while time.time() - t0 < timeout:
        try:
            info = vast_cmd(["show", "instance", str(instance_id)], raw=True)
        except (subprocess.SubprocessError, OSError):
            info = None  # transient CLI failure: keep polling
        status = info.get("actual_status", "") if isinstance(info, dict) else ""
        if status == "running":
            break
        time.sleep(10)
    else:
        print(f"  [{worker_id}] Timeout waiting for instance {instance_id}")
        return False

    # 2. Resolve the SSH endpoint.
    try:
        ssh_url = vast_cmd(["ssh-url", str(instance_id)]).strip()
    except (subprocess.SubprocessError, OSError):
        ssh_url = ""
    if not ssh_url:
        print(f"  [{worker_id}] No SSH URL for {instance_id}")
        return False
    host, port = _parse_ssh_url(ssh_url)
    if host is None:
        print(f"  [{worker_id}] Could not parse SSH URL for {instance_id}: {ssh_url!r}")
        return False

    # 3. Wait for sshd in the container; copying/starting cannot work without it.
    if not _wait_for_ssh(host, port):
        print(f"  [{worker_id}] SSH never became reachable on {host}:{port}")
        return False

    # 4. SCP latest worker.py (has resample fix + GPU fbank) over the image's copy.
    try:
        scp = subprocess.run(
            ["scp", *ssh_opts, "-o", "ConnectTimeout=10", "-P", port,
             "/home/ubuntu/neucodec/worker.py", f"root@{host}:/app/worker.py"],
            capture_output=True, text=True, timeout=30,
        )
        scp_ok = scp.returncode == 0
    except subprocess.TimeoutExpired:
        scp_ok = False
    if not scp_ok:
        print(f"  [{worker_id}] WARNING: failed to copy worker.py; "
              "whatever /app/worker.py the image ships will run")

    # 5. Start the worker in the background.
    try:
        start = subprocess.run(
            ["ssh", *ssh_opts, "-o", "ConnectTimeout=10", "-p", port,
             f"root@{host}", _remote_start_command(worker_id)],
            capture_output=True, text=True, timeout=30,
        )
    except subprocess.TimeoutExpired:
        print(f"  [{worker_id}] Timed out starting worker on {host}:{port}")
        return False
    if start.returncode != 0:
        print(f"  [{worker_id}] FAILED to start worker on {host}:{port}: "
              f"{start.stderr.strip()}")
        return False

    print(f"  [{worker_id}] Started on {host}:{port} (instance {instance_id})")
    return True


def get_instances():
    """Return the Vast.ai instances that belong to this fleet.

    An instance belongs to the fleet when its label contains ``"neucodec"``
    (a missing or null label counts as empty). Returns ``[]`` when the CLI
    output is not a JSON list; non-dict entries are skipped.
    """
    listing = vast_cmd(["show", "instances"], raw=True)
    if not isinstance(listing, list):
        return []

    ours = []
    for entry in listing:
        if not isinstance(entry, dict):
            continue
        label = entry.get("label", "") or ""
        if "neucodec" in label:
            ours.append(entry)
    return ours


def _existing_worker_numbers(instances):
    """Return the set of NNN values parsed from ``...nc-NNN`` instance labels.

    Labels without ``nc-``, or whose three characters after it are not an
    integer, are ignored.
    """
    numbers = set()
    for inst in instances:
        label = inst.get("label") or ""
        _, found, tail = label.partition("nc-")
        if not found:
            continue
        try:
            numbers.add(int(tail[:3]))
        except ValueError:
            pass
    return numbers


def cmd_launch(n):
    """Launch *n* new workers on the cheapest available Vast.ai offers.

    Three times as many offers as needed are fetched. When an offer fails to
    launch (e.g. it was rented in the meantime), the next cheapest one is
    tried, until *n* instances are up or the offers run out. Worker numbers
    continue after the highest ``nc-NNN`` among existing instances. A number
    is consumed only by a successful launch, or by a CLI error whose outcome
    is unknown. Every launched instance is then handed to wait_and_start.
    """
    n = int(n)
    print(f"=== Launching {n} workers ===")

    offers = find_offers(n * 3)  # 3x offers so failed launches fall through to the next
    if not offers:
        print("ERROR: No offers found")
        return

    print(f"Found {len(offers)} offers, cheapest: ${offers[0].get('dph_total', '?')}/hr")

    existing_nums = _existing_worker_numbers(get_instances())
    worker_num = max(existing_nums) + 1 if existing_nums else 1

    launched = []
    for offer in offers:
        if len(launched) >= n:
            break
        offer_id = offer["id"]
        try:
            instance_id, worker_id = launch_one(offer_id, worker_num)
        except subprocess.SubprocessError as exc:
            # e.g. the CLI timed out: the instance may or may not exist, so skip
            # this number rather than risk a duplicate label, and keep going so
            # instances already created still get started below.
            print(f"  [nc-{worker_num:03d}] Error launching from offer {offer_id}: {exc}")
            worker_num += 1
            continue
        if instance_id:
            launched.append((instance_id, worker_id))
            worker_num += 1

    if len(launched) < n:
        print(f"  Ran out of offers after launching {len(launched)} workers")

    if not launched:
        print("No instances launched")
        return

    print(f"\n=== Waiting for {len(launched)} instances to start ===")
    started = 0
    for instance_id, worker_id in launched:
        if wait_and_start(instance_id, worker_id):
            started += 1

    print(f"\n=== Done. {len(launched)} workers launched, {started} started ===")
    print("Run 'python fleet.py status' to check progress")


def cmd_status():
    """Print one line per fleet instance, then run the Supabase status script.

    Fields that are missing from, or null in, the CLI's JSON are shown as
    ``?`` (cost as 0), so the fixed-width format specs never receive ``None``.
    """
    def text(inst, key):
        value = inst.get(key)
        return "?" if value is None else str(value)

    instances = get_instances()
    print(f"=== {len(instances)} Vast.ai instances ===")
    for inst in instances:
        iid = text(inst, "id")
        label = text(inst, "label")
        status = text(inst, "actual_status")
        gpu = text(inst, "gpu_name")
        cost = inst.get("dph_total") or 0
        print(f"  {iid:>10} | {label:>20} | {status:>10} | {gpu} | ${cost:.3f}/hr")

    # Also show Supabase status. Flush first: the child writes directly to the
    # shared stdout, and our buffered lines must come out before its output.
    print(flush=True)
    try:
        subprocess.run(["python3", "/home/ubuntu/neucodec/status.py"], check=False)
    except OSError as exc:
        print(f"Could not run status.py: {exc}")


def cmd_health():
    """Report workers with stale heartbeats and reclaim their stuck shards.

    A worker is stale when its state is not ``'exited'`` and its
    ``last_heartbeat`` is more than 300 s old. If any stale workers exist,
    shards that have been ``claimed``/``processing`` for more than 600 s are
    reset to ``pending``. Shards whose claiming worker is still heartbeating
    are skipped, so healthy workers on long shards keep their work. No
    instances are destroyed or restarted. The DB connection is always closed.
    """
    import psycopg2

    conn = psycopg2.connect(ENV_VARS["DATABASE_URL"])
    try:
        conn.autocommit = True
        with conn.cursor() as cur:
            # Find stale workers (no heartbeat for 5+ min)
            cur.execute("""
                SELECT worker_id, state, progress_phase, last_heartbeat,
                       EXTRACT(EPOCH FROM (NOW() - last_heartbeat)) as age_s
                FROM neucodec_workers
                WHERE EXTRACT(EPOCH FROM (NOW() - last_heartbeat)) > 300
                  AND state != 'exited'
            """)
            stale = cur.fetchall()

            if not stale:
                print("All workers healthy")
                return

            print(f"=== {len(stale)} STALE WORKERS (>5 min no heartbeat) ===")
            for wid, state, phase, _hb, age in stale:
                print(f"  {wid}: state={state}/{phase}, last heartbeat {int(age)}s ago")

            # Reclaim shards held for 10+ minutes, except those whose worker is
            # still alive (not exited, heartbeat within 300 s). Previously this
            # reset every old claim, stealing shards from healthy workers.
            # NOTE(review): assumes neucodec_shards.worker_id stores the same id
            # as neucodec_workers.worker_id; confirm against worker.py.
            cur.execute("""
                UPDATE neucodec_shards AS s
                SET status = 'pending', worker_id = NULL, claimed_at = NULL
                WHERE s.status IN ('claimed', 'processing')
                  AND s.claimed_at < NOW() - INTERVAL '600 seconds'
                  AND NOT EXISTS (
                      SELECT 1 FROM neucodec_workers AS w
                      WHERE w.worker_id = s.worker_id
                        AND w.state != 'exited'
                        AND w.last_heartbeat >= NOW() - INTERVAL '300 seconds'
                  )
                RETURNING s.id, s.shard_prefix
            """)
            reclaimed = cur.fetchall()
            if reclaimed:
                print(f"  Reclaimed {len(reclaimed)} stuck shards")
    finally:
        conn.close()


def cmd_destroy_all():
    """Destroy every fleet instance returned by get_instances().

    A CLI error on one instance (e.g. a timeout) is reported and the loop
    moves on, so one bad call cannot leave the remaining instances running.
    IDs that could not be destroyed are listed at the end.
    """
    instances = get_instances()
    if not instances:
        print("No instances to destroy")
        return

    print(f"Destroying {len(instances)} instances...")
    failed = []
    for inst in instances:
        iid = inst["id"]
        label = inst.get("label", "")
        try:
            vast_cmd(["destroy", "instance", str(iid)])
        except (subprocess.SubprocessError, OSError) as exc:
            failed.append(iid)
            print(f"  FAILED to destroy {iid} ({label}): {exc}")
            continue
        print(f"  Destroyed {iid} ({label})")

    if failed:
        print(f"Done, but {len(failed)} instance(s) could not be destroyed: {failed}")
    else:
        print("Done")


if __name__ == "__main__":
    # CLI entry point: dispatch on argv[1]. Bad input exits with status 1.
    usage = "Usage: python fleet.py [launch N | status | health | destroy-all]"
    if len(sys.argv) < 2:
        print(usage)
        sys.exit(1)

    cmd = sys.argv[1]
    if cmd == "launch":
        try:
            count = int(sys.argv[2]) if len(sys.argv) > 2 else 1
        except ValueError:
            print(f"launch expects an integer N, got {sys.argv[2]!r}")
            sys.exit(1)
        cmd_launch(count)
    elif cmd == "status":
        cmd_status()
    elif cmd == "health":
        cmd_health()
    elif cmd == "destroy-all":
        cmd_destroy_all()
    else:
        print(f"Unknown command: {cmd}")
        print(usage)
        sys.exit(1)
