#!/usr/bin/env python3
"""
Speaker diarization with GPU optimization and compute monitoring.

=== OPTIMIZATION (v6.2) ===
Key improvements:
- In-memory chunk processing: Pass waveform tensors directly to pyannote
- No disk writes: Eliminates chunk WAV file I/O
- VAD-aware chunking (cut at silence)
- Compute utilization tracking
- Memory-efficient processing

=== EDGE CASE HANDLING (v6.7) ===
- Retry logic for transient GPU failures
- OOM recovery with cache clearing
- Graceful fallback on persistent failures
"""

import time
import logging
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Union
import numpy as np
import torch
import torchaudio
from src.models import MODELS

# Module-level logger; the dotted name places it under the "FastPipelineV6"
# logger hierarchy, so handlers/levels configured there also apply here.
logger = logging.getLogger("FastPipelineV6.Diarization")


def find_vad_safe_cut(vad_segments: List[Dict], target_time: float, window: float = 30.0) -> float:
    """
    Pick a chunk boundary that falls in silence near ``target_time``.

    Every pause between consecutive VAD segments (the ``end`` of one segment to
    the ``start`` of the next) whose midpoint lies within ``window`` seconds of
    ``target_time`` is a candidate. The widest candidate wins (the earliest one
    on ties) and its midpoint is returned. Pauses of zero or negative length
    (touching/overlapping segments) are never candidates.

    If there is no candidate, ``target_time`` is returned unchanged, so in that
    case the caller's cut may land inside speech.

    NOTE(review): assumes ``vad_segments`` is sorted by start time — confirm at
    the call sites.
    """
    # (width, midpoint) for every pause between neighbouring speech segments.
    candidates = [
        (following['start'] - preceding['end'], (preceding['end'] + following['start']) / 2)
        for preceding, following in zip(vad_segments, vad_segments[1:])
    ]
    eligible = [
        (width, midpoint)
        for width, midpoint in candidates
        if width > 0.0 and abs(midpoint - target_time) <= window
    ]
    if not eligible:
        return target_time
    # max() keeps the first maximal element, i.e. the earliest of equally wide pauses.
    return max(eligible, key=lambda pair: pair[0])[1]


