#!/usr/bin/env python3
"""
Unified Chunk Embedding Cache - Compute ONCE, use EVERYWHERE.

=== OPTIMIZATION (v6.5 TURBO) ===
This module eliminates BOTH the "Double Compute" AND "VAD Overhead" problems:

BEFORE (v6.3):
- Stage 5 (Embeddings): Load model → extract embeddings for each segment
- Stage 6b (ChunkReassignment): Load SAME model → extract embeddings for 1.5s chunks → extract AGAIN for split portions
- Result: ~82s for a 75-min video, GPU called ~3600 times

v6.4 FIX:
- Single pass: Extract ALL 1.5s chunk embeddings in ONE mega-batch
- BUT: Still ran VAD on EVERY chunk (2783 VAD calls = ~55s wasted!)
- Result: 65s (2.2% GPU utilization - starving!)

v6.5 TURBO FIX:
- SKIP per-chunk VAD: Chunks from VAD segments ALREADY contain speech!
- Use fast energy-based eligibility (0.001s vs 0.02s per chunk)
- Pinned memory for faster CPU→GPU transfer
- Vectorized tensor creation (no Python for-loops)
- Batched GPU→CPU transfer (accumulate on GPU, transfer once)
- Result: ~8-10s predicted (80%+ GPU utilization)

=== HOW IT WORKS ===
1. Split ALL usable audio into 1.5s chunks (from VAD segments - KNOWN speech!)
2. Eligibility: every chunk inherits speech_ratio=1.0 from VAD (the per-chunk energy check is skipped), so the min_speech_ratio filter currently keeps all chunks
3. Vectorized tensor creation with pinned memory
4. Extract ALL chunk embeddings in mega-batches (GPU saturated)
5. Cache embeddings in memory, indexed by time

Usage:
    cache = UnifiedChunkEmbeddingCache(audio_buffer, vad_segments, embedding_model)
    num_chunks = cache.build()  # One-time GPU work
    
    # For clustering (replaces Stage 5)
    seg_embedding = cache.get_segment_embedding(seg['start'], seg['end'])
    
    # For speaker change detection (replaces Stage 6b chunk processing)
    split_points = cache.detect_speaker_changes(seg['start'], seg['end'])
    
    # For reassignment (replaces portion embedding extraction)
    portion_emb = cache.get_segment_embedding(portion_start, portion_end)
"""

import bisect
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity

logger = logging.getLogger("FastPipelineV6.UnifiedEmbeddings")


@dataclass
class ChunkInfo:
    """One fixed-length analysis window cut from a VAD speech region.

    Times are in seconds; sample positions index the audio buffer's waveform
    (start inclusive, end exclusive). ``embedding`` stays ``None`` until the
    embedding cache fills it in.
    """
    idx: int  # position of the chunk in creation order
    start_time: float  # window start, seconds
    end_time: float  # window end, seconds
    start_sample: int  # first sample index (inclusive)
    end_sample: int  # end sample index (exclusive)
    speech_ratio: float  # eligibility score compared against min_speech_ratio
    embedding: Optional[np.ndarray] = field(default=None)  # filled during cache build


@dataclass
class UnifiedEmbeddingConfig:
    """Configuration for unified embedding extraction.

    Raises:
        ValueError: if ``chunk_duration`` is not positive or ``min_batch_size``
            is below 1 (either would make chunking or batching impossible).
    """
    chunk_duration: float = 1.5  # 1.5s chunks (Kimi-Audio standard)
    min_speech_ratio: float = 0.6  # From ChatGPT VAD eligibility
    batch_size: int = 128  # Aggressive batching for GPU saturation
    vram_headroom_gb: float = 1.0  # Less conservative headroom
    max_batch_size: int = 256  # Upper limit
    min_batch_size: int = 32  # Lower limit

    # Speaker change detection thresholds (from chunk_reassignment)
    severe_threshold: float = 0.25  # Circuit-breaker (from Maya)
    normal_threshold: float = 0.40  # Look-ahead confirmed (from Stabilized)

    def __post_init__(self) -> None:
        """Reject values that would leave chunk creation or batching unusable."""
        # A non-positive duration produces zero-sample chunks, which chunk
        # creation cannot step through. `not x > 0` also rejects NaN.
        if not self.chunk_duration > 0:
            raise ValueError(f"chunk_duration must be > 0, got {self.chunk_duration}")
        # The CPU batch size is taken from min_batch_size, so it must be >= 1.
        if self.min_batch_size < 1:
            raise ValueError(f"min_batch_size must be >= 1, got {self.min_batch_size}")


@dataclass
class CacheBuildStats:
    """Counters and timings gathered while the embedding cache is built."""
    total_chunks: int = field(default=0)  # chunks cut from the VAD segments
    eligible_chunks: int = field(default=0)  # chunks passing the speech-ratio filter
    embeddings_extracted: int = field(default=0)  # embeddings reported by the extraction step
    build_time_sec: float = field(default=0.0)  # wall-clock seconds for the whole build
    gpu_batch_count: int = field(default=0)  # number of model batches issued
    avg_batch_size: float = field(default=0.0)  # mean chunks per model batch


