#!/usr/bin/env python3
"""
Gradually ramp up the neucodec encoding fleet.

Steps: 1 → 5 → 25 → 50 → 100 → 200 → 250
After each step, waits for workers to come online and validates encoding.
If not enough offers, waits 30 minutes and retries.
"""

import json
import os
import re
import subprocess
import sys
import threading
import time

# Load env from the project's .env file before reading any os.environ keys.
from dotenv import load_dotenv
load_dotenv("/home/ubuntu/neucodec/.env")

# NOTE(review): VAST_KEY is read here but never used below -- presumably the
# vastai CLI picks credentials up from its own config; confirm before removing.
VAST_KEY = os.environ.get("VAST_KEY", "")
DOCKER_IMAGE = "bharathkumar192/neucodec-worker:latest"  # worker image deployed to every instance
WORKER_FILE = "/home/ubuntu/neucodec/worker.py"  # local worker script SCP'd onto each box
RETRY_WAIT = 1800  # 30 min

# Environment forwarded to every worker instance (database + R2 object storage).
ENV_VARS = {
    "DATABASE_URL": os.environ.get("DATABASE_URL", ""),
    "R2_ENDPOINT_URL": os.environ.get("R2_ENDPOINT_URL", ""),
    "R2_ACCESS_KEY_ID": os.environ.get("R2_ACCESS_KEY_ID", ""),
    "R2_SECRET_ACCESS_KEY": os.environ.get("R2_SECRET_ACCESS_KEY", ""),
    "R2_BUCKET_SFT_DATA": os.environ.get("R2_BUCKET_SFT_DATA", "finalsftdata"),
}

# Fleet sizes to step through; each target must be reached before the next.
RAMP_STEPS = [1, 5, 25, 50, 100, 200, 250]


def log(msg):
    """Emit *msg* to stdout with a timestamp prefix, flushing immediately."""
    print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {msg}", flush=True)


def vast_cmd(args, raw=False):
    """Run a vastai CLI command and return its output.

    When raw=True, "--raw" is appended and non-empty stdout is parsed as
    JSON (an empty list is returned on a parse failure); otherwise the raw
    stdout text is returned unchanged.
    """
    command = ["vastai", *args]
    if raw:
        command.append("--raw")
    proc = subprocess.run(command, capture_output=True, text=True, timeout=60)
    out = proc.stdout
    if not (raw and out.strip()):
        return out
    try:
        return json.loads(out)
    except json.JSONDecodeError:
        return []


def find_offers(n=300):
    """Search vast.ai for rentable single-GPU offers across supported models.

    Queries each GPU type separately, merges the results, and returns up to
    *n* offers sorted by ascending hourly price (dph_total).
    """
    collected = []
    for model in ("RTX_4090", "RTX_3090", "L40S", "RTX_A6000", "L40"):
        query = (
            f"gpu_name={model} num_gpus=1 reliability>0.9 inet_down>100 "
            "disk_space>=50 cuda_vers>=12.0 rentable=true"
        )
        try:
            offers = vast_cmd(
                ["search", "offers", query, "-o", "dph_total"], raw=True
            )
        except Exception as e:
            log(f"  Warning: search for {model} failed: {e}")
            continue
        if isinstance(offers, list):
            collected.extend(offers)
    # Missing prices sort last via the 999 sentinel.
    return sorted(collected, key=lambda o: o.get("dph_total", 999))[:n]


def get_instances():
    """Return the vast.ai instances whose label marks them as neucodec workers."""
    listing = vast_cmd(["show", "instances"], raw=True)
    if not isinstance(listing, list):
        return []
    matched = []
    for item in listing:
        if isinstance(item, dict) and "neucodec" in (item.get("label", "") or ""):
            matched.append(item)
    return matched


