#!/usr/bin/env python3
"""
Frame-Level Speaker Segmentation for Micro-Level Speaker Detection.

Key insight from instructions.md:
- PyAnnote segmentation-3.0 outputs probabilities every ~17ms frame
- Standard 1.5s chunking is "POISON" for 0.4s events
- We need raw frame-level activations for crisp speaker changes

This module provides:
1. Frame-level speaker probabilities (17ms resolution)
2. Micro-turn detection (0.2-0.5s events)
3. Crisp speaker change boundaries
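
A minimal usage sketch lives in the __main__ block at the bottom of this file.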
"""

import time
import logging
from typing import List, Dict, Tuple
import numpy as np
import torchaudio
from pyannote.audio import Inference

from src.models import MODELS

logger = logging.getLogger("FastPipelineV6.Segmentation")


def get_frame_level_activations(
    audio_path: str,
    config,
    step: float = 2.5,
    duration: float = 5.0
) -> Tuple[np.ndarray, float]:
    """
    Get raw frame-level speaker activations from segmentation model.
    
    PyAnnote segmentation-3.0 outputs:
    - Probabilities for up to 3 local speakers per frame
    - One frame every ~17 ms (~59 frames per second)
    
    Args:
        audio_path: Path to audio file
        config: Pipeline configuration
        step: Sliding window step (seconds)
        duration: Window duration (seconds)
    
    Returns:
        (activations, frame_rate): 
            activations: (num_frames, num_speakers) array
            frame_rate: Frames per second
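
    Example (illustrative; requires the segmentation model to be loaded in
    MODELS, and "call.wav" / cfg are placeholders):

        acts, fps = get_frame_level_activations("call.wav", cfg)
        # acts: (num_frames, num_local_speakers); fps: ~59 frames per second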
    """
    if not MODELS.segmentation_model:
        logger.warning("⚠️ Segmentation model not loaded")
        return np.array([]), 0.0
    
    logger.info(f"🎯 Getting frame-level activations (17ms resolution)...")
    start = time.time()
    
    device = MODELS.get_device()
    
    # Create inference with sliding window
    inference = Inference(
        MODELS.segmentation_model,
        device=device,
        duration=duration,
        step=step,
    )
    
    try:
        # Run inference - returns SlidingWindowFeature
        segmentation = inference(audio_path)
        
        # Extract raw activations.
        # The code below assumes the inference result is a frame-aligned 2-D
        # array of shape (num_frames, num_local_speakers) whose sliding window
        # advances one frame (~17 ms) per step; depending on the pyannote.audio
        # version and model, the raw output may instead be per-chunk (3-D) or
        # powerset-encoded, in which case it must be converted first.
        activations = segmentation.data
        frame_rate = 1.0 / segmentation.sliding_window.step
        
        elapsed = time.time() - start
        logger.info(f"✅ Frame activations: {activations.shape[0]} frames @ {frame_rate:.1f} fps ({elapsed:.1f}s)")
        
        return activations, frame_rate
        
    except Exception as e:
        logger.error(f"Frame activation extraction failed: {e}")
        return np.array([]), 0.0


def detect_speaker_changes_framelevel(
    audio_path: str,
    vad_segments: List[Dict],
    config
) -> List[Dict]:
    """
    Detect crisp speaker changes at frame level (17ms resolution).
    
    Per instructions.md:
    - Use raw segmentation model output
    - Don't use 1.5s chunking (poison for 0.4s events)
    - Detect micro-turns (0.2-0.5s acknowledgments)
    
    Args:
        audio_path: Path to audio
        vad_segments: Speech segments from VAD
        config: Pipeline configuration
    
    Returns:
        List of speaker segments with crisp boundaries
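
    Example (illustrative; "call.wav" and cfg are placeholders, and
    vad_segments comes from the pipeline's VAD stage):

        segments = detect_speaker_changes_framelevel("call.wav", vad_segments, cfg)
        # each item: {'start': ..., 'end': ..., 'duration': ...,
        #             'speaker': 'LOCAL_n', 'is_overlap': ...}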
    """
    logger.info("🎯 Detecting speaker changes (frame-level, 17ms resolution)...")
    start = time.time()
    
    if not MODELS.segmentation_model:
        logger.warning("⚠️ No segmentation model, falling back to diarization pipeline")
        return []
    
    # Get frame-level activations
    activations, frame_rate = get_frame_level_activations(
        audio_path, config,
        step=config.segmentation_step,
        duration=config.segmentation_duration
    )
    
    if len(activations) == 0:
        return []
    
    # Get audio duration
    waveform, sr = torchaudio.load(audio_path)
    total_duration = waveform.shape[1] / sr
    
    # Detect speaker turns from activations
    segments = _activations_to_segments(
        activations, 
        frame_rate, 
        total_duration,
        config.overlap_threshold,
        config.min_segment_duration
    )
    
    # Filter to only speech regions (from VAD)
    segments = _filter_to_vad_regions(segments, vad_segments)
    
    elapsed = time.time() - start
    logger.info(f"✅ Frame-level segmentation: {len(segments)} segments ({elapsed:.1f}s)")
    
    return segments


def _activations_to_segments(
    activations: np.ndarray,
    frame_rate: float,
    total_duration: float,
    threshold: float = 0.5,
    min_duration: float = 0.2
) -> List[Dict]:
    """
    Convert frame-level activations to speaker segments.
    
    Strategy:
    1. For each frame, find active speakers (prob > threshold)
    2. Track speaker continuity to form segments
    3. Handle overlaps (multiple speakers active)
    
    Args:
        activations: (num_frames, num_speakers) array
        frame_rate: Frames per second
        total_duration: Audio duration
        threshold: Activation threshold
        min_duration: Minimum segment duration
    
    Returns:
        List of segments with start/end/speaker
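
    Example (toy input; two speakers, 10 frames at 10 fps so each frame
    is 100 ms):

        >>> acts = np.array([[0.9, 0.1]] * 5 + [[0.1, 0.8]] * 5)
        >>> [(s['speaker'], s['start'], s['end'])
        ...  for s in _activations_to_segments(acts, 10.0, 1.0)]
        [('LOCAL_0', 0.0, 0.5), ('LOCAL_1', 0.5, 1.0)]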
    """
    if len(activations) == 0:
        return []
    
    num_frames, num_speakers = activations.shape
    frame_duration = 1.0 / frame_rate
    
    segments = []
    
    # Track state for each speaker
    speaker_states = {}  # speaker_id -> {'start': float, 'active': bool}
    
    for frame_idx in range(num_frames):
        frame_time = frame_idx * frame_duration
        probs = activations[frame_idx]
        
        for spk_idx, prob in enumerate(probs):
            spk_id = f"LOCAL_{spk_idx}"
            is_active = prob > threshold
            
            if spk_id not in speaker_states:
                speaker_states[spk_id] = {'start': None, 'active': False}
            
            state = speaker_states[spk_id]
            
            # Speaker becomes active
            if is_active and not state['active']:
                state['start'] = frame_time
                state['active'] = True
            
            # Speaker becomes inactive
            elif not is_active and state['active']:
                if state['start'] is not None:
                    duration = frame_time - state['start']
                    if duration >= min_duration:
                        segments.append({
                            'start': state['start'],
                            'end': frame_time,
                            'duration': duration,
                            'speaker': spk_id,
                            'is_overlap': False  # Will be marked separately
                        })
                state['active'] = False
                state['start'] = None
    
    # Close any open segments at end
    end_time = min(num_frames * frame_duration, total_duration)
    for spk_id, state in speaker_states.items():
        if state['active'] and state['start'] is not None:
            duration = end_time - state['start']
            if duration >= min_duration:
                segments.append({
                    'start': state['start'],
                    'end': end_time,
                    'duration': duration,
                    'speaker': spk_id,
                    'is_overlap': False
                })
    
    # Sort by start time
    segments.sort(key=lambda x: x['start'])
    
    return segments


def _filter_to_vad_regions(
    segments: List[Dict],
    vad_segments: List[Dict]
) -> List[Dict]:
    """
    Filter speaker segments to only include VAD-detected speech regions.
    
    This removes false positives in silence regions.
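
    Example (toy input):

        >>> segs = [{'start': 0.0, 'end': 5.0, 'speaker': 'LOCAL_0'}]
        >>> out = _filter_to_vad_regions(segs, [{'start': 1.0, 'end': 3.0}])
        >>> out[0]['start'], out[0]['end'], out[0]['duration']
        (1.0, 3.0, 2.0)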
    """
    if not vad_segments:
        return segments
    
    filtered = []
    
    for seg in segments:
        # Check if segment overlaps with any VAD region
        for vad in vad_segments:
            overlap_start = max(seg['start'], vad['start'])
            overlap_end = min(seg['end'], vad['end'])
            
            if overlap_end > overlap_start:
                # Segment overlaps this VAD region - keep it, clipped to the
                # VAD boundaries
                clipped = seg.copy()
                clipped['start'] = overlap_start
                clipped['end'] = overlap_end
                clipped['duration'] = overlap_end - overlap_start
                
                if clipped['duration'] >= 0.2:  # Still long enough
                    filtered.append(clipped)
                break
    
    return filtered


def detect_micro_turns(
    activations: np.ndarray,
    frame_rate: float,
    min_turn_duration: float = 0.2,
    max_turn_duration: float = 1.0,
    threshold: float = 0.5
) -> List[Dict]:
    """
    Specifically detect short micro-turns (0.2-1.0s).
    
    These are typically:
    - Backchannels ("mm-hmm", "yeah", "right")
    - Quick acknowledgments
    - Brief interruptions
    
    Args:
        activations: Frame-level activations
        frame_rate: Frames per second
        min_turn_duration: Minimum turn length (seconds)
        max_turn_duration: Maximum turn length (seconds)
        threshold: Activation threshold
    
    Returns:
        List of micro-turn segments
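
    Example (toy input; 100 frames at 50 fps = 2 s, with a ~0.3 s burst from
    a second speaker that qualifies as a micro-turn):

        >>> acts = np.array([[0.9, 0.8 if 40 <= i < 55 else 0.1]
        ...                  for i in range(100)])
        >>> len(detect_micro_turns(acts, frame_rate=50.0))
        1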
    """
    all_segments = _activations_to_segments(
        activations, frame_rate, 
        total_duration=len(activations) / frame_rate,
        threshold=threshold,
        min_duration=min_turn_duration
    )
    
    # Filter to only short segments
    micro_turns = [
        seg for seg in all_segments 
        if min_turn_duration <= seg['duration'] <= max_turn_duration
    ]
    
    logger.info(f"   Detected {len(micro_turns)} micro-turns (0.2-1.0s)")
    
    return micro_turns


def refine_segment_boundaries(
    segments: List[Dict],
    activations: np.ndarray,
    frame_rate: float,
    threshold: float = 0.5
) -> List[Dict]:
    """
    Refine segment boundaries using frame-level activations.
    
    Takes coarse segments (e.g., from chunked diarization) and
    sharpens their boundaries using raw frame activations.
    
    This is useful when:
    - Diarization output is coarse (1.5s+ boundaries)
    - Need precise boundaries for 0.4s events
    
    Args:
        segments: Coarse segments to refine
        activations: Frame-level activations
        frame_rate: Frames per second
        threshold: Activation threshold
    
    Returns:
        Segments with refined boundaries
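
    Example (toy input; speech activation only between 0.2 s and 1.8 s, so
    coarse boundaries at 0.5 s / 1.5 s get pushed out to the true edges):

        >>> acts = np.array([[0.9 if 10 <= i < 90 else 0.1] for i in range(100)])
        >>> seg = {'start': 0.5, 'end': 1.5, 'speaker': 'SPEAKER_00'}
        >>> out = refine_segment_boundaries([seg], acts, frame_rate=50.0)
        >>> round(out[0]['start'], 2), round(out[0]['end'], 2)
        (0.2, 1.8)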
    """
    if len(activations) == 0 or not segments:
        return segments
    
    frame_duration = 1.0 / frame_rate
    refined = []
    
    for seg in segments:
        # Find frame indices for this segment
        start_frame = int(seg['start'] * frame_rate)
        end_frame = int(seg['end'] * frame_rate)
        
        # Clamp to valid range
        start_frame = max(0, min(start_frame, len(activations) - 1))
        end_frame = max(0, min(end_frame, len(activations)))
        
        if start_frame >= end_frame:
            refined.append(seg)
            continue
        
        # Find actual activation boundaries
        # Search backwards from start to find true start
        true_start = start_frame
        for f in range(start_frame, max(0, start_frame - 30), -1):  # Search up to 30 frames (~0.5 s) back
            if activations[f].max() < threshold:
                true_start = f + 1
                break
        
        # Search forwards from end to find true end
        true_end = end_frame
        for f in range(end_frame, min(len(activations), end_frame + 30)):  # Search up to 30 frames (~0.5 s) forward
            if activations[f].max() < threshold:
                true_end = f
                break
        
        # Update segment with refined boundaries
        new_seg = seg.copy()
        new_seg['start'] = true_start * frame_duration
        new_seg['end'] = true_end * frame_duration
        new_seg['duration'] = new_seg['end'] - new_seg['start']
        
        if new_seg['duration'] >= 0.1:  # Still valid
            refined.append(new_seg)
    
    return refined


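# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the pipeline).
# "call.wav" is a placeholder path; the config below carries only the four
# attributes this module reads, and src.models.MODELS is assumed to have the
# segmentation model loaded before this runs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)

    cfg = SimpleNamespace(
        segmentation_step=2.5,       # sliding-window step (s)
        segmentation_duration=5.0,   # sliding-window duration (s)
        overlap_threshold=0.5,       # per-frame activation threshold
        min_segment_duration=0.2,    # drop segments shorter than this (s)
    )

    # Without real VAD output at hand, treat the whole file as speech so the
    # VAD filter keeps every detected segment (placeholder end time).
    vad = [{'start': 0.0, 'end': 3600.0}]

    for seg in detect_speaker_changes_framelevel("call.wav", vad, cfg):
        print(f"{seg['speaker']:>8}  {seg['start']:8.2f}s -{seg['end']:8.2f}s  "
              f"({seg['duration']:.2f}s)")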