#!/usr/bin/env python3
"""
Hybrid Music Detection Module (v7.4)

=== TWO-STAGE HYBRID APPROACH ===

Stage 1: inaSpeechSegmenter (FAST)
- Specialized CNN for speech/music/noise segmentation
- Processes entire audio quickly (~10x realtime)
- Binary decisions: speech / music / noise / noEnergy
- NO fine-grained classification, just screening

Stage 2: PANNs CNN14 (DETAILED, only when needed)
- Only runs on segments flagged as "music" or "noise" by Stage 1
- Provides AudioSet class labels (527 classes)
- Confirms/refines the detection
- Skipped for clean speech segments

=== PERFORMANCE GAINS ===
Typical podcast (90% clean speech):
- Before: PANNs on 100% of chunks → 16s
- After: INA on 100% (fast) + PANNs on 10% (suspicious) → ~4s

=== DECISION LOGIC ===
| INA Result | PANNs Run? | Final Decision |
|------------|------------|----------------|
| speech     | NO         | clean          |
| noEnergy   | NO         | skip (silence) |
| music      | YES        | PANNs decides  |
| noise      | YES        | PANNs decides  |
"""

import time
import logging
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional, Any
import numpy as np
import torch

logger = logging.getLogger("FastPipelineV6.HybridMusicDetection")


@dataclass
class HybridDetectionConfig:
    """Configuration for hybrid music detection.

    Bundles the tunables for both stages of the pipeline: the fast
    inaSpeechSegmenter screening pass (Stage 1), the PANNs CNN14
    verification pass (Stage 2), and the per-segment decision thresholds
    used by get_segment_stats().
    """
    
    # Stage 1: inaSpeechSegmenter settings
    ina_vad_engine: str = 'sm'  # 'sm' (small), 'smn' (small+noise)
    ina_detect_gender: bool = False  # We don't need gender, skip for speed
    
    # Stage 2: PANNs settings (only for suspicious segments)
    panns_chunk_duration: float = 1.5  # seconds of audio per analysis chunk
    panns_batch_size: int = 64  # chunks per PANNs inference call
    panns_sample_rate: int = 32000  # PANNs CNN14 expects 32 kHz input
    
    # Decision thresholds
    panns_music_threshold: float = 0.20  # PANNs music prob threshold
    panns_noise_threshold: float = 0.25  # PANNs noise prob threshold
    
    # Per-segment thresholds (same as before)
    # *_ratio_* compare against the fraction of flagged chunks in a segment;
    # *_mean_* compare against the mean PANNs probability over the segment.
    music_ratio_clean: float = 0.0
    music_ratio_demucs: float = 0.15
    music_mean_clean: float = 0.10
    music_mean_demucs: float = 0.25
    
    noise_ratio_clean: float = 0.0
    noise_ratio_demucs: float = 0.20
    noise_mean_clean: float = 0.10
    noise_mean_demucs: float = 0.30
    
    strict_tts_mode: bool = True  # stricter decision logic for TTS dataset curation


@dataclass
class HybridChunkInfo:
    """Per-chunk detection results from hybrid approach.

    One instance per fixed-duration audio chunk. Stage 1 (INA) fields are
    always filled; Stage 2 (PANNs) fields stay None/False unless the chunk
    was flagged suspicious and verified.
    """
    idx: int  # sequential chunk index within the audio
    start_time: float  # chunk start, seconds
    end_time: float  # chunk end, seconds
    
    # Stage 1 (INA) results
    ina_label: str  # 'speech', 'music', 'noise', 'noEnergy'
    ina_confidence: float  # fraction of the chunk covered by the winning label
    
    # Stage 2 (PANNs) results - only populated if INA flagged suspicious
    panns_music_prob: Optional[float] = None  # max prob over music classes
    panns_noise_prob: Optional[float] = None  # max prob over noise classes
    panns_ran: bool = False  # True once PANNs inference ran on this chunk
    
    # Final decision
    has_music: bool = False
    has_noise: bool = False
    has_contamination: bool = False  # has_music or has_noise
    decision: str = 'clean'  # 'clean', 'needs_demucs', 'heavy_contamination'


