#!/usr/bin/env python3
"""
Real-time Monitoring Dashboard for Maya3 Distributed Pipeline.

Displays:
- Cluster-wide progress
- Per-worker status
- Error rates and common failures
- Queue depth and throughput

Usage:
    python dashboard.py                    # One-time status
    python dashboard.py --watch            # Live updates (every 30s)
    python dashboard.py --watch --interval 10  # Custom interval
"""

import argparse
import os
import sys
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

# Prepend the directory two levels above this file's directory to sys.path.
# NOTE(review): presumably so project-level packages resolve when this file
# is run directly as a script — confirm against the repository layout.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

# load_dotenv() runs at import time, before Dashboard.__init__ reads the
# URL and SUPABASE_ADMIN variables from os.environ.
from dotenv import load_dotenv
load_dotenv()

from supabase import create_client

# ANSI colors
class C:
    """Namespace of ANSI escape sequences used to style terminal output."""

    # Text attributes: reset everything, bold, dim.
    RESET = '\033[0m'
    BOLD = '\033[1m'
    DIM = '\033[2m'

    # Foreground colours (SGR codes 91-96).
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    CYAN = '\033[96m'


def clear_screen():
    """Erase the terminal display and move the cursor to the top-left corner."""
    # ESC[2J erases the display, ESC[H homes the cursor; no newline is emitted.
    wipe_and_home = '\033[2J\033[H'
    print(wipe_and_home, end='')


def format_duration(seconds: float) -> str:
    """Render a duration in seconds as a short string.

    Values below one minute are shown as whole seconds (``"42s"``), values
    below one hour as minutes with one decimal (``"1.5m"``), and everything
    else as hours with one decimal (``"2.0h"``).
    """
    per_minute = 60
    per_hour = 3600

    if seconds < per_minute:
        return f"{seconds:.0f}s"
    if seconds < per_hour:
        return f"{seconds / per_minute:.1f}m"
    return f"{seconds / per_hour:.1f}h"


def format_number(n: int) -> str:
    """Return *n* as a string with comma thousands separators."""
    return format(n, ',')


