"""Pod lifecycle manager - automatic recycling after N completed sessions.

Mitigates slowly accumulating memory leaks by gracefully shutting down the
process after a configurable number of completed sessions. Kubernetes (with
restartPolicy: Always) then restarts the container with a fresh process.

Flow:
    1. Each completed session increments the counter
    2. When counter hits MAX_POD_SESSIONS → set draining=True
    3. Readiness probe returns 503 → K8s stops routing new traffic
    4. Wait for active sessions to drain (up to DRAIN_TIMEOUT_SECONDS)
    5. os._exit(0) → K8s restarts the container with a fresh process

Usage:
    from utils.pod_lifecycle import on_session_complete, is_draining, get_lifecycle_stats
"""

import asyncio
import os
import sys
import time

from loguru import logger


# ============================================
# Configuration (read once from env vars at import time)
# ============================================
def _int_from_env(name: str, default: str) -> int:
    """Return env var *name* parsed as an int, or *default* parsed if unset.

    A present-but-non-numeric value raises ``ValueError`` at import time.
    """
    return int(os.environ.get(name, default))


# Completed sessions after which the pod starts draining.
MAX_POD_SESSIONS = _int_from_env("MAX_POD_SESSIONS", "1000")
# Longest the drain may wait for active sessions before force-exiting (5 min default).
DRAIN_TIMEOUT_SECONDS = _int_from_env("POD_DRAIN_TIMEOUT", "300")

# ============================================
# Process-wide mutable state
# ============================================
_completed_sessions: int = 0  # sessions finished since this module was imported
_draining: bool = False  # flips to True once the session threshold is reached
_drain_started_at: float = 0.0  # wall-clock time draining began (0.0 until then)
_pod_start_time: float = time.time()  # wall-clock time of module import


# Strong reference to the background drain watcher. The asyncio event loop only
# keeps weak references to tasks, so without this the watcher could be
# garbage-collected mid-drain, leaving the pod flagged as draining (readiness
# 503) without ever exiting.
_drain_task: "asyncio.Future | None" = None


def on_session_complete(active_sessions_dict: dict) -> None:
    """Record one finished session and start draining once the limit is hit.

    Call this when a session task completes (success, error, or cancel).
    Every 100th completion logs a progress line. The first time the
    completed count reaches ``MAX_POD_SESSIONS``, the pod is marked as
    draining (so ``is_draining()`` returns True). A background task is then
    scheduled that waits for ``active_sessions_dict`` to empty and exits the
    process.

    Args:
        active_sessions_dict: Reference to the active_sessions dict from bot.py
            so we can check how many sessions are still running. Only its
            size is read; it is never mutated here.
    """
    global _completed_sessions, _draining, _drain_started_at, _drain_task

    _completed_sessions += 1

    # Periodic progress log so recycling can be tracked from the logs.
    if _completed_sessions % 100 == 0:
        logger.info(
            f"POD_LIFECYCLE: completed={_completed_sessions}/{MAX_POD_SESSIONS} "
            f"active={len(active_sessions_dict)} "
            f"uptime={int(time.time() - _pod_start_time)}s"
        )

    # The `not _draining` guard ensures the watcher is scheduled only once.
    if _completed_sessions >= MAX_POD_SESSIONS and not _draining:
        _draining = True
        _drain_started_at = time.time()
        logger.warning(
            f"POD_DRAINING: completed={_completed_sessions} threshold={MAX_POD_SESSIONS} "
            f"active={len(active_sessions_dict)} — stopping new traffic"
        )
        # Start drain watcher in background and keep a strong reference to it
        # so the event loop's weak reference is not the only thing keeping it alive.
        _drain_task = asyncio.ensure_future(_drain_and_exit(active_sessions_dict))


async def _drain_and_exit(active_sessions_dict: dict) -> None:
    """Wait for active sessions to finish, then hard-exit the process.

    Polls ``active_sessions_dict`` every 5 seconds until it is empty or more
    than ``DRAIN_TIMEOUT_SECONDS`` have elapsed. It then logs a summary and
    terminates with ``os._exit(0)``. This coroutine never returns normally.

    Args:
        active_sessions_dict: The live active-sessions mapping; only its
            size is inspected.
    """
    # Measure the drain with a monotonic clock so wall-clock adjustments
    # (e.g. NTP steps) cannot cut the drain short or stretch it out.
    start = time.monotonic()

    while active_sessions_dict:
        elapsed = time.monotonic() - start
        remaining = len(active_sessions_dict)

        if elapsed > DRAIN_TIMEOUT_SECONDS:
            logger.error(
                f"POD_DRAIN_TIMEOUT: {remaining} sessions still active after "
                f"{DRAIN_TIMEOUT_SECONDS}s — force exiting"
            )
            break

        logger.info(
            f"POD_DRAINING: waiting for {remaining} active sessions "
            f"({int(elapsed)}s/{DRAIN_TIMEOUT_SECONDS}s)"
        )
        await asyncio.sleep(5)

    elapsed = time.monotonic() - start
    # Uptime stays on the wall clock because _pod_start_time is a time.time() stamp.
    logger.warning(
        f"POD_RECYCLING: completed={_completed_sessions} sessions "
        f"drain_time={int(elapsed)}s uptime={int(time.time() - _pod_start_time)}s — exiting"
    )

    # Give logs a moment to flush
    await asyncio.sleep(1)

    # os._exit skips cleanup handlers and does NOT flush stdio buffers, so
    # flush them explicitly to avoid losing buffered output. This is best
    # effort: a closed, broken, or missing stream must never block the exit.
    for stream in (sys.stdout, sys.stderr):
        try:
            stream.flush()
        except (AttributeError, OSError, ValueError):
            pass

    # Use os._exit to ensure clean shutdown even if native threads are stuck
    os._exit(0)


def is_draining() -> bool:
    """Report whether this pod has entered drain mode.

    Readiness endpoints should answer 503 while this is True so that no new
    traffic is routed here.
    """
    draining_now = _draining
    return draining_now


def get_lifecycle_stats() -> dict:
    """Snapshot the recycling counters for health/readiness endpoints.

    Returns:
        A dict with the completed-session count, the recycling threshold,
        the draining flag, and whole seconds elapsed since module import.
    """
    uptime = int(time.time() - _pod_start_time)
    return dict(
        completed_sessions=_completed_sessions,
        max_sessions=MAX_POD_SESSIONS,
        draining=_draining,
        uptime_seconds=uptime,
    )