class UnifiedChunkEmbeddingCache:
    """
    Single source of truth for ALL embeddings in the pipeline.

    Computes 1.5s chunk embeddings ONCE, provides instant access for:
    - Segment embeddings (clustering)
    - Speaker change detection (chunk reassignment)
    - Portion embeddings (reassignment confidence)

    Key insight: All downstream operations become CPU-only (cache lookups + numpy ops).
    Only ``build()`` calls the embedding model; every public query method works
    on the cached numpy vectors.
    """

    def __init__(
        self,
        audio_buffer,  # AudioBuffer instance
        vad_segments: List[Dict],  # VAD speech segments
        embedding_model,  # SpeechBrain ECAPA-TDNN
        vad_model=None,  # Silero VAD for speech ratio (optional)
        vad_utils=None,  # Silero utils
        config: Optional[UnifiedEmbeddingConfig] = None,
        device: Optional[torch.device] = None
    ):
        """
        Initialize the cache. No audio is processed until build() is called.

        Args:
            audio_buffer: AudioBuffer with ``waveform_np`` (indexed here as a 1-D
                sample array) and ``sample_rate``
            vad_segments: List of {'start': float, 'end': float} from VAD (seconds)
            embedding_model: ECAPA-TDNN model exposing ``encode_batch(wavs, wav_lens)``
            vad_model: Optional Silero VAD; stored but not used by the current
                chunking path (chunks trust the VAD segments)
            vad_utils: Silero utilities (stored, currently unused)
            config: UnifiedEmbeddingConfig (defaults used if None)
            device: torch device (CUDA if available, else CPU, when None)
        """
        self.audio_buffer = audio_buffer
        self.vad_segments = vad_segments
        self.embedding_model = embedding_model
        self.vad_model = vad_model
        self.vad_utils = vad_utils
        self.config = config or UnifiedEmbeddingConfig()
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Cache storage
        self.chunks: Dict[int, ChunkInfo] = {}  # chunk_idx -> ChunkInfo
        self._time_index: Dict[Tuple[float, float], int] = {}  # (start, end) -> chunk_idx
        # Lazily built (dict id, count, chunks sorted by center, centers) used by
        # get_chunks_in_range; rebuilt when the chunk dict or its size changes.
        self._range_cache = None
        self._built = False
        self._stats = CacheBuildStats()

    def _uses_cuda(self) -> bool:
        """Return True when the configured device is a CUDA device."""
        return torch.device(self.device).type == "cuda"

    def build(self) -> int:
        """
        Build the cache by extracting ALL chunk embeddings in mega-batches.

        This is the ONLY model call - everything else is cache lookups.
        Calling it again after a build is a no-op.

        Returns:
            Number of chunks stored in the cache (each one has an embedding).
            Chunks whose model batch failed are logged and left out.
        """
        if self._built:
            logger.info(f"   Cache already built ({len(self.chunks)} chunks)")
            return len(self.chunks)

        start_time = time.time()

        logger.info("=" * 70)
        logger.info("⚡ UNIFIED CHUNK EMBEDDING CACHE (Compute Once, Use Everywhere)")
        logger.info(f"   Chunk duration: {self.config.chunk_duration}s")
        logger.info(f"   Min speech ratio: {self.config.min_speech_ratio}")
        logger.info("=" * 70)

        # Step 1: Create ALL chunks from VAD segments
        t0 = time.time()
        chunk_list = self._create_all_chunks()
        self._stats.total_chunks = len(chunk_list)
        logger.info(f"   [TIMING] Chunk creation: {(time.time()-t0)*1000:.0f}ms")

        if not chunk_list:
            logger.warning("   No chunks to process")
            self._built = True
            return 0

        # Step 2: Filter by speech ratio (VAD eligibility)
        t0 = time.time()
        eligible_chunks = [c for c in chunk_list if c.speech_ratio >= self.config.min_speech_ratio]
        self._stats.eligible_chunks = len(eligible_chunks)
        logger.info(f"   [TIMING] Filtering: {(time.time()-t0)*1000:.0f}ms")

        logger.info(f"   Created {len(chunk_list)} chunks, {len(eligible_chunks)} eligible (≥{self.config.min_speech_ratio} speech)")

        if not eligible_chunks:
            logger.warning("   No eligible chunks after speech ratio filter")
            self._built = True
            return 0

        # Step 3: MEGA-BATCH GPU inference (THE OPTIMIZATION)
        t0 = time.time()
        embeddings = self._batch_extract_embeddings(eligible_chunks)
        self._stats.embeddings_extracted = sum(1 for emb in embeddings if emb is not None)
        logger.info(f"   [TIMING] Embedding extraction: {(time.time()-t0)*1000:.0f}ms")

        failed = len(eligible_chunks) - self._stats.embeddings_extracted
        if failed:
            logger.warning(f"   {failed} chunks got no embedding (failed batches) and were not cached")

        # Step 4: Store in cache. Chunks without an embedding are skipped so
        # segment means and similarities never see placeholder vectors.
        t0 = time.time()
        for chunk, emb in zip(eligible_chunks, embeddings):
            if emb is None:
                continue
            chunk.embedding = emb
            self.chunks[chunk.idx] = chunk
            # Index by time for fast segment->chunk mapping
            key = (round(chunk.start_time, 3), round(chunk.end_time, 3))
            self._time_index[key] = chunk.idx
        self._range_cache = None  # chunk set changed; range index is rebuilt lazily
        logger.info(f"   [TIMING] Cache storage: {(time.time()-t0)*1000:.0f}ms")

        self._built = True
        self._stats.build_time_sec = time.time() - start_time

        logger.info(f"✅ Cache built: {len(self.chunks)} embeddings in {self._stats.build_time_sec:.2f}s")
        logger.info(f"   GPU batches: {self._stats.gpu_batch_count}, avg batch size: {self._stats.avg_batch_size:.0f}")
        logger.info("=" * 70)

        return len(self.chunks)

    def _create_all_chunks(self) -> List[ChunkInfo]:
        """
        Create back-to-back, non-overlapping chunks of ``chunk_duration`` inside
        each VAD segment.

        === v6.5 TURBO MAX ===
        Key insight: Chunks from VAD segments ALREADY contain speech!
        We trust VAD completely - every chunk gets speech_ratio=1.0 and no
        energy computation is done. Any tail of a segment shorter than one
        chunk is dropped, and chunks never extend past the end of the waveform.

        Raises:
            ValueError: if chunk_duration is shorter than one sample.
        """
        sr = self.audio_buffer.sample_rate
        chunk_samples = int(self.config.chunk_duration * sr)
        if chunk_samples <= 0:
            # A zero-length step would never advance the loop below.
            raise ValueError(
                f"chunk_duration={self.config.chunk_duration}s is shorter than one sample at {sr} Hz"
            )

        # Chunks past the end of the audio cannot be filled with samples.
        waveform = getattr(self.audio_buffer, "waveform_np", None)
        total_samples = len(waveform) if waveform is not None else None

        all_chunks: List[ChunkInfo] = []
        for vad_seg in self.vad_segments:
            start_sample = int(vad_seg['start'] * sr)
            end_sample = int(vad_seg['end'] * sr)
            if total_samples is not None:
                end_sample = min(end_sample, total_samples)

            # Create chunks within this VAD region
            pos = start_sample
            while pos + chunk_samples <= end_sample:
                all_chunks.append(ChunkInfo(
                    idx=len(all_chunks),
                    start_time=pos / sr,
                    end_time=(pos + chunk_samples) / sr,
                    start_sample=pos,
                    end_sample=pos + chunk_samples,
                    speech_ratio=1.0  # Trust VAD - skip energy check!
                ))
                pos += chunk_samples

        return all_chunks

    def _compute_speech_ratio_fast(self, audio_chunk: np.ndarray) -> float:
        """
        FAST energy-based speech ratio estimate (RMS / 0.03, capped at 1.0).

        Not called by the current chunking path, which trusts VAD instead.
        NEVER call VAD here - it's too slow (0.02s per chunk = 55s for 2800 chunks!)

        Returns:
            Ratio in [0, 1]; 0.0 for an empty chunk.
        """
        samples = np.asarray(audio_chunk, dtype=np.float64)
        if samples.size == 0:
            # np.mean([]) is nan, and min(1.0, nan) would report full speech.
            return 0.0
        energy = np.sqrt(np.mean(samples ** 2))
        # Lower threshold since we know this is from a speech region
        return float(min(1.0, energy / 0.03))

    def _batch_extract_embeddings(self, chunks: List[ChunkInfo]) -> List[Optional[np.ndarray]]:
        """
        Extract embeddings for ALL chunks in optimized mega-batches.

        === v6.6 TURBO ULTRA OPTIMIZATION ===
        Key insight: Model inference was measured at ~1.6s for 2783 chunks;
        most overhead came from per-batch tensor creation/transfer.

        Approach:
        1. Copy ALL chunk audio into one pre-allocated array and move it to the
           device in ONE transfer (via pinned memory when the device is CUDA)
        2. Slice batches from that device tensor (no per-batch allocation)
        3. Keep batch outputs on the device and transfer them back once

        Args:
            chunks: Chunks to embed. Samples past the end of the waveform are
                zero-padded.

        Returns:
            One entry per input chunk, in input order: the embedding row as a
            numpy array, or None if the model raised on that chunk's batch.
        """
        if not chunks:
            return []

        waveform = self.audio_buffer.waveform_np
        chunk_samples = int(self.config.chunk_duration * self.audio_buffer.sample_rate)
        on_cuda = self._uses_cuda()

        # Use larger batch sizes for better GPU utilization (never below 1)
        batch_size = self._compute_optimal_batch_size(len(chunks))
        batch_size = max(1, min(batch_size * 2, 512, len(chunks)))

        total_batches = (len(chunks) + batch_size - 1) // batch_size
        batch_sizes = []

        logger.info(f"   Extracting {len(chunks)} embeddings in {total_batches} batches (size={batch_size})")

        # === STEP 1: Pre-allocate ALL audio data (CPU) ===
        t0 = time.time()
        all_audio_data = np.zeros((len(chunks), chunk_samples), dtype=np.float32)
        for i, chunk in enumerate(chunks):
            # The slice may be short near the end of the buffer; the remainder
            # of the row stays zero instead of raising a broadcast error.
            piece = waveform[chunk.start_sample:chunk.end_sample][:chunk_samples]
            all_audio_data[i, :len(piece)] = piece
        t_cpu_prep = time.time() - t0
        logger.info(f"   [TIMING] CPU audio prep: {t_cpu_prep*1000:.0f}ms")

        # === STEP 2: Transfer ALL to the device in ONE shot ===
        t0 = time.time()
        host_audio = torch.from_numpy(all_audio_data)
        if on_cuda:
            # Single transfer: CPU → pinned → GPU (non-blocking)
            device_audio = host_audio.pin_memory().to(self.device, non_blocking=True)
            device_lens = torch.ones(len(chunks), dtype=torch.float32, device=self.device)
            torch.cuda.synchronize(self.device)  # Wait for transfer to complete
        else:
            device_audio = host_audio.to(self.device)
            device_lens = torch.ones(len(chunks), dtype=torch.float32, device=self.device)
        t_transfer_up = time.time() - t0
        logger.info(f"   [TIMING] CPU→GPU transfer (ALL): {t_transfer_up*1000:.0f}ms")

        # === STEP 3: Inference with tensor slicing (NO per-batch allocation!) ===
        t0 = time.time()
        ok_outputs = []  # device tensors from batches that succeeded
        ok_ranges: List[Tuple[int, int]] = []  # (batch_start, batch_end) per tensor in ok_outputs

        for batch_num in range(total_batches):
            batch_start = batch_num * batch_size
            batch_end = min(batch_start + batch_size, len(chunks))
            batch_sizes.append(batch_end - batch_start)

            # Slice from the pre-allocated device tensor (view, no allocation)
            batch_audio = device_audio[batch_start:batch_end]
            batch_lens = device_lens[batch_start:batch_end]

            try:
                with torch.no_grad():
                    embs = self.embedding_model.encode_batch(batch_audio, batch_lens)

                    # Handle different output shapes
                    if embs.ndim == 3 and embs.shape[1] == 1:
                        embs = embs.squeeze(1)
                    elif embs.ndim == 3:
                        embs = embs.reshape(embs.shape[0], -1)
            except Exception as e:
                # No placeholder vectors: zeros would bias segment means and
                # produce cosine 0 (false "severe" speaker changes).
                logger.error(f"   Batch {batch_num + 1} failed: {e}")
                continue

            ok_outputs.append(embs)
            ok_ranges.append((batch_start, batch_end))

        if on_cuda:
            torch.cuda.synchronize(self.device)
        t_inference = time.time() - t0
        logger.info(f"   [TIMING] GPU inference ({total_batches} batches): {t_inference*1000:.0f}ms")

        # === STEP 4: Concatenate on device and transfer back in ONE shot ===
        t0 = time.time()
        results: List[Optional[np.ndarray]] = [None] * len(chunks)
        if ok_outputs:
            host_embs = torch.cat(ok_outputs, dim=0).cpu().numpy()
            offset = 0
            for batch_start, batch_end in ok_ranges:
                count = batch_end - batch_start
                results[batch_start:batch_end] = list(host_embs[offset:offset + count])
                offset += count
        t_transfer_down = time.time() - t0
        logger.info(f"   [TIMING] GPU→CPU transfer (ALL): {t_transfer_down*1000:.0f}ms")

        # Cleanup
        del device_audio, device_lens, host_audio, ok_outputs
        if on_cuda:
            torch.cuda.empty_cache()

        self._stats.gpu_batch_count = total_batches
        self._stats.avg_batch_size = float(np.mean(batch_sizes)) if batch_sizes else 0

        # Log total breakdown
        t_total = t_cpu_prep + t_transfer_up + t_inference + t_transfer_down
        logger.info(f"   [TIMING] Total embedding extraction: {t_total*1000:.0f}ms "
                   f"(prep={t_cpu_prep*1000:.0f}, up={t_transfer_up*1000:.0f}, "
                   f"infer={t_inference*1000:.0f}, down={t_transfer_down*1000:.0f})")

        return results

    def _extract_single_embedding(self, audio: np.ndarray) -> np.ndarray:
        """
        Extract the embedding for a single audio chunk (fallback helper; not
        used by build()).

        Returns:
            The model output flattened to a 1-D numpy vector.
        """
        # Model weights are float32; a float64 input would produce a double tensor.
        samples = np.ascontiguousarray(audio, dtype=np.float32)
        padded = torch.from_numpy(samples).unsqueeze(0).to(self.device)
        wav_lens = torch.ones(1, dtype=torch.float32, device=self.device)

        with torch.no_grad():
            emb = self.embedding_model.encode_batch(padded, wav_lens).cpu().numpy()

        del padded, wav_lens
        # Flatten (1, D) and (1, 1, D) outputs alike, matching the batch path.
        return emb.reshape(-1)

    def _compute_optimal_batch_size(self, num_chunks: int) -> int:
        """
        Compute an aggressive batch size from free VRAM.

        On a non-CUDA device this returns min(min_batch_size, num_chunks).
        On CUDA the result is clamped to [min_batch_size, max_batch_size] and
        also capped by num_chunks, except that min_batch_size always wins.
        """
        if not self._uses_cuda():
            return min(self.config.min_batch_size, num_chunks)

        try:
            free_bytes, _total_bytes = torch.cuda.mem_get_info(self.device)
            free_gb = free_bytes / (1024 ** 3)
        except Exception:
            free_gb = 8.0  # Fallback when the query fails

        # Available after headroom
        available_gb = max(0.5, free_gb - self.config.vram_headroom_gb)

        # ECAPA-TDNN is lightweight: ~2MB per chunk in batch (with activations)
        # 1.5s @ 16kHz = 24000 samples = 96KB audio + ~2MB activations
        gb_per_chunk = 0.002

        vram_limited_batch = int(available_gb / gb_per_chunk)

        # Clamp to config range
        return max(self.config.min_batch_size,
                   min(vram_limited_batch, self.config.max_batch_size, num_chunks))

    # =========================================================================
    # PUBLIC API: Methods for downstream stages (NO GPU!)
    # =========================================================================

    def _sorted_chunk_index(self) -> Tuple[List[ChunkInfo], List[float]]:
        """
        Return the cached chunks ordered by center time, plus those centers.

        Rebuilt lazily when ``self.chunks`` is replaced or its size changes.
        Replacing an entry in place with another chunk of a different time
        span requires clearing ``self._range_cache``.
        """
        cached = getattr(self, "_range_cache", None)
        if cached is None or cached[0] != id(self.chunks) or cached[1] != len(self.chunks):
            ordered = sorted(self.chunks.values(), key=lambda c: (c.start_time + c.end_time) / 2)
            centers = [(c.start_time + c.end_time) / 2 for c in ordered]
            cached = (id(self.chunks), len(self.chunks), ordered, centers)
            self._range_cache = cached
        return cached[2], cached[3]

    def get_chunks_in_range(self, start: float, end: float) -> List[ChunkInfo]:
        """
        Get all cached chunks whose center lies in [start, end] (inclusive).

        Uses binary search over the sorted center index (O(log n + k) per call
        instead of a full scan).

        Returns:
            Matching chunks sorted by start time (empty if start > end).
        """
        ordered, centers = self._sorted_chunk_index()
        lo = bisect.bisect_left(centers, start)  # first center >= start
        hi = bisect.bisect_right(centers, end)  # first center > end
        return sorted(ordered[lo:hi], key=lambda c: c.start_time)

    def get_segment_embedding(self, start: float, end: float) -> Optional[np.ndarray]:
        """
        Compute a segment embedding as the mean of its chunk embeddings.

        Used by: Clustering stage (replaces extract_embeddings_batched)

        This is CPU only - no model call.

        Returns:
            The element-wise mean embedding, or None if no chunk with an
            embedding has its center inside [start, end].
        """
        chunks = self.get_chunks_in_range(start, end)
        if not chunks:
            return None

        embeddings = [c.embedding for c in chunks if c.embedding is not None]
        if not embeddings:
            return None

        return np.mean(embeddings, axis=0)

    def detect_speaker_changes(
        self,
        seg_start: float,
        seg_end: float,
        severe_threshold: Optional[float] = None,
        normal_threshold: Optional[float] = None,
        require_persistence: bool = True,
        use_relative_valley: bool = True,
        persistence_count: int = 2,  # v7.8: Configurable (1=aggressive, 2=conservative)
    ) -> Tuple[List[float], List[str], int, int]:
        """
        Detect speaker changes within a segment with VALIDATION GATES.

        === v7.7: VALIDATION GATES TO PREVENT FALSE POSITIVE SPLITS ===

        1. Persistence gate (require_persistence=True): a split is proposed only
           after >= persistence_count consecutive low-similarity adjacent pairs.
           In this mode the severe circuit-breaker and look-ahead rules are
           skipped, so severe_count is always 0 and persistent splits are
           counted in lookahead_count.
        2. Relative valley gate (use_relative_valley=True): "low" means below
           max(severe_th, median - 2.5 * max(MAD, 0.05)) of this segment's
           adjacent similarities; normal_threshold is then not used.
        Split times are placed on chunk boundaries (the end of the chunk before
        the first low pair); no energy-based snapping is done here.

        === v7.8: Configurable persistence_count ===
        - persistence_count=1: More aggressive, catches "yeah" and "ok" interruptions
        - persistence_count=2: Conservative (original behavior)

        Args:
            seg_start: Segment start time
            seg_end: Segment end time
            severe_threshold: Override config severe threshold (0.0 is honored)
            normal_threshold: Override config normal threshold (0.0 is honored)
            require_persistence: Enable persistence check (default: True)
            use_relative_valley: Use per-segment adaptive threshold (default: True)
            persistence_count: Min consecutive low-sim pairs needed

        Returns:
            (split_points, reasons, severe_count, lookahead_count). All are
            empty or zero when fewer than 3 chunks, or fewer than 2 comparable
            adjacent pairs, fall in the segment.
        """
        # Explicit None checks so an override of 0.0 is not replaced by the default.
        severe_th = self.config.severe_threshold if severe_threshold is None else severe_threshold
        normal_th = self.config.normal_threshold if normal_threshold is None else normal_threshold

        chunks = self.get_chunks_in_range(seg_start, seg_end)

        if len(chunks) < 3:  # GATE: Need at least 3 chunks for meaningful analysis
            return [], [], 0, 0

        # Similarities of neighbouring chunks (pairs with a missing embedding skipped)
        adjacent_sims: List[Tuple[int, float]] = []
        for i in range(len(chunks) - 1):
            emb_i = chunks[i].embedding
            emb_next = chunks[i + 1].embedding
            if emb_i is not None and emb_next is not None:
                adjacent_sims.append((i, self._cosine_similarity(emb_i, emb_next)))

        if len(adjacent_sims) < 2:
            return [], [], 0, 0

        # === RELATIVE VALLEY GATE: Compute segment-specific adaptive threshold ===
        sims_values = np.array([sim for _, sim in adjacent_sims])
        median_sim = np.median(sims_values)
        mad = np.median(np.abs(sims_values - median_sim))  # Median Absolute Deviation

        # Robust threshold: similarity must be below median - 2.5*MAD to be an outlier
        adaptive_threshold = max(severe_th, median_sim - 2.5 * max(mad, 0.05))

        if use_relative_valley:
            logger.debug(f"   Segment {seg_start:.1f}-{seg_end:.1f}: "
                        f"median_sim={median_sim:.3f}, MAD={mad:.3f}, "
                        f"adaptive_th={adaptive_threshold:.3f}")

        effective_normal_th = adaptive_threshold if use_relative_valley else normal_th

        split_points: List[float] = []
        reasons: List[str] = []
        severe_count = 0
        lookahead_count = 0

        # Persistence tracking. The first low pair's similarity is stored
        # directly rather than looked up by index, which breaks when
        # adjacent_sims has gaps.
        consecutive_low = 0
        pending_split_idx: Optional[int] = None
        pending_start_sim: Optional[float] = None

        for i, sim_next in adjacent_sims:
            emb_i = chunks[i].embedding
            is_severe = sim_next < severe_th
            is_low = sim_next < effective_normal_th

            # === PERSISTENCE GATE: Track consecutive low-sim pairs ===
            if require_persistence:
                if is_low or is_severe:
                    consecutive_low += 1
                    if pending_split_idx is None:
                        pending_split_idx = i  # first low-sim position
                        pending_start_sim = sim_next
                else:
                    if consecutive_low >= persistence_count and pending_split_idx is not None:
                        split_points.append(chunks[pending_split_idx].end_time)
                        reasons.append(
                            f"PERSISTENT: {consecutive_low} consecutive low-sim chunks, "
                            f"start_sim={pending_start_sim:.3f}"
                        )
                        lookahead_count += 1
                    consecutive_low = 0
                    pending_split_idx = None
                    pending_start_sim = None
                continue  # Don't process individual triggers in persistence mode

            # === CIRCUIT-BREAKER (from Maya): Severe drop triggers immediately ===
            if is_severe:
                split_points.append(chunks[i].end_time)
                reasons.append(f"SEVERE: sim={sim_next:.3f} < {severe_th}")
                severe_count += 1
                continue

            # === LOOK-AHEAD PATTERN (from Stabilized): Confirm with chunk[i+2] ===
            if i + 2 < len(chunks):
                emb_after = chunks[i + 2].embedding
                if emb_after is not None:
                    sim_after = self._cosine_similarity(emb_i, emb_after)

                    if sim_next < effective_normal_th and sim_after < effective_normal_th:
                        split_points.append(chunks[i].end_time)
                        reasons.append(
                            f"LOOKAHEAD: sim_next={sim_next:.3f}, sim_after={sim_after:.3f} < {effective_normal_th:.3f}"
                        )
                        lookahead_count += 1

        # A low-similarity run that reaches the end of the segment still counts.
        if require_persistence and consecutive_low >= persistence_count and pending_split_idx is not None:
            split_points.append(chunks[pending_split_idx].end_time)
            reasons.append(
                f"PERSISTENT_END: {consecutive_low} consecutive low-sim chunks"
            )
            lookahead_count += 1

        return split_points, reasons, severe_count, lookahead_count

    def get_chunk_similarities(self, seg_start: float, seg_end: float) -> List[Tuple[int, int, float]]:
        """
        Get cosine similarities between adjacent chunks in a segment.

        Useful for debugging/visualization.

        Returns:
            (chunk_idx, next_chunk_idx, similarity) for each neighbouring pair
            where both chunks have an embedding.
        """
        chunks = self.get_chunks_in_range(seg_start, seg_end)
        similarities = []

        for current, following in zip(chunks, chunks[1:]):
            if current.embedding is not None and following.embedding is not None:
                sim = self._cosine_similarity(current.embedding, following.embedding)
                similarities.append((current.idx, following.idx, sim))

        return similarities

    def get_stats(self) -> Dict[str, Any]:
        """Return cache build statistics (times and averages rounded)."""
        return {
            'total_chunks': self._stats.total_chunks,
            'eligible_chunks': self._stats.eligible_chunks,
            'embeddings_extracted': self._stats.embeddings_extracted,
            'build_time_sec': round(self._stats.build_time_sec, 2),
            'gpu_batch_count': self._stats.gpu_batch_count,
            'avg_batch_size': round(self._stats.avg_batch_size, 1),
        }

    @staticmethod
    def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        """Compute cosine similarity between two vectors (0.0 if either is all-zero)."""
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return float(np.dot(a, b) / (norm_a * norm_b))


