#!/usr/bin/env python3
"""
GPU Auto-Configuration for Universal Worker

Auto-detects GPU type and configures batch sizes to use 80% of available VRAM.
Includes OOM protection with automatic batch size reduction.

Supports: T4 (16GB), A10 (24GB), RTX 3090 (24GB), A100 (40/80GB), etc.
"""

import gc
import logging
import os
import re
import subprocess
import traceback
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple

logger = logging.getLogger("Worker.GPUConfig")


@dataclass
class GPUConfig:
    """
    GPU-specific configuration with safe defaults.

    The defaults match the 16 GB "T4" entry in GPU_PROFILES
    (embed=128, music=32, chunks=1). auto_configure_gpu() fills in the
    detected GPU, batch sizes and worker counts; reduce_batch_sizes()
    lowers the batch sizes in place after an OOM error.
    """
    
    # GPU info
    gpu_name: str = "Unknown"  # name from nvidia-smi; "Unknown" when detection fails
    gpu_vram_gb: float = 16.0  # total VRAM in GB (nvidia-smi MiB / 1024)
    
    # Batch sizes (computed based on VRAM, or taken from GPU_PROFILES)
    embedding_batch_size: int = 128  # ECAPA speaker-embedding batch size
    music_batch_size: int = 32       # PANNs music-detection batch size
    chunk_workers: int = 1           # parallel diarization streams
    
    # VAD (CPU-bound); auto_configure_gpu() sizes this from the CPU count
    vad_workers: int = 32
    
    # Safety settings
    target_utilization: float = 0.80  # Use 80% of VRAM max; passed to compute_config_from_vram()
    min_free_vram_gb: float = 2.0     # Keep 2GB free minimum. NOTE(review): not read anywhere in this file — confirm callers use it
    
    # OOM recovery
    current_reduction_factor: float = 1.0  # Starts at 100%; reduce_batch_sizes() multiplies in each factor applied


# Pre-defined configs for known GPUs (tested values).
# Format: name -> (embedding_batch, music_batch, chunk_workers); the tuple is
# unpacked by auto_configure_gpu() into GPUConfig.embedding_batch_size,
# music_batch_size and chunk_workers.
# get_gpu_profile() tries an exact key match first, then a case-insensitive
# search for a key inside the detected GPU name.
# NOTE(review): keys are compared against the name nvidia-smi reports; confirm
# each key actually occurs in the names seen in production.
GPU_PROFILES = {
    # Entry-level (16GB)
    "T4": (128, 32, 1),
    "V100-16GB": (128, 32, 1),
    "P100": (96, 24, 1),
    "Tesla T4": (128, 32, 1),
    
    # Mid-range (24GB)
    "RTX 3090": (256, 64, 1),
    "RTX 4090": (384, 96, 1),
    "A10": (256, 64, 1),
    "A10G": (256, 64, 1),
    
    # High-end (40-48GB)
    "A6000": (512, 128, 2),
    "A40": (512, 128, 2),
    "A100-40GB": (512, 128, 2),
    "A100-PCIE-40GB": (512, 128, 2),
    
    # Ultra (80GB+)
    "A100-80GB": (754, 150, 2),
    "A100-SXM4-80GB": (754, 150, 2),  # Specific A100 variant
    "A100": (754, 150, 2),  # Default A100 (same values as the 80GB entries)
    "H100": (1024, 200, 2),
}


