"""
Dual vLLM Engine Architecture - Design Document and Implementation Scaffold.

OPTIMIZATION RATIONALE:
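With gpu_memory_utilization reduced to 0.25, the vLLM engine serving the 0.5B model reserves
only ~20GB of 80GB VRAM (weights plus KV cache). This leaves ~60GB free. Running a SECOND
vLLM engine instance on the same GPU is expected to roughly double effective prefill
parallelism, directly addressing the #1 streaming bottleneck. Both engines still share the
same GPU compute, so the actual gain must be measured.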

Architecture Options:

OPTION A: Duplicate Engines (simpler)
    - Engine A and Engine B both handle full requests
    - Round-robin or least-loaded routing
    - Each gets gpu_memory_utilization=0.15 (estimated ~100-190 concurrent seqs each,
      depending on sequence length; to be measured)
    - Total: ~200-380 concurrent sequences, up to 2x prefill parallelism
    
OPTION B: Prefill/Decode Split (more complex, better for TTFB)
    - Engine A: Handles ONLY prefill (prompt encoding), lower memory, prioritizes latency
    - Engine B: Handles ONLY decode (token generation), higher batch throughput
    - Requires request migration between engines (complex)
    
OPTION C: Multiple Lightweight Instances (best for horizontal scaling)
    - Run 4-8 separate processes, each with gpu_memory_utilization=0.05
    - Each handles its own request queue
    - CUDA MPS (Multi-Process Service) for efficient GPU sharing
    - Simplest per-process code (no in-process router), though requests must be
      balanced across processes externally; best for scaling

RECOMMENDATION: Start with Option A (duplicate engines), measure impact,
then consider Option C for production horizontal scaling.

Usage (future):
    from veena3modal.core.dual_engine import DualEngineRouter
    
    router = DualEngineRouter(
        model_path=model_path,
        num_engines=2,
        gpu_memory_per_engine=0.15,
    )
    router.initialize()  # loads all engines; get_engine() raises before this
    
    # In streaming pipeline:
    model = router.get_engine()  # Round-robin; returns a SparkTTSModel
    results = model.engine.generate(prompt, sampling_params, request_id)
"""

from __future__ import annotations

import logging
from typing import Optional, Any
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


@dataclass
class EngineConfig:
    """Settings describing one vLLM engine instance.

    NOTE(review): not referenced by the router code in this module; the
    field names presumably mirror the vLLM engine arguments of the same name.
    """

    # Model location handed to the loader; the only field without a default.
    model_path: str
    # Presumably forwarded as vLLM's per-instance GPU memory fraction.
    gpu_memory_utilization: float = field(default=0.15)
    # Presumably forwarded as vLLM's maximum sequence length.
    max_model_len: int = field(default=4096)
    # Presumably distinguishes sibling engines sharing one GPU.
    instance_id: int = field(default=0)


@dataclass
class DualEngineRouter:
    """
    Routes requests across multiple vLLM engine instances.
    
    NOT YET WIRED INTO PRODUCTION - this is the design scaffold.
    To activate, modify tts_runtime.py to use DualEngineRouter instead of
    a single SparkTTSModel instance.

    Lifecycle: construct the router, call ``initialize()`` once to load the
    engines, then call ``get_engine()`` per request to pick an engine in
    round-robin order.
    
    Memory Budget (A100-80GB, design estimate - not yet measured):
        - 2 engines x (1.3GB model + 10GB KV cache) = ~22.6GB
        - BiCodec decoder: ~0.6GB
        - Total: ~23.2GB (fits comfortably)
        - Each engine supports ~190 concurrent sequences
        - Combined: ~380 concurrent sequences, up to 2x prefill parallelism

    Attributes:
        model_path: Model path passed to every ``SparkTTSModel`` instance.
            Must be non-empty by the time ``initialize()`` runs.
        num_engines: Number of engine instances to create (must be >= 1).
        gpu_memory_per_engine: ``gpu_memory_utilization`` given to each
            engine, in (0, 1]. Because every engine reserves its own fraction
            of the same GPU, ``num_engines * gpu_memory_per_engine`` must not
            exceed 1.0.
        engines: Loaded engine instances, populated by ``initialize()``.
        _counter: Round-robin cursor used by ``get_engine()``.
    """
    model_path: str = ""
    num_engines: int = 2
    gpu_memory_per_engine: float = 0.15
    engines: list[Any] = field(default_factory=list)
    # Internal round-robin cursor. It is still accepted by __init__ for
    # backward compatibility, but hidden from repr.
    _counter: int = field(default=0, repr=False)

    def _validate_config(self) -> None:
        """Raise ValueError if the settings cannot yield a valid engine set."""
        if not self.model_path:
            raise ValueError("DualEngineRouter: model_path must be set before initialize()")
        if self.num_engines < 1:
            raise ValueError(
                f"DualEngineRouter: num_engines must be >= 1, got {self.num_engines}"
            )
        if not 0.0 < self.gpu_memory_per_engine <= 1.0:
            raise ValueError(
                "DualEngineRouter: gpu_memory_per_engine must be in (0, 1], "
                f"got {self.gpu_memory_per_engine}"
            )
        total = self.num_engines * self.gpu_memory_per_engine
        # Assumes SparkTTSModel forwards the fraction to vLLM, where it is a
        # per-instance budget: all instances on one GPU must fit together.
        # The epsilon tolerates float rounding (e.g. 10 x 0.1).
        if total > 1.0 + 1e-9:
            raise ValueError(
                f"DualEngineRouter: {self.num_engines} engines x "
                f"{self.gpu_memory_per_engine} GPU memory = {total:.2f} exceeds 1.0"
            )

    def initialize(self) -> None:
        """Initialize all engine instances.

        Safe to call more than once. If engines are already loaded, this logs
        a warning and returns instead of loading a second set onto the GPU.

        Raises:
            ValueError: If the router configuration is invalid.
            Exception: Anything raised by ``SparkTTSModel`` while loading. In
                that case ``engines`` is left empty rather than partially
                populated.
        """
        if self.engines:
            logger.warning(
                "DualEngineRouter.initialize() called again; %d engines already loaded",
                len(self.engines),
            )
            return

        self._validate_config()

        # Imported inside the method, presumably so this module can be
        # imported without pulling in the model stack.
        from veena3modal.core.model_loader import SparkTTSModel

        # Build into a local list so a failure part-way through does not
        # leave the router serving from a partial engine set.
        new_engines: list[Any] = []
        for i in range(self.num_engines):
            logger.info("Initializing vLLM engine %d/%d...", i + 1, self.num_engines)
            new_engines.append(
                SparkTTSModel(
                    model_path=self.model_path,
                    gpu_memory_utilization=self.gpu_memory_per_engine,
                )
            )
            logger.info("  Engine %d ready", i + 1)

        self.engines.extend(new_engines)
        logger.info("DualEngineRouter: %d engines initialized", self.num_engines)

    def get_engine(self) -> Any:
        """Return the next engine in round-robin order.

        Raises:
            RuntimeError: If ``initialize()`` has not populated any engines.
        """
        if not self.engines:
            raise RuntimeError("No engines initialized")
        # NOTE(review): this read-modify-write of _counter is unsynchronized.
        # That is fine from a single event-loop thread; revisit if it is
        # called from multiple threads.
        engine = self.engines[self._counter % len(self.engines)]
        self._counter += 1
        return engine

    def get_tokenizer(self) -> Any:
        """Return the tokenizer of the first engine.

        All engines are loaded from the same ``model_path``, so the first
        engine's tokenizer stands in for all of them.

        Raises:
            RuntimeError: If ``initialize()`` has not populated any engines.
        """
        if not self.engines:
            raise RuntimeError("No engines initialized")
        return self.engines[0].tokenizer


# === Activation Instructions ===
# To enable dual engine mode:
#
# 1. In tts_runtime.py, replace:
#     model = SparkTTSModel(model_path=model_path, ...)
#   with:
#     from veena3modal.core.dual_engine import DualEngineRouter
#     router = DualEngineRouter(model_path=model_path, num_engines=2, gpu_memory_per_engine=0.15)
#     router.initialize()
#     model = router.engines[0]  # First engine for init (does not advance the round-robin cursor)
#
# 2. In streaming_pipeline.py, modify generate_speech_stream_indic():
#     Replace: self.model.engine.generate(...)
#     With:    router.get_engine().engine.generate(...)
#
# 3. Pass the router through to the pipeline via __init__
#
# Expected effect: roughly 2x prefill parallelism, which should cut streaming TTFB under
# concurrency (estimated up to ~50%); measure before relying on it.