class Dashboard:
    """Poll Supabase for pipeline state and render it as a terminal dashboard.

    The ``get_*`` methods each issue queries through ``self.sb`` and return
    plain lists/dicts; the ``print_*`` methods each render one section to
    stdout. ``run`` ties them together, once or in a refresh loop.
    """

    # Video statuses counted by get_queue_stats.
    QUEUE_STATUSES = ('PENDING', 'CLAIMED', 'COMPLETED', 'FAILED', 'SKIPPED')

    # A heartbeat older than this many seconds is highlighted / reported stale.
    STALE_AFTER_SEC = 300

    def __init__(self):
        """Create the Supabase client from environment variables.

        Raises:
            KeyError: if ``URL`` or ``SUPABASE_ADMIN`` is missing from the
                environment.
        """
        self.sb = create_client(os.environ['URL'], os.environ['SUPABASE_ADMIN'])
        # Stored naive (UTC wall clock) so the attribute keeps its previous shape.
        self.start_time = self._utc_now().replace(tzinfo=None)

    # ------------------------------------------------------------------
    # Pure helpers (no I/O) so the data shaping can be tested in isolation
    # ------------------------------------------------------------------

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        return datetime.now(timezone.utc)

    @staticmethod
    def _query_timestamp(moment: datetime) -> str:
        """Render an aware datetime as a naive-UTC ISO string for query filters.

        The offset suffix is dropped on purpose: the filters have always been
        sent as naive UTC strings, and this keeps that wire format unchanged.
        """
        return moment.astimezone(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def _parse_timestamp(value: Any) -> Optional[datetime]:
        """Parse an ISO-8601 timestamp string, or return None if unparseable.

        A trailing ``Z`` is rewritten to ``+00:00`` and the fractional-second
        part is normalised to exactly six digits, because older versions of
        ``datetime.fromisoformat`` reject both forms.
        """
        if not isinstance(value, str) or not value:
            return None

        text = value.strip()
        if text.endswith(('Z', 'z')):
            text = text[:-1] + '+00:00'

        head, dot, tail = text.partition('.')
        if dot:
            # Split "12345+00:00" into the digit run and whatever follows it.
            digits = len(tail) - len(tail.lstrip('0123456789'))
            fraction = tail[:digits][:6].ljust(6, '0')
            text = f"{head}.{fraction}{tail[digits:]}"

        try:
            return datetime.fromisoformat(text)
        except ValueError:
            return None

    @classmethod
    def _heartbeat_age(cls, last_hb: Any, now: Optional[datetime] = None) -> Optional[float]:
        """Return seconds elapsed since *last_hb*, or None if it cannot be parsed.

        Args:
            last_hb: ISO-8601 timestamp string from a heartbeat row.
            now: Aware reference time; defaults to the current UTC time.
        """
        moment = cls._parse_timestamp(last_hb)
        if moment is None:
            return None
        if moment.tzinfo is None:
            # A timestamp without an offset is taken to be UTC.
            moment = moment.replace(tzinfo=timezone.utc)
        reference = now if now is not None else cls._utc_now()
        # Aware-minus-aware subtraction is correct for any offset.
        return (reference - moment).total_seconds()

    @staticmethod
    def _bucket_by_hour(rows: List[Dict]) -> List[Dict]:
        """Group event rows by the ``YYYY-MM-DDTHH`` prefix of ``created_at``.

        Returns one dict per hour, sorted by hour, with keys ``hour``,
        ``count`` and ``audio_min`` (sum of ``audio_minutes``; null counts as 0).
        """
        buckets: Dict[str, Dict[str, Any]] = {}
        for row in rows:
            key = str(row.get('created_at') or '')[:13]
            slot = buckets.setdefault(key, {'count': 0, 'audio_min': 0})
            slot['count'] += 1
            slot['audio_min'] += row.get('audio_minutes') or 0
        return [{'hour': key, **buckets[key]} for key in sorted(buckets)]

    @staticmethod
    def _count_error_types(rows: List[Dict]) -> Dict[str, int]:
        """Count rows per ``error_type``; missing or null types count as 'Unknown'."""
        counts: Dict[str, int] = {}
        for row in rows:
            err_type = row.get('error_type') or 'Unknown'
            counts[err_type] = counts.get(err_type, 0) + 1
        return counts

    @staticmethod
    def _hour_label(bucket: str) -> str:
        """Turn a ``YYYY-MM-DDTHH`` bucket key into an ``HH:00`` label."""
        if len(bucket) >= 13:
            return f"{bucket[11:13]}:00"
        return bucket or '?'

    @staticmethod
    def _gpu_sort_key(worker: Dict) -> Any:
        """Sort key for workers within a machine; a null gpu_id sorts as 0."""
        gpu = worker.get('gpu_id')
        return 0 if gpu is None else gpu

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get_queue_stats(self) -> Dict[str, int]:
        """Return the number of rows in ``videos`` for each status in QUEUE_STATUSES."""
        stats = {}
        for status in self.QUEUE_STATUSES:
            result = self.sb.table('videos').select('id', count='exact').eq('status', status).execute()
            stats[status] = result.count or 0
        return stats

    def get_worker_status(self) -> List[Dict]:
        """Return every row of ``worker_heartbeats``, ordered by ``worker_id``."""
        result = self.sb.table('worker_heartbeats').select('*').order('worker_id').execute()
        return result.data or []

    def get_recent_errors(self, limit: int = 10) -> List[Dict]:
        """Return the newest *limit* rows of ``error_logs``, most recent first."""
        result = self.sb.table('error_logs').select(
            'video_id, stage, error_type, error_message, created_at'
        ).order('created_at', desc=True).limit(limit).execute()
        return result.data or []

    def get_error_summary(self) -> Dict[str, int]:
        """Return a mapping of error type to number of ``error_logs`` rows.

        NOTE(review): every row is fetched and counted client-side; if the API
        caps rows per response the totals are truncated — confirm the server's
        row limit.
        """
        result = self.sb.table('error_logs').select('error_type').execute()
        return self._count_error_types(result.data or [])

    def get_hourly_throughput(self) -> List[Dict]:
        """Return per-hour counts of 'done' processing events from the last 24h.

        Each item has ``hour`` (``YYYY-MM-DDTHH``), ``count`` and ``audio_min``.

        NOTE(review): subject to the same possible per-response row cap as
        get_error_summary — confirm.
        """
        since = self._query_timestamp(self._utc_now() - timedelta(hours=24))
        result = self.sb.table('processing_events').select(
            'created_at, audio_minutes'
        ).eq('event_type', 'done').gte('created_at', since).execute()
        return self._bucket_by_hour(result.data or [])

    def get_stale_workers(self, threshold_sec: int = STALE_AFTER_SEC) -> List[Dict]:
        """Return non-offline workers whose last heartbeat is older than *threshold_sec*."""
        cutoff = self._query_timestamp(self._utc_now() - timedelta(seconds=threshold_sec))
        result = self.sb.table('worker_heartbeats').select('*').neq(
            'status', 'offline'
        ).lt('last_heartbeat', cutoff).execute()
        return result.data or []

    # ------------------------------------------------------------------
    # Rendering
    # ------------------------------------------------------------------

    def print_header(self):
        """Print the boxed dashboard title with the current UTC time."""
        now = self._utc_now().strftime('%Y-%m-%d %H:%M:%S UTC')
        print(f"\n{C.BOLD}{C.CYAN}╔══════════════════════════════════════════════════════════════════════╗{C.RESET}")
        print(f"{C.BOLD}{C.CYAN}║{C.RESET}            {C.BOLD}MAYA3 DISTRIBUTED PIPELINE DASHBOARD{C.RESET}                      {C.CYAN}║{C.RESET}")
        print(f"{C.BOLD}{C.CYAN}║{C.RESET}                       {now}                       {C.CYAN}║{C.RESET}")
        print(f"{C.BOLD}{C.CYAN}╚══════════════════════════════════════════════════════════════════════╝{C.RESET}")

    def print_queue_stats(self, stats: Dict[str, int]):
        """Print the progress bar and per-status counts for *stats*.

        Progress is completed / total of all statuses; error rate is
        failed / (completed + failed). Both fall back to 0 on an empty
        denominator.
        """
        total = sum(stats.values())
        completed = stats.get('COMPLETED', 0)
        failed = stats.get('FAILED', 0)
        pending = stats.get('PENDING', 0)
        claimed = stats.get('CLAIMED', 0)

        progress = (completed / total * 100) if total > 0 else 0
        finished = completed + failed
        error_rate = (failed / finished * 100) if finished > 0 else 0

        print(f"\n{C.BOLD}📊 QUEUE STATUS{C.RESET}")
        print(f"{'─'*50}")

        # Progress bar
        bar_width = 40
        filled = int(bar_width * progress / 100)
        bar = '█' * filled + '░' * (bar_width - filled)
        print(f"  Progress: [{C.GREEN}{bar}{C.RESET}] {progress:.1f}%")

        # Stats grid
        print(f"  {C.GREEN}✓ Completed:{C.RESET} {format_number(completed):>8}  │  {C.YELLOW}⏳ Pending:{C.RESET} {format_number(pending):>8}")
        print(f"  {C.RED}✗ Failed:{C.RESET}    {format_number(failed):>8}  │  {C.BLUE}🔄 Claimed:{C.RESET} {format_number(claimed):>8}")
        print(f"  {C.DIM}Total:{C.RESET}       {format_number(total):>8}  │  {C.DIM}Error Rate:{C.RESET} {error_rate:.2f}%")

    def print_worker_status(self, workers: List[Dict]):
        """Print one row per worker, grouped by ``machine_id`` and sorted by ``gpu_id``."""
        print(f"\n{C.BOLD}👷 WORKER STATUS ({len(workers)} workers){C.RESET}")
        print(f"{'─'*70}")

        if not workers:
            print(f"  {C.DIM}No workers registered{C.RESET}")
            return

        # Group by machine; a null machine_id is shown as 'unknown'. Keys are
        # stringified so sorting never compares mixed types.
        machines: Dict[str, List[Dict]] = {}
        for w in workers:
            machine_id = w.get('machine_id')
            label = 'unknown' if machine_id is None else str(machine_id)
            machines.setdefault(label, []).append(w)

        for machine, machine_workers in sorted(machines.items()):
            print(f"\n  {C.BOLD}{machine}{C.RESET}")
            for w in sorted(machine_workers, key=self._gpu_sort_key):
                self._print_worker_row(w)

    def _print_worker_row(self, w: Dict):
        """Print the status line for a single worker heartbeat row."""
        # Columns may be present but null, so `or` fallbacks are used instead
        # of .get() defaults (which only apply when the key is absent).
        status = w.get('status') or 'unknown'
        gpu = w.get('gpu_id')
        gpu_display = '?' if gpu is None else gpu
        video = w.get('current_video_id')
        stage = w.get('current_stage_name')
        done = w.get('session_videos_done') or 0
        failed = w.get('session_videos_failed') or 0
        audio_min = w.get('session_audio_minutes') or 0

        # Status indicator
        if status == 'processing':
            status_icon = f"{C.GREEN}●{C.RESET}"
            status_text = f"{C.GREEN}processing{C.RESET}"
        elif status == 'idle':
            status_icon = f"{C.YELLOW}○{C.RESET}"
            status_text = f"{C.YELLOW}idle{C.RESET}"
        else:
            status_icon = f"{C.RED}◌{C.RESET}"
            status_text = f"{C.RED}{status}{C.RESET}"

        # Last heartbeat age; red once it exceeds the staleness threshold.
        age = self._heartbeat_age(w.get('last_heartbeat'))
        if age is None:
            age_str = "?"
        else:
            age_str = format_duration(age)
            if age > self.STALE_AFTER_SEC:
                age_str = f"{C.RED}{age_str}{C.RESET}"

        video_display = str(video)[:11] if video else '-'
        stage_display = str(stage)[:15] if stage else '-'

        print(f"    {status_icon} GPU{gpu_display} │ {status_text:20} │ {video_display:11} │ {stage_display:15} │ ✓{done} ✗{failed} │ {audio_min:.0f}min │ {age_str}")

    def print_error_summary(self, errors: Dict[str, int], recent: List[Dict]):
        """Print the five most frequent error types and up to five recent errors."""
        print(f"\n{C.BOLD}🚨 ERROR ANALYSIS{C.RESET}")
        print(f"{'─'*50}")

        if not errors:
            print(f"  {C.GREEN}No errors recorded{C.RESET}")
            return

        # Error type breakdown, most frequent first
        print(f"  {C.BOLD}By Type:{C.RESET}")
        for err_type, count in sorted(errors.items(), key=lambda x: -x[1])[:5]:
            print(f"    {C.RED}•{C.RESET} {err_type}: {count}")

        # Recent errors
        if recent:
            print(f"\n  {C.BOLD}Recent Errors:{C.RESET}")
            for err in recent[:5]:
                video = err.get('video_id') or '?'
                stage = err.get('stage') or '?'
                msg = (err.get('error_message', '') or '')[:60]
                print(f"    {C.RED}•{C.RESET} {video} @ {stage}: {msg}")

    def print_throughput(self, hourly: List[Dict]):
        """Print a bar chart of the last eight hourly buckets plus the 24h totals."""
        print(f"\n{C.BOLD}📈 THROUGHPUT (last 24h){C.RESET}")
        print(f"{'─'*50}")

        if not hourly:
            print(f"  {C.DIM}No data yet{C.RESET}")
            return

        # Bars are scaled against the busiest of the displayed hours.
        recent = hourly[-8:]
        max_count = max(h['count'] for h in recent)

        for h in recent:
            hour = self._hour_label(h['hour'])  # HH:00
            count = h['count']
            audio = h['audio_min']

            bar_len = int(30 * count / max_count) if max_count > 0 else 0
            bar = '▓' * bar_len

            print(f"  {hour} │ {C.CYAN}{bar:30}{C.RESET} │ {count:4} videos │ {audio:.0f} min")

        # Total stats
        total_videos = sum(h['count'] for h in hourly)
        total_audio = sum(h['audio_min'] for h in hourly)
        print(f"\n  {C.BOLD}24h Total:{C.RESET} {total_videos} videos, {total_audio/60:.1f} hours audio")

    def print_stale_workers(self, stale: List[Dict]):
        """Print a warning section listing *stale* workers; prints nothing if empty."""
        if not stale:
            return

        print(f"\n{C.BOLD}{C.YELLOW}⚠️  STALE WORKERS (no heartbeat >5min){C.RESET}")
        print(f"{'─'*50}")

        for w in stale:
            worker_id = w.get('worker_id', '?')
            # Trim to YYYY-MM-DDTHH:MM:SS; tolerate a null timestamp.
            last_hb = str(w.get('last_heartbeat') or '')[:19]
            video = w.get('current_video_id') or '-'

            print(f"  {C.YELLOW}•{C.RESET} {worker_id} - last seen: {last_hb} - video: {video}")

    # ------------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------------

    def _collect(self) -> Dict[str, Any]:
        """Run every query once and return the results keyed by section."""
        return {
            'queue_stats': self.get_queue_stats(),
            'workers': self.get_worker_status(),
            'errors': self.get_error_summary(),
            'recent_errors': self.get_recent_errors(5),
            'hourly': self.get_hourly_throughput(),
            'stale': self.get_stale_workers(),
        }

    def _render(self, snapshot: Dict[str, Any]):
        """Print all dashboard sections for one collected snapshot."""
        self.print_header()
        self.print_queue_stats(snapshot['queue_stats'])
        self.print_worker_status(snapshot['workers'])
        self.print_throughput(snapshot['hourly'])
        self.print_error_summary(snapshot['errors'], snapshot['recent_errors'])
        self.print_stale_workers(snapshot['stale'])

    def run(self, watch: bool = False, interval: int = 30):
        """Render the dashboard once, or repeatedly when *watch* is true.

        Args:
            watch: Refresh in a loop until interrupted with Ctrl+C.
            interval: Seconds to sleep between refreshes in watch mode.
        """
        try:
            while True:
                # Fetch first, then clear: the previous frame stays visible
                # while the queries are in flight.
                snapshot = self._collect()
                if watch:
                    clear_screen()
                self._render(snapshot)

                if not watch:
                    break

                print(f"\n{C.DIM}Press Ctrl+C to exit{C.RESET}")
                # time.sleep rejects negative values; clamp rather than crash.
                time.sleep(max(interval, 0))

        except KeyboardInterrupt:
            print(f"\n{C.DIM}Dashboard stopped{C.RESET}")


def main():
    """Parse command-line options and run the dashboard.

    ``--watch``/``-w`` enables the refresh loop; ``--interval``/``-i`` sets
    the number of seconds between refreshes (default 30).
    """
    cli = argparse.ArgumentParser(description="Maya3 Pipeline Dashboard")
    cli.add_argument('--watch', '-w', action='store_true', help='Live updates')
    cli.add_argument('--interval', '-i', type=int, default=30, help='Update interval (seconds)')
    options = cli.parse_args()

    Dashboard().run(watch=options.watch, interval=options.interval)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
