"""
Worker lifecycle: register, heartbeat loop, main loop, graceful shutdown.
Each Docker container runs one Worker instance.
"""
from __future__ import annotations

import asyncio
import logging
import signal
import sys
from pathlib import Path
from typing import Optional

from .cache_manager import CacheManager
from .config import (
    EnvConfig, HEARTBEAT_INTERVAL_S, MAX_VIDEOS, PROMPT_VERSION, SCHEMA_VERSION,
    TRIMMER_VERSION, VALIDATOR_VERSION, TEMPERATURE, THINKING_LEVEL,
)
from .db import MockDB, PostgresDB, VideoTask, WorkerStats, get_db, log_db_stats
from .pipeline import Pipeline
from .providers.aistudio import AIStudioProvider
from .providers.base import BaseProvider
from .r2_client import R2Client

logger = logging.getLogger(__name__)


class Worker:
    """Worker lifecycle driver for one container.

    Registers with the DB, heartbeats on a timer, claims and processes
    videos through the Pipeline (with one-video lookahead prefetch), and
    releases claims/resources on graceful shutdown (SIGTERM/SIGINT).
    """

    def __init__(self, config: EnvConfig):
        self.config = config
        self.worker_id = config.worker_id
        self.stats = WorkerStats()
        self._shutdown_event = asyncio.Event()
        self._heartbeat_task: Optional[asyncio.Task] = None

        # Initialize components
        self.db = get_db(config)
        self.r2 = R2Client(config)

        # Multi-key provider pool: primary key + fallback keys (different GCP projects)
        self.primary_provider = AIStudioProvider(
            api_key=config.primary_gemini_key,
            mock_mode=config.mock_mode,
        )
        self.fallback_provider = self._create_fallback()

        key_idx = config.gemini_key_index
        n_keys = len(config.gemini_keys)
        n_fallbacks = len(config.fallback_gemini_keys)
        logger.info(f"Provider pool: key[{key_idx}] primary, {n_fallbacks} fallback(s) ({n_keys} total keys)")

    def _create_fallback(self) -> Optional[BaseProvider]:
        """First fallback key from the pool (different GCP project, same quota)."""
        fallbacks = self.config.fallback_gemini_keys
        if not fallbacks:
            return None
        return AIStudioProvider(
            api_key=fallbacks[0],
            mock_mode=self.config.mock_mode,
        )

    async def start(self):
        """Main entry: connect pool, register, set up cache, start heartbeat, process videos."""
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGTERM, signal.SIGINT):
            # Synchronous callback: the previous version spawned a
            # fire-and-forget Task here, and per the asyncio docs a Task
            # without a saved reference may be garbage-collected before it
            # runs. The handler only logs and sets an event, so no
            # coroutine is needed.
            loop.add_signal_handler(sig, self._handle_shutdown, sig)

        try:
            await self.db.connect()
            await self._register()

            # Mock mode never talks to Gemini, so no context cache is needed.
            if not self.config.mock_mode:
                await self._setup_caches()

            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
            await self._main_loop()

        except Exception as e:
            logger.error(f"Worker fatal error: {e}", exc_info=True)
            await self.db.set_worker_error(self.worker_id, str(e))
        finally:
            await self._cleanup()

    async def _setup_caches(self):
        """Set up V2 cache for primary + fallback keys (each GCP project has its own cache).

        Failure is non-fatal: a provider without a cache name falls back to
        V1 uncached requests.
        """
        providers_to_cache = [("primary", self.primary_provider, self.config.primary_gemini_key)]
        if self.fallback_provider and self.config.fallback_gemini_keys:
            providers_to_cache.append(("fallback", self.fallback_provider, self.config.fallback_gemini_keys[0]))

        for label, provider, api_key in providers_to_cache:
            try:
                cm = CacheManager(api_key)
                cache_name = await cm.ensure_cache()
                if isinstance(provider, AIStudioProvider):
                    provider.cached_content_name = cache_name
                info = await cm.get_cache_info()
                tokens = info.get("usageMetadata", {}).get("totalTokenCount", "?") if info else "?"
                logger.info(f"Cache ready ({label}): {cache_name} ({tokens} tokens)")
            except Exception as e:
                logger.warning(f"Cache setup failed for {label}, falling back to V1 uncached: {e}")

    async def _register(self):
        """Register this worker in the DB with its full pipeline config snapshot."""
        config_json = {
            "prompt_version": PROMPT_VERSION,
            "schema_version": SCHEMA_VERSION,
            "trimmer_version": TRIMMER_VERSION,
            "validator_version": VALIDATOR_VERSION,
            "temperature": TEMPERATURE,
            "thinking_level": THINKING_LEVEL,
            "provider": "aistudio",
            "gemini_key_index": self.config.gemini_key_index,
            "total_keys": len(self.config.gemini_keys),
            "mock_mode": self.config.mock_mode,
        }
        logger.info(f"Registering worker {self.worker_id} (key_index={self.config.gemini_key_index})")
        await self.db.register_worker(
            worker_id=self.worker_id,
            provider="aistudio",
            gpu_type=self.config.gpu_type,
            config_json=config_json,
        )

    async def _heartbeat_loop(self):
        """Periodic heartbeat until shutdown; logs DB pool stats every 10th beat."""
        import random
        hb_count = 0
        while not self._shutdown_event.is_set():
            try:
                await self.db.update_heartbeat(self.worker_id, self.stats)
            except Exception as e:
                # Best-effort: a missed heartbeat must not kill the worker.
                logger.warning(f"Heartbeat update failed: {e}")

            hb_count += 1
            if hb_count % 10 == 0:
                log_db_stats()

            # Jitter (0-5s) prevents all workers hitting DB at same instant
            jitter = random.uniform(0, 5)
            try:
                # Wake early if shutdown is requested during the sleep.
                await asyncio.wait_for(
                    self._shutdown_event.wait(),
                    timeout=HEARTBEAT_INTERVAL_S + jitter,
                )
                break
            except asyncio.TimeoutError:
                pass

    async def _main_loop(self):
        """Claim -> process -> collect-prefetch loop.

        Exits on shutdown, on reaching MAX_VIDEOS, or after 3 consecutive
        empty polls. On exit, any outstanding prefetch claim is released so
        another worker can take it immediately.
        """
        pipeline = Pipeline(
            config=self.config,
            db=self.db,
            r2=self.r2,
            primary_provider=self.primary_provider,
            fallback_provider=self.fallback_provider,
            worker_id=self.worker_id,
            stats=self.stats,
        )

        consecutive_empty = 0
        videos_processed = 0
        prefetch_task: Optional[asyncio.Task] = None
        prefetched: Optional[VideoTask] = None

        if MAX_VIDEOS > 0:
            logger.info(f"MAX_VIDEOS={MAX_VIDEOS} — will stop after {MAX_VIDEOS} video(s)")

        try:
            while not self._shutdown_event.is_set():
                if MAX_VIDEOS > 0 and videos_processed >= MAX_VIDEOS:
                    logger.info(f"Reached MAX_VIDEOS={MAX_VIDEOS}, shutting down.")
                    break

                if prefetched:
                    task = prefetched
                    prefetched = None
                    logger.info(f"Using prefetched video: {task.video_id}")
                else:
                    task = await self.db.claim_video(self.worker_id)

                if task is None:
                    consecutive_empty += 1
                    if consecutive_empty >= 3:
                        logger.info("No more videos to process. Worker going idle.")
                        break
                    logger.info(f"No pending videos, waiting 30s (attempt {consecutive_empty}/3)...")
                    try:
                        await asyncio.wait_for(self._shutdown_event.wait(), timeout=30)
                        break
                    except asyncio.TimeoutError:
                        continue

                consecutive_empty = 0
                logger.info(f"Claimed video: {task.video_id} (lang={task.language})")

                # Prefetch next only if we haven't hit the limit
                remaining = MAX_VIDEOS - videos_processed - 1 if MAX_VIDEOS > 0 else 1
                if remaining > 0:
                    prefetch_task = asyncio.create_task(self._prefetch_next())

                success = await pipeline.process_video(task)
                videos_processed += 1
                if not success:
                    logger.error(f"Video {task.video_id} failed, continuing to next...")

                # Collect prefetch result. wait_for cancels the task on
                # timeout, and _prefetch_next releases its claim when
                # cancelled, so a slow prefetch strands nothing.
                if prefetch_task:
                    try:
                        prefetched = await asyncio.wait_for(prefetch_task, timeout=5)
                    except Exception:
                        prefetched = None
                    prefetch_task = None
        finally:
            # Don't strand claims on exit (shutdown can land between the
            # prefetch and its consumption): cancel any in-flight prefetch
            # and release whatever it claimed.
            if prefetch_task is not None:
                if not prefetch_task.done():
                    prefetch_task.cancel()
                leftover: Optional[VideoTask] = None
                try:
                    leftover = await prefetch_task
                except BaseException:  # includes CancelledError
                    leftover = None
                if leftover is not None:
                    await self._release_quietly(leftover.video_id)
            if prefetched is not None:
                await self._release_quietly(prefetched.video_id)

        logger.info(f"Main loop ended: {videos_processed} videos processed")

    async def _release_quietly(self, video_id) -> None:
        """Best-effort release of a claimed video; logs but never raises."""
        try:
            await self.db.release_video(video_id)
        except Exception as e:
            logger.warning(f"Failed to release video {video_id}: {e}")

    async def _prefetch_next(self) -> Optional[VideoTask]:
        """Claim and pre-download the next video while current one processes.

        On failure or cancellation after the claim, the video is released
        back to the pending pool instead of waiting for lease expiry.
        """
        try:
            task = await self.db.claim_video(self.worker_id)
        except Exception as e:
            logger.warning(f"[prefetch] Failed: {e}")
            return None
        if task is None:
            return None
        try:
            logger.info(f"[prefetch] Claimed next: {task.video_id}, downloading tar...")
            import tempfile
            work_dir = Path(tempfile.mkdtemp(prefix=f"prefetch_{self.worker_id}_"))
            # download_tar is a blocking call; running it directly would
            # stall the event loop (heartbeats + the current video's
            # pipeline), defeating the point of prefetching. Run it in the
            # default executor instead.
            # NOTE(review): assumes R2Client tolerates a call from a worker
            # thread (boto3-style clients do) — confirm.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self.r2.download_tar, task.video_id, work_dir)
            task.prefetch_dir = work_dir
            logger.info(f"[prefetch] {task.video_id} tar ready at {work_dir}")
            return task
        except asyncio.CancelledError:
            # Cancelled by wait_for timeout or shutdown: give the claim back.
            await self._release_quietly(task.video_id)
            raise
        except Exception as e:
            logger.warning(f"[prefetch] Failed: {e}")
            await self._release_quietly(task.video_id)
            return None

    def _handle_shutdown(self, sig):
        """Signal callback (runs inside the event loop): request graceful shutdown."""
        logger.info(f"Received {sig.name}, initiating graceful shutdown...")
        self._shutdown_event.set()

    async def _cleanup(self):
        """Stop heartbeat, release any in-progress video, mark offline, close DB."""
        logger.info("Cleaning up worker...")
        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass

        if self.stats.current_video_id:
            logger.info(f"Releasing video {self.stats.current_video_id}")
            await self.db.release_video(self.stats.current_video_id)

        try:
            # Final stats flush is best-effort; the worker is going away anyway.
            await self.db.update_heartbeat(self.worker_id, self.stats)
        except Exception:
            pass
        await self.db.set_worker_offline(self.worker_id)

        log_db_stats()

        await self.db.close()
        logger.info(f"Worker {self.worker_id} shutdown complete. "
                     f"Stats: {self.stats.segments_completed} completed, "
                     f"{self.stats.segments_failed} failed, "
                     f"{self.stats.batches_completed} batches")
