#!/usr/bin/env python3
"""
Overlapped Speech Detection (OSD) - Mark Unusable Regions.

=== OPTIMIZATION NOTE (v6.1) ===
Previous approach ran diarization TWICE:
1. OSD stage: full file diarization just to find overlaps
2. Diarization stage: per-chunk diarization again

NEW approach: Extract overlaps directly from diarization result using
pyannote's built-in `get_overlap()` and `extrude()` methods.
This saves ~50% GPU compute time.

Per instructions.md:
- Overlap regions are marked as UNUSABLE
- Pad overlap boundaries by ±100-200ms for safety
- Use pyannote's get_overlap() for efficient detection
"""

import time
import logging
from typing import List, Dict, Tuple
import numpy as np
import torchaudio

from src.models import MODELS

logger = logging.getLogger("FastPipelineV6.OSD")
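

# A minimal sketch of the unified single-pass flow described in the module
# docstring. This helper is illustrative only (it is not called by the
# pipeline); it assumes MODELS.diarization_pipeline is loaded and that
# `config` exposes `overlap_padding_ms` and `overlap_min_duration`.
def _sketch_unified_pass(audio_path: str, config) -> Tuple[List[Dict], List[Dict]]:
    diarization = MODELS.diarization_pipeline(audio_path)  # the ONE diarization run
    # Both overlap regions and clean segments come from this single result
    return extract_overlaps_from_diarization(diarization, config)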


def _get_annotation_from_result(result):
    """
    Extract Annotation object from diarization result.
    
    === PYANNOTE 4.x COMPATIBILITY ===
    - pyannote 3.x: Returns Annotation directly
    - pyannote 4.x: Returns DiarizeOutput with .speaker_diarization attribute
    """
    # Check if it's a DiarizeOutput (pyannote 4.x)
    if hasattr(result, 'speaker_diarization'):
        return result.speaker_diarization
    # Otherwise it's already an Annotation (pyannote 3.x)
    return result


def extract_overlaps_from_diarization(diarization, config) -> Tuple[List[Dict], List[Dict]]:
    """
    Extract overlap regions AND clean segments from a diarization result.
    
    === PYANNOTE 4.x COMPATIBILITY ===
    Handles both pyannote 3.x (Annotation) and 4.x (DiarizeOutput).
    
    Uses pyannote's built-in get_overlap() - much faster than O(n²) comparison.
    This is the KEY optimization: no separate diarization run needed!
    
    Args:
        diarization: pyannote Annotation or DiarizeOutput object from diarization
        config: Pipeline configuration
    
    Returns:
        (overlap_segments, clean_diarization_segments)
    """
    # === PYANNOTE 4.x COMPATIBILITY: Extract Annotation from DiarizeOutput ===
    diarization = _get_annotation_from_result(diarization)
    
    try:
        # === USE PYANNOTE'S BUILT-IN OVERLAP DETECTION ===
        # This is much faster than manual O(n²) segment comparison
        overlap_timeline = diarization.get_overlap()
        
        # Convert Timeline to segment list with padding
        padding_sec = config.overlap_padding_ms / 1000.0
        overlap_segments = []
        
        for segment in overlap_timeline:
            if segment.duration >= config.overlap_min_duration:
                padded_start = max(0, segment.start - padding_sec)
                padded_end = segment.end + padding_sec
                overlap_segments.append({
                    'start': padded_start,
                    'end': padded_end,
                    # Recompute from the padded bounds so clamping at 0 is reflected
                    'duration': padded_end - padded_start,
                    'type': 'overlap'
                })
        
        # Merge close overlap regions
        if overlap_segments:
            overlap_segments = _merge_close_segments(overlap_segments, max_gap=0.1)
        
        # === EXTRACT CLEAN SEGMENTS (remove overlaps) ===
        # extrude() removes overlap regions from the diarization
        clean_annotation = diarization.extrude(overlap_timeline)
        
        clean_segments = []
        for segment, _, speaker in clean_annotation.itertracks(yield_label=True):
            clean_segments.append({
                'start': segment.start,
                'end': segment.end,
                'duration': segment.duration,
                'speaker': speaker
            })
        
        logger.debug(f"   Extracted {len(overlap_segments)} overlap regions, {len(clean_segments)} clean segments")
        
        return overlap_segments, clean_segments
        
    except Exception as e:
        logger.warning(f"get_overlap() failed, falling back to manual detection: {e}")
        # Fallback to manual detection if pyannote method fails
        overlap_segments = _find_overlapping_regions_manual(diarization, config)
        
        # Extract all segments (without extrude)
        all_segments = []
        for segment, _, speaker in diarization.itertracks(yield_label=True):
            all_segments.append({
                'start': segment.start,
                'end': segment.end,
                'duration': segment.duration,
                'speaker': speaker
            })
        
        return overlap_segments, all_segments
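

# Hedged usage sketch: build a toy two-speaker pyannote Annotation and run the
# extractor on it. The config stub and all timing values here are illustrative.
def _demo_extract_overlaps() -> None:
    from types import SimpleNamespace
    from pyannote.core import Annotation, Segment

    ann = Annotation()
    ann[Segment(0.0, 5.0)] = "SPEAKER_00"
    ann[Segment(4.0, 9.0)] = "SPEAKER_01"  # 4.0-5.0s is overlapped speech

    cfg = SimpleNamespace(overlap_padding_ms=150, overlap_min_duration=0.1)
    overlaps, clean = extract_overlaps_from_diarization(ann, cfg)
    # overlaps ≈ [{'start': 3.85, 'end': 5.15, ...}] after ±150ms padding;
    # clean keeps SPEAKER_00 on 0.0-4.0 and SPEAKER_01 on 5.0-9.0
    print(overlaps, clean)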


def detect_overlaps_framelevel(
    audio_path: str,
    config
) -> Tuple[List[Dict], np.ndarray]:
    """
    DEPRECATED: This function runs a SEPARATE diarization pass.
    
    For efficiency, use extract_overlaps_from_diarization() instead,
    which extracts overlaps from an existing diarization result.
    
    Kept for backwards compatibility but logs a warning.
    """
    logger.warning("⚠️ detect_overlaps_framelevel() runs separate diarization - consider using extract_overlaps_from_diarization()")
    logger.info("🔍 Detecting overlapping speech...")
    start = time.time()
    
    if not MODELS.diarization_pipeline:
        logger.warning("⚠️ Diarization pipeline not loaded, skipping OSD")
        return [], np.array([])
    
    try:
        # Run diarization - redundant if the pipeline already uses the unified approach
        diarization = MODELS.diarization_pipeline(audio_path)
        
        # Use optimized overlap extraction
        overlap_segments, _ = extract_overlaps_from_diarization(diarization, config)
        
        elapsed = time.time() - start
        total_overlap = sum(s['duration'] for s in overlap_segments)
        
        logger.info(f"✅ OSD: {len(overlap_segments)} overlap regions, {total_overlap:.1f}s total ({elapsed:.1f}s)")
        
        return overlap_segments, np.array([])
        
    except Exception:
        logger.exception("OSD failed")
        return [], np.array([])


def _find_overlapping_regions_manual(diarization, config) -> List[Dict]:
    """
    Manual O(n²) fallback for finding overlapping regions.
    
    NOTE: Prefer using extract_overlaps_from_diarization() which uses
    pyannote's optimized get_overlap() method.
    
    Args:
        diarization: pyannote Annotation or DiarizeOutput object
        config: Pipeline configuration
    
    Returns:
        List of overlap segments
    """
    # === PYANNOTE 4.x COMPATIBILITY ===
    diarization = _get_annotation_from_result(diarization)
    
    # Get all segments with their speakers
    segments = []
    for segment, _, speaker in diarization.itertracks(yield_label=True):
        segments.append({
            'start': segment.start,
            'end': segment.end,
            'speaker': speaker
        })
    
    if not segments:
        return []
    
    # Sort by start time
    segments.sort(key=lambda x: x['start'])
    
    # Find overlapping regions - O(n²) but with early termination
    overlap_regions = []
    
    for i, seg1 in enumerate(segments):
        for seg2 in segments[i+1:]:
            # Segments are sorted by start, so once seg2 begins at or after
            # seg1's end, no later segment can overlap seg1
            if seg2['start'] >= seg1['end']:
                break
            
            # Check if different speakers overlap
            if seg1['speaker'] != seg2['speaker']:
                overlap_start = max(seg1['start'], seg2['start'])
                overlap_end = min(seg1['end'], seg2['end'])
                
                if overlap_end > overlap_start:
                    duration = overlap_end - overlap_start
                    if duration >= config.overlap_min_duration:
                        overlap_regions.append({
                            'start': overlap_start,
                            'end': overlap_end,
                            'duration': duration,
                            'type': 'overlap',
                            'speakers': [seg1['speaker'], seg2['speaker']]
                        })
    
    # Merge overlapping overlap regions
    if overlap_regions:
        overlap_regions = _merge_close_segments(overlap_regions, max_gap=0.1)
    
    # Pad boundaries for safety
    padding_sec = config.overlap_padding_ms / 1000.0
    overlap_regions = _pad_segments(overlap_regions, padding_sec)
    
    logger.debug(f"   Found {len(overlap_regions)} overlap regions from {len(segments)} segments (manual)")
    
    return overlap_regions


def _frames_to_segments(
    frame_indices: List[int],
    frame_duration: float,
    min_duration: float = 0.2
) -> List[Dict]:
    """Convert list of frame indices to contiguous segments."""
    if not frame_indices:
        return []
    
    segments = []
    start_frame = frame_indices[0]
    prev_frame = frame_indices[0]
    
    for frame in frame_indices[1:]:
        # Check for gap
        if frame > prev_frame + 1:
            # Close current segment
            seg_start = start_frame * frame_duration
            seg_end = (prev_frame + 1) * frame_duration
            duration = seg_end - seg_start
            
            if duration >= min_duration:
                segments.append({
                    'start': seg_start,
                    'end': seg_end,
                    'duration': duration,
                    'type': 'overlap'
                })
            
            # Start new segment
            start_frame = frame
        
        prev_frame = frame
    
    # Close final segment
    seg_start = start_frame * frame_duration
    seg_end = (prev_frame + 1) * frame_duration
    duration = seg_end - seg_start
    
    if duration >= min_duration:
        segments.append({
            'start': seg_start,
            'end': seg_end,
            'duration': duration,
            'type': 'overlap'
        })
    
    return segments
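

# Quick illustrative check with made-up values: at 100ms per frame, indices
# {0, 1, 2} and {7, 8} collapse into two contiguous segments.
def _demo_frames_to_segments() -> None:
    segs = _frames_to_segments([0, 1, 2, 7, 8], frame_duration=0.1)
    # -> [{'start': 0.0, 'end': ~0.3, ...}, {'start': ~0.7, 'end': ~0.9, ...}]
    print(segs)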


def _pad_segments(segments: List[Dict], padding: float) -> List[Dict]:
    """Add padding to segment boundaries (safety margin)."""
    padded = []
    for seg in segments:
        padded.append({
            'start': max(0, seg['start'] - padding),
            'end': seg['end'] + padding,
            'duration': seg['duration'] + 2 * padding,
            'type': seg.get('type', 'overlap')
        })
    return padded


def _merge_close_segments(segments: List[Dict], max_gap: float = 0.2) -> List[Dict]:
    """Merge segments that are very close together."""
    if not segments:
        return []
    
    sorted_segs = sorted(segments, key=lambda x: x['start'])
    merged = [sorted_segs[0].copy()]
    
    for seg in sorted_segs[1:]:
        last = merged[-1]
        gap = seg['start'] - last['end']
        
        if gap <= max_gap:
            # Merge; max() guards against segments nested inside the previous one
            last['end'] = max(last['end'], seg['end'])
            last['duration'] = last['end'] - last['start']
        else:
            merged.append(seg.copy())
    
    return merged
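

# Small illustrative example (hypothetical values): an 80ms gap merges under
# max_gap=0.1, a 500ms gap does not.
def _demo_merge_close_segments() -> None:
    segs = [
        {'start': 0.00, 'end': 1.00, 'duration': 1.00},
        {'start': 1.08, 'end': 2.00, 'duration': 0.92},  # 80ms gap -> merged
        {'start': 2.50, 'end': 3.00, 'duration': 0.50},  # 500ms gap -> kept separate
    ]
    merged = _merge_close_segments(segs, max_gap=0.1)
    # -> [{'start': 0.0, 'end': 2.0, 'duration': 2.0}, {'start': 2.5, ...}]
    print(merged)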


def detect_nonspeech_regions(
    vad_segments: List[Dict],
    total_duration: float,
    min_silence: float = 0.3
) -> List[Dict]:
    """
    Detect non-speech regions (gaps between VAD segments).
    
    These are silence/noise regions that should be marked as unusable.
    
    Args:
        vad_segments: Speech segments from VAD
        total_duration: Total audio duration
        min_silence: Minimum silence duration to report
    
    Returns:
        List of non-speech segments
    """
    nonspeech = []
    
    # No VAD segments at all -> the entire file is non-speech
    if not vad_segments:
        if total_duration >= min_silence:
            nonspeech.append({
                'start': 0.0,
                'end': total_duration,
                'duration': total_duration,
                'type': 'non_speech'
            })
        return nonspeech
    
    # Check start of audio
    if vad_segments[0]['start'] >= min_silence:
        nonspeech.append({
            'start': 0.0,
            'end': vad_segments[0]['start'],
            'duration': vad_segments[0]['start'],
            'type': 'non_speech'
        })
    
    # Gaps between segments
    for i in range(len(vad_segments) - 1):
        gap_start = vad_segments[i]['end']
        gap_end = vad_segments[i + 1]['start']
        gap_duration = gap_end - gap_start
        
        if gap_duration >= min_silence:
            nonspeech.append({
                'start': gap_start,
                'end': gap_end,
                'duration': gap_duration,
                'type': 'non_speech'
            })
    
    # Check end of audio (the empty-VAD case returned early above)
    if total_duration - vad_segments[-1]['end'] >= min_silence:
        nonspeech.append({
            'start': vad_segments[-1]['end'],
            'end': total_duration,
            'duration': total_duration - vad_segments[-1]['end'],
            'type': 'non_speech'
        })
    
    return nonspeech
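

# Illustrative example with made-up VAD output: a 10s file with speech at
# 1-3s and 5-8s yields non-speech regions at 0-1s, 3-5s, and 8-10s.
def _demo_nonspeech_regions() -> None:
    vad = [
        {'start': 1.0, 'end': 3.0},
        {'start': 5.0, 'end': 8.0},
    ]
    gaps = detect_nonspeech_regions(vad, total_duration=10.0)
    # -> three 'non_speech' segments: 0.0-1.0, 3.0-5.0, 8.0-10.0
    print(gaps)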


def remove_overlap_regions(
    speech_segments: List[Dict],
    overlap_segments: List[Dict]
) -> List[Dict]:
    """
    Remove parts of speech segments that overlap with detected overlap regions.
    
    Strategy: If 50% or more of a segment lies in overlap, discard it entirely;
    otherwise keep the whole segment (use split_on_overlaps() to salvage
    partial clean regions instead).
    
    Args:
        speech_segments: Speech segments from VAD/diarization
        overlap_segments: Detected overlap regions
    
    Returns:
        Clean single-speaker segments only
    """
    if not overlap_segments:
        return speech_segments
    
    clean = []
    
    for speech_seg in speech_segments:
        # Calculate total overlap
        total_overlap_amount = 0
        
        for overlap_seg in overlap_segments:
            overlap_start = max(speech_seg['start'], overlap_seg['start'])
            overlap_end = min(speech_seg['end'], overlap_seg['end'])
            
            if overlap_end > overlap_start:
                total_overlap_amount += overlap_end - overlap_start
        
        # Check overlap percentage
        overlap_pct = total_overlap_amount / speech_seg['duration'] if speech_seg['duration'] > 0 else 0
        
        if overlap_pct < 0.5:
            # Keep segment (mostly clean)
            clean.append(speech_seg)
        else:
            # Discard (too much overlap)
            logger.debug(f"   Discarding segment {speech_seg['start']:.2f}-{speech_seg['end']:.2f} ({overlap_pct*100:.0f}% overlap)")
    
    return clean
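

# Illustrative example of the 50% rule with made-up segments: the first speech
# segment is 75% covered by an overlap and is dropped; the second is only 10%
# covered and survives intact.
def _demo_remove_overlap_regions() -> None:
    speech = [
        {'start': 0.0, 'end': 2.0, 'duration': 2.0},  # 1.5s of 2.0s overlapped
        {'start': 3.0, 'end': 5.0, 'duration': 2.0},  # 0.2s of 2.0s overlapped
    ]
    overlaps = [
        {'start': 0.5, 'end': 2.0, 'duration': 1.5},
        {'start': 4.8, 'end': 5.0, 'duration': 0.2},
    ]
    kept = remove_overlap_regions(speech, overlaps)
    # -> only the second segment (10% overlap < 50% threshold)
    print(kept)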


def estimate_segment_quality(audio_chunk: np.ndarray, sr: int) -> Dict:
    """
    Estimate audio quality metrics for TTS usability filtering.
    
    === ADDED (v6.1) ===
    Quality metrics help identify segments that are technically single-speaker
    but acoustically unsuitable for TTS training (noisy, clipped, etc.)
    
    Args:
        audio_chunk: Audio samples (numpy array, float32, [-1, 1] range)
        sr: Sample rate
    
    Returns:
        Dict with quality metrics:
        - snr_db: Estimated signal-to-noise ratio
        - has_clipping: True if clipping detected
        - rms_db: RMS level in dB
        - quality_score: Overall quality score (0-1)
        - is_usable: True if meets minimum quality thresholds
    """
    # Ensure we have data
    if len(audio_chunk) < sr * 0.1:  # Less than 100ms
        return {
            'snr_db': 0, 'has_clipping': False, 'rms_db': -60,
            'quality_score': 0, 'is_usable': False
        }
    
    # RMS level
    rms = np.sqrt(np.mean(audio_chunk**2) + 1e-10)
    rms_db = 20 * np.log10(rms + 1e-10)
    
    # Clipping detection (samples near ±1.0)
    clipping_threshold = 0.99
    clipping_ratio = float(np.mean(np.abs(audio_chunk) > clipping_threshold))
    has_clipping = bool(clipping_ratio > 0.001)  # >0.1% of samples clipped (plain bool keeps the dict JSON-serializable)
    
    # SNR estimation (simplified approach)
    # Use percentile-based noise floor estimation
    sorted_abs = np.sort(np.abs(audio_chunk))
    noise_floor = np.mean(sorted_abs[:len(sorted_abs)//10])  # Bottom 10%
    signal_level = np.mean(sorted_abs[len(sorted_abs)*9//10:])  # Top 10%
    snr_db = 20 * np.log10((signal_level + 1e-10) / (noise_floor + 1e-10))
    
    # Quality score (0-1)
    # Good TTS audio typically has: SNR > 20dB, RMS > -40dB, no clipping
    snr_score = min(1.0, max(0.0, (snr_db - 10) / 30))  # 10-40dB range
    rms_score = min(1.0, max(0.0, (rms_db + 50) / 30))  # -50 to -20dB range
    clip_penalty = 0.5 if has_clipping else 0.0
    
    quality_score = max(0.0, (snr_score * 0.5 + rms_score * 0.5) - clip_penalty)
    
    # Minimum thresholds for TTS usability
    is_usable = (
        snr_db >= 15 and  # At least 15dB SNR
        rms_db >= -45 and  # Not too quiet
        not has_clipping  # No clipping
    )
    
    return {
        'snr_db': round(snr_db, 1),
        'has_clipping': has_clipping,
        'rms_db': round(rms_db, 1),
        'quality_score': round(quality_score, 3),
        'is_usable': is_usable
    }
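

# Synthetic sanity check (illustrative values only): a 440Hz tone at a healthy
# level with a touch of noise should come back usable, while low-level noise
# alone fails the RMS threshold.
def _demo_segment_quality() -> None:
    sr = 16000
    t = np.linspace(0, 1.0, sr, endpoint=False)
    rng = np.random.default_rng(0)
    tone = 0.3 * np.sin(2 * np.pi * 440 * t) + 0.003 * rng.standard_normal(sr)
    noise = 0.003 * rng.standard_normal(sr)
    print(estimate_segment_quality(tone.astype(np.float32), sr))   # is_usable: True
    print(estimate_segment_quality(noise.astype(np.float32), sr))  # is_usable: False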


def filter_by_quality(
    segments: List[Dict],
    audio_path: str,
    min_snr_db: float = 15.0,
    min_quality_score: float = 0.3,
    audio_buffer=None
) -> Tuple[List[Dict], List[Dict]]:
    """
    Filter segments by audio quality, marking low-quality as unusable.
    
    === OPTIMIZATION (v6.2) ===
    Accept audio_buffer to avoid file re-read.
    
    Args:
        segments: List of segments to filter
        audio_path: Path to audio file
        min_snr_db: Minimum SNR threshold
        min_quality_score: Minimum quality score threshold
        audio_buffer: Optional AudioBuffer (avoids file re-read)
    
    Returns:
        (usable_segments, unusable_segments)
    """
    # Use buffer if provided, otherwise load from file
    if audio_buffer is not None:
        waveform_np = audio_buffer.waveform_np
        sr = audio_buffer.sample_rate
    else:
        waveform, sr = torchaudio.load(audio_path)
        # Downmix to mono; squeeze(0) would leave stereo files 2-D and break
        # the time-based slicing below
        waveform_np = waveform.mean(dim=0).numpy()
    
    usable = []
    unusable = []
    
    for seg in segments:
        # Skip already marked segments
        if seg.get('speaker') in ['OVERLAP', 'NON_SPEECH']:
            unusable.append(seg)
            continue
        
        start_sample = int(seg['start'] * sr)
        end_sample = int(seg['end'] * sr)
        chunk = waveform_np[start_sample:end_sample]
        
        quality = estimate_segment_quality(chunk, sr)
        
        # Add quality info to segment
        seg['quality'] = quality
        
        if quality['snr_db'] >= min_snr_db and quality['quality_score'] >= min_quality_score:
            usable.append(seg)
        else:
            seg['status'] = 'unusable'
            seg['unusable_reason'] = 'low_quality'
            unusable.append(seg)
    
    logger.info(f"   Quality filter: {len(usable)} usable, {len(unusable)} unusable")
    
    return usable, unusable
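

# Hedged sketch of the buffer path: anything with `waveform_np` and
# `sample_rate` attributes works, so a SimpleNamespace stands in for the
# pipeline's AudioBuffer here. Segment values are made up.
def _demo_filter_by_quality() -> None:
    from types import SimpleNamespace

    sr = 16000
    t = np.linspace(0, 2.0, 2 * sr, endpoint=False)
    buf = SimpleNamespace(
        waveform_np=(0.3 * np.sin(2 * np.pi * 220 * t)).astype(np.float32),
        sample_rate=sr,
    )
    segs = [{'start': 0.0, 'end': 1.0, 'duration': 1.0, 'speaker': 'SPEAKER_00'}]
    usable, unusable = filter_by_quality(segs, audio_path="", audio_buffer=buf)
    print(len(usable), len(unusable))  # -> 1 0 for this clean synthetic tone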


def split_on_overlaps(
    speech_segments: List[Dict],
    overlap_segments: List[Dict],
    min_segment: float = 0.2
) -> List[Dict]:
    """
    Split speech segments around overlap regions.
    
    More sophisticated than remove_overlap_regions - actually splits
    segments to salvage clean portions.
    
    Args:
        speech_segments: Speech segments to split
        overlap_segments: Overlap regions to split around
        min_segment: Minimum resulting segment length
    
    Returns:
        List of clean segments (overlaps removed/split)
    """
    if not overlap_segments:
        return speech_segments
    
    clean = []
    
    for speech_seg in speech_segments:
        seg_start = speech_seg['start']
        seg_end = speech_seg['end']
        
        # Find all overlaps that intersect this segment
        intersecting = []
        for ovl in overlap_segments:
            if ovl['start'] < seg_end and ovl['end'] > seg_start:
                intersecting.append(ovl)
        
        if not intersecting:
            # No overlaps - keep as is
            clean.append(speech_seg)
            continue
        
        # Sort overlaps by start time
        intersecting.sort(key=lambda x: x['start'])
        
        # Extract clean portions between overlaps
        current_start = seg_start
        
        for ovl in intersecting:
            # Clean portion before this overlap
            if ovl['start'] > current_start:
                duration = ovl['start'] - current_start
                if duration >= min_segment:
                    clean.append({
                        'start': current_start,
                        'end': ovl['start'],
                        'duration': duration,
                        'speaker': speech_seg.get('speaker', 'UNKNOWN'),
                        'is_split': True
                    })
            
            # Move past this overlap
            current_start = max(current_start, ovl['end'])
        
        # Clean portion after last overlap
        if seg_end > current_start:
            duration = seg_end - current_start
            if duration >= min_segment:
                clean.append({
                    'start': current_start,
                    'end': seg_end,
                    'duration': duration,
                    'speaker': speech_seg.get('speaker', 'UNKNOWN'),
                    'is_split': True
                })
    
    return clean
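

# Illustrative example: a 0-10s speech segment with an overlap at 4-6s is
# split into two clean pieces, 0-4s and 6-10s. All values are made up.
def _demo_split_on_overlaps() -> None:
    speech = [{'start': 0.0, 'end': 10.0, 'duration': 10.0, 'speaker': 'SPEAKER_00'}]
    overlaps = [{'start': 4.0, 'end': 6.0, 'duration': 2.0}]
    pieces = split_on_overlaps(speech, overlaps)
    # -> [{'start': 0.0, 'end': 4.0, 'is_split': True, ...},
    #     {'start': 6.0, 'end': 10.0, 'is_split': True, ...}]
    print(pieces)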


# === OVERLAP DENSITY FILTER (v6.9) ===
def filter_overlap_sandwich_segments(
    segments: List[Dict],
    config
) -> List[Dict]:
    """
    Mark short segments sandwiched between overlaps as unusable.
    
    === OVERLAP DENSITY DETECTION (v6.9) ===
    Problem: pyannote's overlap detection may leave short "clean" segments
    between two overlap regions. These are often:
    - Brief pauses in overlapping speech
    - Model uncertainty at speaker transitions
    - Not truly clean single-speaker audio
    
    Solution: If a short segment is sandwiched between two overlaps,
    mark it as unusable due to "overlap_proximity".
    
    Criteria:
    - Segment is currently marked as usable (not overlap/non_speech)
    - Previous segment is OVERLAP (within max_gap seconds)
    - Next segment is OVERLAP (within max_gap seconds)
    - Segment duration < max_duration
    
    Args:
        segments: List of all segments (sorted by time)
        config: Pipeline config with overlap_density_* settings
    
    Returns:
        Updated segments list with overlap_proximity segments marked unusable
    """
    if not getattr(config, 'overlap_density_filter', True):
        return segments
    
    max_gap = getattr(config, 'overlap_density_max_gap', 1.0)
    max_duration = getattr(config, 'overlap_density_max_duration', 3.0)
    
    sandwiched_count = 0
    sandwiched_duration = 0.0
    
    # Work on a copy to avoid mutation issues
    result = []
    
    for i, seg in enumerate(segments):
        seg_copy = seg.copy()
        
        # Skip first/last segments (can't be sandwiched)
        if i == 0 or i == len(segments) - 1:
            result.append(seg_copy)
            continue
        
        # Skip if already overlap or non_speech
        if seg.get('speaker') in ['OVERLAP', 'NON_SPEECH']:
            result.append(seg_copy)
            continue
        
        # Skip if already marked unusable
        if seg.get('status') == 'unusable':
            result.append(seg_copy)
            continue
        
        prev_seg = segments[i - 1]
        next_seg = segments[i + 1]
        
        # Check if sandwiched between overlaps
        prev_is_overlap = prev_seg.get('speaker') == 'OVERLAP'
        next_is_overlap = next_seg.get('speaker') == 'OVERLAP'
        
        if not (prev_is_overlap and next_is_overlap):
            result.append(seg_copy)
            continue
        
        # Check gaps (should be small/zero for truly sandwiched segments)
        gap_to_prev = seg['start'] - prev_seg['end']
        gap_to_next = next_seg['start'] - seg['end']
        
        prev_close = gap_to_prev <= max_gap
        next_close = gap_to_next <= max_gap
        
        if not (prev_close and next_close):
            result.append(seg_copy)
            continue
        
        # Check segment duration
        is_short = seg['duration'] < max_duration
        
        if is_short:
            # Mark as unusable due to overlap proximity
            seg_copy['status'] = 'unusable'
            seg_copy['unusable_reason'] = 'overlap_proximity'
            sandwiched_count += 1
            sandwiched_duration += seg['duration']
            logger.debug(f"   Overlap sandwich: {seg['start']:.3f}-{seg['end']:.3f} ({seg['duration']:.2f}s) {seg.get('speaker', 'UNKNOWN')}")
        
        result.append(seg_copy)
    
    if sandwiched_count > 0:
        logger.info(f"   🥪 Overlap density filter: {sandwiched_count} segments marked unusable ({sandwiched_duration:.1f}s)")
    
    return result
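

# Illustrative example of the sandwich rule with a stub config: a 1.5s
# single-speaker segment wedged between two OVERLAP regions gets marked
# unusable with reason 'overlap_proximity'. All values are made up.
def _demo_overlap_sandwich() -> None:
    from types import SimpleNamespace

    segs = [
        {'start': 0.0, 'end': 2.0, 'duration': 2.0, 'speaker': 'OVERLAP'},
        {'start': 2.0, 'end': 3.5, 'duration': 1.5, 'speaker': 'SPEAKER_00'},
        {'start': 3.5, 'end': 6.0, 'duration': 2.5, 'speaker': 'OVERLAP'},
    ]
    cfg = SimpleNamespace(
        overlap_density_filter=True,
        overlap_density_max_gap=1.0,
        overlap_density_max_duration=3.0,
    )
    out = filter_overlap_sandwich_segments(segs, cfg)
    print(out[1].get('status'), out[1].get('unusable_reason'))
    # -> unusable overlap_proximity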


# === Legacy compatibility function ===
def detect_overlap_and_nonspeech(
    audio_path: str,
    vad_segments: List[Dict],
    config
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    Combined overlap and non-speech detection (legacy interface).
    
    Returns:
        (clean_segments, overlap_segments, nonspeech_segments)
    """
    # Get audio duration
    waveform, sr = torchaudio.load(audio_path)
    total_duration = waveform.shape[1] / sr
    
    # Detect overlaps
    overlap_segments, _ = detect_overlaps_framelevel(audio_path, config)
    
    # Detect non-speech
    nonspeech_segments = detect_nonspeech_regions(vad_segments, total_duration)
    
    # Remove overlapping portions from VAD segments
    clean_segments = remove_overlap_regions(vad_segments, overlap_segments)
    
    logger.info(f"   Clean speech: {len(clean_segments)} segments (from {len(vad_segments)} VAD)")
    
    return clean_segments, overlap_segments, nonspeech_segments