class HybridMusicDetector:
    """
    Two-stage music detection: Fast INA screening + PANNs verification.

    Stage 1 runs inaSpeechSegmenter over the whole file and labels uniform
    chunks as speech / music / noise / noEnergy. Stage 2 runs PANNs CNN14
    only on chunks INA flagged as 'music' or 'noise'.

    Usage:
        detector = HybridMusicDetector(audio_buffer, device='cuda')
        detector.build()

        for segment in segments:
            stats = detector.get_segment_stats(segment['start'], segment['end'])
            if stats['decision'] == 'needs_demucs':
                segment['needs_demucs'] = True
    """

    def __init__(
        self,
        audio_buffer,
        panns_model=None,
        config: Optional[HybridDetectionConfig] = None,
        device: Optional[torch.device] = None
    ):
        """
        Args:
            audio_buffer: Audio source exposing ``waveform_np``,
                ``sample_rate`` and ``duration`` (project buffer type —
                confirm exact attributes against the caller).
            panns_model: Optional PANNs CNN14 wrapper with an
                ``inference(batch)`` method. When None, suspicious chunks
                are conservatively marked as contaminated.
            config: Detection thresholds; defaults to HybridDetectionConfig().
            device: Torch device; auto-selects CUDA when available.
        """
        self.audio_buffer = audio_buffer
        self.panns_model = panns_model
        self.config = config or HybridDetectionConfig()
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self._ina_segmenter = None  # lazily-loaded Segmenter instance
        self._ina_results = []  # List of (label, start, end) tuples
        self._chunks: Dict[int, HybridChunkInfo] = {}  # chunk idx -> info
        self._time_index: Dict[Tuple[float, float], int] = {}  # (start, end) -> chunk idx
        self._built = False
        self._stats: Dict[str, Any] = {}

    def _load_ina_segmenter(self):
        """Lazily import and construct inaSpeechSegmenter (idempotent)."""
        if self._ina_segmenter is not None:
            return

        try:
            from inaSpeechSegmenter import Segmenter

            logger.info("Loading inaSpeechSegmenter...")
            start = time.time()

            # detect_gender=False skips the gender CNN for speed; the VAD
            # engine ('sm' vs 'smn') comes from config.
            self._ina_segmenter = Segmenter(
                vad_engine=self.config.ina_vad_engine,
                detect_gender=self.config.ina_detect_gender
            )

            logger.info(f"✅ inaSpeechSegmenter loaded in {time.time()-start:.1f}s")

        except ImportError:
            logger.error("❌ inaSpeechSegmenter not installed. Run: pip install inaSpeechSegmenter")
            raise

    def build(self) -> int:
        """
        Build hybrid detection cache using the two-stage approach.

        Returns:
            Number of chunks PANNs confirmed as contaminated (0 when no
            chunk required verification or no PANNs model was provided).
        """
        if self._built:
            # BUGFIX: return the same quantity as the first call (confirmed
            # contaminated count) rather than the count of chunks PANNs ran on.
            return self._stats.get('panns_verified', 0)

        start_time = time.time()
        logger.info("=" * 70)
        logger.info("⚡ HYBRID MUSIC DETECTION (v7.4)")
        logger.info("   Stage 1: inaSpeechSegmenter (fast screening)")
        logger.info("   Stage 2: PANNs CNN14 (verification, suspicious only)")
        logger.info("=" * 70)

        # === STAGE 1: Fast INA Segmentation ===
        t0 = time.time()
        self._run_ina_segmentation()
        ina_time = time.time() - t0
        logger.info(f"   [Stage 1] INA segmentation: {ina_time:.1f}s")

        # Convert INA results to uniform chunks matching our standard format.
        chunk_duration = self.config.panns_chunk_duration
        total_duration = self.audio_buffer.duration

        # BUGFIX: ceil instead of truncation so a trailing partial chunk
        # (< chunk_duration) is still analyzed instead of silently dropped.
        num_chunks = int(np.ceil(total_duration / chunk_duration))

        suspicious_chunks = []
        clean_chunks = 0

        for i in range(num_chunks):
            chunk_start = i * chunk_duration
            chunk_end = min((i + 1) * chunk_duration, total_duration)

            # Find INA label for this chunk (majority vote if it spans
            # multiple INA segments).
            ina_label, ina_conf = self._get_ina_label_for_range(chunk_start, chunk_end)

            chunk_info = HybridChunkInfo(
                idx=i,
                start_time=chunk_start,
                end_time=chunk_end,
                ina_label=ina_label,
                ina_confidence=ina_conf,
            )

            # Decide if PANNs verification is needed.
            if ina_label in ('music', 'noise'):
                suspicious_chunks.append(chunk_info)
            else:
                # Clean speech or silence - no PANNs needed.
                chunk_info.decision = 'clean' if ina_label == 'speech' else 'skip'
                clean_chunks += 1

            self._chunks[i] = chunk_info
            self._time_index[(round(chunk_start, 3), round(chunk_end, 3))] = i

        logger.info(f"   Stage 1 results: {clean_chunks} clean, {len(suspicious_chunks)} suspicious")

        # === STAGE 2: PANNs verification (only suspicious chunks) ===
        panns_verified = 0
        if suspicious_chunks and self.panns_model is not None:
            t0 = time.time()
            panns_verified = self._run_panns_verification(suspicious_chunks)
            panns_time = time.time() - t0
            logger.info(f"   [Stage 2] PANNs verification: {panns_time:.1f}s ({len(suspicious_chunks)} chunks)")
        elif suspicious_chunks:
            logger.warning("   ⚠️ PANNs model not provided, suspicious chunks unverified")
            # Mark all suspicious as needs_demucs (conservative).
            for chunk in suspicious_chunks:
                chunk.has_contamination = True
                chunk.decision = 'needs_demucs'

        # Calculate stats
        total_chunks = len(self._chunks)
        chunks_with_contamination = sum(1 for c in self._chunks.values() if c.has_contamination)

        self._stats = {
            'total_chunks': total_chunks,
            'clean_chunks': clean_chunks,
            'suspicious_chunks': len(suspicious_chunks),
            'panns_verified': panns_verified,
            'chunks_with_contamination': chunks_with_contamination,
            'contamination_ratio': chunks_with_contamination / total_chunks if total_chunks > 0 else 0,
            'panns_skip_ratio': clean_chunks / total_chunks if total_chunks > 0 else 0,
            'build_time_sec': time.time() - start_time,
            'ina_time_sec': ina_time,
        }

        self._built = True

        logger.info(f"✅ Hybrid detection complete: {chunks_with_contamination} contaminated "
                   f"({self._stats['contamination_ratio']*100:.1f}%)")
        logger.info(f"   PANNs skipped: {self._stats['panns_skip_ratio']*100:.1f}% of chunks")
        logger.info(f"   Total time: {self._stats['build_time_sec']:.1f}s")

        return panns_verified

    def _run_ina_segmentation(self):
        """Run inaSpeechSegmenter on the full audio, filling self._ina_results.

        INA's API expects a file path, so the buffer is written to a
        temporary WAV file and deleted afterwards.
        """
        self._load_ina_segmenter()

        import os
        import tempfile
        import soundfile as sf

        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp_path = tmp.name

        try:
            # Save audio to temp file (INA expects file path)
            waveform = self.audio_buffer.waveform_np
            sr = self.audio_buffer.sample_rate
            sf.write(tmp_path, waveform, sr)

            # Run segmentation -> list of (label, start, end) tuples.
            self._ina_results = self._ina_segmenter(tmp_path)

            # Log total duration per label for visibility.
            labels = {}
            for label, start, end in self._ina_results:
                labels[label] = labels.get(label, 0) + (end - start)

            total = sum(labels.values())
            logger.info(f"   INA segments: {len(self._ina_results)}")
            for label, duration in sorted(labels.items(), key=lambda x: -x[1]):
                pct = duration / total * 100 if total > 0 else 0
                logger.info(f"      {label}: {duration:.1f}s ({pct:.1f}%)")

        finally:
            # Best-effort cleanup; narrow except (was a bare except) so real
            # bugs (e.g. NameError) are no longer swallowed.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    def _get_ina_label_for_range(self, start: float, end: float) -> Tuple[str, float]:
        """
        Get INA label for a time range (majority vote if spans multiple segments).

        Returns:
            (label, confidence) where confidence is the proportion of the range
            covered by the winning label. Falls back to ('speech', 0.5) when no
            INA results exist, and ('noEnergy', 1.0) when nothing overlaps.
        """
        if not self._ina_results:
            return 'speech', 0.5

        label_durations = {}

        for label, seg_start, seg_end in self._ina_results:
            # Accumulate overlap duration per label.
            overlap_start = max(start, seg_start)
            overlap_end = min(end, seg_end)

            if overlap_end > overlap_start:
                overlap_duration = overlap_end - overlap_start
                label_durations[label] = label_durations.get(label, 0) + overlap_duration

        if not label_durations:
            return 'noEnergy', 1.0

        # Return the label with most coverage of the requested range.
        total_duration = end - start
        best_label = max(label_durations.items(), key=lambda x: x[1])
        confidence = best_label[1] / total_duration if total_duration > 0 else 0

        return best_label[0], confidence

    def _run_panns_verification(self, suspicious_chunks: List[HybridChunkInfo]) -> int:
        """
        Run PANNs on suspicious chunks only, updating each chunk in place.

        Returns:
            Number of chunks confirmed as contaminated (including chunks
            conservatively marked contaminated after an inference failure).
        """
        if not suspicious_chunks:
            return 0

        from src.music_detection import ALL_MUSIC_CLASSES, ALL_NOISE_CLASSES

        # Prepare audio chunks for PANNs (requires 32kHz input).
        import scipy.signal as signal

        source_sr = self.audio_buffer.sample_rate
        target_sr = self.config.panns_sample_rate
        target_samples = int(self.config.panns_chunk_duration * target_sr)

        waveform = self.audio_buffer.waveform_np

        # Rational resampling ratio source_sr -> target_sr.
        from math import gcd
        g = gcd(target_sr, source_sr)
        up = target_sr // g
        down = source_sr // g

        # Zero-padded fixed-size batch; short tail chunks stay zero-padded.
        audio_chunks = np.zeros((len(suspicious_chunks), target_samples), dtype=np.float32)

        for i, chunk in enumerate(suspicious_chunks):
            start_sample = int(chunk.start_time * source_sr)
            end_sample = min(int(chunk.end_time * source_sr), len(waveform))

            if start_sample < len(waveform):
                chunk_16k = waveform[start_sample:end_sample]
                chunk_32k = signal.resample_poly(chunk_16k, up, down).astype(np.float32)
                chunk_len = min(len(chunk_32k), target_samples)
                audio_chunks[i, :chunk_len] = chunk_32k[:chunk_len]

        # Run PANNs inference in batches.
        batch_size = self.config.panns_batch_size
        contaminated = 0

        for batch_start in range(0, len(audio_chunks), batch_size):
            batch_end = min(batch_start + batch_size, len(audio_chunks))
            batch = audio_chunks[batch_start:batch_end]

            try:
                clipwise_output, _ = self.panns_model.inference(batch)

                for j in range(len(batch)):
                    chunk = suspicious_chunks[batch_start + j]
                    probs = clipwise_output[j]

                    # Music probability (max across music classes)
                    music_probs = [probs[idx] for idx in ALL_MUSIC_CLASSES if idx < len(probs)]
                    music_prob = float(max(music_probs)) if music_probs else 0.0

                    # Noise probability (max across noise classes)
                    noise_probs = [probs[idx] for idx in ALL_NOISE_CLASSES if idx < len(probs)]
                    noise_prob = float(max(noise_probs)) if noise_probs else 0.0

                    # Update chunk info
                    chunk.panns_ran = True
                    chunk.panns_music_prob = music_prob
                    chunk.panns_noise_prob = noise_prob
                    chunk.has_music = music_prob > self.config.panns_music_threshold
                    chunk.has_noise = noise_prob > self.config.panns_noise_threshold
                    chunk.has_contamination = chunk.has_music or chunk.has_noise

                    if chunk.has_contamination:
                        contaminated += 1
                        chunk.decision = 'needs_demucs'
                    else:
                        # PANNs cleared it - INA was false positive
                        chunk.decision = 'clean'

            except Exception as e:
                logger.error(f"PANNs batch failed: {e}")
                # Mark batch as suspicious (conservative)
                for j in range(batch_end - batch_start):
                    chunk = suspicious_chunks[batch_start + j]
                    chunk.has_contamination = True
                    chunk.decision = 'needs_demucs'
                # BUGFIX: count each failed chunk exactly once. Previously
                # `batch_end - batch_start` was added INSIDE the loop above,
                # inflating the count quadratically per failed batch.
                contaminated += batch_end - batch_start

        return contaminated

    def get_chunks_in_range(self, start: float, end: float) -> List[HybridChunkInfo]:
        """Return chunks whose midpoint falls inside [start, end], sorted by time.

        NOTE: membership is decided by chunk midpoint, not strict overlap,
        so a chunk straddling a segment boundary is assigned to exactly one
        segment.
        """
        return sorted([
            c for c in self._chunks.values()
            if start <= (c.start_time + c.end_time) / 2 <= end
        ], key=lambda c: c.start_time)

    def get_segment_stats(self, start: float, end: float) -> Dict[str, Any]:
        """
        Compute segment-level stats for hybrid detection.

        Args:
            start: Segment start time in seconds.
            end: Segment end time in seconds.

        Returns:
            Dict with per-segment music/noise summaries and a 'decision' of
            'clean', 'needs_demucs', or 'heavy_contamination'. A range with
            no chunks returns an all-zero 'clean' result.
        """
        chunks = self.get_chunks_in_range(start, end)

        if not chunks:
            return {
                'music_mean': 0.0, 'music_max': 0.0, 'music_ratio': 0.0,
                'noise_mean': 0.0, 'noise_max': 0.0, 'noise_ratio': 0.0,
                'has_music': False, 'has_noise': False, 'has_contamination': False,
                'decision': 'clean', 'chunks_analyzed': 0,
                'ina_labels': {}, 'panns_ran_count': 0
            }

        # INA label distribution (counts per label).
        ina_labels = {}
        for c in chunks:
            ina_labels[c.ina_label] = ina_labels.get(c.ina_label, 0) + 1

        # PANNs results (only from chunks that actually ran PANNs).
        panns_chunks = [c for c in chunks if c.panns_ran]

        if panns_chunks:
            music_probs = [c.panns_music_prob or 0 for c in panns_chunks]
            noise_probs = [c.panns_noise_prob or 0 for c in panns_chunks]

            music_mean = float(np.mean(music_probs))
            music_max = float(np.max(music_probs))
            noise_mean = float(np.mean(noise_probs))
            noise_max = float(np.max(noise_probs))
        else:
            music_mean = music_max = noise_mean = noise_max = 0.0

        # Fraction of chunks flagged, over ALL chunks in range (not just
        # the PANNs-verified ones).
        music_ratio = sum(1 for c in chunks if c.has_music) / len(chunks)
        noise_ratio = sum(1 for c in chunks if c.has_noise) / len(chunks)
        contamination_ratio = sum(1 for c in chunks if c.has_contamination) / len(chunks)

        # Decision logic (same thresholds as the non-hybrid path).
        if self.config.strict_tts_mode:
            if contamination_ratio == 0 and music_mean < self.config.music_mean_clean:
                decision = 'clean'
            elif (music_ratio < self.config.music_ratio_demucs and
                  noise_ratio < self.config.noise_ratio_demucs and
                  music_mean < self.config.music_mean_demucs):
                decision = 'needs_demucs'
            else:
                decision = 'heavy_contamination'
        else:
            if music_ratio < self.config.music_ratio_clean and music_mean < self.config.music_mean_clean:
                decision = 'clean'
            elif music_ratio < self.config.music_ratio_demucs and music_mean < self.config.music_mean_demucs:
                decision = 'needs_demucs'
            else:
                decision = 'heavy_contamination'

        return {
            'music_mean': round(music_mean, 4),
            'music_max': round(music_max, 4),
            'music_ratio': round(music_ratio, 4),
            'noise_mean': round(noise_mean, 4),
            'noise_max': round(noise_max, 4),
            'noise_ratio': round(noise_ratio, 4),
            'contamination_ratio': round(contamination_ratio, 4),
            'has_music': bool(any(c.has_music for c in chunks)),
            'has_noise': bool(any(c.has_noise for c in chunks)),
            'has_contamination': bool(any(c.has_contamination for c in chunks)),
            'decision': decision,
            'chunks_analyzed': len(chunks),
            'panns_ran_count': len(panns_chunks),
            'ina_labels': {k: int(v) for k, v in ina_labels.items()},  # Ensure JSON serializable
        }

    def get_stats(self) -> Dict[str, Any]:
        """Get overall detection statistics (floats rounded to 4 decimals)."""
        return {k: round(v, 4) if isinstance(v, float) else v
                for k, v in self._stats.items()}


