"""
Triton ASR Server Configuration

Maps Triton Inference Server concepts (model config, dynamic batching,
instance groups) to Python dataclasses for local serving.

Structured so config.pbtxt generation is trivial when migrating to full Triton.

Environment variables override defaults (12-factor app pattern).
"""

import os
from dataclasses import dataclass, field
from typing import List


@dataclass
class ModelConfig:
    """
    Model loading configuration.
    Equivalent to Triton's model repository entry.
    """
    # Local path or HuggingFace download target
    model_path: str = os.environ.get("MODEL_PATH", "")
    hf_repo: str = os.environ.get("HF_MODEL_REPO", "BayAreaBoys/nemotron-hindi")
    hf_file: str = os.environ.get("HF_MODEL_FILE", "final_model.nemo")
    hf_cache_dir: str = os.environ.get("HF_CACHE_DIR", "/tmp/triton_asr_model")

    # Compute
    device: str = "cuda"
    device_id: int = 0
    compute_dtype: str = "bfloat16"
    use_amp: bool = True

    # Streaming pipeline params
    att_context_size: List[int] = field(default_factory=lambda: [70, 6])
    num_slots: int = 1024
    pipeline_batch_size: int = 512  # NeMo's internal max batch (frames across streams)
    sample_rate: int = 16000

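
# A minimal sketch (hypothetical helper, not wired into the server) of how the
# fields above are expected to resolve to a local .nemo path: prefer model_path
# when set, otherwise fetch hf_file from hf_repo into hf_cache_dir.
def resolve_model_path(cfg: ModelConfig) -> str:
    if cfg.model_path:
        return cfg.model_path
    # Assumes huggingface_hub is installed in the serving image.
    from huggingface_hub import hf_hub_download
    return hf_hub_download(repo_id=cfg.hf_repo, filename=cfg.hf_file,
                           cache_dir=cfg.hf_cache_dir)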

@dataclass
class DynamicBatchingConfig:
    """
    Dynamic batching configuration.
    Mirrors Triton's dynamic_batching { } block in config.pbtxt.

    Algorithm:
      - Collect frames from ready streams into a batch
      - Dispatch when EITHER:
        (a) batch reaches max_batch_size, OR
        (b) max_batch_latency_ms has elapsed since first queued frame
      - preferred_batch_sizes: GPU runs most efficiently at these sizes
        (Triton's dynamic batcher tries to form batches of these sizes when
        scheduling, which keeps CUDA kernel launches at efficient shapes)
    """
    max_batch_size: int = 128
    # WHY 50ms: at an 80ms chunk size, a 50ms wait means we fire within ~1 chunk arrival.
    # Under high load, the event signal fires immediately (0ms wait).
    max_batch_latency_ms: float = 50.0  # Max ms to wait before dispatching
    preferred_batch_sizes: List[int] = field(default_factory=lambda: [8, 16, 32, 64, 128])

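
# A minimal sketch (hypothetical, not the actual scheduler) of the dispatch rule
# documented above: fire when the batch is full OR the oldest queued frame has
# waited at least max_batch_latency_ms.
def _should_dispatch(queued_frames: int, oldest_wait_ms: float,
                     cfg: DynamicBatchingConfig) -> bool:
    return (queued_frames >= cfg.max_batch_size
            or oldest_wait_ms >= cfg.max_batch_latency_ms)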

@dataclass
class InstanceGroupConfig:
    """
    Instance group configuration.
    Mirrors Triton's instance_group { } block.

    For now: single instance per GPU (safest for stateful RNNT).
    Multi-instance only if profiling shows idle GPU between batches.
    """
    count: int = 1  # Number of model instances per GPU
    kind: str = "GPU"
    gpus: List[int] = field(default_factory=lambda: [0])


@dataclass
class ServerConfig:
    """
    Server configuration for the gateway + engine.
    """
    # Network
    http_port: int = int(os.environ.get("SERVER_PORT", "8000"))
    grpc_port: int = int(os.environ.get("GRPC_PORT", "8001"))
    metrics_port: int = int(os.environ.get("METRICS_PORT", "8002"))

    # Capacity
    max_active_streams: int = int(os.environ.get("MAX_ACTIVE_STREAMS", "200"))

    # VAD
    enable_vad: bool = os.environ.get("ENABLE_VAD", "true").lower() == "true"
    vad_threshold: float = float(os.environ.get("VAD_THRESHOLD", "0.5"))

    # Warmup
    warmup_iterations: int = 3

    # Logging
    log_level: str = os.environ.get("LOG_LEVEL", "INFO")
    log_file: str = os.environ.get("LOG_FILE", "/tmp/triton_asr.log")

    # Timeouts
    ws_idle_timeout_s: float = 300.0  # Close WS after 5min of no audio
    http_timeout_s: float = 120.0  # Max HTTP request duration


@dataclass
class TritonASRConfig:
    """Top-level config aggregating all sub-configs."""
    model: ModelConfig = field(default_factory=ModelConfig)
    batching: DynamicBatchingConfig = field(default_factory=DynamicBatchingConfig)
    instance_group: InstanceGroupConfig = field(default_factory=InstanceGroupConfig)
    server: ServerConfig = field(default_factory=ServerConfig)

    @classmethod
    def from_env(cls) -> "TritonASRConfig":
        """Build config from environment variables (Vast.ai / Docker pattern)."""
        cfg = cls()
        # Re-apply env overrides for the capacity and batching knobs
        if os.environ.get("MAX_ACTIVE_STREAMS"):
            cfg.server.max_active_streams = int(os.environ["MAX_ACTIVE_STREAMS"])
        if os.environ.get("MAX_BATCH_SIZE"):
            cfg.batching.max_batch_size = int(os.environ["MAX_BATCH_SIZE"])
        if os.environ.get("MAX_BATCH_LATENCY_MS"):
            cfg.batching.max_batch_latency_ms = float(os.environ["MAX_BATCH_LATENCY_MS"])
        return cfg
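

# A rough sketch (hypothetical helper; the field-to-pbtxt mapping is illustrative) of
# the config.pbtxt generation the module docstring anticipates: render the batching
# and instance-group sub-configs as the corresponding Triton protobuf-text blocks.
def to_pbtxt_snippet(cfg: TritonASRConfig) -> str:
    preferred = ", ".join(str(s) for s in cfg.batching.preferred_batch_sizes)
    gpus = ", ".join(str(g) for g in cfg.instance_group.gpus)
    return (
        f"max_batch_size: {cfg.batching.max_batch_size}\n"
        "dynamic_batching {\n"
        f"  preferred_batch_size: [ {preferred} ]\n"
        f"  max_queue_delay_microseconds: {int(cfg.batching.max_batch_latency_ms * 1000)}\n"
        "}\n"
        "instance_group [\n"
        "  {\n"
        f"    count: {cfg.instance_group.count}\n"
        f"    kind: KIND_{cfg.instance_group.kind}\n"
        f"    gpus: [ {gpus} ]\n"
        "  }\n"
        "]\n"
    )
# Usage: print(to_pbtxt_snippet(TritonASRConfig.from_env()))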
