#!/usr/bin/env python3
"""
Debug script to analyze why audio is being marked as heavy_contamination.

Run this directly on the audio file to see:
1. Raw PANNs output per chunk
2. Which chunks are being flagged and why
3. Segment-level decision breakdown
"""

import sys
import os
import numpy as np
import torch
import json

# Add src to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from src.music_detection import (
    MusicDetectionConfig,
    ALL_MUSIC_CLASSES,
    ALL_NOISE_CLASSES,
    SPEECH_CLASS_IDX,
)


def load_audio(audio_path: str):
    """Load an audio file, downmix to mono, and resample to 16 kHz.

    Returns:
        (waveform, sr, duration): 1-D float32 numpy samples at 16 kHz,
        the sample rate (always 16000 after this call), and the duration
        in seconds.
    """
    import torchaudio
    import scipy.signal as signal
    from math import gcd

    print(f"\n[1] Loading audio: {audio_path}")
    waveform, sr = torchaudio.load(audio_path)

    # Downmix multi-channel audio by averaging; otherwise just drop the
    # (single) channel axis so we always end up with a 1-D signal.
    n_channels = waveform.shape[0]
    if n_channels > 1:
        print(f"    Converting stereo ({n_channels} channels) to mono...")
        waveform = waveform.mean(dim=0).numpy()
    else:
        waveform = waveform.squeeze(0).numpy()

    print(f"    Original: {sr}Hz, {len(waveform)/sr:.1f}s")

    # Polyphase resampling to the 16 kHz rate the downstream analysis expects.
    if sr != 16000:
        print(f"    Resampling {sr}Hz → 16kHz...")
        factor = gcd(16000, sr)
        waveform = signal.resample_poly(
            waveform, 16000 // factor, sr // factor
        ).astype(np.float32)
        sr = 16000

    duration = len(waveform) / sr
    print(f"    Final: {sr}Hz, {duration:.1f}s")

    return waveform, sr, duration


