#!/usr/bin/env python3
"""
Spectral Analysis for Audio Quality Assessment (v7.6 - Simplified)

=== PURPOSE ===
Detect the REAL sample rate of audio to identify upsampled content.
Many YouTube videos have 48kHz containers but were recorded at 44.1kHz or lower.

=== v7.6 SIMPLIFICATION ===
Previous approach (1% rolloff) was too aggressive for speech content:
- Speech has most energy 100Hz-8kHz, little above 10kHz
- Using 1% threshold found 779Hz rolloff (wrong!)

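New approach:
1. Use the 95% cumulative-energy rolloff (the frequency below which 95% of
   the spectral energy lies) to find where content actually ends
2. Look for "brick wall" cutoffs: the highest frequency whose power is still
   within 20 dB of the median power of the lower half of the spectrum
3. Map the cutoff to common sample-rate boundaries: 16kHz (32k source),
   22kHz (44.1k source), 24kHz (48k source)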

Typical findings:
- True 48kHz: Content visible up to ~23-24kHz
- Upsampled 44.1→48: Sharp cliff at ~22kHz  
- MP3 source: Cutoff at 16-20kHz depending on bitrate
- Very low quality: Cutoff at ~8-11kHz
"""

import logging
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
import numpy as np
from scipy import signal

# Dotted name makes this a child of the "FastPipelineV6" logger, so records
# propagate to whatever handlers are configured on that parent.
logger = logging.getLogger("FastPipelineV6.SpectralAnalysis")


@dataclass
class SpectralQuality:
    """Audio spectral quality assessment results."""
    
    claimed_sample_rate: int        # What the file claims to be (e.g., 48000)
    detected_nyquist: float         # Where content actually ends (Hz)
    effective_sample_rate: int      # Estimated true sample rate
    is_upsampled: bool              # True if upsampled from lower rate
    original_format_guess: str      # "native_48k", "upsampled_44.1k", "mp3_source", etc.
    confidence: float               # 0-1 confidence in detection
    
    # Detailed metrics
    high_freq_energy_ratio: float   # Energy above 20kHz / total energy
    rolloff_frequency: float        # Frequency where energy drops to 1% of peak
    spectral_flatness: float        # How "noisy" vs "tonal" the signal is
    
    # For TTS quality decisions
    is_suitable_for_tts: bool       # True if quality sufficient for TTS training
    quality_issues: list            # List of detected issues


def analyze_spectral_quality(
    waveform: np.ndarray,
    sample_rate: int,
    analysis_duration: float = 10.0,
    num_samples: int = 3,
    max_time_budget: float = 5.0,
) -> SpectralQuality:
    """
    Analyze audio to detect effective sample rate and quality issues.
    
    === v7.6: SIMPLIFIED DETECTION ===
    Uses 95th percentile energy threshold instead of 1% rolloff.
    This is more robust for speech content which has low energy at high frequencies.
    
    Detection strategy:
    1. Compute spectrogram from middle portion of audio
    2. Find the frequency band where cumulative energy reaches 95% 
    3. Look for "brick wall" cutoffs at common sample rate boundaries
    4. Map detected cutoff to original sample rate
    
    Args:
        waveform: Audio samples (numpy array, mono, float32)
        sample_rate: Claimed sample rate (from file)
        analysis_duration: Duration to analyze per segment (seconds)
        num_samples: Number of segments to sample from audio
        max_time_budget: Maximum time allowed for analysis (seconds)
    
    Returns:
        SpectralQuality with detection results
    """
    import time
    start_time = time.time()
    nyquist = sample_rate / 2
    
    # Sample from middle portion of audio (avoid intro/outro)
    total_samples = len(waveform)
    analysis_samples = int(analysis_duration * sample_rate)
    
    # Use middle 80% of audio to avoid intro/outro effects
    usable_start = int(total_samples * 0.10)
    usable_end = int(total_samples * 0.90)
    usable_length = usable_end - usable_start
    
    if usable_length < analysis_samples * num_samples:
        segments = [waveform]
    else:
        step = (usable_length - analysis_samples) // max(1, num_samples - 1)
        segments = []
        for i in range(num_samples):
            start = usable_start + i * step
            end = min(start + analysis_samples, total_samples)
            if end - start > sample_rate:
                segments.append(waveform[start:end])
    
    if not segments:
        segments = [waveform[:min(len(waveform), analysis_samples)]]
    
    # Analyze each segment
    rolloff_frequencies = []
    high_freq_ratios = []
    spectral_flatness_values = []
    
    for segment in segments:
        elapsed = time.time() - start_time
        if elapsed > max_time_budget and len(rolloff_frequencies) >= 1:
            logger.debug(f"   ⏱️ Spectral analysis time budget exceeded ({elapsed:.1f}s)")
            break
            
        # Compute spectrogram with high frequency resolution
        nperseg = min(8192, len(segment) // 4)
        f, t, Sxx = signal.spectrogram(
            segment, fs=sample_rate, nperseg=nperseg, noverlap=nperseg//2
        )
        
        # Average power across time
        mean_power = np.mean(Sxx, axis=1)
        mean_power = np.maximum(mean_power, 1e-10)
        
        # === v7.6: Use 95th percentile cumulative energy for rolloff ===
        # This is more robust than 1% threshold for speech
        cumulative_energy = np.cumsum(mean_power)
        total_energy = cumulative_energy[-1]
        
        # Find frequency where we reach 95% of total energy
        threshold_95 = total_energy * 0.95
        rolloff_idx = np.searchsorted(cumulative_energy, threshold_95)
        rolloff_idx = min(rolloff_idx, len(f) - 1)
        rolloff_freq_95 = f[rolloff_idx]
        
        # === v7.6: Also look for brick-wall cutoffs (sharp cliff in energy) ===
        # Compute derivative to find sharp drops in power spectrum
        # Look for the highest frequency where there's still meaningful content
        log_power = np.log10(mean_power + 1e-10)
        
        # Find frequency where log power drops below noise floor
        # Use median as baseline, look for where we're 20dB below
        median_log_power = np.median(log_power[:len(log_power)//2])  # Focus on lower half
        noise_floor = median_log_power - 2.0  # 20dB below median
        
        # Find highest frequency above noise floor
        brick_wall_idx = len(f) - 1
        for i in range(len(log_power) - 1, 0, -1):
            if log_power[i] > noise_floor:
                brick_wall_idx = i
                break
        brick_wall_freq = f[brick_wall_idx]
        
        # Use the HIGHER of the two methods (more conservative)
        rolloff_freq = max(rolloff_freq_95, brick_wall_freq)
        rolloff_frequencies.append(rolloff_freq)
        
        # Calculate high frequency energy ratio (above 20kHz)
        if nyquist > 20000:
            idx_20k = np.searchsorted(f, 20000)
            energy_above_20k = np.sum(mean_power[idx_20k:])
            high_freq_ratio = energy_above_20k / total_energy if total_energy > 0 else 0
        else:
            high_freq_ratio = 0.0
        high_freq_ratios.append(high_freq_ratio)
        
        # Spectral flatness (Wiener entropy)
        geometric_mean = np.exp(np.mean(np.log(mean_power + 1e-10)))
        arithmetic_mean = np.mean(mean_power)
        flatness = geometric_mean / arithmetic_mean if arithmetic_mean > 0 else 0
        spectral_flatness_values.append(flatness)
    
    # Aggregate results (use median for robustness)
    median_rolloff = np.median(rolloff_frequencies)
    median_high_freq = np.median(high_freq_ratios)
    median_flatness = np.median(spectral_flatness_values)
    
    # Determine effective sample rate based on rolloff
    # Standard sample rates and their Nyquist frequencies:
    # 22050 → 11025 Hz, 32000 → 16000 Hz, 44100 → 22050 Hz, 48000 → 24000 Hz
    
    effective_sr, original_format, is_upsampled = _detect_original_format(
        median_rolloff, sample_rate, median_high_freq
    )
    
    # Calculate confidence based on consistency across segments
    rolloff_std = np.std(rolloff_frequencies) if len(rolloff_frequencies) > 1 else 0
    confidence = max(0.0, 1.0 - (rolloff_std / median_rolloff)) if median_rolloff > 0 else 0.5
    
    # Determine quality issues
    quality_issues = []
    is_suitable = True
    
    if is_upsampled and effective_sr < 32000:
        quality_issues.append(f"Very low quality source ({effective_sr}Hz effective)")
        is_suitable = False
    elif is_upsampled and effective_sr < 44100:
        quality_issues.append(f"Medium quality source ({effective_sr}Hz effective)")
    
    if median_flatness > 0.5:
        quality_issues.append("High spectral flatness (possible noise/compression artifacts)")
    
    # Check for MP3-style brickwall filter at common cutoffs
    if 15000 < median_rolloff < 17000:
        quality_issues.append("Likely MP3 source (128kbps cutoff detected)")
        original_format = "mp3_128k"
    elif 17000 < median_rolloff < 19000:
        quality_issues.append("Likely MP3 source (192kbps cutoff detected)")
        original_format = "mp3_192k"
    
    return SpectralQuality(
        claimed_sample_rate=int(sample_rate),
        detected_nyquist=float(median_rolloff),
        effective_sample_rate=int(effective_sr),
        is_upsampled=bool(is_upsampled),
        original_format_guess=str(original_format),
        confidence=float(confidence),
        high_freq_energy_ratio=float(median_high_freq),
        rolloff_frequency=float(median_rolloff),
        spectral_flatness=float(median_flatness),
        is_suitable_for_tts=bool(is_suitable),
        quality_issues=list(quality_issues),
    )


def _detect_original_format(
    rolloff_freq: float, 
    claimed_sr: int,
    high_freq_ratio: float
) -> Tuple[int, str, bool]:
    """
    Determine original sample rate based on spectral rolloff.
    
    === v7.6: More robust detection ===
    Maps detected rolloff frequency to common sample rate boundaries:
    - 24kHz Nyquist → 48kHz native
    - 22kHz Nyquist → 44.1kHz (CD quality) 
    - 16kHz Nyquist → 32kHz
    - 11kHz Nyquist → 22.05kHz
    - 8kHz Nyquist → 16kHz (telephone quality)
    """
    claimed_nyquist = claimed_sr / 2
    rolloff_ratio = rolloff_freq / claimed_nyquist if claimed_nyquist > 0 else 0
    
    # === v7.6: Use absolute frequency ranges for detection ===
    # These ranges are based on standard Nyquist frequencies
    
    # Native 48kHz: content extends past 22.5kHz (Nyquist of 44.1k)
    if rolloff_freq > 22500:
        return claimed_sr, f"native_{claimed_sr//1000}k", False
    
    # Upsampled from 44.1kHz: content ends around 20-22.5kHz
    # This is the most common case for YouTube podcasts
    if rolloff_freq > 19000:
        return 44100, "upsampled_44.1k", True
    
    # MP3 192kbps: typical cutoff around 18-19kHz
    if rolloff_freq > 17000:
        return 44100, "mp3_192k_source", True
    
    # MP3 128kbps or 32kHz source: cutoff around 15-17kHz
    if rolloff_freq > 14000:
        return 32000, "mp3_128k_or_32k", True
    
    # Lower quality: 22.05kHz source, cutoff around 10-11kHz
    if rolloff_freq > 9000:
        return 22050, "low_quality_22k", True
    
    # Very low quality: 16kHz telephone-style, cutoff around 7-8kHz
    if rolloff_freq > 6000:
        return 16000, "telephone_16k", True
    
    # Extremely low quality
    if rolloff_freq > 3000:
        return 8000, "narrowband_8k", True
    
    # Fallback: estimate from rolloff (Nyquist = rolloff, SR = 2*Nyquist)
    effective_sr = int(rolloff_freq * 2)
    # Round to nearest common sample rate
    common_rates = [96000, 48000, 44100, 32000, 22050, 16000, 8000]
    for sr in common_rates:
        if abs(effective_sr - sr) < sr * 0.15:
            return sr, f"estimated_{sr//1000}k", True
    
    return max(effective_sr, 8000), "unknown", rolloff_ratio < 0.85


def add_spectral_info_to_metadata(
    metadata: Dict,
    waveform: np.ndarray,
    sample_rate: int,
) -> Dict:
    """
    Attach a spectral quality report to ``metadata`` (mutated in place and returned).

    Meant to run on the original audio, before any resampling, so the report
    describes the source file rather than processed output.

    On success this writes ``metadata['spectral_quality']`` (a dict of plain,
    rounded values) plus the flat ``audio_native_sample_rate`` and
    ``is_native_quality`` keys for easy DB queries. If the analysis raises,
    the failure is logged as a warning and ``metadata['spectral_quality']`` is
    set to ``{'error': <message>}``; the exception is not propagated.

    Args:
        metadata: Existing metadata dict
        waveform: Original audio samples (before resampling)
        sample_rate: Original sample rate

    Returns:
        The same ``metadata`` dict, updated
    """
    try:
        quality = analyze_spectral_quality(waveform, sample_rate)
    except Exception as exc:
        # Best-effort: a failed analysis must not abort the caller's pipeline
        logger.warning(f"   ⚠️ Spectral analysis failed: {exc}")
        metadata['spectral_quality'] = {'error': str(exc)}
        return metadata

    report = {
        'claimed_sample_rate': quality.claimed_sample_rate,
        'effective_sample_rate': quality.effective_sample_rate,
        'detected_nyquist': round(quality.detected_nyquist, 0),
        'is_upsampled': quality.is_upsampled,
        'original_format': quality.original_format_guess,
        'high_freq_energy_ratio': round(quality.high_freq_energy_ratio, 4),
        'rolloff_frequency': round(quality.rolloff_frequency, 0),
        'spectral_flatness': round(quality.spectral_flatness, 4),
        'confidence': round(quality.confidence, 2),
        'is_suitable_for_tts': quality.is_suitable_for_tts,
        'quality_issues': quality.quality_issues,
    }
    metadata['spectral_quality'] = report

    # Flat copies of the key facts for simple DB filtering
    metadata['audio_native_sample_rate'] = quality.effective_sample_rate
    metadata['is_native_quality'] = not quality.is_upsampled

    summary = (
        f"   📊 Spectral: {quality.original_format_guess} "
        f"(rolloff={quality.rolloff_frequency:.0f}Hz, "
        f"confidence={quality.confidence:.0%})"
    )
    logger.info(summary)

    issues = quality.quality_issues
    if issues:
        logger.warning(f"   ⚠️ Quality issues: {', '.join(issues)}")

    return metadata


# === Quick test function ===
def test_spectral_detection():
    """Smoke-test format detection on synthetic multi-tone signals.

    Each case sums equal-amplitude (0.1) sinusoids in a 5-second, 48 kHz
    container, runs ``analyze_spectral_quality`` on it, and prints the
    detected effective sample rate and format guess. Nothing is asserted;
    inspect the printed output.
    """
    sr = 48000
    duration = 5.0
    t = np.arange(int(sr * duration)) / sr

    cases = (
        # True 48kHz signal (content up to 23kHz)
        ("True 48kHz", [440, 1000, 5000, 10000, 15000, 20000, 23000]),
        # 44.1kHz upsampled to 48kHz (content only up to 22kHz)
        ("Upsampled 44.1kHz", [440, 1000, 5000, 10000, 15000, 20000, 21500]),
        # Low quality (content only up to 16kHz - like 32kHz source)
        ("Low quality 32kHz", [440, 1000, 5000, 10000, 15000]),
    )
    for label, freqs in cases:
        tones = sum(0.1 * np.sin(2 * np.pi * f * t) for f in freqs)
        result = analyze_spectral_quality(tones.astype(np.float32), sr)
        print(f"{label}: detected={result.effective_sample_rate}, format={result.original_format_guess}")


# Running this file directly executes the synthetic-signal smoke test.
if __name__ == "__main__":
    test_spectral_detection()