# =============================================================================
# INTEGRATION HELPERS
# =============================================================================

def build_segment_embeddings_from_cache(
    segments: List[Dict],
    cache: UnifiedChunkEmbeddingCache
) -> Dict[int, np.ndarray]:
    """
    Map segment indices to cached embeddings (drop-in for extract_embeddings_batched).

    Segments labelled 'OVERLAP' or 'NON_SPEECH' are ignored. Other segments
    get the cache's mean chunk embedding, and segments with no cached chunks
    are left out. CPU only - no model call.

    Args:
        segments: List of segment dicts with 'start' and 'end' (and optionally 'speaker')
        cache: Built UnifiedChunkEmbeddingCache

    Returns:
        Dict of segment index -> embedding (same format as extract_embeddings_batched)
    """
    excluded_labels = ('OVERLAP', 'NON_SPEECH')
    result: Dict[int, np.ndarray] = {}
    skipped = 0

    for seg_idx, segment in enumerate(segments):
        if segment.get('speaker') in excluded_labels:
            continue

        vector = cache.get_segment_embedding(segment['start'], segment['end'])
        if vector is None:
            skipped += 1
        else:
            result[seg_idx] = vector

    if skipped > 0:
        logger.debug(f"   Skipped {skipped} segments (no chunks in range)")

    return result