def get_next_worker_num(instances=None):
    """Return the next unused worker number based on existing instance labels.

    Labels embed the worker id as "nc-<digits>" (e.g. "neucodec-nc-007").
    Returns max(existing) + 1, or 1 when no numbered workers exist.

    Args:
        instances: optional pre-fetched instance list (as returned by
            get_instances()); fetched on demand when None.

    Fix: the previous parser took only the first three characters after
    "nc-", which truncated four-digit ids ("nc-1000" -> 100) and could hand
    out duplicate worker numbers; the full digit run is parsed now.
    """
    if instances is None:
        instances = get_instances()
    nums = set()
    for inst in instances:
        match = re.search(r"nc-(\d+)", inst.get("label", "") or "")
        if match:
            nums.add(int(match.group(1)))
    return max(nums) + 1 if nums else 1


def launch_instance(offer_id, worker_num):
    """Create a vast.ai instance for *offer_id*; return (contract_id, worker_id).

    contract_id is None when the create call does not report success.
    """
    worker_id = f"nc-{worker_num:03d}"
    env_flags = " ".join(f"-e {key}={val}" for key, val in ENV_VARS.items() if val)
    response = vast_cmd([
        "create", "instance", str(offer_id),
        "--image", DOCKER_IMAGE,
        "--disk", "80",
        "--ssh", "--direct",
        "--env", env_flags,
        "--label", f"neucodec-{worker_id}",
    ], raw=True)

    if not (isinstance(response, dict) and response.get("success")):
        return None, worker_id
    return response["new_contract"], worker_id


def start_worker(instance_id, worker_id, timeout=300):
    """Wait for instance, SCP worker.py, start worker.

    Polls vast.ai until the instance reports "running" (up to *timeout*
    seconds), waits for SSH to respond, copies the latest worker.py onto
    the box, and launches it detached under nohup.  Returns True on
    success, False if the instance or SSH never came up.
    """
    t0 = time.time()
    # Poll for "running"; the while-loop's else fires only on timeout.
    while time.time() - t0 < timeout:
        try:
            info = vast_cmd(["show", "instance", str(instance_id)], raw=True)
            if isinstance(info, dict) and info.get("actual_status") == "running":
                break
        except Exception:
            pass
        time.sleep(10)
    else:
        return False

    # ssh-url has the form ssh://root@HOST:PORT
    ssh_url = vast_cmd(["ssh-url", str(instance_id)]).strip()
    if not ssh_url:
        return False

    host = ssh_url.replace("ssh://root@", "").split(":")[0]
    port = ssh_url.replace("ssh://root@", "").split(":")[1] if ":" in ssh_url.replace("ssh://root@", "") else "22"

    # Wait for SSH (12 attempts, ~5s apart); for-loop's else returns on failure.
    for _ in range(12):
        r = subprocess.run(
            ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
             "-p", port, f"root@{host}", "echo ok"],
            capture_output=True, text=True, timeout=15
        )
        if "ok" in r.stdout:
            break
        time.sleep(5)
    else:
        return False

    # SCP latest worker.py; failure is deliberately not checked (best effort,
    # the copy baked into the Docker image would be used instead).
    subprocess.run(
        ["scp", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
         "-P", port, WORKER_FILE, f"root@{host}:/app/worker.py"],
        capture_output=True, text=True, timeout=30
    )

    # Start worker detached; env vars are exported inline because the SSH
    # session does not inherit this script's environment.
    # NOTE(review): cmd is wrapped in single quotes for bash -c, so any env
    # value containing a single quote would break the quoting -- confirm
    # values are quote-free.
    env_export = "; ".join(f'export {k}="{v}"' for k, v in ENV_VARS.items() if v)
    cmd = f'{env_export}; cd /app; nohup python3 worker.py --worker-id {worker_id} > /app/worker.log 2>&1 &'
    subprocess.run(
        ["ssh", "-o", "StrictHostKeyChecking=no", "-p", port, f"root@{host}", f"bash -c '{cmd}'"],
        capture_output=True, text=True, timeout=30
    )
    return True