def create_chunks(
    audio_path: str, 
    vad_segments: List[Dict], 
    config,
    audio_buffer=None,
    in_memory: bool = True
) -> List[Tuple[Union[str, Dict], float, float]]:
    """
    Create VAD-aware audio chunks for diarization.
    
    === OPTIMIZATION (v6.2) ===
    - in_memory=True (default): Return waveform tensors directly (no disk I/O!)
    - in_memory=False: Legacy disk-based WAV files
    
    Strategy:
    1. Aim for chunks of ``config.chunk_duration`` seconds.
    2. Move each boundary to the midpoint of the widest VAD pause near the
       target (``find_vad_safe_cut``), so cuts prefer silence.
    3. If that cut does not move past the current chunk start, or would leave a
       chunk shorter than ``config.min_chunk_duration``, fall back to a hard cut
       at the target. A hard cut may land inside speech.
    
    Args:
        audio_path: Path to audio file (also used to locate the ``chunks``
            directory when in_memory=False)
        vad_segments: Speech segments from VAD (dicts with 'start'/'end' seconds)
        config: Pipeline configuration; reads ``chunk_duration`` and
            ``min_chunk_duration``
        audio_buffer: Optional AudioBuffer (avoids file re-read); reads its
            ``waveform_np`` and ``sample_rate``
        in_memory: If True, return dicts with waveform tensors instead of file paths
    
    Returns:
        List of (chunk_data, start_time, duration) where chunk_data is either:
        - str: file path (in_memory=False)
        - dict: {"waveform": tensor, "sample_rate": int} (in_memory=True).
          The waveform is a slice of the full waveform tensor (a view, not a copy).

    Raises:
        ValueError: if ``config.chunk_duration`` is not positive (the chunking
            loop could never advance).
    """
    if config.chunk_duration <= 0:
        raise ValueError(
            f"config.chunk_duration must be positive, got {config.chunk_duration}"
        )

    logger.info(f"✂️ Creating VAD-aware chunks ({'in-memory' if in_memory else 'disk'})...")
    
    # Use audio buffer if provided, otherwise load from file
    if audio_buffer is not None:
        sr = audio_buffer.sample_rate
        # NOTE(review): assumes waveform_np is 1-D (mono); unsqueeze adds the channel dim.
        waveform = torch.from_numpy(audio_buffer.waveform_np).unsqueeze(0)
    else:
        waveform, sr = torchaudio.load(audio_path)  # (channels, frames)
    
    total_duration = waveform.shape[1] / sr
    
    # Only create chunk directory if not in-memory mode
    if not in_memory:
        chunk_dir = Path(audio_path).parent / "chunks"
        chunk_dir.mkdir(exist_ok=True)
        # Register for cleanup on crash
        from src.audio_buffer import TEMP_MANAGER
        TEMP_MANAGER.register(chunk_dir)
    
    chunks = []
    start = 0.0
    idx = 0
    
    while start < total_duration:
        hard_end = min(start + config.chunk_duration, total_duration)
        
        if hard_end < total_duration:
            # Find good cut point in silence, never beyond the audio itself
            # (VAD timestamps can slightly overshoot the decoded length).
            end = min(find_vad_safe_cut(vad_segments, hard_end), total_duration)
            # Reject cuts that do not advance (the widest pause can lie at or
            # before `start`); without the `end <= start` check the loop never
            # terminates when min_chunk_duration <= 0. Also enforce min size.
            if end <= start or end - start < config.min_chunk_duration:
                end = hard_end
        else:
            end = total_duration
        
        # Extract chunk
        start_sample = int(start * sr)
        end_sample = int(end * sr)
        chunk_waveform = waveform[:, start_sample:end_sample]
        
        if in_memory:
            # === IN-MEMORY: Return dict that pyannote can process directly ===
            # pyannote Pipeline accepts {"waveform": tensor, "sample_rate": int}
            chunk_data = {
                "waveform": chunk_waveform,
                "sample_rate": sr
            }
            chunks.append((chunk_data, start, end - start))
        else:
            # === LEGACY: Write WAV file ===
            chunk_path = chunk_dir / f"chunk_{idx:03d}.wav"
            torchaudio.save(str(chunk_path), chunk_waveform, sr)
            chunks.append((str(chunk_path), start, end - start))
        
        start = end
        idx += 1
    
    mode_str = "in-memory" if in_memory else "to disk"
    logger.info(f"✅ Created {len(chunks)} chunks ({mode_str}, VAD-aware boundaries)")
    return chunks


def _get_annotation_from_result(result):
    """
    Extract Annotation object from diarization result.
    
    === PYANNOTE 4.x COMPATIBILITY ===
    - pyannote 3.x: Returns Annotation directly
    - pyannote 4.x: Returns DiarizeOutput with .speaker_diarization attribute
    """
    # Check if it's a DiarizeOutput (pyannote 4.x)
    if hasattr(result, 'speaker_diarization'):
        return result.speaker_diarization
    # Otherwise it's already an Annotation (pyannote 3.x)
    return result


def _clear_caches_safely(chunk_idx: int, include_cuda: bool) -> None:
    """
    Best-effort release of model (and optionally CUDA) caches; never raises.

    This runs inside ``except`` handlers. After a CUDA failure,
    ``torch.cuda.synchronize()`` can re-raise the same error, and letting that
    escape would bypass the retry loop and the ``([], [])`` fallback of
    ``diarize_chunk``. Cleanup failures are therefore only logged.
    """
    try:
        MODELS.clear_cache(aggressive=True)
        if include_cuda and torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
    except Exception as cleanup_error:
        logger.warning(f"   Chunk {chunk_idx}: cache cleanup failed: {cleanup_error}")


def _to_global_segments(tracks, chunk_start: float, chunk_id: str, min_duration: float) -> List[Dict]:
    """
    Shift chunk-local (start, end, speaker) tracks to global time.

    Segments shorter than ``min_duration`` are dropped. Speaker labels are
    prefixed with ``chunk_id`` so labels from different chunks never collide.
    """
    segments = []
    for local_start, local_end, speaker in tracks:
        global_start = chunk_start + local_start
        global_end = chunk_start + local_end
        duration = global_end - global_start
        if duration >= min_duration:
            segments.append({
                'start': global_start,
                'end': global_end,
                'duration': duration,
                'speaker': f"{chunk_id}_{speaker}",
            })
    return segments


