#!/usr/bin/env python3
"""Music Detection Module - PANNs CNN14 for TTS quality.

=== v7.0 OPTIMIZATIONS ===

1. Early-Exit Sampling: Sample 10% of chunks first. If every sampled chunk is
   below the early-exit threshold for both music and noise, skip full analysis
   and mark all segments as clean. Saves ~20s on clean podcasts.

2. Fast Polyphase Resampling: Chunks are resampled 16kHz→32kHz per batch with
   scipy.signal.resample_poly (a trivial 2x integer ratio), avoiding both the
   per-chunk scipy.signal.resample overhead and the cost of FFT-resampling the
   entire file up front (see _precompute_resampled_audio for details).
"""

import time
import logging
from dataclasses import dataclass
from typing import List, Tuple, Optional
import numpy as np
import torch

logger = logging.getLogger("FastPipelineV6.MusicDetection")

# === AudioSet Class Indices for TTS-Problematic Sounds ===
# Reference: https://github.com/audioset/ontology

# Speech (keep, but detect for ratio)
SPEECH_CLASS_IDX = 0

# Music-related classes (indices 137-281 in the AudioSet label list)
MUSIC_RELATED_CLASSES = list(range(137, 282))  # Instruments, genres, music elements

# Singing/Vocal music (problematic for TTS - will confuse voice model)
SINGING_CLASSES = [16, 17, 18, 19, 20]  # Singing, choir, humming, chanting

# === NEW: Other problematic sounds for TTS ===
# These should trigger demucs processing or rejection

# Crowd/Audience (background chatter ruins TTS)
CROWD_CLASSES = [22, 23, 24, 25, 26, 27]  # Crowd, applause, cheering, laughter

# Environmental noise
ENVIRONMENTAL_CLASSES = [
    288, 289, 290,  # Rain, thunder, wind
    308, 309, 310, 311, 312,  # Vehicle sounds (car, bus, truck)
    337, 338, 339, 340,  # Mechanical sounds (engine, machinery)
    394, 395, 396,  # Construction, drilling
    420, 421,  # Typing, keyboard
    500, 501, 502,  # Alarm, siren
]

# Room acoustics / reverb indicators (echoey audio = bad TTS)
REVERB_INDICATOR_CLASSES = [
    503, 504,  # Echo, reverberation markers (if present)
]

# Background hum/noise
BACKGROUND_NOISE_CLASSES = [
    132, 133,  # Static, hum, buzz
    505, 506,  # White noise, pink noise
]

# All music classes (original)
ALL_MUSIC_CLASSES = set(MUSIC_RELATED_CLASSES + SINGING_CLASSES)

# === NEW: All problematic classes for TTS (beyond just music) ===
ALL_NOISE_CLASSES = set(
    CROWD_CLASSES + 
    ENVIRONMENTAL_CLASSES + 
    REVERB_INDICATOR_CLASSES + 
    BACKGROUND_NOISE_CLASSES
)

# Combined: Everything that should trigger demucs or rejection
ALL_PROBLEMATIC_CLASSES = ALL_MUSIC_CLASSES | ALL_NOISE_CLASSES
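

# Sanity-check sketch for the hard-coded indices above. Assuming the
# panns_inference package exposes its 527-entry AudioSet label list as
# `panns_inference.labels`, this (unused) helper resolves each index group to
# class names so the groupings can be verified against the ontology.
def _debug_print_class_names():
    """Illustrative helper, not called by the pipeline."""
    try:
        from panns_inference import labels  # assumption: provided by panns_inference
    except ImportError:
        logger.warning("panns_inference not installed; cannot resolve class names")
        return
    groups = {
        'SINGING_CLASSES': SINGING_CLASSES,
        'CROWD_CLASSES': CROWD_CLASSES,
        'ENVIRONMENTAL_CLASSES': ENVIRONMENTAL_CLASSES,
        'REVERB_INDICATOR_CLASSES': REVERB_INDICATOR_CLASSES,
        'BACKGROUND_NOISE_CLASSES': BACKGROUND_NOISE_CLASSES,
    }
    for name, indices in groups.items():
        resolved = [labels[i] for i in indices if i < len(labels)]
        logger.info(f"{name}: {resolved}")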


@dataclass
class MusicDetectionConfig:
    chunk_duration: float = 1.5
    panns_sample_rate: int = 32000
    source_sample_rate: int = 16000
    batch_size: int = 64
    
    # === STRICT THRESHOLDS FOR TTS QUALITY ===
    # A chunk "has contamination" if music OR noise prob > this
    music_prob_threshold: float = 0.20  # Lowered from 0.30 for stricter detection
    noise_prob_threshold: float = 0.25  # Threshold for non-music noise
    
    # Per-segment decision thresholds (STRICT but realistic)
    # NOTE: ANY chunk contamination → needs_demucs. Only truly clean audio passes.
    # IMPORTANT: Baseline neural net output is ~0.04-0.06, so mean thresholds must be > 0.05
    music_ratio_clean: float = 0.0     # ANY chunk detection → needs_demucs
    music_ratio_demucs: float = 0.15   # 0-15% → demucs, >15% → unusable
    music_mean_clean: float = 0.10     # ADJUSTED: 0.10 accounts for baseline (was 0.05)
    music_mean_demucs: float = 0.25    # Threshold for heavy_music
    
    # Same thresholds for noise detection
    noise_ratio_clean: float = 0.0     # ANY chunk detection → needs processing
    noise_ratio_demucs: float = 0.20   # 0-20% → demucs, >20% → unusable
    noise_mean_clean: float = 0.10     # ADJUSTED: 0.10 accounts for baseline (was 0.05)
    noise_mean_demucs: float = 0.30
    
    # === v7.0: Early-exit sampling ===
    early_exit_enabled: bool = True
    early_exit_sample_ratio: float = 0.10  # Sample 10% of chunks
    early_exit_threshold: float = 0.05     # STRICTER: Was 0.10. Only skip if truly clean
    
    # === TTS Quality Mode ===
    # When True: Zero tolerance - ANY detection → needs_demucs
    strict_tts_mode: bool = True  # Enable strict mode by default
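
# Worked example of how the thresholds above combine (illustrative numbers only;
# the decision logic itself lives in MusicDetectionCache.get_segment_music_stats):
#   - A segment covering 20 chunks, 2 of which exceed music_prob_threshold, has
#     music_ratio = 2/20 = 0.10 and contamination_ratio > 0.
#   - In strict_tts_mode the segment therefore cannot be 'clean' (zero tolerance),
#     but 0.10 < music_ratio_demucs (0.15), so provided music_mean stays below
#     music_mean_demucs (0.25) and the noise stats stay under their demucs
#     thresholds, it is classified 'needs_demucs' rather than 'heavy_contamination'.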


@dataclass
class MusicChunkInfo:
    """Per-chunk detection results for music AND noise."""
    idx: int
    start_time: float
    end_time: float
    music_prob: float      # Max probability across music classes
    noise_prob: float      # Max probability across noise classes (NEW)
    speech_prob: float     # Speech probability
    has_music: bool        # True if music_prob > threshold
    has_noise: bool        # True if noise_prob > threshold (NEW)
    has_contamination: bool  # True if either music OR noise detected (NEW)
    top_music_class: Optional[str] = None
    top_noise_class: Optional[str] = None  # NEW: What type of noise detected


class MusicDetectionCache:
    """
    Music detection cache with v7.0 optimizations:
    1. Fast per-batch polyphase resampling to 32kHz (resample_poly, not per-chunk FFT resample)
    2. Early-exit sampling (check 10% first, skip if clean)
    """
    
    def __init__(self, audio_buffer, vad_segments, panns_model, config=None, device=None):
        self.audio_buffer = audio_buffer
        self.vad_segments = vad_segments
        self.panns_model = panns_model
        self.config = config or MusicDetectionConfig()
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.chunks = {}
        self._time_index = {}
        self._built = False
        self._resampled_audio = None  # v7.0: always None; signals fast per-batch resampling (see _precompute_resampled_audio)
        self._source_audio = None  # 16kHz source waveform, set in _precompute_resampled_audio()
        self._early_exit_triggered = False  # v7.0: Track if we skipped full analysis
        self._stats = {'total_chunks': 0, 'chunks_with_music': 0, 'build_time_sec': 0.0}

    def build(self):
        if self._built:
            return len(self.chunks)
        start_time = time.time()
        logger.info("=" * 70)
        logger.info("⚡ MUSIC DETECTION CACHE (PANNs CNN14) - v7.0")
        logger.info(f"   Chunk: {self.config.chunk_duration}s, Threshold: {self.config.music_prob_threshold}")
        logger.info(f"   Early-exit: {'ENABLED' if self.config.early_exit_enabled else 'DISABLED'}")
        logger.info("=" * 70)
        
        chunk_positions = self._create_chunk_positions()
        self._stats['total_chunks'] = len(chunk_positions)
        if not chunk_positions:
            self._built = True
            return 0
        
        # Store source audio reference for chunked resampling
        self._precompute_resampled_audio()
        
        # === v7.0 OPTIMIZATION: Early-exit sampling ===
        if self.config.early_exit_enabled and len(chunk_positions) > 20:
            t0 = time.time()
            is_clean = self._try_early_exit(chunk_positions)
            early_exit_time = time.time() - t0
            logger.info(f"   [TIMING] Early-exit check: {early_exit_time*1000:.0f}ms")
            
            if is_clean:
                # All samples clean - mark everything as clean and skip full analysis
                self._early_exit_triggered = True
                logger.info(f"   ⚡ EARLY EXIT: All samples clean, skipping full analysis")
                
                # Create chunks with zero contamination
                for i, (start, end) in enumerate(chunk_positions):
                    self.chunks[i] = MusicChunkInfo(
                        idx=i, start_time=start, end_time=end,
                        music_prob=0.0, noise_prob=0.0, speech_prob=1.0, 
                        has_music=False, has_noise=False, has_contamination=False
                    )
                    self._time_index[(round(start, 3), round(end, 3))] = i
                
                self._stats['chunks_with_music'] = 0
                self._stats['chunks_with_noise'] = 0
                self._stats['chunks_with_contamination'] = 0
                self._stats['music_ratio'] = 0.0
                self._stats['noise_ratio'] = 0.0
                self._stats['contamination_ratio'] = 0.0
                self._stats['avg_music_prob'] = 0.0
                self._stats['avg_noise_prob'] = 0.0
                self._stats['max_music_prob'] = 0.0
                self._stats['max_noise_prob'] = 0.0
                self._stats['early_exit'] = True
                self._stats['build_time_sec'] = time.time() - start_time
                self._built = True
                
                logger.info(f"✅ Detection cache (early-exit): {len(self.chunks)} chunks marked clean in {self._stats['build_time_sec']:.2f}s")
                return len(self.chunks)
        
        # Full analysis with fast per-batch resampling
        t0 = time.time()
        audio_chunks = self._prepare_audio_chunks_fast(chunk_positions)
        logger.info(f"   [TIMING] Chunk preparation + resample: {(time.time()-t0)*1000:.0f}ms")
        
        t0 = time.time()
        music_probs, noise_probs, speech_probs = self._batch_inference(audio_chunks)
        logger.info(f"   [TIMING] PANNs inference: {(time.time()-t0)*1000:.0f}ms")
        
        chunks_with_music = 0
        chunks_with_noise = 0
        chunks_with_contamination = 0
        
        for i, (start, end) in enumerate(chunk_positions):
            has_music = music_probs[i] > self.config.music_prob_threshold
            has_noise = noise_probs[i] > self.config.noise_prob_threshold
            has_contamination = has_music or has_noise
            
            if has_music:
                chunks_with_music += 1
            if has_noise:
                chunks_with_noise += 1
            if has_contamination:
                chunks_with_contamination += 1
            
            self.chunks[i] = MusicChunkInfo(
                idx=i, start_time=start, end_time=end,
                music_prob=music_probs[i], 
                noise_prob=noise_probs[i],
                speech_prob=speech_probs[i], 
                has_music=has_music,
                has_noise=has_noise,
                has_contamination=has_contamination
            )
            self._time_index[(round(start, 3), round(end, 3))] = i
        
        # Stats
        self._stats['chunks_with_music'] = chunks_with_music
        self._stats['chunks_with_noise'] = chunks_with_noise
        self._stats['chunks_with_contamination'] = chunks_with_contamination
        self._stats['music_ratio'] = chunks_with_music / len(chunk_positions) if chunk_positions else 0
        self._stats['noise_ratio'] = chunks_with_noise / len(chunk_positions) if chunk_positions else 0
        self._stats['contamination_ratio'] = chunks_with_contamination / len(chunk_positions) if chunk_positions else 0
        self._stats['avg_music_prob'] = float(np.mean(music_probs)) if music_probs else 0
        self._stats['avg_noise_prob'] = float(np.mean(noise_probs)) if noise_probs else 0
        self._stats['max_music_prob'] = float(np.max(music_probs)) if music_probs else 0
        self._stats['max_noise_prob'] = float(np.max(noise_probs)) if noise_probs else 0
        self._stats['early_exit'] = False
        self._stats['build_time_sec'] = time.time() - start_time
        self._built = True
        
        logger.info(f"✅ Detection cache: {len(self.chunks)} chunks")
        logger.info(f"   Music: {chunks_with_music} ({self._stats['music_ratio']*100:.1f}%)")
        logger.info(f"   Noise: {chunks_with_noise} ({self._stats['noise_ratio']*100:.1f}%)")
        logger.info(f"   Total contaminated: {chunks_with_contamination} ({self._stats['contamination_ratio']*100:.1f}%)")
        logger.info(f"   Build time: {self._stats['build_time_sec']:.2f}s")
        return len(self.chunks)
    
    def _precompute_resampled_audio(self):
        """
        Prepare audio buffer reference for chunked resampling.
        
        NOTE: Full pre-resampling was REMOVED because scipy.signal.resample
        on very long audio (>30min) is actually SLOWER than per-batch resampling
        due to FFT size overhead. For 157min audio, it took 122s vs ~5s batched.
        
        We now just store the reference and do fast per-batch resampling.
        """
        # Just store reference - actual resampling done per-batch in _prepare_audio_chunks_fast
        self._resampled_audio = None  # Signal to use fast per-batch resampling
        self._source_audio = self.audio_buffer.waveform_np
    
    def _try_early_exit(self, chunk_positions: List[Tuple[float, float]]) -> bool:
        """
        === v7.0 OPTIMIZATION: Early-exit sampling ===
        
        Sample 10% of chunks evenly distributed. If ALL samples have music_prob < threshold,
        we can skip full analysis (entire audio is clean).
        
        Returns:
            True if audio is clean (can skip full analysis), False otherwise
        """
        sample_ratio = self.config.early_exit_sample_ratio
        threshold = self.config.early_exit_threshold
        
        # Select evenly distributed sample indices
        num_chunks = len(chunk_positions)
        num_samples = max(10, int(num_chunks * sample_ratio))  # At least 10 samples
        step = max(1, num_chunks // num_samples)
        sample_indices = list(range(0, num_chunks, step))[:num_samples]
        
        logger.info(f"   ⚡ Early-exit: Sampling {len(sample_indices)}/{num_chunks} chunks")
        
        # Prepare sample chunks with fast resampling
        sample_positions = [chunk_positions[i] for i in sample_indices]
        sample_chunks = self._prepare_audio_chunks_fast(sample_positions)
        
        # Run inference on samples
        music_probs, noise_probs, _ = self._batch_inference(sample_chunks)
        
        # Check if all samples are clean (both music AND noise must be below threshold)
        max_music = max(music_probs) if music_probs else 0
        max_noise = max(noise_probs) if noise_probs else 0
        avg_music = sum(music_probs) / len(music_probs) if music_probs else 0
        avg_noise = sum(noise_probs) / len(noise_probs) if noise_probs else 0
        
        logger.info(f"   ⚡ Sample results: music_max={max_music:.3f}, noise_max={max_noise:.3f} (threshold={threshold})")
        
        # Conservative: ALL samples must be below threshold for BOTH music and noise
        music_clean = all(prob < threshold for prob in music_probs)
        noise_clean = all(prob < threshold for prob in noise_probs)
        is_clean = music_clean and noise_clean
        
        return is_clean
    
    def _prepare_audio_chunks_fast(self, chunk_positions: List[Tuple[float, float]]) -> np.ndarray:
        """
        Prepare audio chunks with fast polyphase resampling.
        
        Uses scipy.signal.resample_poly which is much faster than resample()
        for integer ratio resampling (16kHz → 32kHz = 2x = trivial).
        """
        import scipy.signal as signal
        
        source_sr = self.config.source_sample_rate
        target_sr = self.config.panns_sample_rate
        target_samples = int(self.config.chunk_duration * target_sr)
        
        all_chunks = np.zeros((len(chunk_positions), target_samples), dtype=np.float32)
        waveform = self._source_audio
        
        # Compute resampling ratio (for 16k→32k, up=2, down=1)
        from math import gcd
        g = gcd(target_sr, source_sr)
        up = target_sr // g
        down = source_sr // g
        
        for i, (start, end) in enumerate(chunk_positions):
            start_sample = int(start * source_sr)
            end_sample = min(int(end * source_sr), len(waveform))
            
            if start_sample < len(waveform):
                chunk_16k = waveform[start_sample:end_sample]
                # resample_poly is MUCH faster than resample for integer ratios
                chunk_32k = signal.resample_poly(chunk_16k, up, down).astype(np.float32)
                chunk_len = min(len(chunk_32k), target_samples)
                all_chunks[i, :chunk_len] = chunk_32k[:chunk_len]
        
        return all_chunks

    def _create_chunk_positions(self):
        """
        Create chunk positions covering the ENTIRE audio, not just VAD segments.
        
        === FIX (v6.7.1): Previously only created chunks from VAD speech segments ===
        This caused instrumental sections in songs to have 0 chunks analyzed,
        defaulting to 'clean' even when music was present.
        
        Now creates uniform chunks across entire audio duration for comprehensive
        music detection coverage.
        """
        sr = self.audio_buffer.sample_rate
        chunk_samples = int(self.config.chunk_duration * sr)
        total_samples = len(self.audio_buffer.waveform_np)
        
        positions = []
        pos = 0
        while pos + chunk_samples <= total_samples:
            positions.append((pos / sr, (pos + chunk_samples) / sr))
            pos += chunk_samples
        
        return positions

    # NOTE: _prepare_audio_chunks removed in v7.0 - replaced by _prepare_audio_chunks_fast,
    # which uses fast per-batch polyphase resampling (resample_poly) for better performance

    def _batch_inference(self, audio_chunks):
        """
        Run PANNs inference and extract music, noise, and speech probabilities.
        
        Returns:
            (music_probs, noise_probs, speech_probs) - Lists of max probabilities per chunk
        """
        batch_size = self.config.batch_size
        music_probs, noise_probs, speech_probs = [], [], []
        
        for batch_start in range(0, len(audio_chunks), batch_size):
            batch_end = min(batch_start + batch_size, len(audio_chunks))
            batch = audio_chunks[batch_start:batch_end]
            try:
                clipwise_output, _ = self.panns_model.inference(batch)
                for j in range(len(batch)):
                    probs = clipwise_output[j]
                    
                    # Music probability (max across music classes)
                    music_class_probs = [probs[idx] for idx in ALL_MUSIC_CLASSES if idx < len(probs)]
                    music_probs.append(float(max(music_class_probs)) if music_class_probs else 0.0)
                    
                    # Noise probability (max across noise classes) - NEW
                    noise_class_probs = [probs[idx] for idx in ALL_NOISE_CLASSES if idx < len(probs)]
                    noise_probs.append(float(max(noise_class_probs)) if noise_class_probs else 0.0)
                    
                    # Speech probability
                    speech_probs.append(float(probs[SPEECH_CLASS_IDX]) if SPEECH_CLASS_IDX < len(probs) else 0.0)
                    
            except Exception as e:
                logger.error(f"PANNs batch failed: {e}")
                for _ in range(batch_end - batch_start):
                    music_probs.append(0.0)
                    noise_probs.append(0.0)
                    speech_probs.append(0.0)
        
        return music_probs, noise_probs, speech_probs

    def get_chunks_in_range(self, start, end):
        return sorted([c for c in self.chunks.values() 
                      if start <= (c.start_time + c.end_time) / 2 <= end],
                     key=lambda c: c.start_time)

    def get_segment_music_stats(self, start, end):
        """
        Compute segment-level stats for music AND noise detection.
        
        Decision logic (STRICT for TTS quality):
        - 'clean': NO music AND NO noise detected (zero tolerance)
        - 'needs_demucs': Some music/noise but below heavy threshold (can be cleaned)
        - 'heavy_contamination': Too much music/noise (unusable for TTS)
        """
        chunks = self.get_chunks_in_range(start, end)
        if not chunks:
            return {
                'music_mean': 0.0, 'music_max': 0.0, 'music_ratio': 0.0,
                'noise_mean': 0.0, 'noise_max': 0.0, 'noise_ratio': 0.0,
                'has_music': False, 'has_noise': False, 'has_contamination': False,
                'decision': 'clean', 'chunks_analyzed': 0
            }
        
        # Music stats
        music_probs = [c.music_prob for c in chunks]
        music_mean = float(np.mean(music_probs))
        music_max = float(np.max(music_probs))
        music_ratio = sum(1 for c in chunks if c.has_music) / len(chunks)
        
        # Noise stats (NEW)
        noise_probs = [c.noise_prob for c in chunks]
        noise_mean = float(np.mean(noise_probs))
        noise_max = float(np.max(noise_probs))
        noise_ratio = sum(1 for c in chunks if c.has_noise) / len(chunks)
        
        # Combined contamination
        contamination_ratio = sum(1 for c in chunks if c.has_contamination) / len(chunks)
        
        # === STRICT DECISION LOGIC FOR TTS ===
        # In strict_tts_mode: ANY contamination → needs_demucs
        if self.config.strict_tts_mode:
            if contamination_ratio == 0 and music_mean < self.config.music_mean_clean and noise_mean < self.config.noise_mean_clean:
                decision = 'clean'
            elif (music_ratio < self.config.music_ratio_demucs and 
                  noise_ratio < self.config.noise_ratio_demucs and
                  music_mean < self.config.music_mean_demucs and 
                  noise_mean < self.config.noise_mean_demucs):
                decision = 'needs_demucs'
            else:
                decision = 'heavy_contamination'
        else:
            # Original (lenient) logic
            if music_ratio <= self.config.music_ratio_clean and music_mean < self.config.music_mean_clean:
                decision = 'clean'
            elif music_ratio < self.config.music_ratio_demucs and music_mean < self.config.music_mean_demucs:
                decision = 'needs_demucs'
            else:
                decision = 'heavy_contamination'
        
        return {
            'music_mean': round(music_mean, 4), 'music_max': round(music_max, 4),
            'music_ratio': round(music_ratio, 4), 
            'noise_mean': round(noise_mean, 4), 'noise_max': round(noise_max, 4),
            'noise_ratio': round(noise_ratio, 4),
            'contamination_ratio': round(contamination_ratio, 4),
            'has_music': any(c.has_music for c in chunks),
            'has_noise': any(c.has_noise for c in chunks),
            'has_contamination': any(c.has_contamination for c in chunks),
            'decision': decision, 
            'chunks_analyzed': len(chunks),
            'chunks_with_music': sum(1 for c in chunks if c.has_music),
            'chunks_with_noise': sum(1 for c in chunks if c.has_noise),
            'chunks_with_contamination': sum(1 for c in chunks if c.has_contamination)
        }

    def get_stats(self):
        return {k: round(v, 4) if isinstance(v, float) else v for k, v in self._stats.items()}


def build_segment_music_stats_batch(segments, music_cache):
    """
    Compute segment-level music/noise stats and classify segments.
    
    Classification (for TTS quality):
    - 'clean': No contamination detected → use as-is
    - 'needs_demucs': Some contamination → process with Demucs before TTS
    - 'heavy_contamination': Too much contamination → mark as unusable
    """
    logger.info("Computing segment-level contamination statistics...")
    stats = {
        'segments_clean': 0, 
        'segments_needs_demucs': 0, 
        'segments_heavy_contamination': 0,
        'duration_clean': 0.0, 
        'duration_needs_demucs': 0.0, 
        'duration_heavy_contamination': 0.0
    }
    
    for seg in segments:
        if seg.get('speaker') in ['OVERLAP', 'NON_SPEECH'] or seg.get('status') == 'unusable':
            continue
        
        music_stats = music_cache.get_segment_music_stats(seg['start'], seg['end'])
        seg['music_stats'] = music_stats
        decision = music_stats['decision']
        duration = seg.get('duration', seg['end'] - seg['start'])
        
        if decision == 'clean':
            stats['segments_clean'] += 1
            stats['duration_clean'] += duration
        elif decision == 'needs_demucs':
            stats['segments_needs_demucs'] += 1
            stats['duration_needs_demucs'] += duration
            seg['needs_demucs'] = True
            # Add detail about what was detected
            if music_stats.get('has_music'):
                seg['contamination_type'] = 'music'
            elif music_stats.get('has_noise'):
                seg['contamination_type'] = 'noise'
            else:
                seg['contamination_type'] = 'unknown'
        else:  # heavy_contamination
            stats['segments_heavy_contamination'] += 1
            stats['duration_heavy_contamination'] += duration
            seg['status'] = 'unusable'
            seg['unusable_reason'] = 'heavy_contamination'
            # Add detail
            if music_stats.get('music_ratio', 0) > music_stats.get('noise_ratio', 0):
                seg['contamination_type'] = 'heavy_music'
            else:
                seg['contamination_type'] = 'heavy_noise'
    
    total = stats['duration_clean'] + stats['duration_needs_demucs'] + stats['duration_heavy_contamination']
    if total > 0:
        stats['pct_clean'] = round(stats['duration_clean'] / total * 100, 1)
        stats['pct_needs_demucs'] = round(stats['duration_needs_demucs'] / total * 100, 1)
        stats['pct_heavy_contamination'] = round(stats['duration_heavy_contamination'] / total * 100, 1)
    else:
        stats['pct_clean'] = stats['pct_needs_demucs'] = stats['pct_heavy_contamination'] = 0.0
    
    logger.info(f"   ✅ Clean (no processing needed): {stats['segments_clean']} segments ({stats.get('pct_clean', 0)}%)")
    logger.info(f"   ⚠️ Needs Demucs: {stats['segments_needs_demucs']} segments ({stats.get('pct_needs_demucs', 0)}%)")
    logger.info(f"   ❌ Heavy Contamination (unusable): {stats['segments_heavy_contamination']} segments ({stats.get('pct_heavy_contamination', 0)}%)")
    
    return segments, stats
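

# Minimal usage sketch (illustrative only, not part of the pipeline). It assumes
# the panns_inference and soundfile packages are installed, that the input file
# is 16 kHz mono, and that a simple object exposing waveform_np and sample_rate
# is enough for MusicDetectionCache. The file path and segment below are hypothetical.
if __name__ == "__main__":
    import soundfile as sf
    from panns_inference import AudioTagging

    logging.basicConfig(level=logging.INFO)

    @dataclass
    class _AudioBuffer:
        waveform_np: np.ndarray
        sample_rate: int

    waveform, sr = sf.read("podcast_16k.wav", dtype="float32")  # hypothetical input
    buffer = _AudioBuffer(waveform_np=waveform, sample_rate=sr)

    model = AudioTagging(checkpoint_path=None,
                         device="cuda" if torch.cuda.is_available() else "cpu")

    cache = MusicDetectionCache(buffer, vad_segments=[], panns_model=model)
    cache.build()

    # One hypothetical diarized segment spanning the first 30 seconds
    segments = [{"start": 0.0, "end": min(30.0, len(waveform) / sr), "speaker": "SPEAKER_00"}]
    segments, stats = build_segment_music_stats_batch(segments, cache)
    print(segments[0].get("music_stats"), stats)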