def get_fleet_status():
    """Get status from Supabase.

    Returns a dict with shard counts for the final-export dataset
    (completed/pending/processing) plus the number of recently-heartbeating
    workers and their summed RTF, or None when the database is unreachable
    or a query fails.

    Fix: the connection is now closed in a finally block -- previously a
    failing query leaked the open connection.
    """
    import psycopg2  # local import: only this helper needs the DB driver
    conn = None
    try:
        conn = psycopg2.connect(ENV_VARS["DATABASE_URL"])
        cur = conn.cursor()
        cur.execute("SELECT count(*) FROM neucodec_shards WHERE status='completed' AND dataset='final-export'")
        completed = cur.fetchone()[0]
        cur.execute("SELECT count(*) FROM neucodec_shards WHERE status='pending' AND dataset='final-export'")
        pending = cur.fetchone()[0]
        cur.execute("SELECT count(*) FROM neucodec_shards WHERE status='processing' AND dataset='final-export'")
        processing = cur.fetchone()[0]
        # A worker counts as active only with a heartbeat in the last 2 min.
        cur.execute("""SELECT count(*), coalesce(sum(progress_rtf), 0)
            FROM neucodec_workers WHERE state='working' AND progress_rtf > 0
            AND EXTRACT(EPOCH FROM (NOW()-last_heartbeat)) < 120""")
        active, rtf = cur.fetchone()
        return {
            "completed": completed, "pending": pending, "processing": processing,
            "active_workers": active or 0, "fleet_rtf": float(rtf or 0)
        }
    except Exception as e:
        log(f"  DB status check failed: {e}")
        return None
    finally:
        if conn is not None:
            conn.close()


def validate_first_worker():
    """Wait for the first worker to report RTF and verify it's ~50 tok/s.

    Polls fleet status every 10s for up to 5 minutes.  Once a worker is
    active with nonzero RTF, checks one completed shard's tokens-per-second
    ratio: returns True for ~50 tok/s (or when no shard has completed yet /
    the check query fails -- the worker is running, so validation can
    happen later), False for a wrong ratio or a poll timeout.

    Fixes: the validation DB connection is now always closed (previously
    leaked when a query raised), and query failures are logged instead of
    silently swallowed; the best-effort outcome (return True) is unchanged.
    """
    log("Waiting for first worker to start encoding and report RTF...")
    for _ in range(30):  # 5 minutes max
        time.sleep(10)
        status = get_fleet_status()
        if not (status and status["active_workers"] > 0 and status["fleet_rtf"] > 0):
            continue
        log(f"  First worker active! RTF={status['fleet_rtf']:.1f}x")
        # Check if any completed shards have correct tok/s
        import psycopg2
        conn = None
        try:
            conn = psycopg2.connect(ENV_VARS["DATABASE_URL"])
            cur = conn.cursor()
            cur.execute("""SELECT total_tokens::float / total_audio_seconds
                FROM neucodec_shards WHERE dataset='final-export' AND status='completed'
                AND total_audio_seconds > 0 LIMIT 1""")
            row = cur.fetchone()
            if row:
                tok_s = row[0]
                if 45 < tok_s < 55:
                    log(f"  VALIDATED: {tok_s:.1f} tok/s (correct!)")
                    return True
                log(f"  WARNING: {tok_s:.1f} tok/s (expected ~50)")
                return False
            log("  Worker active but no completed shards yet, continuing...")
        except Exception as e:
            log(f"  Validation query failed ({e}); worker is active, continuing")
        finally:
            if conn is not None:
                conn.close()
        return True  # Worker is running, validation will happen soon
    log("  Timeout waiting for first worker")
    return False