def detect_gpu() -> Tuple[str, float]:
    """
    Detect GPU name and VRAM using nvidia-smi.

    Name and total memory are read in a single nvidia-smi query; only the
    first GPU listed is considered. When nvidia-smi is missing, times out,
    exits non-zero or prints nothing, the result is ("Unknown", 16.0). If
    the name is readable but the memory figure is not, the name is kept
    and only the VRAM falls back to 16.0.

    Returns:
        Tuple of (gpu_name, vram_gb); vram_gb is nvidia-smi's memory.total
        (MiB, unit suffix stripped by ``nounits``) divided by 1024.
    """
    default_name, default_vram_gb = "Unknown", 16.0

    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=name,memory.total',
             '--format=csv,noheader,nounits'],
            capture_output=True, text=True, timeout=10
        )
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        # OSError: binary missing / not executable; SubprocessError: timeout;
        # ValueError: includes UnicodeDecodeError from text=True decoding.
        logger.warning(f"GPU detection failed: {e}, using defaults")
        return default_name, default_vram_gb

    lines = result.stdout.strip().splitlines() if result.returncode == 0 else []
    if not lines:
        logger.warning(f"GPU detection failed: nvidia-smi exit code {result.returncode}, "
                       f"using defaults")
        return default_name, default_vram_gb

    # Expected shape: "<name>, <memory MiB>". Split on the LAST comma so a
    # comma inside the name cannot break parsing.
    name_field, sep, vram_field = lines[0].rpartition(',')
    if not sep:
        name_field, vram_field = lines[0], ""
    gpu_name = name_field.strip() or default_name

    try:
        vram_gb = int(vram_field.strip()) / 1024
    except ValueError:
        logger.warning(f"Could not parse GPU memory {vram_field.strip()!r}, "
                       f"assuming {default_vram_gb} GB")
        vram_gb = default_vram_gb

    return gpu_name, vram_gb


def get_gpu_profile(
    gpu_name: str,
    profiles: Optional[dict] = None,
) -> Optional[Tuple[int, int, int]]:
    """
    Get pre-defined profile for a known GPU.

    Lookup order:
    1. Exact (case-sensitive) key match.
    2. Case-insensitive search for a profile key inside ``gpu_name``, where
       the key must not be preceded or followed by a letter or digit, so
       "A10" does not match "A100-SXM4-80GB". When several keys match, the
       longest (most specific) one wins: "NVIDIA A100-SXM4-80GB" resolves
       to "A100-SXM4-80GB" rather than the generic "A100".

    Args:
        gpu_name: GPU name as reported by nvidia-smi (e.g. "NVIDIA A10G").
        profiles: Optional mapping of profile name ->
            (embedding_batch, music_batch, chunk_workers). Defaults to
            GPU_PROFILES.

    Returns:
        Tuple of (embedding_batch, music_batch, chunk_workers) or None
        if no profile matches.
    """
    table = GPU_PROFILES if profiles is None else profiles

    # Try exact match first
    if gpu_name in table:
        return table[gpu_name]

    # Partial match on token boundaries; keep the longest matching key.
    # Ties keep the earlier key in table order.
    gpu_name_lower = gpu_name.lower()
    best_key = None
    for profile_name in table:
        pattern = (r"(?<![a-z0-9])" + re.escape(profile_name.lower())
                   + r"(?![a-z0-9])")
        if re.search(pattern, gpu_name_lower) and (
            best_key is None or len(profile_name) > len(best_key)
        ):
            best_key = profile_name

    return None if best_key is None else table[best_key]


def compute_config_from_vram(vram_gb: float, target_util: float = 0.80) -> Tuple[int, int, int]:
    """
    Derive batch sizes and a chunk-worker count from total VRAM.

    Budget model (per-item costs are empirical estimates):
    - Usable VRAM is ``vram_gb * target_util``.
    - 8 GB of it is reserved for the pyannote diarization model.
    - If 2 GB or less remains, the fixed minimal config (64, 16, 1) is used.
    - Otherwise 50% of the remainder is budgeted for ECAPA embeddings
      (~0.01 GB per item) and 30% for PANNs music (~0.02 GB per item);
      the other 20% is left unallocated as headroom.
    - One chunk worker (parallel diarization stream) per 12 GB of usable
      VRAM, never fewer than one.

    Args:
        vram_gb: Total VRAM in GB
        target_util: Target utilization (0.0-1.0)

    Returns:
        Tuple of (embedding_batch, music_batch, chunk_workers), with the
        embedding batch clamped to [32, 1024] and the music batch to [16, 256].
    """
    def clamp(value: int, lowest: int, highest: int) -> int:
        return max(lowest, min(value, highest))

    usable_gb = vram_gb * target_util
    after_diarization_gb = usable_gb - 8.0

    # Not enough left once the diarization model is accounted for.
    if after_diarization_gb <= 2:
        return (64, 16, 1)

    embedding_batch = clamp(int((after_diarization_gb * 0.5) / 0.01), 32, 1024)
    music_batch = clamp(int((after_diarization_gb * 0.3) / 0.02), 16, 256)
    chunk_workers = max(1, int(usable_gb / 12))

    return embedding_batch, music_batch, chunk_workers


