#!/usr/bin/env python3
"""Monitor neucodec encoding progress across all workers."""

import os
import sys
import psycopg2
from datetime import datetime, timezone

# Connection string for the shard/worker database. It must come from the
# environment: credentials are never embedded in source.
DATABASE_URL = os.environ.get("DATABASE_URL")

# Seconds without a heartbeat after which a non-exited worker is flagged stale.
STALE_AFTER_SECONDS = 120
# Width, in characters, of the per-worker progress bar.
BAR_LEN = 30
# Horizontal rule printed around every report section.
RULE = "=" * 72
# Display glyph per worker state; unknown states render as "?".
STATE_ICONS = {"working": "●", "idle": "○", "exited": "✕"}


def fmt_value(value, spec, suffix=""):
    """Return ``format(value, spec) + suffix``, or ``"?"`` when *value* is None.

    Worker progress columns are read straight from the database and may be
    NULL (e.g. before a worker's first progress update); formatting None with
    a float spec would raise TypeError.
    """
    if value is None:
        return "?"
    return f"{format(value, spec)}{suffix}"


def heartbeat_age(hb, now):
    """Return seconds elapsed between heartbeat *hb* and *now*, or None if *hb* is None.

    A naive *hb* is assumed to be UTC so it can be subtracted from the
    timezone-aware *now* (mixing naive and aware datetimes raises TypeError).
    NOTE(review): confirm the column really stores UTC when it is naive.
    """
    if hb is None:
        return None
    if hb.tzinfo is None:
        hb = hb.replace(tzinfo=timezone.utc)
    return (now - hb).total_seconds()


def progress_bar(pct, width=BAR_LEN):
    """Render percentage *pct* (None/0 -> empty) as a bar exactly *width* chars wide."""
    filled = int(width * pct / 100) if pct else 0
    # Clamp so an out-of-range percentage cannot change the bar's width.
    filled = max(0, min(width, filled))
    return "█" * filled + "░" * (width - filled)


def print_overall(rows):
    """Print the per-status shard summary table and return the total shard count.

    *rows* are ``(status, count, audio_seconds, tokens)`` tuples as produced by
    the GROUP BY query in ``main``.
    """
    total_shards = sum(r[1] for r in rows)
    print(RULE)
    print(f"  NEUCODEC ENCODING — {total_shards} total shards")
    print(RULE)
    print(f"  {'Status':>12s} | {'Shards':>7s} | {'Pct':>6s} | {'Audio (hrs)':>11s} | {'Tokens':>14s}")
    print(f"  {'─'*12}─┼─{'─'*7}─┼─{'─'*6}─┼─{'─'*11}─┼─{'─'*14}")
    for status, count, audio_s, tokens in rows:
        # GROUP BY rows always have count >= 1, so total_shards > 0 inside the loop.
        pct = count / total_shards * 100
        # str() keeps a NULL status from crashing the string format spec.
        print(f"  {str(status):>12s} | {count:>7d} | {pct:>5.1f}% | {audio_s/3600:>11.1f} | {tokens:>14,.0f}")
    return total_shards


def print_workers(workers, now):
    """Print one section per worker row and return the number of live working workers.

    A worker is "live" when its state is "working" and it is not stale, so a
    crashed worker that never updated its state does not inflate the ETA.
    Prints nothing and returns 0 when *workers* is empty.
    """
    if not workers:
        return 0

    print(f"\n{RULE}")
    print(f"  WORKERS ({len(workers)})")
    print(RULE)

    active = 0
    for (wid, host, gpu, state, phase, shard, pct, seg, total_seg,
         rtf, vram, done, fail, hours, hb, err) in workers:
        age = heartbeat_age(hb, now)
        # No heartbeat at all counts as stale, like a very old one.
        is_stale = state != "exited" and (age is None or age > STALE_AFTER_SECONDS)
        if state == "working" and not is_stale:
            active += 1

        stale = " ⚠ STALE" if is_stale else ""
        icon = STATE_ICONS.get(state, "?")
        hb_text = "never" if age is None else f"{int(age)}s ago"

        print(f"\n  {icon} {wid}  [{gpu}]{stale}")
        print(f"    State: {state}/{phase}  |  Heartbeat: {hb_text}  |  Done: {done} shards  |  Failed: {fail}")

        if state == "working" and shard:
            print(f"    Shard: {shard}")
            print(
                f"    [{progress_bar(pct)}] {fmt_value(pct, '>5.1f', '%')}  ({seg}/{total_seg})  "
                f"RTF={fmt_value(rtf, '.1f', 'x')}  VRAM={fmt_value(vram, '.1f', 'GB')}"
            )

        if err and state != "working":
            print(f"    Last error: {err[:80]}")
    return active


def print_eta(pending, avg_wall, active):
    """Print the remaining-work footer; prints nothing when *pending* is not positive.

    *avg_wall* is the mean wall-clock seconds per completed shard (None if unknown).
    """
    if pending <= 0:
        return
    print(f"\n{RULE}")
    if avg_wall and active > 0:
        # Assumes remaining shards take the historical average and are spread
        # evenly across the currently live workers.
        eta_h = (pending * avg_wall / active) / 3600
        print(f"  ETA: {pending} pending / {active} active workers → ~{eta_h:.1f} hrs ({eta_h/24:.1f} days)")
    else:
        print(f"  {pending} shards pending, {active} active workers")
    print(RULE)


def main():
    """Query the shard and worker tables and print the encoding progress report."""
    if not DATABASE_URL:
        sys.exit("DATABASE_URL is not set; export the Postgres connection string and retry.")

    conn = psycopg2.connect(DATABASE_URL)
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT status, COUNT(*),
                       COALESCE(SUM(total_audio_seconds), 0),
                       COALESCE(SUM(total_tokens), 0)
                FROM neucodec_shards GROUP BY status ORDER BY status
            """)
            rows = cur.fetchall()

            # NULLS LAST: workers that never sent a heartbeat sort to the bottom
            # (Postgres puts NULLs first for DESC by default).
            cur.execute("""
                SELECT worker_id, hostname, gpu_name, state, progress_phase,
                       current_shard_prefix, progress_pct, progress_segments, progress_total_segments,
                       progress_rtf, progress_vram_gb,
                       shards_completed, shards_failed, total_audio_hours,
                       last_heartbeat, last_error
                FROM neucodec_workers
                ORDER BY last_heartbeat DESC NULLS LAST
            """)
            workers = cur.fetchall()

            cur.execute(
                "SELECT AVG(encode_wall_seconds) FROM neucodec_shards "
                "WHERE status = 'completed' AND encode_wall_seconds > 0"
            )
            # AVG over zero rows is NULL; treat 0 the same as "unknown".
            avg_wall = cur.fetchone()[0] or None
    finally:
        # Always release the pooled connection, even if a query fails.
        conn.close()

    print_overall(rows)
    active = print_workers(workers, datetime.now(timezone.utc))
    # Reuse the status breakdown instead of a second COUNT query, so the
    # footer agrees with the summary table.
    pending = next((count for status, count, *_ in rows if status == "pending"), 0)
    print_eta(pending, avg_wall, active)


if __name__ == "__main__":
    main()