def launch_batch(target_total, current_count):
    """Launch instances to reach target_total, return how many actually launched."""
    need = target_total - current_count
    if need <= 0:
        log(f"  Already have {current_count} >= target {target_total}")
        return 0

    offers = find_offers(need * 2)
    if not offers:
        log("  No offers available!")
        return 0

    log(f"  Found {len(offers)} offers, launching {need} instances...")
    base_num = get_next_worker_num()

    launched = []
    for idx, offer in enumerate(offers[:need]):
        contract_id, worker_id = launch_instance(offer["id"], base_num + idx)
        if contract_id:
            launched.append((contract_id, worker_id))

    if not launched:
        log("  No instances launched!")
        return 0

    log(f"  Launched {len(launched)} instances, waiting for SSH + starting workers...")

    # Start workers in parallel, at most 15 SSH setups at a time.
    gate = threading.Semaphore(15)
    outcomes = []

    def bootstrap(iid, wid):
        with gate:
            ok = start_worker(iid, wid)
            outcomes.append((wid, ok))  # list.append is atomic in CPython
            if ok:
                log(f"    [{wid}] started")

    threads = [threading.Thread(target=bootstrap, args=pair) for pair in launched]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join(timeout=360)

    started = sum(1 for _, ok in outcomes if ok)
    log(f"  {started}/{len(launched)} workers started successfully")
    return started


def main():
    """Ramp the encoding fleet through RAMP_STEPS.

    For each target size: launch instances until the target is met (backing
    off RETRY_WAIT between attempts when offers run short), validate the
    very first worker's output rate before scaling past 1, then log fleet
    progress.  Exits with status 1 if first-worker validation fails.
    """
    log("=== NEUCODEC RE-ENCODING FLEET RAMP-UP ===")
    log(f"Target: re-encode 4,350 final-export shards (109k hrs, 48kHz→16kHz)")
    log(f"Ramp steps: {RAMP_STEPS}")

    for step_target in RAMP_STEPS:
        log(f"\n{'='*60}")
        log(f"RAMP STEP: target {step_target} workers")
        log(f"{'='*60}")

        current_instances = len(get_instances())

        # Keep launching until this step's target is met; when the market
        # has too few offers, wait RETRY_WAIT and retry.
        # NOTE(review): launch_batch's return value is unused -- progress is
        # re-measured via get_instances() instead.
        while current_instances < step_target:
            launched = launch_batch(step_target, current_instances)
            current_instances = len(get_instances())

            if current_instances < step_target:
                shortfall = step_target - current_instances
                log(f"  Only {current_instances}/{step_target} instances. Short {shortfall}.")
                log(f"  Waiting {RETRY_WAIT//60} minutes for more offers...")
                time.sleep(RETRY_WAIT)
                current_instances = len(get_instances())

        # After first step (1 worker), validate encoding before scaling out.
        if step_target == 1:
            log("Validating first worker encoding (50 tok/s check)...")
            time.sleep(120)  # Wait 2 min for model loading
            valid = validate_first_worker()
            if not valid:
                log("VALIDATION FAILED! Check worker logs. Stopping ramp-up.")
                sys.exit(1)

        # Status check after each step; stop ramping early once all shards
        # are accounted for.
        time.sleep(30)
        status = get_fleet_status()
        if status:
            remaining = status["pending"] + status["processing"]
            log(f"  Fleet: {status['active_workers']} workers, RTF={status['fleet_rtf']:.0f}x")
            log(f"  Progress: {status['completed']}/4350 done, {remaining} remaining")
            if remaining == 0:
                log("ALL SHARDS DONE!")
                break

        # Brief pause between steps
        if step_target < RAMP_STEPS[-1]:
            log("  Waiting 60s before next ramp step...")
            time.sleep(60)

    # Final summary with a rough ETA from fleet throughput.
    log("\n=== RAMP-UP COMPLETE ===")
    status = get_fleet_status()
    if status:
        log(f"Final fleet: {status['active_workers']} workers, RTF={status['fleet_rtf']:.0f}x")
        if status["fleet_rtf"] > 0:
            remaining = status["pending"] + status["processing"]
            avg_hrs = 109160.0 / 4350  # ~25.1 hrs/shard
            eta = remaining * avg_hrs / status["fleet_rtf"]
            log(f"ETA: {eta:.1f} hours")

    log("\nMonitor running in background. Fleet will auto-recover via monitor.py.")


# Script entry point.
if __name__ == "__main__":
    main()
