#!/usr/bin/env python3
"""
Compute Resource Detection & Monitoring - Adaptive Pipeline Configuration

Auto-detects: nproc, vCPUs, GPU vRAM, CPU cores
Monitors: CPU/GPU utilization at each stage
Adapts: Worker counts, batch sizes, parallelism based on available resources

Key principle: Use ALL available compute at every stage.
- CPU-bound: Maximize workers up to core count
- GPU-bound: Keep GPU saturated, prepare CPU work in parallel
- Memory-bound: Adapt batch sizes to available RAM/VRAM
"""

import os
import time
import logging
import threading
from dataclasses import dataclass, field
from typing import Dict, Optional, Callable, Any
from contextlib import contextmanager

logger = logging.getLogger("FastPipelineV5.Compute")


@dataclass
class ComputeResources:
    """Detected compute resources for the current system.

    Populated by ComputeMonitor._detect_resources(); the `optimal_*` fields
    are derived from the detected hardware by _compute_optimal_settings().
    """
    # CPU
    cpu_count: int = 1           # Physical cores (psutil, falls back to logical)
    cpu_count_logical: int = 1    # Logical cores (with hyperthreading)
    nproc: int = 1               # nproc output (usable processors, honors cgroup/affinity limits)
    
    # Memory (0.0 means detection was skipped, e.g. psutil missing)
    ram_total_gb: float = 0.0
    ram_available_gb: float = 0.0
    
    # GPU (defaults describe the CPU-only case)
    gpu_available: bool = False
    gpu_name: str = ""
    gpu_vram_total_gb: float = 0.0
    gpu_vram_free_gb: float = 0.0     # total minus torch's reserved memory at detection time
    gpu_compute_capability: str = ""  # e.g. "7.5" (SM major.minor)
    
    # === UTILIZATION CAP (v6.9) ===
    # Target max utilization - leave headroom for spikes/system stability
    max_utilization: float = 0.80  # 80% cap
    
    # Derived optimal settings (auto-computed based on hardware)
    optimal_vad_workers: int = 8
    optimal_embedding_batch: int = 8
    optimal_music_batch: int = 64
    optimal_download_workers: int = 4
    optimal_cpu_workers: int = 8
    optimal_chunk_workers: int = 1  # GPU diarization parallelism
    optimal_chunk_duration: float = 300.0  # seconds; smaller on low-VRAM GPUs to avoid OOM


@dataclass
class StageMetrics:
    """Metrics for a single pipeline stage.

    Filled in by ComputeMonitor.monitor_stage(): timing fields on exit,
    CPU/GPU fields aggregated from the background sampler thread.
    """
    stage_name: str
    start_time: float = 0.0        # wall-clock (time.time()) at stage entry
    end_time: float = 0.0          # wall-clock at stage exit
    duration: float = 0.0          # end_time - start_time, seconds
    cpu_percent_avg: float = 0.0   # mean of sampled system-wide CPU %
    cpu_percent_max: float = 0.0   # peak sampled CPU %
    gpu_percent_avg: float = 0.0   # mean of sampled GPU utilization %
    gpu_percent_max: float = 0.0   # peak sampled GPU utilization %
    gpu_memory_used_gb: float = 0.0  # peak sampled VRAM use during the stage
    items_processed: int = 0
    throughput: float = 0.0  # items/sec