def analyze_panns_chunks(waveform, sr, device="cuda"):
    """Run PANNs on each chunk and return detailed stats.

    Slices `waveform` into fixed-length chunks, resamples each chunk to
    32 kHz for the PANNs CNN14 tagger, runs batched inference, and records
    per-chunk music/noise/speech probabilities plus contamination flags
    derived from the config thresholds.

    Args:
        waveform: 1-D numpy array of audio samples (assumed to be the 16 kHz
            mono output of load_audio — TODO confirm callers never pass
            other rates; the resampling factors below handle any `sr`).
        sr: sample rate of `waveform` in Hz.
        device: torch device string for the PANNs model ("cuda" or "cpu").

    Returns:
        (chunk_details, config): chunk_details is a list of dicts with
        'idx', 'start'/'end' times (seconds), the three probabilities, and
        'has_music'/'has_noise'/'has_contamination' booleans; config is the
        MusicDetectionConfig whose thresholds were applied.
    """
    import scipy.signal as signal
    from math import gcd
    from panns_inference import AudioTagging
    
    config = MusicDetectionConfig()
    
    print(f"\n[2] Loading PANNs CNN14 model...")
    # checkpoint_path=None: presumably lets panns_inference use its default
    # pretrained checkpoint — verify against the library docs.
    at = AudioTagging(checkpoint_path=None, device=device)
    print(f"    Model loaded on {device}")
    
    # Chunk parameters
    chunk_duration = config.chunk_duration  # 1.5s
    chunk_samples_16k = int(chunk_duration * sr)
    target_sr = 32000  # PANNs CNN14 input rate
    chunk_samples_32k = int(chunk_duration * target_sr)
    
    # Resample params (polyphase up/down factors from sr to target_sr)
    g = gcd(target_sr, sr)
    up = target_sr // g
    down = sr // g
    
    # NOTE(review): floor division drops any partial trailing chunk, so up to
    # chunk_duration seconds at the end of the audio are never analyzed —
    # confirm this matches the production chunking in music_detection.py.
    num_chunks = int(len(waveform) / chunk_samples_16k)
    
    print(f"\n[3] Analyzing {num_chunks} chunks ({chunk_duration}s each)...")
    print(f"    Thresholds: music_prob > {config.music_prob_threshold}, noise_prob > {config.noise_prob_threshold}")
    
    # Prepare all chunks for batch processing: resample each chunk to 32 kHz
    # into a fixed-size, zero-padded matrix (rows shorter than
    # chunk_samples_32k stay zero-padded at the tail).
    all_chunks_32k = np.zeros((num_chunks, chunk_samples_32k), dtype=np.float32)
    
    for i in range(num_chunks):
        start = i * chunk_samples_16k
        end = min(start + chunk_samples_16k, len(waveform))
        chunk_16k = waveform[start:end]
        chunk_32k = signal.resample_poly(chunk_16k, up, down).astype(np.float32)
        chunk_len = min(len(chunk_32k), chunk_samples_32k)
        all_chunks_32k[i, :chunk_len] = chunk_32k[:chunk_len]
    
    # Batch inference over the chunk matrix; clipwise_output is per-chunk
    # class probabilities (AudioSet classes, indexed by the *_CLASSES lists).
    batch_size = 128
    music_probs = []
    noise_probs = []
    speech_probs = []
    chunk_details = []
    
    for batch_start in range(0, num_chunks, batch_size):
        batch_end = min(batch_start + batch_size, num_chunks)
        batch = all_chunks_32k[batch_start:batch_end]
        
        clipwise_output, _ = at.inference(batch)
        
        for j in range(len(batch)):
            probs = clipwise_output[j]
            
            # Music probability: max over all music-related class indices
            # (the `idx < len(probs)` guard skips any out-of-range index).
            music_class_probs = [probs[idx] for idx in ALL_MUSIC_CLASSES if idx < len(probs)]
            music_prob = float(max(music_class_probs)) if music_class_probs else 0.0
            music_probs.append(music_prob)
            
            # Noise probability: same max-over-classes reduction for noise.
            noise_class_probs = [probs[idx] for idx in ALL_NOISE_CLASSES if idx < len(probs)]
            noise_prob = float(max(noise_class_probs)) if noise_class_probs else 0.0
            noise_probs.append(noise_prob)
            
            # Speech: single class index (informational only; it does not
            # feed into the contamination flags below).
            speech_prob = float(probs[SPEECH_CLASS_IDX]) if SPEECH_CLASS_IDX < len(probs) else 0.0
            speech_probs.append(speech_prob)
            
            # Flag the chunk against the configured per-chunk thresholds.
            has_music = music_prob > config.music_prob_threshold
            has_noise = noise_prob > config.noise_prob_threshold
            has_contamination = has_music or has_noise
            
            chunk_idx = batch_start + j
            chunk_details.append({
                'idx': chunk_idx,
                'start': chunk_idx * chunk_duration,
                'end': (chunk_idx + 1) * chunk_duration,
                'music_prob': music_prob,
                'noise_prob': noise_prob,
                'speech_prob': speech_prob,
                'has_music': has_music,
                'has_noise': has_noise,
                'has_contamination': has_contamination,
            })
    
    return chunk_details, config