def execute_chunk_reassignment_cached(
    segments: List[Dict],
    cache: UnifiedChunkEmbeddingCache,
    speaker_centroids: Dict[str, np.ndarray],
    min_segment_for_analysis: float = 3.0,
    min_split_portion: float = 1.5,
    assign_min_similarity: float = 0.55,
    margin_min: float = 0.10,
    create_new_on_ambiguous: bool = False,  # Default False to prevent speaker explosion
    # === v7.8: Tunable gate parameters for catching small interruptions ===
    persistence_threshold: int = 2,  # Min consecutive low-sim chunks (1 = more aggressive, 2 = conservative)
    centroid_match_threshold: float = 0.60,  # Both portions must be < this to split (lower = more aggressive)
) -> Tuple[List[Dict], Dict[str, Any]]:
    """
    Execute chunk reassignment with VALIDATION GATES using the unified cache.

    === v7.7: VALIDATION GATES TO PREVENT FALSE POSITIVE SPLITS ===

    The old implementation caused 66x speaker fragmentation (6 speakers → 400+)
    because it created new speakers on ANY ambiguous assignment.

    Each candidate split point from ``cache.detect_speaker_changes`` must pass:
    1. Gate B1 (same-speaker centroid gate): rejected if BOTH the pre- and
       post-split portions are more similar than ``centroid_match_threshold``
       to the original speaker's centroid.
    2. Gate B2 (directional): the post-split portion must be at least 0.05
       less similar to the original centroid than the pre-split portion.
    3. Gate B3 (specific other speaker): the post-split portion must be
       confidently assigned by ``_reassign_with_margin`` to a speaker other
       than the original one.
    If the original speaker has no centroid in ``speaker_centroids``, the
    gates are skipped and every detected split point is accepted.

    Accepted split points are snapped to a nearby energy minimum (to avoid
    mid-word cuts), clamped to lie at least 0.1 s inside the segment, and
    de-duplicated, so two splits snapping to the same instant cannot produce
    a zero-length portion.

    Resulting portions:
    - shorter than ``min_split_portion``: original speaker, ``'unusable'``;
    - no cached embedding: original speaker, ``'usable'``;
    - ambiguous assignment: original speaker (default), or a fresh
      ``SPEAKER_NEW_xx`` id when ``create_new_on_ambiguous`` is True;
    - otherwise: the confidently assigned speaker.

    === v7.8: Tunable gate parameters ===
    - persistence_threshold: Min consecutive low-sim chunks needed (1=aggressive, 2=conservative)
    - centroid_match_threshold: Max similarity to original speaker for both portions (lower=aggressive)

    Args:
        segments: Original segments from diarization (dicts with 'start',
            'end', 'duration' and optionally 'speaker').
        cache: Built UnifiedChunkEmbeddingCache
        speaker_centroids: Speaker ID -> centroid embedding
        min_segment_for_analysis: Only analyze segments >= this duration
        min_split_portion: Minimum portion duration after split
        assign_min_similarity: Minimum similarity for assignment
        margin_min: Required margin between best and second-best speaker
        create_new_on_ambiguous: Create new speaker ID on ambiguous (DEFAULT: FALSE!)
        persistence_threshold: Min consecutive low-sim chunks (v7.8: configurable);
            0 disables the persistence requirement.
        centroid_match_threshold: Centroid gate threshold (v7.8: configurable)

    Returns:
        (refined_segments, stats_dict). ``refined_segments`` is sorted by start
        time; ``stats_dict`` contains the counters initialised below plus
        ``'time_sec'``.
    """
    start_time = time.time()
    logger.debug(f"Chunk Reassignment v7.8 (persist={persistence_threshold}, centroid={centroid_match_threshold})")

    refined_segments = []
    stats = {
        'segments_analyzed': 0,
        'segments_with_changes': 0,
        'total_split_points': 0,
        'validated_split_points': 0,  # After validation gates
        'rejected_by_centroid_gate': 0,
        'new_speakers_created': 0,
        'portions_unusable': 0,
        'severe_triggers': 0,
        'lookahead_triggers': 0,
        'kept_original_speaker': 0,  # Ambiguous assignments kept as original
    }

    # New speaker ids continue numbering after the known speakers.
    new_speaker_counter = len(speaker_centroids)

    for seg in segments:
        # Pass through non-analyzable segments
        if seg.get('speaker') in ['OVERLAP', 'NON_SPEECH']:
            refined_segments.append(seg)
            continue

        if seg['duration'] < min_segment_for_analysis:
            refined_segments.append(seg)
            continue

        stats['segments_analyzed'] += 1
        original_speaker = seg.get('speaker', 'UNKNOWN')
        original_centroid = speaker_centroids.get(original_speaker)

        # Detect speaker changes using CACHED embeddings, with persistence and
        # relative-valley gates (v7.8: configurable persistence threshold).
        split_points, reasons, severe, lookahead = cache.detect_speaker_changes(
            seg['start'], seg['end'],
            require_persistence=persistence_threshold > 0,  # GATE: Enable persistence check
            use_relative_valley=True,  # GATE: Use per-segment adaptive threshold
            persistence_count=persistence_threshold,  # v7.8: Configurable (1=aggressive, 2=conservative)
        )

        stats['severe_triggers'] += severe
        stats['lookahead_triggers'] += lookahead
        stats['total_split_points'] += len(split_points)

        if not split_points:
            refined_segments.append(seg)
            continue

        # === SAME-SPEAKER CENTROID GATE ===
        # Validate each split point: reject if BOTH portions still match original speaker
        validated_splits = []
        for split_time in split_points:
            pre_emb = cache.get_segment_embedding(seg['start'], split_time)
            post_emb = cache.get_segment_embedding(split_time, seg['end'])

            if pre_emb is None or post_emb is None:
                continue  # Can't validate, skip this split

            # === GATE B: DIRECTIONAL CENTROID SEPARATION ===
            # A valid split requires:
            #   - Pre-portion looks like original speaker
            #   - Post-portion does NOT look like original speaker
            #   - Post-portion looks like SOME SPECIFIC OTHER speaker with margin
            if original_centroid is not None:
                pre_sim_orig = _cosine_sim(pre_emb, original_centroid)
                post_sim_orig = _cosine_sim(post_emb, original_centroid)

                # GATE B1: If BOTH portions still match original speaker strongly,
                # this is likely NOT a real speaker change (natural embedding variation)
                if pre_sim_orig > centroid_match_threshold and post_sim_orig > centroid_match_threshold:
                    stats['rejected_by_centroid_gate'] += 1
                    logger.debug(f"   Centroid gate B1 rejected at {split_time:.2f}s: "
                               f"both match original (pre={pre_sim_orig:.3f}, post={post_sim_orig:.3f})")
                    continue

                # GATE B2: Post-portion must be LESS similar to original (directional)
                if post_sim_orig >= pre_sim_orig - 0.05:  # Post not clearly different
                    stats['rejected_by_centroid_gate'] += 1
                    logger.debug(f"   Centroid gate B2 rejected at {split_time:.2f}s: "
                               f"post not clearly different (pre={pre_sim_orig:.3f}, post={post_sim_orig:.3f})")
                    continue

                # GATE B3: Post-portion must match a SPECIFIC DIFFERENT speaker with margin
                post_assigned, post_best_sim, assign_reason = _reassign_with_margin(
                    post_emb, speaker_centroids, assign_min_similarity, margin_min
                )

                if post_assigned is None:
                    # Can't confidently assign to any speaker - likely not a real change
                    stats['rejected_by_centroid_gate'] += 1
                    logger.debug(f"   Centroid gate B3 rejected at {split_time:.2f}s: "
                               f"post-portion has no confident assignment ({assign_reason})")
                    continue

                if post_assigned == original_speaker:
                    # Post-portion still best matches original speaker
                    stats['rejected_by_centroid_gate'] += 1
                    logger.debug(f"   Centroid gate B3 rejected at {split_time:.2f}s: "
                               f"post still matches original (sim={post_best_sim:.3f})")
                    continue

                logger.debug(f"   Split VALIDATED at {split_time:.2f}s: "
                           f"pre→{original_speaker}({pre_sim_orig:.2f}), "
                           f"post→{post_assigned}({post_best_sim:.2f})")

            validated_splits.append(split_time)

        if not validated_splits:
            refined_segments.append(seg)
            continue

        # Segment has VALIDATED speaker changes - split it
        stats['segments_with_changes'] += 1
        stats['validated_split_points'] += len(validated_splits)

        # === v7.7: ENERGY BOUNDARY SNAPPING ===
        # Snap each split point to a local energy minimum to avoid mid-word cuts.
        # A set is used because distinct splits may snap/clamp to the same
        # instant; keeping duplicates would create zero-length portions.
        snapped_splits = set()
        for split_time in validated_splits:
            snapped_time = _snap_to_energy_minimum(
                cache.audio_buffer, split_time,
                search_window_ms=200.0,  # ±200ms search
                frame_size_ms=10.0
            )
            # Keep the snapped time inside the segment (with a small margin)
            snapped_time = max(seg['start'] + 0.1, min(snapped_time, seg['end'] - 0.1))
            snapped_splits.add(snapped_time)

        # Create portions from snapped split points
        all_times = [seg['start']] + sorted(snapped_splits) + [seg['end']]

        for i in range(len(all_times) - 1):
            portion_start = all_times[i]
            portion_end = all_times[i + 1]
            duration = portion_end - portion_start

            # === MINIMUM PORTION GATE: Filter out too-short portions ===
            if duration < min_split_portion:
                # Mark as unusable instead of creating tiny fragments
                refined_segments.append({
                    'start': portion_start,
                    'end': portion_end,
                    'duration': duration,
                    'speaker': original_speaker,  # Keep original, don't create new
                    'status': 'unusable',
                    'unusable_reason': 'too_short_after_split',
                    'was_split': True,
                })
                stats['portions_unusable'] += 1
                continue

            portion_emb = cache.get_segment_embedding(portion_start, portion_end)

            if portion_emb is None:
                # No cached chunks for this portion - keep original speaker
                refined_segments.append({
                    'start': portion_start,
                    'end': portion_end,
                    'duration': duration,
                    'speaker': original_speaker,
                    'status': 'usable',
                    'was_split': True,
                    'note': 'no_cached_embedding'
                })
                continue

            # Reassign with margin-based confidence
            assigned, confidence, reason = _reassign_with_margin(
                portion_emb,
                speaker_centroids,
                assign_min_similarity,
                margin_min
            )

            if assigned is None:
                # === NO-NEW-SPEAKER-UNLESS-PROVEN RULE ===
                # This is the KEY fix that prevents speaker explosion
                if create_new_on_ambiguous:
                    assigned = f"SPEAKER_NEW_{new_speaker_counter:02d}"
                    new_speaker_counter += 1
                    stats['new_speakers_created'] += 1
                else:
                    # Default: keep original speaker on ambiguous assignment
                    # (prevents 6 speakers → 400+ fragmentation)
                    refined_segments.append({
                        'start': portion_start,
                        'end': portion_end,
                        'duration': duration,
                        'speaker': original_speaker,  # Keep original!
                        'status': 'usable',  # Still usable, just ambiguous
                        'was_split': True,
                        'reassignment_reason': f"kept_original_{reason}",
                        'original_speaker': original_speaker,
                    })
                    stats['kept_original_speaker'] += 1
                    continue

            refined_segments.append({
                'start': portion_start,
                'end': portion_end,
                'duration': duration,
                'speaker': assigned,
                'status': 'usable',
                'was_split': True,
                'reassignment_confidence': round(confidence, 3),
                'reassignment_reason': reason,
                'original_speaker': original_speaker,
            })

    # Sort by start time
    refined_segments.sort(key=lambda x: x['start'])

    elapsed = time.time() - start_time
    stats['time_sec'] = round(elapsed, 2)

    logger.debug(f"Reassignment: {stats['segments_with_changes']}/{stats['segments_analyzed']} split, "
                f"{stats['total_split_points']}→{stats['validated_split_points']} splits, {elapsed:.2f}s")

    return refined_segments, stats