class ComputeMonitor:
    """
    Singleton compute resource monitor.
    
    Features:
    1. Auto-detect system resources on init
    2. Compute optimal worker counts
    3. Track utilization during stages
    4. Suggest optimizations
    """
    _instance = None
    
    def __new__(cls):
        # Classic singleton: first call creates, later calls return the cached
        # instance. NOTE(review): not thread-safe; assumed to be constructed
        # first at module import time (see COMPUTE below) — confirm.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance
    
    def __init__(self):
        # __init__ runs on every ComputeMonitor() call, including for the
        # cached instance — guard against re-running detection.
        if self._initialized:
            return
        self._initialized = True
        
        self.resources = ComputeResources()
        self.stage_metrics: Dict[str, StageMetrics] = {}
        self._monitoring = False
        self._monitor_thread: Optional[threading.Thread] = None
        self._current_stage: Optional[str] = None
        self._cpu_samples: list = []  # CPU % samples for the active stage
        self._gpu_samples: list = []  # (gpu_util %, vram_used_gb) tuples
        
        # Auto-detect resources
        self._detect_resources()
    
    def _detect_resources(self):
        """Detect CPU, RAM, and GPU resources, then derive optimal settings.

        All detection is best-effort: missing psutil/torch/nproc degrades to
        conservative defaults instead of raising.
        """
        logger.info("=" * 70)
        logger.info("🖥️  DETECTING COMPUTE RESOURCES (Adaptive Pipeline)")
        logger.info("=" * 70)
        
        # === CPU Detection ===
        try:
            import multiprocessing
            self.resources.cpu_count_logical = multiprocessing.cpu_count()
            
            # Physical cores are a better basis for CPU-bound worker planning;
            # fall back to logical count when psutil is unavailable.
            try:
                import psutil
                self.resources.cpu_count = psutil.cpu_count(logical=False) or self.resources.cpu_count_logical
            except ImportError:
                self.resources.cpu_count = self.resources.cpu_count_logical
            
            # nproc honors cgroup/affinity limits (what's actually usable).
            try:
                import subprocess
                result = subprocess.run(['nproc'], capture_output=True, text=True, timeout=5)
                self.resources.nproc = int(result.stdout.strip())
            except Exception:
                self.resources.nproc = self.resources.cpu_count_logical
            
        except Exception as e:
            logger.warning(f"CPU detection failed: {e}, using defaults")
            self.resources.cpu_count = 4
            self.resources.cpu_count_logical = 4
            self.resources.nproc = 4
        
        logger.info(f"   CPU: {self.resources.cpu_count} physical / {self.resources.cpu_count_logical} logical cores")
        logger.info(f"   nproc: {self.resources.nproc} usable processors")
        
        # === Memory Detection ===
        try:
            import psutil
            mem = psutil.virtual_memory()
            self.resources.ram_total_gb = mem.total / (1024**3)
            self.resources.ram_available_gb = mem.available / (1024**3)
            logger.info(f"   RAM: {self.resources.ram_available_gb:.1f} / {self.resources.ram_total_gb:.1f} GB available")
        except ImportError:
            # ram_*_gb stay 0.0; _compute_optimal_settings treats that as
            # "unknown" and skips RAM-based caps.
            logger.warning("psutil not available, memory detection skipped")
        
        # === GPU Detection ===
        try:
            import torch
            if torch.cuda.is_available():
                self.resources.gpu_available = True
                self.resources.gpu_name = torch.cuda.get_device_name(0)
                
                # Fetch device properties once (was queried twice before).
                props = torch.cuda.get_device_properties(0)
                total = props.total_memory / (1024**3)
                # "free" = total minus this process's caching-allocator
                # reservation; other processes' VRAM use is not visible here.
                reserved = torch.cuda.memory_reserved(0) / (1024**3)
                
                self.resources.gpu_vram_total_gb = total
                self.resources.gpu_vram_free_gb = total - reserved
                self.resources.gpu_compute_capability = f"{props.major}.{props.minor}"
                
                logger.info(f"   GPU: {self.resources.gpu_name}")
                logger.info(f"   VRAM: {self.resources.gpu_vram_free_gb:.1f} / {self.resources.gpu_vram_total_gb:.1f} GB free")
                logger.info(f"   Compute: SM {self.resources.gpu_compute_capability}")
            else:
                logger.info("   GPU: None available (CPU-only mode)")
        except Exception as e:
            logger.warning(f"GPU detection failed: {e}")
        
        # === Compute Optimal Settings ===
        self._compute_optimal_settings()
        
        logger.info("-" * 70)
        logger.info(f"🎯 OPTIMAL SETTINGS (auto-computed @ {self.resources.max_utilization*100:.0f}% cap):")
        logger.info(f"   VAD workers: {self.resources.optimal_vad_workers}")
        logger.info(f"   Embedding batch: {self.resources.optimal_embedding_batch}")
        logger.info(f"   Music batch: {self.resources.optimal_music_batch}")
        logger.info(f"   Download workers: {self.resources.optimal_download_workers}")
        logger.info(f"   CPU workers: {self.resources.optimal_cpu_workers}")
        logger.info(f"   Diarization workers: {self.resources.optimal_chunk_workers}")
        logger.info(f"   Chunk duration: {self.resources.optimal_chunk_duration:.0f}s")
        logger.info("=" * 70)
    
    def _compute_optimal_settings(self):
        """
        Compute optimal worker counts based on detected resources.
        
        All settings respect max_utilization cap (default 80%) to leave headroom
        for system stability and unexpected spikes.
        """
        r = self.resources
        cap = r.max_utilization  # 0.80 by default
        
        # === VAD Workers (CPU-bound, I/O-light) ===
        # VAD is embarrassingly parallel - use most CPU cores
        # Each worker loads ~200MB model, so balance with RAM
        
        # Use available processors with utilization cap
        base_workers = int(max(r.cpu_count, r.nproc) * cap)
        r.optimal_vad_workers = max(8, min(base_workers, 64))
        
        # Adjust based on RAM (each VAD worker ~300MB with buffer).
        # BUGFIX: only apply this cap when RAM was actually detected —
        # ram_available_gb == 0.0 means detection was skipped (psutil missing)
        # and previously clamped workers to the floor regardless of CPU count.
        if r.ram_available_gb > 0:
            max_by_ram = int(r.ram_available_gb * cap * 1024 / 300)
            r.optimal_vad_workers = min(r.optimal_vad_workers, max_by_ram)
        r.optimal_vad_workers = max(4, r.optimal_vad_workers)  # Minimum 4
        
        # === Embedding Batch Size (GPU-bound, VRAM-sensitive) ===
        # ECAPA-TDNN uses ~0.3GB per batch of 64
        # Scale with available VRAM, respecting cap
        if r.gpu_available:
            total_vram = r.gpu_vram_total_gb
            usable_vram = r.gpu_vram_free_gb * cap

            # Low-VRAM GPUs (e.g. T4 16GB): be conservative to avoid OOM.
            # The pyannote diarization pipeline + ECAPA model already reserve significant VRAM.
            if total_vram <= 16.5:
                r.optimal_embedding_batch = 128
            else:
                # Embeddings are relatively lightweight, but keep a cap.
                r.optimal_embedding_batch = max(64, min(int(usable_vram * 40), 1024))
        else:
            r.optimal_embedding_batch = 16
        
        # === Music Detection Batch (GPU-bound, same as embeddings) ===
        # PANNs CNN14 is heavier: ~1GB for batch=64
        if r.gpu_available:
            total_vram = r.gpu_vram_total_gb
            usable_vram = r.gpu_vram_free_gb * cap

            if total_vram <= 16.5:
                r.optimal_music_batch = 32
            else:
                r.optimal_music_batch = max(32, min(int(usable_vram * 8), 256))
        else:
            r.optimal_music_batch = 16
        
        # === Download Workers (I/O-bound, network bottleneck) ===
        # Modest parallelism - network is the real limiter
        r.optimal_download_workers = max(2, min(8, int(r.nproc * cap / 8)))
        
        # === General CPU Workers ===
        r.optimal_cpu_workers = max(4, int(r.nproc * cap / 4))
        
        # === Diarization Chunk Workers (GPU VRAM-bound) ===
        # Pyannote diarization uses ~6-8GB VRAM per stream
        # Most systems can only run 1-2 concurrent diarizations
        if r.gpu_available:
            total_vram = r.gpu_vram_total_gb
            usable_vram = total_vram * cap
            vram_per_diarization = 8.0  # GB per concurrent stream (heuristic)
            r.optimal_chunk_workers = max(1, int(usable_vram / vram_per_diarization))

            # On 16GB GPUs we force single-stream settings and smaller chunks for stability.
            if total_vram <= 16.5:
                r.optimal_chunk_workers = 1
                r.optimal_chunk_duration = 180.0
            else:
                r.optimal_chunk_duration = 300.0
        else:
            r.optimal_chunk_workers = 1
            r.optimal_chunk_duration = 300.0
    
    def get_optimal_config(self) -> Dict[str, Any]:
        """Return optimal configuration dict for Config class."""
        r = self.resources
        return {
            'vad_workers': r.optimal_vad_workers,
            'max_workers': r.optimal_cpu_workers,
            'download_workers': r.optimal_download_workers,
            'chunk_workers': r.optimal_chunk_workers,
            'chunk_duration': r.optimal_chunk_duration,
            'embedding_batch_size': r.optimal_embedding_batch,
            'music_batch_size': r.optimal_music_batch,
        }
    
    def set_max_utilization(self, cap: float):
        """
        Set max utilization cap and recompute optimal settings.
        
        Args:
            cap: Utilization cap between 0.1 and 1.0 (default 0.8 = 80%).
                 Values outside the range are clamped.
        """
        cap = max(0.1, min(1.0, cap))
        self.resources.max_utilization = cap
        self._compute_optimal_settings()
        logger.info(f"🎚️  Utilization cap updated to {cap*100:.0f}%, settings recomputed")
    
    def apply_to_config(self, config) -> None:
        """
        Apply optimal settings to a Config object.
        
        This updates the config's parallelism and batch size settings
        based on auto-detected system resources. Only attributes the
        config already defines are touched.
        
        Args:
            config: Config dataclass instance to update
        """
        optimal = self.get_optimal_config()
        for key, value in optimal.items():
            if hasattr(config, key):
                setattr(config, key, value)
    
    @contextmanager
    def monitor_stage(self, stage_name: str, items: int = 0):
        """
        Context manager to monitor a pipeline stage.
        
        Starts a background sampler thread on entry; on exit aggregates the
        samples into a StageMetrics record, stores it in self.stage_metrics,
        and logs a one-line summary. Metrics are finalized even if the body
        raises (the exception still propagates).
        
        Usage:
            with COMPUTE.monitor_stage("VAD", items=100) as metrics:
                # do work
            print(metrics.duration, metrics.cpu_percent_avg)
        """
        metrics = StageMetrics(stage_name=stage_name, items_processed=items)
        metrics.start_time = time.time()
        
        self._current_stage = stage_name
        self._cpu_samples = []
        self._gpu_samples = []
        
        # Start background monitoring
        self._start_monitoring()
        
        try:
            yield metrics
        finally:
            # Stop monitoring
            self._stop_monitoring()
            
            metrics.end_time = time.time()
            metrics.duration = metrics.end_time - metrics.start_time
            
            # Compute averages
            if self._cpu_samples:
                metrics.cpu_percent_avg = sum(self._cpu_samples) / len(self._cpu_samples)
                metrics.cpu_percent_max = max(self._cpu_samples)
            
            if self._gpu_samples:
                gpu_utils = [s[0] for s in self._gpu_samples]
                gpu_mems = [s[1] for s in self._gpu_samples]
                metrics.gpu_percent_avg = sum(gpu_utils) / len(gpu_utils) if gpu_utils else 0
                metrics.gpu_percent_max = max(gpu_utils) if gpu_utils else 0
                metrics.gpu_memory_used_gb = max(gpu_mems) if gpu_mems else 0
            
            if items > 0 and metrics.duration > 0:
                metrics.throughput = items / metrics.duration
            
            self.stage_metrics[stage_name] = metrics
            self._current_stage = None
            
            # Log stage summary
            self._log_stage_summary(metrics)
    
    def _start_monitoring(self):
        """Start the background utilization sampling thread (daemon)."""
        self._monitoring = True
        self._monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self._monitor_thread.start()
    
    def _stop_monitoring(self):
        """Signal the sampler thread to stop and reap it (best-effort)."""
        self._monitoring = False
        if self._monitor_thread:
            self._monitor_thread.join(timeout=1.0)
            # Drop the stale handle so a dead thread isn't joined again later.
            self._monitor_thread = None
    
    def _monitor_loop(self):
        """Background thread that samples CPU/GPU utilization ~every 0.6s.

        Appends to self._cpu_samples / self._gpu_samples until
        self._monitoring is cleared. NVML is initialized once per loop run
        (previously it was init/shutdown on every sample).
        """
        try:
            import psutil
            has_psutil = True
        except ImportError:
            has_psutil = False
        
        has_torch = False
        try:
            import torch
            has_torch = torch.cuda.is_available()
        except ImportError:
            pass
        
        # Initialize NVML once; fall back to torch allocator stats if it fails.
        nvml = None
        gpu_handle = None
        if has_torch:
            try:
                import pynvml as nvml
                nvml.nvmlInit()
                gpu_handle = nvml.nvmlDeviceGetHandleByIndex(0)
            except Exception:
                nvml = None
                gpu_handle = None
        
        try:
            while self._monitoring:
                try:
                    # Sample CPU (system-wide %, blocks ~0.1s)
                    if has_psutil:
                        self._cpu_samples.append(psutil.cpu_percent(interval=0.1))
                    
                    # Sample GPU
                    if gpu_handle is not None:
                        util = nvml.nvmlDeviceGetUtilizationRates(gpu_handle)
                        mem_info = nvml.nvmlDeviceGetMemoryInfo(gpu_handle)
                        self._gpu_samples.append((util.gpu, mem_info.used / (1024**3)))
                    elif has_torch:
                        # Fallback: no utilization %, only torch-allocated VRAM.
                        self._gpu_samples.append((0, torch.cuda.memory_allocated() / (1024**3)))
                    
                    time.sleep(0.5)  # Sample every 500ms
                except Exception:
                    time.sleep(0.5)
        finally:
            if nvml is not None:
                try:
                    nvml.nvmlShutdown()
                except Exception:
                    pass
    
    def _log_stage_summary(self, metrics: StageMetrics):
        """Log summary for a completed stage."""
        cpu_status = "✅ HIGH" if metrics.cpu_percent_avg > 70 else "⚠️ LOW" if metrics.cpu_percent_avg < 30 else "➖ OK"
        gpu_status = "✅ HIGH" if metrics.gpu_percent_avg > 70 else "⚠️ LOW" if metrics.gpu_percent_avg < 30 else "➖ OK"
        
        logger.info(f"📊 {metrics.stage_name}: {metrics.duration:.1f}s | "
                   f"CPU: {metrics.cpu_percent_avg:.0f}% ({cpu_status}) | "
                   f"GPU: {metrics.gpu_percent_avg:.0f}% ({gpu_status}) | "
                   f"VRAM: {metrics.gpu_memory_used_gb:.1f}GB")
        
        # Optimization hints (only for stages long enough to matter)
        if metrics.cpu_percent_avg < 30 and metrics.duration > 5:
            logger.info(f"   💡 HINT: CPU underutilized - consider more workers for {metrics.stage_name}")
        if metrics.gpu_percent_avg > 0 and metrics.gpu_percent_avg < 30 and metrics.duration > 5:
            logger.info(f"   💡 HINT: GPU underutilized - could run CPU work in parallel")
    
    def summary(self) -> str:
        """Return full pipeline summary (duration-weighted CPU/GPU averages)."""
        lines = [
            "\n" + "=" * 70,
            "📊 COMPUTE UTILIZATION SUMMARY",
            "=" * 70
        ]
        
        total_time = 0
        cpu_weighted = 0
        gpu_weighted = 0
        
        for name, m in self.stage_metrics.items():
            lines.append(f"   {name:20s}: {m.duration:6.1f}s | CPU: {m.cpu_percent_avg:5.1f}% | GPU: {m.gpu_percent_avg:5.1f}%")
            total_time += m.duration
            cpu_weighted += m.cpu_percent_avg * m.duration
            gpu_weighted += m.gpu_percent_avg * m.duration
        
        if total_time > 0:
            avg_cpu = cpu_weighted / total_time
            avg_gpu = gpu_weighted / total_time
            lines.append("-" * 70)
            lines.append(f"   {'TOTAL':20s}: {total_time:6.1f}s | CPU: {avg_cpu:5.1f}% | GPU: {avg_gpu:5.1f}%")
            
            # Efficiency rating
            efficiency = (avg_cpu + avg_gpu) / 2
            if efficiency > 60:
                rating = "🟢 EXCELLENT"
            elif efficiency > 40:
                rating = "🟡 GOOD"
            else:
                rating = "🔴 NEEDS OPTIMIZATION"
            lines.append(f"   Efficiency: {rating}")
        
        lines.append("=" * 70)
        return "\n".join(lines)
    
    def refresh_gpu_memory(self):
        """Refresh GPU memory stats (call after clearing cache)."""
        try:
            import torch
            if torch.cuda.is_available():
                reserved = torch.cuda.memory_reserved(0) / (1024**3)
                self.resources.gpu_vram_free_gb = self.resources.gpu_vram_total_gb - reserved
        except Exception:
            # Best-effort: stale free-VRAM numbers are acceptable here.
            pass