def auto_configure_gpu(target_utilization: float = 0.80) -> GPUConfig:
    """
    Detect the GPU and build a GPUConfig for it.

    A matching pre-defined profile (get_gpu_profile) takes precedence;
    otherwise batch sizes and chunk workers are derived from total VRAM
    via compute_config_from_vram(). VAD workers are sized from the CPU
    count: 80% of cores clamped to 8..64, or 32 if the count is unavailable.

    Args:
        target_utilization: VRAM fraction to plan for (default 80%). It is
            recorded on the returned config but only influences batch sizes
            on the VRAM-based path; pre-defined profiles ignore it.

    Returns:
        GPUConfig with the chosen settings
    """
    gpu_name, vram_gb = detect_gpu()
    logger.info(f"🔍 Detected GPU: {gpu_name} ({vram_gb:.1f} GB)")

    profile = get_gpu_profile(gpu_name)
    if profile:
        embed, music, chunks = profile
        logger.info(f"✅ Using pre-defined profile for {gpu_name}")
    else:
        embed, music, chunks = compute_config_from_vram(vram_gb, target_utilization)
        logger.info(f"📊 Computed config from VRAM: embed={embed}, music={music}, chunks={chunks}")

    # VAD is CPU-bound, so scale it with the core count.
    try:
        import multiprocessing
        vad = max(8, min(int(multiprocessing.cpu_count() * 0.8), 64))
    except Exception:
        vad = 32

    config = GPUConfig(
        gpu_name=gpu_name,
        gpu_vram_gb=vram_gb,
        embedding_batch_size=embed,
        music_batch_size=music,
        chunk_workers=chunks,
        vad_workers=vad,
        target_utilization=target_utilization,
    )

    logger.info(f"🎯 Final config: embed={config.embedding_batch_size}, "
                f"music={config.music_batch_size}, chunks={config.chunk_workers}, "
                f"vad={config.vad_workers}")

    return config


def reduce_batch_sizes(config: GPUConfig, factor: float = 0.7) -> GPUConfig:
    """
    Shrink batch sizes in place after an OOM error.

    Each batch size is scaled by ``factor`` and truncated to int, but never
    drops below 16 (embedding) or 8 (music). ``current_reduction_factor``
    accumulates the product of every factor applied. chunk_workers is not
    touched.

    Args:
        config: GPUConfig to modify; it is mutated, not copied
        factor: Scale applied to each batch size (0.7 = reduce to 70%)

    Returns:
        The same GPUConfig object that was passed in
    """
    def _shrink(size: int, floor: int) -> int:
        return max(floor, int(size * factor))

    config.current_reduction_factor *= factor
    config.embedding_batch_size = _shrink(config.embedding_batch_size, 16)
    config.music_batch_size = _shrink(config.music_batch_size, 8)

    logger.warning(f"⚠️ OOM Recovery: Reduced batches to embed={config.embedding_batch_size}, "
                   f"music={config.music_batch_size} (reduction factor: {config.current_reduction_factor:.2f})")

    return config


