"""
Multi-Engine vLLM Wrapper - Transparent round-robin across N engine instances.

OPTIMIZATION: A 0.5B model uses ~1.3GB weights + ~5-10GB KV cache per engine.
On an 80GB A100, we can comfortably run 3 engines at ~12GB each (36GB total).
This triples effective prefill capacity, directly cutting streaming TTFB under
concurrent load by ~3x.

The wrapper is transparent to pipeline code: it exposes the same `.engine.generate()`
interface, but routes calls to different underlying engines via round-robin.

Usage:
    # Creates N engines, each with their own vLLM subprocess
    model = create_multi_engine_model(model_path, num_engines=3, gpu_memory_per_engine=0.12)
    
    # Pipeline uses it identically to a single engine:
    model.engine.generate(prompt, sampling_params, request_id)  # routed to engine N
"""

import logging
import threading
from typing import Any, AsyncGenerator, Optional

logger = logging.getLogger(__name__)


class RoundRobinEngine:
    """
    Wraps multiple AsyncLLMEngine instances with round-robin dispatch.

    Each call to generate() goes to the next engine in rotation.
    Thread-safe: the rotation index is advanced under a lock.
    """

    def __init__(self, engines: list):
        """
        Args:
            engines: Non-empty list of engine instances, each exposing
                generate().

        Raises:
            ValueError: If engines is empty — dispatch would otherwise fail
                later with an obscure ZeroDivisionError/IndexError.
        """
        if not engines:
            raise ValueError("RoundRobinEngine requires at least one engine")
        self._engines = engines
        self._counter = 0
        self._lock = threading.Lock()

    def _next_engine(self):
        """Return the next engine in rotation (thread-safe)."""
        with self._lock:
            engine = self._engines[self._counter]
            # Advance modulo len() so the index stays bounded instead of
            # growing without limit over the process lifetime.
            self._counter = (self._counter + 1) % len(self._engines)
            return engine

    def generate(self, *args, **kwargs) -> AsyncGenerator:
        """Route generate() to the next engine via round-robin."""
        return self._next_engine().generate(*args, **kwargs)

    def __getattr__(self, name):
        """Forward any other attribute access to the first engine."""
        return getattr(self._engines[0], name)


class MultiEngineModel:
    """
    Drop-in replacement for SparkTTSModel that wraps N engines.

    Exposes:
        .engine    -> RoundRobinEngine (dispatches generate() calls)
        .tokenizer -> shared tokenizer (same across all engines)
        .model_path, .model_type, .hf_token
    """

    def __init__(self, models: list):
        """
        Args:
            models: Non-empty list of SparkTTSModel instances (each has its
                own engine).

        Raises:
            ValueError: If models is empty — otherwise models[0] raises a
                bare IndexError with no context.
        """
        if not models:
            raise ValueError("MultiEngineModel requires at least one model")
        self._models = models
        self._primary = models[0]

        # Tokenizer and metadata are identical across engines, so expose the
        # primary model's copies directly.
        self.tokenizer = self._primary.tokenizer
        self.model_path = self._primary.model_path
        self.model_type = self._primary.model_type
        self.hf_token = self._primary.hf_token

        # All generate() traffic is spread across the engines round-robin.
        self.engine = RoundRobinEngine([m.engine for m in models])

    def get_model_type(self):
        """Return the model type reported by the primary model."""
        return self._primary.get_model_type()

    def get_tokenizer(self):
        """Return the shared tokenizer."""
        return self.tokenizer

    def get_engine(self):
        """Return the round-robin engine wrapper."""
        return self.engine


def create_multi_engine_model(
    model_path: str,
    num_engines: int = 2,
    hf_token: Optional[str] = None,
    gpu_memory_per_engine: float = 0.12,
) -> MultiEngineModel:
    """
    Create a multi-engine model with N vLLM instances on the same GPU.

    Memory budget per engine (A100-80GB):
        - Model weights: ~1.3GB (shared via CUDA, but each engine loads independently)
        - KV cache at 0.12 utilization: ~8GB
        - Total per engine: ~10GB
        - 3 engines: ~30GB, leaves 50GB free

    Args:
        model_path: Path to model directory
        num_engines: Number of vLLM engine instances (default: 2)
        hf_token: HuggingFace token
        gpu_memory_per_engine: GPU memory fraction per engine

    Returns:
        MultiEngineModel with round-robin dispatch

    Raises:
        ValueError: If num_engines is less than 1.
    """
    # Imported lazily: pulling in the model loader (and vLLM behind it) at
    # module import time is expensive and unnecessary for callers that never
    # build a multi-engine model.
    from veena3modal.core.model_loader import SparkTTSModel

    if num_engines < 1:
        raise ValueError(f"num_engines must be >= 1, got {num_engines}")

    # Lazy %-style args so formatting is skipped when INFO is disabled.
    logger.info(
        "Creating multi-engine model: %d engines x %.0f%% GPU each",
        num_engines,
        gpu_memory_per_engine * 100,
    )

    models = []
    for i in range(num_engines):
        logger.info("  Initializing engine %d/%d...", i + 1, num_engines)
        model = SparkTTSModel(
            model_path=model_path,
            hf_token=hf_token,
            gpu_memory_utilization=gpu_memory_per_engine,
        )
        models.append(model)
        logger.info("  Engine %d ready", i + 1)

    multi = MultiEngineModel(models)
    logger.info(
        "Multi-engine model ready: %d engines, round-robin dispatch", num_engines
    )
    return multi