# Global singleton: constructed at import time so resource detection runs
# once per process; all importers share this instance (enforced by
# ComputeMonitor.__new__).
COMPUTE = ComputeMonitor()


def get_adaptive_batch_size(segment_lengths: list, base_batch: int, max_memory_mb: int = 500) -> int:
    """
    Compute an adaptive batch size from the average segment length.

    Shorter segments allow larger batches; longer segments shrink the batch
    to stay under the memory budget. The result is clamped to
    [1, base_batch * 2].

    Args:
        segment_lengths: Segment durations in seconds (may be empty).
        base_batch: Baseline batch size; returned unchanged when no
            adaptation is possible.
        max_memory_mb: Memory budget for one batch, in MB.

    Returns:
        Adapted batch size (at least 1).
    """
    # No segments to reason about -> keep the caller's baseline.
    if not segment_lengths:
        return base_batch

    mean_seconds = sum(segment_lengths) / len(segment_lengths)

    # 16kHz float32 audio is ~50KB per second -> 0.05 MB/s heuristic.
    est_mb_per_segment = mean_seconds * 0.05
    if est_mb_per_segment <= 0:
        return base_batch

    # How many average-sized segments fit in the budget, capped at 2x base.
    fit_count = int(max_memory_mb / est_mb_per_segment)
    return max(1, min(fit_count, base_batch * 2))