def clear_gpu_memory():
    """
    Best-effort release of GPU memory after OOM or between videos.

    Runs a full garbage collection first, so tensors kept alive only by
    reference cycles are freed. Then, if PyTorch with CUDA is available,
    it waits for queued GPU work and returns unused cached blocks from
    PyTorch's caching allocator. A missing torch install is ignored; other
    failures are logged at debug level instead of being raised.
    """
    # gc must run BEFORE empty_cache(): empty_cache() can only release
    # cached blocks whose tensors are already dead.
    gc.collect()
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
    except ImportError:
        pass  # no torch installed: nothing GPU-side to release
    except Exception as e:
        # Often called while handling another error; never mask that error.
        logger.debug(f"GPU memory cleanup failed: {e}")


class OOMProtectedContext:
    """
    Context manager for OOM-protected GPU operations.

    A ``with`` statement cannot re-run its own body. On an OOM this context
    manager prepares a retry (frees memory, shrinks batch sizes) and
    suppresses the error; something must loop to actually retry.
    ``run()`` does that loop for you:

        config = auto_configure_gpu()
        ctx = OOMProtectedContext(config, max_retries=3)
        result = ctx.run(lambda cfg: process_video(video_id, cfg))

    Equivalent manual loop:

        ctx = OOMProtectedContext(config, max_retries=3)
        while True:
            with ctx:
                result = process_video(video_id, ctx.config)
                break

    On OOM (at most ``max_retries`` times per instance):
    - Clears the failed frames' locals and the GPU cache
    - Reduces batch sizes (``ctx.config`` reflects the new values)
    - Suppresses the exception so the caller's loop can retry
    Non-OOM errors, and an OOM after retries are exhausted, propagate.
    """

    # Lower-cased message fragments that identify allocation failures.
    # Deliberately narrow: errors such as "illegal memory access" mention
    # memory but are not allocation failures, so shrinking batches cannot
    # fix them and retrying would only hide the real cause.
    _OOM_MARKERS = ("out of memory", "_status_alloc_failed")

    def __init__(self, config: GPUConfig, max_retries: int = 3):
        self.config = config
        self.max_retries = max_retries
        self.attempts = 0  # OOM retries consumed so far on this instance
        self.last_error: Optional[str] = None  # message of the most recent OOM handled

    def __enter__(self):
        return self

    @classmethod
    def _is_oom(cls, exc: BaseException) -> bool:
        """Return True if ``exc`` looks like a (GPU) out-of-memory error."""
        # torch.cuda.OutOfMemoryError is matched by name so this module does
        # not need to import torch.
        if type(exc).__name__ == "OutOfMemoryError":
            return True
        message = str(exc).lower()
        return any(marker in message for marker in cls._OOM_MARKERS)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            return False

        if not self._is_oom(exc_val) or self.attempts >= self.max_retries:
            # Not an OOM, or out of retries: let the exception propagate.
            return False

        self.attempts += 1
        self.last_error = str(exc_val)

        # The traceback keeps the frames that raised alive, and their locals
        # may still reference GPU tensors. Clear them so the memory can
        # actually be released; frames still executing are skipped.
        traceback.clear_frames(exc_tb)

        # Clear memory and reduce batch sizes
        clear_gpu_memory()
        self.config = reduce_batch_sizes(self.config)

        logger.warning(f"🔄 OOM detected (attempt {self.attempts}/{self.max_retries}), "
                      f"retrying with reduced batch sizes...")

        # Return True to suppress exception and allow retry
        return True

    def run(self, func: Callable[[GPUConfig], Any]) -> Any:
        """
        Call ``func(self.config)``, retrying on OOM with reduced batch sizes.

        Args:
            func: Callable taking the current GPUConfig. After each OOM it is
                called again with the updated (smaller) config.

        Returns:
            Whatever ``func`` returns on its first successful call.

        Raises:
            The OOM error once ``max_retries`` is exhausted (the retry budget
            is shared across all uses of this instance), or any non-OOM
            error immediately.
        """
        while True:
            with self:
                return func(self.config)


# Manual smoke test: run this file directly to print the detected configuration.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    detected_config = auto_configure_gpu()
    print(f"\nGPU Config: {detected_config}")

