#!/usr/bin/env python3
"""
Model management - Hot inference with adaptive loading.

Key improvements:
- Diarization tries pyannote community-1 first (better-performing per the pyannote docs) and falls back to speaker-diarization-3.1
- Lazy loading with compute-aware batching
- Better GPU memory management

=== v7.0 OPTIMIZATION: cuDNN Benchmark + TF32 ===
Enables CUDA optimizations for faster inference:
- cudnn.benchmark=True: Auto-tune convolution algorithms
- TF32 precision: 3x faster matmul on Ampere+ GPUs
Expected gain: 5-10% on diarization (conv-heavy model)
"""

import functools
import gc
import logging
import time

import torch

# === v7.0 OPTIMIZATION: Enable CUDA optimizations ===
# Process-wide backend switches, applied only when a CUDA device is visible:
# - cudnn.benchmark: let cuDNN time candidate convolution algorithms for each
#   new input shape and reuse the fastest one
# - allow_tf32 (matmul + cuDNN): permit TensorFloat-32 math, which Ampere and
#   newer GPUs run faster at slightly reduced precision
# They are set at import time, i.e. before any model in this module is loaded.
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True  # TF32 for matmul
    torch.backends.cudnn.allow_tf32 = True  # TF32 for cuDNN convolutions
    torch.backends.cudnn.benchmark = True  # autotune convolution algorithms

# === CRITICAL: Patch torch.load BEFORE any other imports ===
# Patch before importing libraries that may bind torch.load by name at import
# time (``from torch import load``); such bindings would keep the unpatched
# function.
#
# Every call is forced to weights_only=False, overriding any value the caller
# passed. Presumably this keeps pyannote / lightning checkpoints that pickle
# arbitrary Python objects loadable on PyTorch versions whose default is
# weights_only=True.
# SECURITY: weights_only=False unpickles arbitrary objects and can execute
# code - only load checkpoints from trusted sources.
_original_torch_load = torch.load