def _postprocess_chunk(result, chunk_start: float, config, chunk_id: str,
                       extract_overlaps: bool) -> Tuple[List[Dict], List[Dict]]:
    """Convert one chunk's Annotation into (segments, overlaps) with global timestamps."""
    overlaps = []
    if extract_overlaps:
        # === UNIFIED OSD+DIARIZATION: overlaps come from the same result ===
        from src.overlap_detection import extract_overlaps_from_diarization
        chunk_overlaps, clean_segments = extract_overlaps_from_diarization(result, config)
        overlaps = [
            {
                'start': chunk_start + ovl['start'],
                'end': chunk_start + ovl['end'],
                'duration': ovl['duration'],
                'type': 'overlap',
            }
            for ovl in chunk_overlaps
        ]
        # Use clean segments (overlaps already removed)
        tracks = ((seg['start'], seg['end'], seg['speaker']) for seg in clean_segments)
    else:
        # Original behavior: just extract all segments
        tracks = (
            (turn.start, turn.end, speaker)
            for turn, _, speaker in result.itertracks(yield_label=True)
        )
    segments = _to_global_segments(tracks, chunk_start, chunk_id, config.min_segment_duration)
    return segments, overlaps


def diarize_chunk(
    chunk_data: Union[str, Dict], 
    chunk_start: float, 
    config, 
    extract_overlaps: bool = False,
    chunk_idx: int = 0,
    max_retries: int = 2
) -> Tuple[List[Dict], List[Dict]]:
    """
    Diarize a single chunk using the hot model.
    
    === OPTIMIZATION (v6.2) ===
    Supports both file paths AND in-memory waveform dicts.
    No more disk I/O for chunks!
    
    === PYANNOTE 4.x COMPATIBILITY ===
    Handles both pyannote 3.x (returns Annotation) and 4.x (returns DiarizeOutput).
    
    === EDGE CASE HANDLING (v6.7) ===
    - Retry logic for transient GPU failures (only the model call is retried)
    - OOM recovery with cache clearing (cleanup itself can never raise)
    - Graceful fallback on persistent failures: returns ([], []) after
      max_retries + 1 failed model calls, or if post-processing fails
    
    Uses global MODELS singleton to avoid reloading.
    Returns segments with global timestamps, and optionally overlap regions.
    
    Args:
        chunk_data: Either file path (str) OR dict {"waveform": tensor, "sample_rate": int}
        chunk_start: Start time offset for global timestamps
        config: Pipeline configuration; reads min_speakers, max_speakers and
            min_segment_duration
        extract_overlaps: If True, also extract overlap regions (for unified OSD+diarization)
        chunk_idx: Chunk index for speaker naming (used for in-memory chunks) and logs
        max_retries: Number of retries on failure (default: 2)
    
    Returns:
        (segments, overlaps): segments list, and overlaps list (empty if extract_overlaps=False)
    """
    # Chunk identifier used to namespace speaker labels.
    if isinstance(chunk_data, str):
        chunk_id = Path(chunk_data).stem
    else:
        chunk_id = f"chunk_{chunk_idx:03d}"

    last_error = None
    for attempt in range(max_retries + 1):
        try:
            # pyannote Pipeline accepts both str (path) and dict {"waveform": tensor, "sample_rate": int}
            raw_result = MODELS.diarization_pipeline(
                chunk_data,
                min_speakers=config.min_speakers,
                max_speakers=config.max_speakers
            )
            # === PYANNOTE 4.x COMPATIBILITY: Extract Annotation from DiarizeOutput ===
            result = _get_annotation_from_result(raw_result)
            break
        except RuntimeError as e:
            # === EDGE CASE: GPU OOM or CUDA errors ===
            last_error = e
            error_str = str(e).lower()
            if "out of memory" in error_str or "cuda" in error_str:
                logger.warning(f"   Chunk {chunk_idx} attempt {attempt + 1}: GPU error, clearing cache and retrying")
                _clear_caches_safely(chunk_idx, include_cuda=True)
                if attempt < max_retries:
                    time.sleep(1)  # Brief pause before retry
            else:
                logger.error(f"   Chunk {chunk_idx} RuntimeError: {e}")
        except Exception as e:
            last_error = e
            logger.warning(f"   Chunk {chunk_idx} attempt {attempt + 1} failed: {e}")
            if attempt < max_retries:
                _clear_caches_safely(chunk_idx, include_cuda=False)
    else:
        # Loop finished without `break`: every model call failed.
        logger.error(f"Chunk {chunk_idx} diarization failed after {max_retries + 1} attempts: {last_error}")
        return [], []

    try:
        return _postprocess_chunk(result, chunk_start, config, chunk_id, extract_overlaps)
    except Exception:
        # Re-running the expensive model would presumably reproduce the same
        # post-processing failure, so fall back immediately instead of retrying.
        logger.exception(f"Chunk {chunk_idx} post-processing failed; returning no segments")
        return [], []


