"""
Multi-Engine vLLM Wrapper - transparent dispatch across N engine instances.

OPTIMIZATION: A 0.5B model uses ~1.3GB weights + ~5-10GB KV cache per engine.
On an 80GB A100, we can comfortably run 3 engines at ~12GB each (36GB total).
This triples effective prefill capacity, directly cutting streaming TTFB under
concurrent load by ~3x.

The wrapper is transparent to pipeline code: it exposes the same `.engine.generate()`
interface, but routes calls to different underlying engines.

Usage:
    # Creates N engines, each with their own vLLM subprocess
    model = create_multi_engine_model(model_path, num_engines=3, gpu_memory_per_engine=0.12)
    
    # Pipeline uses it identically to a single engine:
    model.engine.generate(prompt, sampling_params, request_id)  # routed to engine N
"""

import logging
import threading
from typing import Any, AsyncGenerator, List, Optional

logger = logging.getLogger(__name__)


class LoadAwareEngine:
    """
    Wraps multiple AsyncLLMEngine instances with least-inflight dispatch.

    Each generate() call is routed to the engine with the fewest in-flight
    requests; a rotating cursor breaks ties so load still spreads fairly.
    This reduces tail latency under uneven request completion times versus
    strict round-robin.

    Thread-safe: the in-flight counters and the tie-break cursor are guarded
    by a single lock.
    """

    def __init__(self, engines: list):
        """
        Args:
            engines: List of engine objects exposing generate(). Should be
                non-empty; an empty list makes generate() raise RuntimeError.
        """
        self._engines = engines
        self._inflight: List[int] = [0] * len(engines)
        self._cursor = 0
        self._lock = threading.Lock()

    def _acquire_engine(self) -> tuple[int, Any]:
        """Pick the least-loaded engine (ties broken by the rotating cursor),
        increment its in-flight count, and return (index, engine).

        Raises:
            RuntimeError: If no engines were configured.
        """
        with self._lock:
            n = len(self._engines)
            if n == 0:
                raise RuntimeError("No engines configured")
            min_load = min(self._inflight)
            start = self._cursor % n
            selected = start
            # Scan from the cursor so repeated ties rotate across engines.
            for offset in range(n):
                idx = (start + offset) % n
                if self._inflight[idx] == min_load:
                    selected = idx
                    break
            self._inflight[selected] += 1
            # Advance past the pick so the next tie goes to a different engine.
            self._cursor = (selected + 1) % n
            return selected, self._engines[selected]

    def _release_engine(self, idx: int) -> None:
        """Decrement the in-flight count for engine idx (clamped at zero)."""
        with self._lock:
            if 0 <= idx < len(self._inflight):
                self._inflight[idx] = max(0, self._inflight[idx] - 1)

    def generate(self, *args, **kwargs) -> AsyncGenerator:
        """Route generate() to the least-loaded engine and release on completion.

        The engine's in-flight count is released when the returned async
        generator is exhausted, closed, or garbage-collected — or immediately
        if the underlying generate() call itself raises synchronously.
        """
        engine_idx, engine = self._acquire_engine()
        try:
            base_generator = engine.generate(*args, **kwargs)
        except Exception:
            # Synchronous failure: the request never started, release now.
            self._release_engine(engine_idx)
            raise

        async def _wrapped() -> AsyncGenerator:
            try:
                async for item in base_generator:
                    yield item
            finally:
                # Runs on normal exhaustion, consumer error, or aclose().
                self._release_engine(engine_idx)

        return _wrapped()

    def inflight_snapshot(self) -> List[int]:
        """Return a copy of the current in-flight count per engine index."""
        with self._lock:
            return list(self._inflight)

    def __getattr__(self, name):
        """Forward any other attribute access to the first engine.

        Reads _engines via __dict__ to avoid infinite recursion when the
        instance exists without __init__ having run (e.g. copy/pickle
        protocols), and raises AttributeError — not IndexError — when the
        engine list is empty or absent.
        """
        engines = self.__dict__.get("_engines")
        if not engines:
            raise AttributeError(name)
        return getattr(engines[0], name)


class MultiEngineModel:
    """
    Drop-in replacement for SparkTTSModel backed by N engine instances.

    Exposes:
        .engine    -> LoadAwareEngine (dispatches generate() calls)
        .tokenizer -> shared tokenizer (same across all engines)
        .model_path, .model_type, etc.
    """

    def __init__(self, models: list):
        """
        Args:
            models: List of SparkTTSModel instances (each has its own engine)
        """
        self._models = models
        primary = models[0]
        self._primary = primary

        # Every engine loads the same checkpoint, so the first model's
        # tokenizer and metadata stand in for the whole group.
        for shared_attr in ("tokenizer", "model_path", "model_type", "hf_token"):
            setattr(self, shared_attr, getattr(primary, shared_attr))

        # Dispatch layer that routes generate() across all underlying engines.
        self.engine = LoadAwareEngine([m.engine for m in models])

    def get_model_type(self):
        return self._primary.get_model_type()

    def get_tokenizer(self):
        return self.tokenizer

    def get_engine(self):
        return self.engine


def create_multi_engine_model(
    model_path: str,
    num_engines: int = 2,
    hf_token: Optional[str] = None,
    gpu_memory_per_engine: float = 0.12,
    **engine_kwargs,
) -> MultiEngineModel:
    """
    Create a multi-engine model with N vLLM instances on the same GPU.

    Memory budget per engine (A100-80GB):
        - Model weights: ~1.3GB (each engine loads its own copy)
        - KV cache at 0.12 utilization: ~8GB
        - Total per engine: ~10GB
        - 3 engines: ~30GB, leaves 50GB free

    Args:
        model_path: Path to model directory
        num_engines: Number of vLLM engine instances (default: 2). Must be >= 1.
        hf_token: HuggingFace token (None to rely on ambient credentials)
        gpu_memory_per_engine: GPU memory fraction per engine
        **engine_kwargs: Additional AsyncEngineArgs overrides passed to each engine

    Returns:
        MultiEngineModel with least-loaded (load-aware) dispatch

    Raises:
        ValueError: If num_engines is less than 1.
    """
    # Validate before the heavy import so bad config fails fast and clearly
    # (an empty model list would otherwise crash inside MultiEngineModel).
    if num_engines < 1:
        raise ValueError(f"num_engines must be >= 1, got {num_engines}")

    # Local import avoids a circular dependency at module load time.
    from veena3modal.core.model_loader import SparkTTSModel

    # Lazy %-style args so formatting is skipped when INFO is disabled.
    logger.info(
        "Creating multi-engine model: %d engines x %.0f%% GPU each",
        num_engines,
        gpu_memory_per_engine * 100,
    )

    models = []
    for i in range(num_engines):
        logger.info("  Initializing engine %d/%d...", i + 1, num_engines)
        model = SparkTTSModel(
            model_path=model_path,
            hf_token=hf_token,
            gpu_memory_utilization=gpu_memory_per_engine,
            **engine_kwargs,
        )
        models.append(model)
        logger.info("  Engine %d ready", i + 1)

    multi = MultiEngineModel(models)
    # NOTE: dispatch is least-loaded (see LoadAwareEngine), not round-robin.
    logger.info(
        "Multi-engine model ready: %d engines, least-loaded dispatch", num_engines
    )
    return multi