def build_segment_music_stats_hybrid(
    segments: List[Dict],
    detector: HybridMusicDetector
) -> Tuple[List[Dict], Dict[str, Any]]:
    """
    Compute segment-level music/noise stats using hybrid detection.

    Annotates each usable segment in place with 'music_stats' (and
    'needs_demucs' / 'status' flags where applicable) and returns the
    segment list together with aggregate counters and percentages.

    Drop-in replacement for build_segment_music_stats_batch.
    """
    logger.info("Computing segment-level contamination (hybrid approach)...")

    buckets = ('clean', 'needs_demucs', 'heavy_contamination')
    stats = {f'segments_{name}': 0 for name in buckets}
    stats.update({f'duration_{name}': 0.0 for name in buckets})

    for seg in segments:
        # Overlap / non-speech / already-unusable segments are not scored.
        skip = (seg.get('speaker') in ('OVERLAP', 'NON_SPEECH')
                or seg.get('status') == 'unusable')
        if skip:
            continue

        music_stats = detector.get_segment_stats(seg['start'], seg['end'])
        seg['music_stats'] = music_stats
        verdict = music_stats['decision']
        seg_len = seg.get('duration', seg['end'] - seg['start'])

        if verdict == 'clean':
            bucket = 'clean'
        elif verdict == 'needs_demucs':
            bucket = 'needs_demucs'
            seg['needs_demucs'] = True
            # Tag the dominant contamination type (music wins ties).
            if music_stats.get('has_music'):
                seg['contamination_type'] = 'music'
            elif music_stats.get('has_noise'):
                seg['contamination_type'] = 'noise'
        else:  # heavy_contamination
            bucket = 'heavy_contamination'
            seg['status'] = 'unusable'
            seg['unusable_reason'] = 'heavy_contamination'

        stats[f'segments_{bucket}'] += 1
        stats[f'duration_{bucket}'] += seg_len

    total = sum(stats[f'duration_{name}'] for name in buckets)
    for name in buckets:
        if total > 0:
            stats[f'pct_{name}'] = round(stats[f'duration_{name}'] / total * 100, 1)
        else:
            stats[f'pct_{name}'] = 0.0

    # Add hybrid-specific stats
    detector_stats = detector.get_stats()
    stats['panns_skip_ratio'] = detector_stats.get('panns_skip_ratio', 0)
    stats['ina_time_sec'] = detector_stats.get('ina_time_sec', 0)
    stats['hybrid_mode'] = True

    logger.info(f"   ✅ Clean: {stats['segments_clean']} ({stats['pct_clean']}%)")
    logger.info(f"   ⚠️ Needs Demucs: {stats['segments_needs_demucs']} ({stats['pct_needs_demucs']}%)")
    logger.info(f"   ❌ Heavy Contamination: {stats['segments_heavy_contamination']} ({stats['pct_heavy_contamination']}%)")
    logger.info(f"   ⚡ PANNs skipped: {stats['panns_skip_ratio']*100:.1f}% of chunks")

    return segments, stats