def run_diarization(
    chunks: List[Tuple[Union[str, Dict], float, float]], 
    config,
    extract_overlaps: bool = False
) -> Tuple[List[Dict], List[Dict]]:
    """
    Run diarization on all chunks with progress tracking.
    
    === OPTIMIZATION (v6.1) ===
    When extract_overlaps=True, this function extracts overlap regions
    from the diarization results in a SINGLE PASS, eliminating the need
    for a separate OSD stage that would run diarization again.

    Chunks are processed serially. Model caches are cleared every
    ``config.clear_cache_every_n_chunks`` chunks (0/None disables the periodic
    clear) and always once more at the end, even if a chunk raises.
    
    Args:
        chunks: List of (chunk_data, chunk_start, chunk_duration), where
            chunk_data is a WAV path (str) or a
            {"waveform": tensor, "sample_rate": int} dict
        config: Pipeline configuration
        extract_overlaps: If True, also extract and return overlap regions
    
    Returns:
        (all_segments, all_overlaps): both sorted by 'start'; overlaps empty if
        not extracted
    """
    mode_str = " + OSD" if extract_overlaps else ""
    # _model_name is a private attribute; a missing one must not abort the run.
    model_name = getattr(MODELS, "_model_name", "unknown")
    logger.info(f"🎙️ Running diarization{mode_str} ({len(chunks)} chunks, model: {model_name})...")
    start = time.perf_counter()
    
    every_n = config.clear_cache_every_n_chunks
    all_segments: List[Dict] = []
    all_overlaps: List[Dict] = []
    
    try:
        # Process chunks serially (GPU-bound)
        for i, (chunk_data, chunk_start, chunk_dur) in enumerate(chunks):
            t0 = time.perf_counter()
            # Pass chunk index for speaker naming (needed for in-memory chunks)
            segments, overlaps = diarize_chunk(
                chunk_data, chunk_start, config, 
                extract_overlaps=extract_overlaps,
                chunk_idx=i
            )
            all_segments.extend(segments)
            all_overlaps.extend(overlaps)
            
            elapsed = time.perf_counter() - t0
            overlap_str = f", {len(overlaps)} overlaps" if extract_overlaps else ""
            logger.info(f"   Chunk {i+1}/{len(chunks)}: {len(segments)} segments{overlap_str} ({elapsed:.1f}s, {chunk_dur:.0f}s audio)")
            
            # Periodic clear to prevent OOM; the truthiness guard avoids `% 0`.
            if every_n and (i + 1) % every_n == 0:
                MODELS.clear_cache(aggressive=True)
    finally:
        # Final cleanup, also when a chunk raised.
        MODELS.clear_cache(aggressive=True)
    
    # Sort by start time
    all_segments.sort(key=lambda x: x['start'])
    all_overlaps.sort(key=lambda x: x['start'])
    
    elapsed = time.perf_counter() - start
    unique_speakers = len({s['speaker'] for s in all_segments})
    
    overlap_str = f" | {len(all_overlaps)} overlaps" if extract_overlaps else ""
    logger.info(f"✅ Diarization{mode_str}: {elapsed:.1f}s | {len(all_segments)} segments{overlap_str} | ~{unique_speakers} fragments")
    
    return all_segments, all_overlaps
