"""
Analytics Reporter for Maya3 Pipeline

Lightweight, non-blocking analytics that reports worker status to Supabase.

Design principles:
- Never blocks main processing thread (background queue + thread)
- Batches writes to reduce Supabase calls (~2-3 per video, not 10+)
- Uses upserts for heartbeats (single row per worker)
- Gracefully handles failures (analytics never crashes worker)
- Local counters remain accurate even if Supabase is down
"""

import threading
import queue
import time
import socket
import hashlib
from typing import Optional, Dict, Any, List, Union
from dataclasses import dataclass, field
from datetime import datetime


@dataclass
class AnalyticsEvent:
    """Event to be sent to Supabase."""
    event_type: str  # 'heartbeat', 'done', 'fail'
    data: Dict[str, Any]
    timestamp: datetime = field(default_factory=datetime.utcnow)


def get_machine_id() -> str:
    """
    Generate a unique, readable machine ID.
    Format: hostname-xxxx (where xxxx is a short hash)
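    Example (illustrative): a host named "rig7" might map to "rig7-3fa2".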
    """
    hostname = socket.gethostname()[:20]  # Truncate long hostnames
    # Short hash derived from hostname + FQDN to reduce collisions between hosts
    unique_str = f"{hostname}-{socket.getfqdn()}"
    short_hash = hashlib.md5(unique_str.encode()).hexdigest()[:4]
    return f"{hostname}-{short_hash}"


class AnalyticsReporter:
    """
    Non-blocking analytics reporter with background flushing.

    Usage:
        reporter = AnalyticsReporter(gpu_id=0, supabase_url=..., supabase_key=...)
        reporter.session_start()

        # For each video:
        reporter.pick(video_id)
        reporter.stage_update(video_id, 3, "Diarization")  # Optional stage updates
        reporter.done(video_id, total_time, audio_min, speakers, usable_pct)
        # or
        reporter.fail(video_id, "Download", "NetworkError", "Connection timeout", traceback_str)

        # On shutdown:
        reporter.shutdown()
    """

    def __init__(
        self,
        gpu_id: int,
        supabase_url: str,
        supabase_key: str,
        machine_id: Optional[str] = None,
        flush_interval: float = 5.0,
        batch_size: int = 10
    ):
        self.gpu_id = gpu_id
        self.machine_id = machine_id or get_machine_id()
        self.worker_id = f"{self.machine_id}_gpu{gpu_id}"

        # Background queue for non-blocking writes
        self._queue: "queue.Queue[AnalyticsEvent]" = queue.Queue()
        self._stop_event = threading.Event()

        # Configuration
        self._flush_interval = flush_interval
        self._batch_size = batch_size

        # Local counters (always accurate, even if Supabase fails)
        self.session_videos_done = 0
        self.session_videos_failed = 0
        self.session_audio_minutes = 0.0
        self.session_usable_minutes = 0.0
        self.session_start_time = datetime.utcnow()

        # Supabase client (lazy init in background thread)
        self._supabase_url = supabase_url
        self._supabase_key = supabase_key
        self._client = None
        self._client_lock = threading.Lock()

        # Start background flusher (daemon=True so it can never block
        # interpreter exit; shutdown() remains the clean path)
        self._flush_thread = threading.Thread(target=self._flush_loop, daemon=True)
        self._flush_thread.start()

    def _get_client(self):
        """Lazy init Supabase client in background thread (thread-safe)."""
        if self._client is None:
            with self._client_lock:
                if self._client is None:  # Double-check after acquiring lock
                    try:
                        from supabase import create_client
                        self._client = create_client(self._supabase_url, self._supabase_key)
                    except Exception:
                        # Leave _client as None so the next flush retries
                        # creation; analytics must never crash the worker
                        pass
        return self._client

    def _flush_loop(self):
        """Background thread that flushes events to Supabase."""
        events_buffer = []
        last_flush = time.time()

        while not self._stop_event.is_set():
            try:
                # Wait up to 1s for the next event; the timeout keeps the
                # loop responsive to the stop event without busy-waiting
                try:
                    event = self._queue.get(timeout=1.0)
                    events_buffer.append(event)
                except queue.Empty:
                    pass

                # Flush conditions: batch_size+ events OR flush_interval passed
                should_flush = (
                    len(events_buffer) >= self._batch_size or
                    (events_buffer and time.time() - last_flush >= self._flush_interval)
                )

                if should_flush and events_buffer:
                    self._flush_events(events_buffer)
                    events_buffer = []
                    last_flush = time.time()

            except Exception:
                # Never crash the flush thread
                pass

        # Final flush on shutdown: drain anything still queued (e.g. the
        # 'offline' heartbeat sent by shutdown()) so it is not lost
        while True:
            try:
                events_buffer.append(self._queue.get_nowait())
            except queue.Empty:
                break
        if events_buffer:
            self._flush_events(events_buffer)

    def _flush_events(self, events: List[AnalyticsEvent]):
        """Batch write events to Supabase."""
        client = self._get_client()
        if not client:
            return

        try:
            # Separate by type
            heartbeats = [e for e in events if e.event_type == 'heartbeat']
            done_events = [e for e in events if e.event_type == 'done']
            fail_events = [e for e in events if e.event_type == 'fail']

            # Upsert latest heartbeat (only 1 per worker, use most recent)
            if heartbeats:
                latest = heartbeats[-1]
                client.table('worker_heartbeats').upsert(
                    latest.data,
                    on_conflict='worker_id'
                ).execute()

            # Batch insert done events
            if done_events:
                rows = [e.data for e in done_events]
                client.table('processing_events').insert(rows).execute()

            # Batch insert fail events + error_logs
            if fail_events:
                event_rows = [e.data for e in fail_events]
                client.table('processing_events').insert(event_rows).execute()

                # Also log to error_logs table with full details
                error_rows = [{
                    'worker_id': e.data['worker_id'],
                    'video_id': e.data['video_id'],
                    'stage': e.data.get('stage_name'),
                    'error_type': e.data.get('error_type'),
                    'error_message': e.data.get('error_message'),
                    'stack_trace': e.data.get('stack_trace')
                } for e in fail_events]
                client.table('error_logs').insert(error_rows).execute()

        except Exception:
            # Silently fail - analytics should never crash worker
            pass
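
    # -------------------------------------------------------------------------
    # Assumed Supabase schema (a sketch inferred from the writes above; the
    # authoritative table definitions live in the Supabase project itself):
    #   worker_heartbeats -- one row per worker_id (the upsert key): machine_id,
    #                        gpu_id, status, current_video_id, current_stage,
    #                        current_stage_name, last_heartbeat, session_start,
    #                        and the session_* counters
    #   processing_events -- append-only: worker_id, video_id, event_type
    #                        ('done' | 'fail'), plus per-type metric columns
    #                        (duration_seconds, audio_minutes, speakers,
    #                        usable_pct / stage_name, error_type, ...)
    #   error_logs        -- append-only: worker_id, video_id, stage,
    #                        error_type, error_message, stack_trace
    # -------------------------------------------------------------------------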

    def _send_heartbeat(
        self,
        status: str,
        video_id: Optional[str] = None,
        stage: Optional[int] = None,
        stage_name: Optional[str] = None
    ):
        """Queue a full heartbeat snapshot of this worker's current state."""
        self._queue.put(AnalyticsEvent('heartbeat', {
            'worker_id': self.worker_id,
            'machine_id': self.machine_id,
            'gpu_id': self.gpu_id,
            'status': status,
            'current_video_id': video_id,
            'current_stage': stage,
            'current_stage_name': stage_name,
            'last_heartbeat': datetime.utcnow().isoformat(),
            'session_start': self.session_start_time.isoformat(),
            'session_videos_done': self.session_videos_done,
            'session_videos_failed': self.session_videos_failed,
            'session_audio_minutes': self.session_audio_minutes,
            'session_usable_minutes': self.session_usable_minutes
        }))

    # =========================================================================
    # Public API (all non-blocking)
    # =========================================================================

    def session_start(self):
        """Register worker on startup."""
        self.session_start_time = datetime.utcnow()
        self._send_heartbeat('idle')

    def pick(self, video_id: str):
        """Worker claimed a video."""
        self._send_heartbeat('processing', video_id, stage=1, stage_name='Download')

    def stage_update(self, video_id: str, stage_num: int, stage_name: str):
        """
        Update current stage (lightweight heartbeat).
        Call this optionally for long-running stages.
        """
        self._send_heartbeat('processing', video_id, stage=stage_num, stage_name=stage_name)

    def done(
        self,
        video_id: str,
        total_time: float,
        audio_min: float,
        speakers: int,
        usable_pct: float
    ):
        """Video processed successfully."""
        # Update local counters
        self.session_videos_done += 1
        self.session_audio_minutes += audio_min
        usable_min = audio_min * (usable_pct / 100.0)
        self.session_usable_minutes += usable_min

        # Queue done event (for processing_events table)
        self._queue.put(AnalyticsEvent('done', {
            'worker_id': self.worker_id,
            'video_id': video_id,
            'event_type': 'done',
            'duration_seconds': total_time,
            'audio_minutes': audio_min,
            'speakers': speakers,
            'usable_pct': usable_pct
        }))

        # Update heartbeat with new counters (status back to idle)
        self._send_heartbeat('idle')

    def fail(
        self,
        video_id: str,
        stage: str,
        error_type: str,
        message: str,
        stack_trace: Optional[str] = None
    ):
        """Video processing failed."""
        # Update local counters
        self.session_videos_failed += 1

        # Queue fail event (for processing_events + error_logs tables)
        self._queue.put(AnalyticsEvent('fail', {
            'worker_id': self.worker_id,
            'video_id': video_id,
            'event_type': 'fail',
            'stage_name': stage,
            'error_type': error_type,
            'error_message': message[:500] if message else None,  # Truncate
            'stack_trace': stack_trace[:2000] if stack_trace else None  # Truncate
        }))

        # Update heartbeat (status back to idle)
        self._send_heartbeat('idle')

    def shutdown(self):
        """Graceful shutdown - flush remaining events."""
        # Send final heartbeat marking this worker offline
        self._send_heartbeat('offline')

        # Signal stop; the flush thread drains the queue before exiting
        self._stop_event.set()
        self._flush_thread.join(timeout=5.0)

    def get_local_stats(self) -> Dict[str, Any]:
        """Get local session statistics (always accurate)."""
        return {
            'worker_id': self.worker_id,
            'machine_id': self.machine_id,
            'gpu_id': self.gpu_id,
            'session_videos_done': self.session_videos_done,
            'session_videos_failed': self.session_videos_failed,
            'session_audio_minutes': self.session_audio_minutes,
            'session_usable_minutes': self.session_usable_minutes,
            'session_duration_minutes': (datetime.utcnow() - self.session_start_time).total_seconds() / 60.0
        }


class NoOpAnalyticsReporter:
    """
    No-op analytics reporter for when analytics is disabled.
    Same interface as AnalyticsReporter; keeps the local counters accurate
    but never writes to Supabase.
    """

    def __init__(self, gpu_id: int = 0, **kwargs):
        self.gpu_id = gpu_id
        self.machine_id = 'local'
        self.worker_id = f"local_gpu{gpu_id}"
        self.session_videos_done = 0
        self.session_videos_failed = 0
        self.session_audio_minutes = 0.0
        self.session_usable_minutes = 0.0
        self.session_start_time = datetime.utcnow()

    def session_start(self): pass
    def pick(self, video_id: str): pass
    def stage_update(self, video_id: str, stage_num: int, stage_name: str): pass

    def done(self, video_id: str, total_time: float, audio_min: float,
             speakers: int, usable_pct: float):
        self.session_videos_done += 1
        self.session_audio_minutes += audio_min
        self.session_usable_minutes += audio_min * (usable_pct / 100.0)

    def fail(self, video_id: str, stage: str, error_type: str,
             message: str, stack_trace: Optional[str] = None):
        self.session_videos_failed += 1

    def shutdown(self): pass

    def get_local_stats(self) -> Dict[str, Any]:
        return {
            'worker_id': self.worker_id,
            'machine_id': self.machine_id,
            'gpu_id': self.gpu_id,
            'session_videos_done': self.session_videos_done,
            'session_videos_failed': self.session_videos_failed,
            'session_audio_minutes': self.session_audio_minutes,
            'session_usable_minutes': self.session_usable_minutes,
            'session_duration_minutes': (datetime.utcnow() - self.session_start_time).total_seconds() / 60.0
        }


def create_reporter(
    gpu_id: int,
    supabase_url: Optional[str] = None,
    supabase_key: Optional[str] = None,
    machine_id: Optional[str] = None,
    enabled: bool = True
) -> Union[AnalyticsReporter, NoOpAnalyticsReporter]:
    """
    Factory function to create appropriate reporter.
    Returns NoOpAnalyticsReporter if disabled or credentials missing.
    """
    if not enabled or not supabase_url or not supabase_key:
        return NoOpAnalyticsReporter(gpu_id=gpu_id)

    return AnalyticsReporter(
        gpu_id=gpu_id,
        supabase_url=supabase_url,
        supabase_key=supabase_key,
        machine_id=machine_id
    )
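

# -----------------------------------------------------------------------------
# Minimal smoke-test sketch. Assumptions: the SUPABASE_URL / SUPABASE_KEY env
# var names and all video metrics below are illustrative, not part of the
# pipeline's contract. Without credentials, create_reporter falls back to the
# no-op reporter, so running this never writes anywhere by accident.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    import os
    import traceback

    reporter = create_reporter(
        gpu_id=0,
        supabase_url=os.environ.get("SUPABASE_URL"),
        supabase_key=os.environ.get("SUPABASE_KEY"),
    )
    reporter.session_start()

    # Success path: pick -> stage update -> done
    reporter.pick("demo-video-001")
    reporter.stage_update("demo-video-001", 3, "Diarization")
    reporter.done("demo-video-001", total_time=42.0, audio_min=10.0,
                  speakers=2, usable_pct=85.0)

    # Failure path: capture the stack trace the way fail() expects it
    reporter.pick("demo-video-002")
    try:
        raise RuntimeError("simulated download failure")
    except RuntimeError:
        reporter.fail("demo-video-002", "Download", "RuntimeError",
                      "simulated download failure", traceback.format_exc())

    reporter.shutdown()
    print(reporter.get_local_stats())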