@functools.wraps(_original_torch_load)
def _patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False.

    ``functools.wraps`` keeps ``__name__``/``__doc__`` and sets ``__wrapped__``,
    so ``inspect.signature(torch.load)`` still reports the real parameters.
    """
    kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)


torch.load = _patched_torch_load

# Patch PyTorch Lightning too: its checkpoint loader is rebound to call the
# unpatched torch.load with weights_only=False. Skipped when pytorch_lightning
# is not installed or its saving module does not expose `pl_load`.
try:
    import pytorch_lightning.core.saving as pl_saving
    _orig_pl_load = pl_saving.pl_load  # unpatched loader, kept for reference

    @functools.wraps(_orig_pl_load)
    def _patched_pl_load(path_or_url, map_location=None, **kwargs):
        """Load with weights_only=False; extra keyword arguments are ignored."""
        return _original_torch_load(path_or_url, map_location=map_location, weights_only=False)

    pl_saving.pl_load = _patched_pl_load
except (ImportError, AttributeError):
    # AttributeError: a lightning version without `pl_load` must not make
    # importing this module fail.
    pass

# Patch lightning_fabric's loader the same way. The replacement accepts and
# ignores extra keyword arguments (e.g. `weights_only`), so a caller passing
# them does not hit a TypeError; weights_only is always forced to False.
try:
    import lightning_fabric.utilities.cloud_io as cloud_io
    _orig_fabric_load = cloud_io._load  # unpatched loader, kept for reference

    @functools.wraps(_orig_fabric_load)
    def _patched_fabric_load(path_or_url, map_location=None, **kwargs):
        """Load with weights_only=False; extra keyword arguments are ignored."""
        return _original_torch_load(path_or_url, map_location=map_location, weights_only=False)

    cloud_io._load = _patched_fabric_load
except (ImportError, AttributeError):
    pass

# Logger for this module. The name is fixed rather than derived from __name__,
# placing records under the "FastPipelineV6" logger hierarchy.
_LOGGER_NAME = "FastPipelineV6.Models"
logger = logging.getLogger(_LOGGER_NAME)


class ModelManager:
    """
    Singleton model manager - loads all models once, keeps them hot.

    Model priority (per user request):
    1. community-1 (better performance per pyannote docs)
    2. speaker-diarization-3.1 (fallback)

    Every ``ModelManager()`` call returns the same instance (see ``__new__``);
    ``__init__`` only sets up state the first time. ``load_all`` loads Silero
    VAD, the pyannote diarization pipeline and ECAPA-TDNN embeddings;
    ``load_panns`` loads the optional PANNs music-detection model on demand.

    Key optimization: Load once, inference many times.
    """
    _instance = None

    def __new__(cls):
        # Create the shared instance on first use and hand it back afterwards.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # Python runs __init__ on every ModelManager() call, even when __new__
        # returned the existing instance, so only initialise state once.
        if self._initialized:
            return
        self._initialized = True
        self.device = None  # resolved lazily by get_device()
        self.silero_vad = None
        self.silero_utils = None
        self.diarization_pipeline = None
        self.segmentation_model = None  # never loaded (see load_all, step 3)
        self.embedding_model = None
        self.panns_model = None  # v6.7: PANNs CNN14 for music detection
        self._loaded = False
        self._panns_loaded = False  # Separate flag for optional PANNs
        self._model_name = None  # Short name of the diarization model that loaded

    def get_device(self) -> torch.device:
        """Return (and cache) the CUDA device if available, otherwise the CPU."""
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return self.device

    def load_all(self, config):
        """
        Load all models once and keep them resident.

        Order: Silero VAD -> PyAnnote diarization (community-1 first,
        speaker-diarization-3.1 as fallback) -> ECAPA-TDNN embeddings. The
        standalone segmentation model is not loaded. Once a call has succeeded,
        later calls only log and return.

        Args:
            config: object with an ``hf_token`` attribute, passed as ``token``
                to ``Pipeline.from_pretrained``.

        Raises:
            RuntimeError: if no diarization pipeline could be loaded. Errors
                from the VAD / ECAPA loaders propagate unchanged.
        """
        if self._loaded:
            logger.info(f"✅ Models already loaded: {self._model_name} (hot inference)")
            return

        logger.info("=" * 70)
        logger.info("🚀 LOADING ALL MODELS INTO GPU")
        logger.info("=" * 70)
        start = time.time()

        device = self.get_device()
        logger.info(f"Device: {device}")

        # Log GPU memory before: device 0 total minus this process's reserved pool.
        if torch.cuda.is_available():
            free = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_reserved(0)
            logger.info(f"GPU Memory before: {free / 1e9:.2f}GB free")

        # 1. Silero VAD via torch.hub (not moved to `device` here)
        logger.info("Loading Silero VAD...")
        model, utils = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            force_reload=False,
            onnx=False
        )
        self.silero_vad = model
        self.silero_utils = utils
        logger.info("✅ Silero VAD loaded")

        # 2. PyAnnote diarization - community-1 first, 3.1 as fallback
        self._load_diarization(config, device)

        # 3. Segmentation model - DISABLED (not used in current pipeline)
        # NOTE: segmentation-3.0 was for frame-level OSD but we use community-1 diarization instead
        # The community-1 model handles both diarization AND overlap detection in one pass
        # Keeping this disabled to avoid noisy "gated model" warnings
        self.segmentation_model = None

        # 4. ECAPA-TDNN for embeddings (checkpoint cache dir is hard-coded)
        logger.info("Loading ECAPA-TDNN...")
        from speechbrain.inference.speaker import EncoderClassifier
        self.embedding_model = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            savedir="/ephemeral/models/ecapa",
            run_opts={"device": str(device)}
        )
        logger.info("✅ ECAPA-TDNN loaded")

        self._loaded = True
        elapsed = time.time() - start

        # Log GPU memory after
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1e9
            reserved = torch.cuda.memory_reserved() / 1e9
            logger.info(f"GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")

        logger.info(f"✅ All models loaded in {elapsed:.1f}s (STAYING HOT)")
        logger.info("=" * 70)

    def _load_diarization(self, config, device):
        """
        Load the first diarization pipeline that succeeds and move it to ``device``.

        Sets ``self.diarization_pipeline`` and ``self._model_name``.

        Raises:
            RuntimeError: if every candidate raised or returned None.
        """
        from pyannote.audio import Pipeline
        from pyannote.audio.pipelines import SpeakerDiarization

        logger.info("Loading PyAnnote Diarization...")

        # Patch SpeakerDiarization to ignore unsupported params (like 'plda').
        # The helper is a no-op on an already-patched __init__, so a retry of
        # load_all after a failure does not stack wrappers.
        SpeakerDiarization.__init__ = self._wrap_init_ignoring_unknown_kwargs(
            SpeakerDiarization.__init__
        )

        # === PRIORITY: community-1 FIRST (better performance per pyannote docs) ===
        # Repo: https://huggingface.co/pyannote/speaker-diarization-community-1
        # Fallback: speaker-diarization-3.1 if community-1 fails
        models_to_try = [
            ("pyannote/speaker-diarization-community-1", "community-1"),  # FIRST CHOICE
            ("pyannote/speaker-diarization-3.1", "speaker-diarization-3.1"),  # FALLBACK
        ]

        for index, (model_name, short_name) in enumerate(models_to_try):
            logger.info(f"Trying {short_name}...")
            try:
                # pyannote.audio 4.x uses 'token' parameter
                pipeline = Pipeline.from_pretrained(model_name, token=config.hf_token)
            except Exception as e:
                logger.warning(f"Failed to load {short_name}: {e}")
                continue
            if pipeline is None:
                # Some pyannote.audio versions return None instead of raising
                # (e.g. gated repo / missing token); without this check the
                # fallback would never be tried.
                logger.warning(f"Failed to load {short_name}: from_pretrained returned None")
                continue
            self.diarization_pipeline = pipeline
            self._model_name = short_name
            role = "preferred model" if index == 0 else "fallback model"
            logger.info(f"✅ Loaded {short_name} ({role})")
            break
        else:
            raise RuntimeError("Failed to load any diarization pipeline")

        self.diarization_pipeline.to(device)
        logger.info(f"✅ PyAnnote Diarization loaded on {device}")

    @staticmethod
    def _wrap_init_ignoring_unknown_kwargs(original_init):
        """
        Return an ``__init__`` replacement that drops keyword arguments
        ``original_init`` does not declare (e.g. a ``plda`` entry in a
        pipeline config written for a different pyannote version).

        If ``original_init`` accepts ``**kwargs`` nothing is dropped. Passing a
        function this helper already produced returns it unchanged.
        """
        import functools
        import inspect

        if getattr(original_init, "_ignores_unknown_kwargs", False):
            return original_init

        params = inspect.signature(original_init).parameters
        accepts_any_kwarg = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        valid_params = frozenset(params)

        @functools.wraps(original_init)
        def _filtered_init(self, *args, **kwargs):
            if not accepts_any_kwarg:
                removed = set(kwargs) - valid_params
                if removed:
                    logger.debug(f"Filtered out unsupported params: {removed}")
                    kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
            return original_init(self, *args, **kwargs)

        # Marker checked above to keep the patch idempotent.
        _filtered_init._ignores_unknown_kwargs = True
        return _filtered_init

    def clear_cache(self, aggressive=False):
        """
        Release cached memory between operations.

        When CUDA is available, cached-but-unused blocks are always returned to
        the driver via ``torch.cuda.empty_cache()``.

        Args:
            aggressive: If True, also run Python's garbage collector (on any
                device) and reset CUDA peak-memory statistics.
        """
        if aggressive:
            # Collect first so tensors freed by the GC can be released by empty_cache.
            gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            if aggressive:
                torch.cuda.reset_peak_memory_stats()

    def get_vram_free_gb(self) -> float:
        """
        Return an estimate of free VRAM on CUDA device 0, in GiB (1024**3 bytes).

        Computed as total device memory minus memory reserved by this process's
        PyTorch caching allocator; memory held by other processes is not
        subtracted. Returns 0.0 when CUDA is unavailable.
        """
        if torch.cuda.is_available():
            total = torch.cuda.get_device_properties(0).total_memory
            reserved = torch.cuda.memory_reserved(0)
            return (total - reserved) / (1024**3)
        return 0.0

    def load_panns(self, device: str = 'cuda'):
        """
        Load PANNs CNN14 model for music detection (lazy loading).

        === v6.7 FEATURE: Music Detection ===
        PANNs (Pretrained Audio Neural Networks) trained on AudioSet.
        Used to detect music/instruments in audio chunks.
        Later calls return the already-loaded model.

        Args:
            device: 'cuda' or 'cpu'; 'cuda' falls back to 'cpu' (with a
                warning) when CUDA is unavailable.

        Returns:
            Loaded AudioTagging model

        Raises:
            ImportError: if the ``panns_inference`` package is not installed.
            Exception: any other loading error is logged and re-raised.
        """
        if self._panns_loaded and self.panns_model is not None:
            logger.info("✅ PANNs CNN14 already loaded (hot inference)")
            return self.panns_model

        logger.info("Loading PANNs CNN14 for music detection...")
        start = time.time()

        try:
            from panns_inference import AudioTagging

            # Determine device
            if device == 'cuda' and not torch.cuda.is_available():
                device = 'cpu'
                logger.warning("CUDA not available, using CPU for PANNs")

            # Load model - uses Cnn14 by default
            # PANNs API: AudioTagging(model=None, checkpoint_path=None, device='cuda')
            self.panns_model = AudioTagging(device=device)
            self._panns_loaded = True
            elapsed = time.time() - start
            logger.info(f"✅ PANNs CNN14 loaded on {device} in {elapsed:.1f}s")

            return self.panns_model

        except ImportError as e:
            logger.error("❌ PANNs not installed. Run: pip install panns-inference")
            raise ImportError("panns-inference package required for music detection. "
                              "Install with: pip install panns-inference") from e
        except Exception as e:
            logger.error(f"❌ Failed to load PANNs: {e}")
            raise


# Global singleton instance. ModelManager.__new__ always returns the same
# object, so this is also what any later ModelManager() call yields.
MODELS: ModelManager = ModelManager()
