"""Worker heartbeat thread: periodically reports status to Supabase.

Runs in a daemon thread so it dies with the main process.
Reports: liveness, current video and stage, latest and rolling-average RTF,
cumulative progress counters, and failure count.
"""

from __future__ import annotations

import logging
import threading
import time

from codecbench.pipeline.supabase_client import SupabaseOrchestrator

logger = logging.getLogger(__name__)


class HeartbeatThread:
    """Background reporter that periodically pushes worker status.

    A daemon thread calls ``orchestrator.heartbeat(...)`` once every
    ``interval_s`` seconds with a snapshot of the public status attributes.
    The first report is sent one full interval after :meth:`start`. A failed
    heartbeat call is logged and the loop keeps running.

    The main thread updates the public attributes and calls
    :meth:`record_rtf`; the heartbeat thread only reads them.

    Args:
        orchestrator: Client whose ``heartbeat`` method receives each report.
        worker_id: Identifier sent with every report.
        interval_s: Seconds between reports; must be positive.
        rtf_window: How many of the most recent :meth:`record_rtf` samples
            are kept for :attr:`avg_encode_rtf`; must be positive.

    Raises:
        ValueError: If ``interval_s`` or ``rtf_window`` is not positive.
    """

    def __init__(
        self,
        orchestrator: SupabaseOrchestrator,
        worker_id: str,
        interval_s: int = 30,
        rtf_window: int = 50,
    ):
        if interval_s <= 0:
            # Event.wait(0) returns immediately, which would turn _run into a
            # busy loop hammering the backend.
            raise ValueError(f"interval_s must be positive, got {interval_s!r}")
        if rtf_window <= 0:
            raise ValueError(f"rtf_window must be positive, got {rtf_window!r}")

        self._orch = orchestrator
        self._worker_id = worker_id
        self._interval = interval_s
        self._rtf_window = rtf_window
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None
        # Monotonic clock so reported uptime is unaffected by wall-clock changes.
        self._start_time = time.monotonic()

        # Mutable state (updated by main thread, read by heartbeat thread).
        self.current_video: str | None = None
        self.current_stage: str | None = None
        self.rtf: float | None = None
        self.total_audio_s: float = 0.0
        self.total_videos: int = 0
        self.total_failed: int = 0
        self.total_shards: int = 0
        self.shard_buffer_count: int = 0
        # Rolling RTF samples. The list is mutated by record_rtf and aggregated
        # by the heartbeat thread, so every access goes through _rtf_lock.
        self._rtf_samples: list[float] = []
        self._rtf_lock = threading.Lock()

    def record_rtf(self, rtf: float) -> None:
        """Add an encode RTF sample, keeping only the newest ``rtf_window``."""
        with self._rtf_lock:
            self._rtf_samples.append(rtf)
            # Trim in place; a no-op while the list is within the window.
            del self._rtf_samples[: -self._rtf_window]

    @property
    def avg_encode_rtf(self) -> float | None:
        """Mean of the retained RTF samples, or ``None`` if none were recorded."""
        with self._rtf_lock:
            samples = list(self._rtf_samples)
        if not samples:
            return None
        return sum(samples) / len(samples)

    def start(self) -> None:
        """Start the daemon heartbeat thread and reset the uptime clock.

        Calling this while the thread is still alive only logs a warning, so
        two loops never report for the same worker. Calling it after
        :meth:`stop` starts a fresh loop.
        """
        if self._thread is not None and self._thread.is_alive():
            logger.warning("Heartbeat thread already running; ignoring start()")
            return
        # stop() leaves the event set; without clearing it a restarted loop
        # would exit immediately.
        self._stop.clear()
        self._start_time = time.monotonic()
        self._thread = threading.Thread(
            target=self._run, daemon=True, name="heartbeat"
        )
        self._thread.start()
        logger.info("Heartbeat thread started (every %d s)", self._interval)

    def _run(self) -> None:
        """Send one heartbeat per interval until :meth:`stop` sets the event."""
        # Event.wait returns True as soon as stop() sets the event, so shutdown
        # does not have to wait out a full interval.
        while not self._stop.wait(self._interval):
            try:
                self._orch.heartbeat(
                    worker_id=self._worker_id,
                    current_video=self.current_video,
                    current_stage=self.current_stage,
                    rtf=self.rtf,
                    avg_encode_rtf=self.avg_encode_rtf,
                    total_audio_s=self.total_audio_s,
                    total_videos=self.total_videos,
                    total_failed=self.total_failed,
                    total_shards=self.total_shards,
                    shard_buffer_count=self.shard_buffer_count,
                    uptime_s=time.monotonic() - self._start_time,
                )
            except Exception as e:
                # Best effort: one failed report must not kill the thread;
                # the next tick tries again.
                logger.error("Heartbeat failed: %s", e)

    def stop(self) -> None:
        """Signal the loop to exit and wait up to 5 s for the thread to finish.

        A heartbeat call already in progress is not interrupted; if it outlasts
        the timeout, the daemon thread is left to exit on its own.
        """
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout=5)