def simulate_segment_decisions(chunk_details, config, segment_durations=None):
    """
    Simulate the segment-level decision logic.

    Mirrors the production decision path (strict_tts_mode=True): a segment
    is 'clean' only when no chunk is flagged and both mean probabilities sit
    below the clean thresholds; 'needs_demucs' when contamination is mild
    enough for source separation; otherwise 'heavy_contamination', with the
    specific threshold(s) exceeded recorded in 'reason'.

    Args:
        chunk_details: per-chunk dicts from analyze_panns_chunks (requires
            'start', 'end', 'music_prob', 'noise_prob', 'has_music',
            'has_noise', 'has_contamination').
        config: object exposing the decision thresholds
            (music_mean_clean, music_ratio_demucs, music_mean_demucs, and
            the noise_* counterparts).
        segment_durations: optional list of {'start': s, 'end': e} dicts.
            If None, the entire audio is treated as one segment.

    Returns:
        List of dicts: {'segment', 'decision', 'reason', 'stats'}.
    """
    # Fix: dropped the unused local `chunk_duration = config.chunk_duration`
    # that the original read but never used.
    print(f"\n[4] Simulating segment-level decisions...")
    
    # Create segments - use full audio as one segment for analysis
    if segment_durations is None:
        # Treat as one big segment spanning every analyzed chunk
        max_time = max(c['end'] for c in chunk_details)
        segments = [{'start': 0, 'end': max_time}]
    else:
        segments = segment_durations
    
    results = []
    
    for seg in segments:
        start, end = seg['start'], seg['end']
        
        # Get chunks in range (matching get_chunks_in_range logic):
        # a chunk belongs to the segment when its midpoint lies inside it.
        chunks = [c for c in chunk_details 
                 if start <= (c['start'] + c['end']) / 2 <= end]
        
        if not chunks:
            # Nothing overlaps this segment: nothing to judge, default clean.
            results.append({
                'segment': seg,
                'decision': 'clean',
                'reason': 'no_chunks',
                'stats': {}
            })
            continue
        
        # Calculate stats (exactly as in music_detection.py)
        music_probs = [c['music_prob'] for c in chunks]
        noise_probs = [c['noise_prob'] for c in chunks]
        
        music_mean = float(np.mean(music_probs))
        music_max = float(np.max(music_probs))
        music_ratio = sum(1 for c in chunks if c['has_music']) / len(chunks)
        
        noise_mean = float(np.mean(noise_probs))
        noise_max = float(np.max(noise_probs))
        noise_ratio = sum(1 for c in chunks if c['has_noise']) / len(chunks)
        
        contamination_ratio = sum(1 for c in chunks if c['has_contamination']) / len(chunks)
        
        # Decision logic (strict_tts_mode=True)
        if contamination_ratio == 0 and music_mean < config.music_mean_clean and noise_mean < config.noise_mean_clean:
            decision = 'clean'
            reason = f"contamination_ratio=0, music_mean={music_mean:.3f}<{config.music_mean_clean}, noise_mean={noise_mean:.3f}<{config.noise_mean_clean}"
        elif (music_ratio < config.music_ratio_demucs and 
              noise_ratio < config.noise_ratio_demucs and
              music_mean < config.music_mean_demucs and 
              noise_mean < config.noise_mean_demucs):
            decision = 'needs_demucs'
            reason = f"music_ratio={music_ratio:.3f}<{config.music_ratio_demucs}, noise_ratio={noise_ratio:.3f}<{config.noise_ratio_demucs}"
        else:
            decision = 'heavy_contamination'
            # Figure out why: report every demucs threshold that was hit.
            reasons = []
            if music_ratio >= config.music_ratio_demucs:
                reasons.append(f"music_ratio={music_ratio:.3f}>={config.music_ratio_demucs}")
            if noise_ratio >= config.noise_ratio_demucs:
                reasons.append(f"noise_ratio={noise_ratio:.3f}>={config.noise_ratio_demucs}")
            if music_mean >= config.music_mean_demucs:
                reasons.append(f"music_mean={music_mean:.3f}>={config.music_mean_demucs}")
            if noise_mean >= config.noise_mean_demucs:
                reasons.append(f"noise_mean={noise_mean:.3f}>={config.noise_mean_demucs}")
            # "unknown" is defensive only: reaching here implies at least one
            # of the four comparisons above was true (absent NaNs).
            reason = " | ".join(reasons) if reasons else "unknown"
        
        results.append({
            'segment': seg,
            'decision': decision,
            'reason': reason,
            'stats': {
                'chunks_analyzed': len(chunks),
                'music_mean': music_mean,
                'music_max': music_max,
                'music_ratio': music_ratio,
                'noise_mean': noise_mean,
                'noise_max': noise_max,
                'noise_ratio': noise_ratio,
                'contamination_ratio': contamination_ratio,
                'chunks_with_music': sum(1 for c in chunks if c['has_music']),
                'chunks_with_noise': sum(1 for c in chunks if c['has_noise']),
            }
        })
    
    return results