def _cosine_sim(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of ``a`` and ``b``.

    A zero-norm input on either side yields 0.0 instead of a division by zero.
    """
    mag_a, mag_b = np.linalg.norm(a), np.linalg.norm(b)
    has_zero_norm = mag_a == 0 or mag_b == 0
    return 0.0 if has_zero_norm else float(np.dot(a, b) / (mag_a * mag_b))


def _snap_to_energy_minimum(
    audio_buffer,
    split_time: float,
    search_window_ms: float = 200.0,
    frame_size_ms: float = 10.0
) -> float:
    """
    Snap a split point to the nearest local energy minimum.

    === v7.7: ENERGY BOUNDARY SNAPPING ===
    Avoids mid-word cuts by finding a local silence/pause near the proposed split.

    Frames of ``frame_size_ms`` are scanned with 50% overlap across
    ``split_time ± search_window_ms``; the frame with the lowest RMS energy
    wins. Ties (e.g. a run of digital silence) are broken in favour of the
    frame closest to ``split_time``, so a split that already sits in silence
    is not dragged to the edge of the search window.

    Args:
        audio_buffer: Object exposing ``waveform_np`` (sample array) and
            ``sample_rate`` (Hz).
        split_time: Proposed split time in seconds
        search_window_ms: ±window to search for minimum (default: 200ms)
        frame_size_ms: Frame size for RMS computation (default: 10ms)

    Returns:
        Adjusted split time at the local energy minimum, or ``split_time``
        unchanged when the search window is too small or no finite energy
        could be computed.
    """
    sr = audio_buffer.sample_rate
    waveform = audio_buffer.waveform_np

    # Convert to samples. A frame/hop of at least one sample is enforced:
    # a zero hop would make range() raise ValueError at low sample rates.
    window_samples = int(search_window_ms * sr / 1000)
    frame_samples = max(1, int(frame_size_ms * sr / 1000))
    hop_samples = max(1, frame_samples // 2)
    center_sample = int(split_time * sr)

    # Search window bounds
    start_sample = max(0, center_sample - window_samples)
    end_sample = min(len(waveform), center_sample + window_samples)

    if end_sample - start_sample < frame_samples * 2:
        return split_time  # Window too small, keep original

    best_key = None
    min_energy = float('inf')
    min_energy_sample = center_sample

    # "+ 1" so the last frame that fits entirely inside the window is scanned.
    for pos in range(start_sample, end_sample - frame_samples + 1, hop_samples):
        # float64 avoids overflow when squaring integer PCM samples.
        frame = np.asarray(waveform[pos:pos + frame_samples], dtype=np.float64)
        rms = float(np.sqrt(np.mean(frame ** 2)))
        if not np.isfinite(rms):
            continue  # Corrupt (NaN/inf) audio can never be a valid minimum

        frame_center = pos + frame_samples // 2
        # Lowest energy first; among equal energies, closest to the proposal.
        key = (rms, abs(frame_center - center_sample))
        if best_key is None or key < best_key:
            best_key = key
            min_energy = rms
            min_energy_sample = frame_center

    if best_key is None:
        return split_time  # Nothing usable found, keep original

    # Convert back to time
    snapped_time = min_energy_sample / sr

    # Log if significant adjustment
    adjustment_ms = (snapped_time - split_time) * 1000
    if abs(adjustment_ms) > 20:
        logger.debug(f"   Boundary snapped: {split_time:.3f}s → {snapped_time:.3f}s "
                    f"(Δ={adjustment_ms:.0f}ms, energy={min_energy:.4f})")

    return snapped_time


def _reassign_with_margin(
    embedding: np.ndarray,
    speaker_centroids: Dict[str, np.ndarray],
    min_similarity: float,
    margin_min: float
) -> Tuple[Optional[str], float, str]:
    """
    Assign an embedding to the most similar speaker centroid, with confidence checks.

    An assignment is made only when the best cosine similarity is at least
    ``min_similarity`` AND exceeds the runner-up by at least ``margin_min``.
    Zero-norm or non-finite similarities are treated as 0.0. Without this, NaN
    similarities slip past both comparisons and the first speaker would be
    reported as "confident".

    Args:
        embedding: Query embedding vector.
        speaker_centroids: Speaker ID -> centroid embedding.
        min_similarity: Minimum cosine similarity for the best speaker.
        margin_min: Minimum gap between best and second-best similarity.

    Returns:
        ``(speaker_or_None, best_similarity, reason)`` where ``reason`` is
        ``"confident"``, ``"no_centroids"``, ``"low_similarity_<sim>"`` or
        ``"ambiguous_margin_<margin>"``.
    """
    if not speaker_centroids:
        return None, 0.0, "no_centroids"

    # The query norm is loop-invariant; compute it once.
    emb_norm = float(np.linalg.norm(embedding))

    similarities = {}
    for speaker, centroid in speaker_centroids.items():
        denom = emb_norm * float(np.linalg.norm(centroid))
        sim = float(np.dot(embedding, centroid) / denom) if denom > 0 else 0.0
        if not np.isfinite(sim):
            sim = 0.0
        similarities[speaker] = sim

    # Sort by similarity (stable: ties keep centroid insertion order)
    sorted_speakers = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
    best_speaker, best_sim = sorted_speakers[0]

    # Check minimum similarity
    if best_sim < min_similarity:
        return None, best_sim, f"low_similarity_{best_sim:.3f}"

    # Check margin between best and second-best
    if len(sorted_speakers) > 1:
        second_sim = sorted_speakers[1][1]
        margin = best_sim - second_sim

        if margin < margin_min:
            return None, best_sim, f"ambiguous_margin_{margin:.3f}"

    return best_speaker, best_sim, "confident"


# =============================================================================
# v7.8: MICRO-CONTAMINATION DETECTION
# =============================================================================
# Detects brief interjections (0.2-1.0s) that diarization missed.
# Three-stage approach: Coarse suspicion → Fine localization → Burn policy
# =============================================================================

def detect_micro_contamination(
    segments: List[Dict],
    cache: 'UnifiedChunkEmbeddingCache',
    speaker_centroids: Dict[str, np.ndarray],
    embedding_model,
    min_segment_duration: float = 4.0,
    suspicion_threshold: float = 0.12,
    fine_window_sec: float = 0.5,
    fine_hop_sec: float = 0.15,
    contamination_threshold: float = 0.45,
    min_contamination_duration: float = 0.2,
    min_clean_portion: float = 2.0,
) -> Tuple[List[Dict], Dict[str, Any]]:
    """
    === v7.8: MICRO-CONTAMINATION DETECTOR ===

    Catches brief speaker interjections (0.2-1.0s "yeah", "true", etc.) that
    diarization missed. Uses a 3-stage approach to minimize compute overhead.

    STAGE 1 (Coarse - Cheap):
        For each segment, check whether any cached 1.5s chunk has similarity to
        the speaker centroid significantly below the segment's own mean. If so,
        flag it as suspicious. Uses existing cached embeddings only.

    STAGE 2 (Fine - Only on Suspicious):
        For flagged segments, extract embeddings with smaller windows
        (``fine_window_sec``) around the suspicious chunks to locate
        contaminated regions. This is the only stage that runs the model.

    STAGE 3 (Action - Burn Policy):
        Mark contaminated regions as UNUSABLE (hard boundary).
        Keep clean portions before/after if ≥ min_clean_portion.
        Never merge across the burned region.
        Region boundaries are snapped to nearby energy minima and then clamped
        to the segment, so no portion extends outside the original segment.

    Args:
        segments: List of segment dicts with 'start', 'end', 'duration', 'speaker'
        cache: UnifiedChunkEmbeddingCache with pre-computed 1.5s embeddings
        speaker_centroids: Dict mapping speaker ID to centroid embedding
        embedding_model: Model for fine-grained embedding extraction
        min_segment_duration: Only analyze segments ≥ this duration
        suspicion_threshold: Flag if any chunk < (mean - threshold)
        fine_window_sec: Window size for fine-grained scan (Stage 2)
        fine_hop_sec: Hop size for fine-grained scan
        contamination_threshold: Similarity below this = contaminated
        min_contamination_duration: Ignore contamination shorter than this
        min_clean_portion: Keep clean portions only if ≥ this duration

    Returns:
        (refined_segments, stats_dict). ``refined_segments`` is sorted by start
        time (suspicious segments are processed after the others, so the list
        is re-sorted before returning).
    """
    start_time = time.time()
    logger.debug("Micro-Contamination Detection v7.8 (3-stage)")

    stats = {
        'segments_checked': 0,
        'segments_suspicious': 0,
        'segments_contaminated': 0,
        'contamination_regions_found': 0,
        'segments_split_by_contamination': 0,
        'portions_burned': 0,
        'clean_portions_kept': 0,
        'total_duration_burned': 0.0,
        'stage1_time': 0.0,
        'stage2_time': 0.0,
    }

    refined_segments = []
    suspicious_segments = []

    # =========================================================================
    # STAGE 1: COARSE SUSPICION (Cheap - uses cached embeddings)
    # =========================================================================
    stage1_start = time.time()

    for seg in segments:
        # Skip non-analyzable segments
        if seg.get('speaker') in ['OVERLAP', 'NON_SPEECH']:
            refined_segments.append(seg)
            continue

        if seg.get('status') == 'unusable':
            refined_segments.append(seg)
            continue

        if seg['duration'] < min_segment_duration:
            refined_segments.append(seg)
            continue

        speaker = seg.get('speaker', 'UNKNOWN')
        centroid = speaker_centroids.get(speaker)

        if centroid is None:
            refined_segments.append(seg)
            continue

        stats['segments_checked'] += 1

        # Get chunk embeddings for this segment
        chunks = cache.get_chunks_in_range(seg['start'], seg['end'])
        if len(chunks) < 2:
            refined_segments.append(seg)
            continue

        # Compute chunk-to-centroid similarities
        chunk_sims = [
            (chunk, _cosine_sim(chunk.embedding, centroid))
            for chunk in chunks
            if chunk.embedding is not None
        ]

        if len(chunk_sims) < 2:
            refined_segments.append(seg)
            continue

        # Check for suspicious dip
        sims_only = [s for _, s in chunk_sims]
        mean_sim = np.mean(sims_only)
        min_sim = min(sims_only)

        # v7.8.1: RELATIVE-ONLY detection - a significant dip from THIS
        # segment's mean. (An absolute threshold, min_sim < 0.50, caused 53%
        # data loss and was removed.) Being relative accounts for speakers
        # whose embeddings naturally score lower against their centroid.
        is_suspicious = (
            # Primary: Clear dip below segment mean
            (min_sim < mean_sim - suspicion_threshold) and
            # Secondary: The dip must be substantial enough (not just noise)
            (mean_sim - min_sim > 0.10) and
            # Tertiary: Mean should be reasonably high (speaker is recognizable)
            (mean_sim > 0.55)
        )

        if is_suspicious:
            stats['segments_suspicious'] += 1
            # Find which chunks triggered suspicion
            suspicious_chunks = [
                (chunk, sim) for chunk, sim in chunk_sims
                if sim < mean_sim - suspicion_threshold * 0.8
            ]
            suspicious_segments.append({
                'segment': seg,
                'chunks': chunk_sims,
                'suspicious_chunks': suspicious_chunks,
                'mean_sim': mean_sim,
                'min_sim': min_sim,
            })
        else:
            refined_segments.append(seg)

    stats['stage1_time'] = time.time() - stage1_start

    # =========================================================================
    # STAGE 2 + 3: FINE LOCALIZATION + BURN POLICY (Only on suspicious)
    # v7.8.1: All regions are collected first and embedded in one batched
    # pass instead of one model call per region.
    # =========================================================================
    stage2_start = time.time()

    # Step 1: Collect all suspicious regions with metadata
    all_regions_to_scan = []

    for seg_idx, item in enumerate(suspicious_segments):
        seg = item['segment']
        speaker = seg.get('speaker', 'UNKNOWN')
        centroid = speaker_centroids.get(speaker)

        if centroid is None:
            continue

        # Suspicious time regions: ±0.5s around each suspicious chunk,
        # clipped to the segment (v7.8.1: reduced from ±1.0s).
        suspicious_regions = [
            (max(seg['start'], chunk.start_time - 0.5),
             min(seg['end'], chunk.end_time + 0.5))
            for chunk, _sim in item['suspicious_chunks']
        ]
        suspicious_regions = _merge_overlapping_regions(suspicious_regions)

        for region_start, region_end in suspicious_regions:
            all_regions_to_scan.append({
                'seg_idx': seg_idx,
                'region_start': region_start,
                'region_end': region_end,
                'speaker': speaker,
                'centroid': centroid,
                'segment': seg,
            })

    # Step 2: BATCHED fine embedding extraction
    fine_embeddings_by_region = _extract_fine_embeddings_batched(
        cache.audio_buffer,
        embedding_model,
        all_regions_to_scan,
        window_sec=fine_window_sec,
        hop_sec=fine_hop_sec
    )

    # Step 3: Find contamination runs per region (CPU only)
    seg_contaminations = {}  # seg_idx -> list of contaminated regions

    for region_info, fine_embeddings in zip(all_regions_to_scan, fine_embeddings_by_region):
        if not fine_embeddings:
            continue

        # v7.8.1: "different speaker confirmation" reduces false positives
        contamination_runs = _find_contamination_runs(
            fine_embeddings,
            region_info['centroid'],
            all_speaker_centroids=speaker_centroids,
            original_speaker=region_info['speaker'],
            threshold=contamination_threshold,
            min_duration=min_contamination_duration,
            require_different_speaker=True  # Critical for reducing FPs!
        )

        if contamination_runs:
            seg_contaminations.setdefault(region_info['seg_idx'], []).extend(contamination_runs)

    # Step 4: Apply the burn policy to each suspicious segment
    for seg_idx, item in enumerate(suspicious_segments):
        seg = item['segment']
        speaker = seg.get('speaker', 'UNKNOWN')
        seg_start = seg['start']
        seg_end = seg['end']

        contaminated_regions = _merge_overlapping_regions(
            seg_contaminations.get(seg_idx, []), gap_tolerance=0.3
        )

        if not contaminated_regions:
            # False alarm - segment is clean
            refined_segments.append(seg)
            continue

        # =====================================================================
        # STAGE 3: BURN POLICY - Split around contamination
        # =====================================================================
        stats['segments_contaminated'] += 1
        stats['contamination_regions_found'] += len(contaminated_regions)

        # Split points from contamination boundaries, snapped to energy minima
        # for cleaner cuts. Snapping can move a boundary up to ±200 ms, so it is
        # clamped back into [seg_start, seg_end]; otherwise a portion could be
        # emitted outside the original segment.
        all_boundaries = [seg_start, seg_end]
        for cont_start, cont_end in contaminated_regions:
            for boundary in (cont_start, cont_end):
                snapped = _snap_to_energy_minimum(cache.audio_buffer, boundary)
                all_boundaries.append(min(max(snapped, seg_start), seg_end))
        all_boundaries = sorted(set(all_boundaries))

        portions_created = []
        for portion_start, portion_end in zip(all_boundaries, all_boundaries[1:]):
            duration = portion_end - portion_start

            # A portion is contaminated if it overlaps any (unsnapped)
            # contaminated region by more than the 0.1s tolerance.
            is_contaminated = any(
                portion_start < cont_end - 0.1 and portion_end > cont_start + 0.1
                for cont_start, cont_end in contaminated_regions
            )

            if is_contaminated:
                # BURN: Mark as unusable (hard boundary)
                portions_created.append({
                    'start': portion_start,
                    'end': portion_end,
                    'duration': duration,
                    'speaker': speaker,
                    'status': 'unusable',
                    'unusable_reason': 'micro_contamination',
                    'contamination_detected': True,
                    'original_speaker': speaker,
                })
                stats['portions_burned'] += 1
                stats['total_duration_burned'] += duration
            elif duration >= min_clean_portion:
                # Clean portion, long enough to keep
                portions_created.append({
                    'start': portion_start,
                    'end': portion_end,
                    'duration': duration,
                    'speaker': speaker,
                    'status': 'usable',
                    'was_decontaminated': True,
                    'original_speaker': speaker,
                })
                stats['clean_portions_kept'] += 1
            else:
                # Too short after decontamination - also burn
                portions_created.append({
                    'start': portion_start,
                    'end': portion_end,
                    'duration': duration,
                    'speaker': speaker,
                    'status': 'unusable',
                    'unusable_reason': 'too_short_after_decontamination',
                    'original_speaker': speaker,
                })
                stats['portions_burned'] += 1
                stats['total_duration_burned'] += duration

        if len(portions_created) > 1:
            stats['segments_split_by_contamination'] += 1

        refined_segments.extend(portions_created)

    stats['stage2_time'] = time.time() - stage2_start

    # Suspicious segments were appended after all pass-through segments;
    # restore chronological order (stable, like execute_chunk_reassignment_cached).
    refined_segments.sort(key=lambda x: x['start'])

    elapsed = time.time() - start_time
    stats['total_time'] = round(elapsed, 2)

    logger.debug(f"MicroContam: {stats['segments_contaminated']}/{stats['segments_suspicious']} contaminated, "
                f"burned {stats['total_duration_burned']:.1f}s, {elapsed:.2f}s total")

    return refined_segments, stats


def _merge_overlapping_regions(
    regions: List[Tuple[float, float]], 
    gap_tolerance: float = 0.0
) -> List[Tuple[float, float]]:
    """Coalesce time regions that overlap or lie within ``gap_tolerance`` of each other.

    Regions are ordered by start time. Each region is folded into the previous
    merged one when it starts no later than that region's end plus
    ``gap_tolerance``. Returns an empty list for empty input.
    """
    if not regions:
        return []

    ordered = sorted(regions, key=lambda r: r[0])
    result = [ordered[0]]

    for cur_start, cur_end in ordered[1:]:
        prev_start, prev_end = result[-1]
        can_merge = cur_start <= prev_end + gap_tolerance
        if not can_merge:
            result.append((cur_start, cur_end))
            continue
        result[-1] = (prev_start, max(prev_end, cur_end))

    return result


def _extract_fine_embeddings_batched(
    audio_buffer,
    embedding_model,
    regions_to_scan: List[Dict],
    window_sec: float = 0.5,
    hop_sec: float = 0.15,
    batch_size: int = 256,
) -> List[List[Tuple[float, float, np.ndarray]]]:
    """
    v7.8.1: BATCHED fine embedding extraction for 10x+ speedup.

    Instead of running inference once per suspicious region, this collects
    the audio windows of ALL regions first, runs batched inference over them
    (in sub-batches of ``batch_size`` to bound memory), then redistributes
    the results to their respective regions.

    Args:
        audio_buffer: AudioBuffer with ``sample_rate`` and ``waveform_np``
            (assumed to be a 1-D mono sample array).
        embedding_model: SpeechBrain embedding model; anything whose
            ``encode_batch(tensor)`` returns a tensor with one row per window.
        regions_to_scan: List of dicts with 'region_start', 'region_end'
            (seconds).
        window_sec: Fine-grained window size (default 0.5s).
        hop_sec: Hop between windows (default 0.15s). A hop shorter than one
            sample is clamped to one sample.
        batch_size: Maximum number of windows per ``encode_batch`` call
            (default 256).

    Returns:
        List of embedding lists, one per input region, in input order.
        Each embedding list contains (start, end, embedding) tuples for the
        full-length windows whose RMS exceeds 0.005. Region bounds are clamped
        to the buffer, so a region running past the end of the audio only
        loses its out-of-range windows instead of breaking the whole batch.
        If inference raises, a warning is logged and every region gets [].
    """
    if not regions_to_scan:
        return []

    sr = audio_buffer.sample_rate
    waveform = audio_buffer.waveform_np
    num_samples = len(waveform)
    window_samples = int(window_sec * sr)
    # A zero hop would never advance `pos`, so the scan loop would never end.
    hop_samples = max(1, int(hop_sec * sr))
    batch_size = max(1, int(batch_size))

    if window_samples <= 0:
        return [[] for _ in regions_to_scan]

    # Step 1: Collect ALL audio chunks across ALL regions
    all_chunks_audio = []
    all_chunk_times = []
    all_chunk_region_idx = []  # Track which region each chunk belongs to

    for region_idx, region in enumerate(regions_to_scan):
        # Clamp to the buffer: out-of-range bounds would produce short or
        # wrapped slices, and one short chunk makes np.stack fail for ALL
        # regions.
        start_sample = max(0, int(region['region_start'] * sr))
        end_sample = min(num_samples, int(region['region_end'] * sr))

        pos = start_sample
        while pos + window_samples <= end_sample:
            chunk_audio = waveform[pos:pos + window_samples]

            # Quick energy check - skip if too quiet
            rms = np.sqrt(np.mean(chunk_audio ** 2))
            if rms > 0.005:  # Minimal speech threshold
                all_chunks_audio.append(chunk_audio)
                all_chunk_times.append((pos / sr, (pos + window_samples) / sr))
                all_chunk_region_idx.append(region_idx)

            pos += hop_samples

    if not all_chunks_audio:
        return [[] for _ in regions_to_scan]

    # Step 2: Batched inference in bounded sub-batches
    try:
        # (num_windows, window_samples); np.stack copies, so torch receives a
        # contiguous, writable array.
        batch_tensor = torch.from_numpy(np.stack(all_chunks_audio)).float()

        resample = None
        if sr != 16000:
            # Model expects 16kHz. Resampling per sub-batch (rather than the
            # whole batch up front) keeps peak memory bounded.
            import torchaudio
            resample = torchaudio.functional.resample

        embedding_parts = []
        with torch.no_grad():
            for i in range(0, batch_tensor.shape[0], batch_size):
                batch_slice = batch_tensor[i:i + batch_size]
                if resample is not None:
                    batch_slice = resample(batch_slice, sr, 16000)
                embeddings = embedding_model.encode_batch(batch_slice)
                embedding_parts.append(embeddings.cpu().numpy())

        all_embeddings = np.concatenate(embedding_parts, axis=0)

    except Exception as e:
        logger.warning(f"Batched fine embedding extraction failed: {e}")
        return [[] for _ in regions_to_scan]

    # Step 3: Redistribute embeddings back to their respective regions
    results_by_region = [[] for _ in regions_to_scan]
    for i, (t_start, t_end) in enumerate(all_chunk_times):
        results_by_region[all_chunk_region_idx[i]].append(
            (t_start, t_end, all_embeddings[i])
        )

    return results_by_region


def _extract_fine_embeddings(
    audio_buffer,
    embedding_model,
    start_time: float,
    end_time: float,
    window_sec: float = 0.5,
    hop_sec: float = 0.15,
    batch_size: int = 256,
) -> List[Tuple[float, float, np.ndarray]]:
    """
    Extract fine-grained embeddings with smaller windows over one time span.

    Windows of ``window_sec`` are taken every ``hop_sec`` between
    ``start_time`` and ``end_time``. Only full-length windows are produced, and
    windows whose RMS is at or below 0.005 are skipped. Span bounds are
    clamped to the audio buffer, so a span running past the end of the audio
    does not produce a short window that would break batching.

    Args:
        audio_buffer: AudioBuffer with ``sample_rate`` and ``waveform_np``
            (assumed to be a 1-D mono sample array).
        embedding_model: Object with ``encode_batch(tensor)`` returning one
            row per window.
        start_time: Span start (seconds).
        end_time: Span end (seconds).
        window_sec: Window size (default 0.5s).
        hop_sec: Hop between windows (default 0.15s). A hop shorter than one
            sample is clamped to one sample.
        batch_size: Maximum windows per ``encode_batch`` call (default 256),
            matching the batched variant, so long spans cannot exhaust memory.

    Returns:
        List of (start, end, embedding) tuples. Returns [] if no window
        qualifies, or if inference raises (a warning is logged in that case).
    """
    sr = audio_buffer.sample_rate
    waveform = audio_buffer.waveform_np

    window_samples = int(window_sec * sr)
    # A zero hop would never advance `pos`, so the scan loop would never end.
    hop_samples = max(1, int(hop_sec * sr))
    batch_size = max(1, int(batch_size))
    if window_samples <= 0:
        return []

    # Clamp to the buffer so every slice is exactly window_samples long.
    start_sample = max(0, int(start_time * sr))
    end_sample = min(len(waveform), int(end_time * sr))

    chunks_audio = []
    chunk_times = []

    pos = start_sample
    while pos + window_samples <= end_sample:
        chunk_audio = waveform[pos:pos + window_samples]

        # Quick energy check - skip if too quiet
        rms = np.sqrt(np.mean(chunk_audio ** 2))
        if rms > 0.005:  # Minimal speech threshold
            chunks_audio.append(chunk_audio)
            chunk_times.append((pos / sr, (pos + window_samples) / sr))

        pos += hop_samples

    if not chunks_audio:
        return []

    try:
        # (num_windows, window_samples) float tensor
        batch_tensor = torch.from_numpy(np.stack(chunks_audio)).float()

        resample = None
        if sr != 16000:
            # Model expects 16kHz
            import torchaudio
            resample = torchaudio.functional.resample

        embedding_parts = []
        with torch.no_grad():
            for i in range(0, batch_tensor.shape[0], batch_size):
                batch_slice = batch_tensor[i:i + batch_size]
                if resample is not None:
                    batch_slice = resample(batch_slice, sr, 16000)
                embedding_parts.append(
                    embedding_model.encode_batch(batch_slice).cpu().numpy()
                )
        embeddings = np.concatenate(embedding_parts, axis=0)

    except Exception as e:
        logger.warning(f"Fine embedding extraction failed: {e}")
        return []

    # Pair with times
    return [
        (t_start, t_end, embeddings[i])
        for i, (t_start, t_end) in enumerate(chunk_times)
    ]


def _find_contamination_runs(
    fine_embeddings: List[Tuple[float, float, np.ndarray]],
    speaker_centroid: np.ndarray,
    all_speaker_centroids: Dict[str, np.ndarray],
    original_speaker: str,
    threshold: float = 0.45,
    min_duration: float = 0.2,
    require_different_speaker: bool = True,  # v7.8.1: Critical for reducing FPs
    min_other_speaker_sim: float = 0.50,
    other_speaker_margin: float = 0.08,
) -> List[Tuple[float, float]]:
    """
    Find contiguous runs where similarity to speaker is below threshold
    AND (optionally) the region matches a DIFFERENT speaker.

    v7.8.1 FIX: Added "different speaker confirmation" to reduce false positives.
    Without this, prosody/phoneme variation triggers false contamination.

    A run is a maximal sequence of consecutive entries of ``fine_embeddings``
    whose ``_cosine_sim`` to ``speaker_centroid`` is below ``threshold``. It
    spans from the first entry's start to the last entry's end. Runs are formed
    from list adjacency only, so time gaps between entries do not split a run.
    A run is reported when it lasts at least ``min_duration`` and is confirmed:

    * If ``require_different_speaker`` is True and at least one other speaker
      centroid exists (``original_speaker``, 'OVERLAP', 'NON_SPEECH' and
      'UNKNOWN' are excluded), the run's mean embedding must match some other
      speaker with similarity > ``min_other_speaker_sim`` AND exceed its
      similarity to ``speaker_centroid`` by more than ``other_speaker_margin``.
    * Otherwise, every sufficiently long run is reported.

    Args:
        fine_embeddings: (start, end, embedding) tuples in time order.
        speaker_centroid: Centroid of the segment's assigned speaker.
        all_speaker_centroids: Speaker label -> centroid.
        original_speaker: Label of the segment's assigned speaker.
        threshold: Similarity below which an entry joins a run.
        min_duration: Minimum run length (seconds) to be considered.
        require_different_speaker: Whether to require other-speaker confirmation.
        min_other_speaker_sim: Minimum best-other-speaker similarity (default 0.50).
        other_speaker_margin: Required lead over the original speaker (default 0.08).

    Returns:
        List of (start, end) tuples for contaminated regions, in input order.
    """
    if not fine_embeddings:
        return []

    # Get centroids of OTHER speakers (for confirmation)
    excluded_labels = {'OVERLAP', 'NON_SPEECH', 'UNKNOWN'}
    other_centroids = [
        centroid for spk, centroid in all_speaker_centroids.items()
        if spk != original_speaker and spk not in excluded_labels
    ]
    confirm_with_other_speakers = require_different_speaker and bool(other_centroids)

    def _is_confirmed(run_embeddings: List[np.ndarray]) -> bool:
        """Decide whether a long-enough run should be reported."""
        if not confirm_with_other_speakers:
            # Confirmation disabled, or no other speakers to compare (rare)
            return True
        avg_emb = np.mean(run_embeddings, axis=0)
        best_other_sim = 0.0
        for other_centroid in other_centroids:
            best_other_sim = max(best_other_sim, _cosine_sim(avg_emb, other_centroid))
        orig_sim = _cosine_sim(avg_emb, speaker_centroid)
        # Must match another speaker well AND better than the original (with margin)
        return (
            best_other_sim > min_other_speaker_sim
            and best_other_sim > orig_sim + other_speaker_margin
        )

    # Pass 1: group consecutive below-threshold entries into runs of
    # [start, end, embeddings].
    runs = []
    current_run = None
    for t_start, t_end, embedding in fine_embeddings:
        if _cosine_sim(embedding, speaker_centroid) < threshold:
            if current_run is None:
                current_run = [t_start, t_end, []]
            current_run[1] = t_end
            current_run[2].append(embedding)
        elif current_run is not None:
            runs.append(current_run)
            current_run = None
    if current_run is not None:
        runs.append(current_run)

    # Pass 2: keep runs that are long enough and confirmed. Confirmation is
    # only evaluated for runs that pass the duration check.
    return [
        (run_start, run_end)
        for run_start, run_end, run_embeddings in runs
        if run_end - run_start >= min_duration and _is_confirmed(run_embeddings)
    ]