def print_analysis(chunk_details, segment_results, config):
    """Print detailed analysis."""
    
    print("\n" + "="*80)
    print("CHUNK-LEVEL ANALYSIS")
    print("="*80)
    
    # Summary stats
    total_chunks = len(chunk_details)
    flagged_music = sum(1 for c in chunk_details if c['has_music'])
    flagged_noise = sum(1 for c in chunk_details if c['has_noise'])
    flagged_any = sum(1 for c in chunk_details if c['has_contamination'])
    
    music_probs = [c['music_prob'] for c in chunk_details]
    noise_probs = [c['noise_prob'] for c in chunk_details]
    speech_probs = [c['speech_prob'] for c in chunk_details]
    
    print(f"\n  Total chunks: {total_chunks}")
    print(f"\n  Music detection:")
    print(f"    Threshold: {config.music_prob_threshold}")
    print(f"    Flagged: {flagged_music}/{total_chunks} ({flagged_music/total_chunks*100:.1f}%)")
    print(f"    Prob range: [{min(music_probs):.4f}, {max(music_probs):.4f}]")
    print(f"    Prob mean: {np.mean(music_probs):.4f}")
    
    print(f"\n  Noise detection:")
    print(f"    Threshold: {config.noise_prob_threshold}")
    print(f"    Flagged: {flagged_noise}/{total_chunks} ({flagged_noise/total_chunks*100:.1f}%)")
    print(f"    Prob range: [{min(noise_probs):.4f}, {max(noise_probs):.4f}]")
    print(f"    Prob mean: {np.mean(noise_probs):.4f}")
    
    print(f"\n  Combined contamination:")
    print(f"    Flagged: {flagged_any}/{total_chunks} ({flagged_any/total_chunks*100:.1f}%)")
    
    print(f"\n  Speech confidence:")
    print(f"    Mean: {np.mean(speech_probs):.4f}")
    
    # Show flagged chunks
    flagged = [c for c in chunk_details if c['has_contamination']]
    if flagged:
        print(f"\n  Flagged chunks (showing first 20):")
        for c in flagged[:20]:
            flags = []
            if c['has_music']:
                flags.append(f"music={c['music_prob']:.3f}")
            if c['has_noise']:
                flags.append(f"noise={c['noise_prob']:.3f}")
            print(f"    [{c['idx']:3d}] {c['start']:6.1f}s - {c['end']:6.1f}s | {', '.join(flags)}")
        if len(flagged) > 20:
            print(f"    ... and {len(flagged)-20} more")
    else:
        print(f"\n  No chunks flagged! Audio should be CLEAN.")
    
    print("\n" + "="*80)
    print("SEGMENT-LEVEL ANALYSIS (why heavy_contamination?)")
    print("="*80)
    
    for res in segment_results:
        seg = res['segment']
        stats = res['stats']
        print(f"\n  Segment: {seg['start']:.1f}s - {seg['end']:.1f}s")
        print(f"  Decision: {res['decision'].upper()}")
        print(f"  Reason: {res['reason']}")
        if stats:
            print(f"  Stats:")
            print(f"    chunks_analyzed: {stats['chunks_analyzed']}")
            print(f"    music_ratio: {stats['music_ratio']:.4f} (threshold for heavy: >={config.music_ratio_demucs})")
            print(f"    noise_ratio: {stats['noise_ratio']:.4f} (threshold for heavy: >={config.noise_ratio_demucs})")
            print(f"    music_mean: {stats['music_mean']:.4f} (threshold for heavy: >={config.music_mean_demucs})")
            print(f"    noise_mean: {stats['noise_mean']:.4f} (threshold for heavy: >={config.noise_mean_demucs})")


def main():
    """CLI entry point: load the audio, analyze it, print and save a report.

    Usage: python debug_music_issue.py <audio_path>
    Exits with status 1 when no audio path is given.
    """
    if len(sys.argv) < 2:
        print("Usage: python debug_music_issue.py <audio_path>")
        print("Example: python debug_music_issue.py /tmp/-5fJwfA0I5E.wav")
        sys.exit(1)
    
    audio_path = sys.argv[1]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")
    
    # Load audio
    waveform, sr, duration = load_audio(audio_path)
    
    # Analyze chunks
    chunk_details, config = analyze_panns_chunks(waveform, sr, device)
    
    # Simulate segment decisions
    segment_results = simulate_segment_decisions(chunk_details, config)
    
    # Print analysis
    print_analysis(chunk_details, segment_results, config)
    
    # Save raw data next to the input file.
    # Fix: the original used audio_path.replace('.wav', '_debug.json'),
    # which is a no-op for any non-.wav input — output_file would equal
    # audio_path and json.dump would OVERWRITE the audio file. splitext
    # works for every extension.
    base, _ = os.path.splitext(audio_path)
    output_file = base + '_debug.json'
    with open(output_file, 'w') as f:
        json.dump({
            'audio_path': audio_path,
            'duration': duration,
            'config': {
                'music_prob_threshold': config.music_prob_threshold,
                'noise_prob_threshold': config.noise_prob_threshold,
                'music_ratio_demucs': config.music_ratio_demucs,
                'noise_ratio_demucs': config.noise_ratio_demucs,
                'music_mean_demucs': config.music_mean_demucs,
                'noise_mean_demucs': config.noise_mean_demucs,
            },
            'chunk_details': chunk_details,
            'segment_results': [{
                'segment': r['segment'],
                'decision': r['decision'],
                'reason': r['reason'],
                'stats': r['stats'],
            } for r in segment_results],
        }, f, indent=2)
    print(f"\n[5] Raw data saved to: {output_file}")

if __name__ == "__main__":
    main()









