#!/usr/bin/env python3
"""
Precision Surgical Splitter - Hybrid 1.5s Chunk-Based Speaker Reassignment

=== HYBRID APPROACH ===
Combines insights from multiple AI approaches:
- ChatGPT 5.2: VAD-based eligibility, margin-based reassignment, quality metrics
- Gemini Maya: Severe threshold circuit-breaker, 1.5s minimum portion
- Gemini Stabilized: Look-ahead comparison pattern, asymmetry of risk philosophy
- Original Claude: Surgical reassignment, conservative clustering

=== CORE PHILOSOPHY ===
"Asymmetry of Risk" (from Stabilized):
- Risk A (False Negative): Miss a speaker change → Catastrophic TTS poison
- Risk B (False Positive): Over-split same speaker → Negligible/Positive impact
- Risk C (Compute Cost): Extra GPU cycles → Irrelevant

Therefore: Be AGGRESSIVE with detection, CONSERVATIVE with reassignment.

=== ALGORITHM ===
1. PHASE 1: Chunk extraction with VAD-based eligibility filtering
   - Divide segment into 1.5s chunks
   - Compute speech_ratio per chunk using VAD
   - Mark chunks as "eligible" only if speech_ratio >= min_speech_ratio (0.6)
   
2. PHASE 2: GPU-batched embedding extraction
   - Extract embeddings for all eligible chunks with batched GPU inference (up to 32 chunks per batch by default)
   
3. PHASE 3: Speaker change detection (Look-ahead pattern from Stabilized)
   - Compare eligible chunk[i] to the next two eligible chunks (i+1 AND i+2)
   - Circuit-breaker: severe drop (<0.25) triggers immediately (from Maya)
   - Normal: requires look-ahead confirmation
   
4. PHASE 4: Split and reassign with margin-based confidence (from ChatGPT)
   - Split segment at detected change points
   - Apply minimum duration filter (1.5s for quality)
   - Reassign with margin requirement between top candidates

=== OPTIMIZATIONS ===
- GPU-batched embedding extraction (eligible chunks only, batched instead of one call per chunk)
- Vectorized cosine similarity computation
- Only process segments > 3s (short segments unlikely to have changes)
- Reuse existing speaker centroids from clustering stage
"""

import time
import logging
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass, field
import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity

logger = logging.getLogger("FastPipelineV6.ChunkReassignment")


@dataclass
class ChunkReassignmentConfig:
    """
    Configuration for the Precision Surgical Splitter.

    Each parameter is annotated with its source from the analysis.

    Raises:
        ValueError: from ``__post_init__`` when a value would make the
            algorithm ill-defined (non-positive chunk length, or a speech
            ratio threshold outside [0, 1]).
    """
    # === CHUNKING ===
    chunk_duration_seconds: float = 1.5  # Kimi-Audio paper standard
    min_segment_for_chunking: float = 3.0  # Only analyze segments > 3s

    # === ELIGIBILITY (from ChatGPT) ===
    # Chunk must have >= X% speech to be eligible for comparison
    # Filters out silence/breath chunks that cause false positives
    min_speech_ratio: float = 0.6

    # === DETECTION THRESHOLDS ===
    # Severe threshold: Single chunk triggers split (from Maya)
    # If similarity drops below this, it's DEFINITELY a speaker change
    severe_threshold: float = 0.25

    # Normal threshold: Requires look-ahead confirmation (from Stabilized)
    # With "asymmetry of risk" philosophy, we lean aggressive (0.40)
    normal_threshold: float = 0.40

    # === SPLITTING ===
    # Minimum usable portion after split (from Maya/Stabilized)
    # 1.5s ensures quality embeddings; shorter fragments are "burned"
    min_split_portion_seconds: float = 1.5

    # === REASSIGNMENT (from ChatGPT) ===
    # Minimum similarity for ANY assignment (low = new speaker or unusable)
    assign_min_similarity: float = 0.55

    # Required margin between best and second-best candidate
    # Prevents ambiguous assignments (best=0.62, second=0.61 is NOT confident)
    margin_min: float = 0.10

    # If ambiguous, create new speaker ID vs mark unusable
    create_new_on_ambiguous: bool = True

    def __post_init__(self) -> None:
        """Reject values that would break chunking or eligibility filtering."""
        # `not x > 0` also rejects NaN. A non-positive chunk length makes the
        # chunking loop's step zero/negative, which used to fail far from here.
        if not self.chunk_duration_seconds > 0:
            raise ValueError(
                f"chunk_duration_seconds must be > 0, got {self.chunk_duration_seconds!r}"
            )
        # Speech ratios are fractions in [0, 1]; anything else is a config typo.
        if not 0.0 <= self.min_speech_ratio <= 1.0:
            raise ValueError(
                f"min_speech_ratio must be within [0, 1], got {self.min_speech_ratio!r}"
            )


@dataclass
class ChunkInfo:
    """
    One fixed-length analysis chunk cut out of a segment.

    As produced by ``PrecisionSurgicalSplitter.extract_chunks_with_eligibility``:
    sample offsets index the full audio buffer (not the segment), times are
    absolute seconds, and ``embedding`` stays ``None`` until the chunk has
    been embedded.
    """
    idx: int  # position among the chunks kept for the segment
    start_sample: int  # first sample (inclusive) in the full buffer
    end_sample: int  # last sample (exclusive) in the full buffer
    start_time: float  # chunk start, seconds
    end_time: float  # chunk end, seconds
    speech_ratio: float  # fraction of the chunk detected as speech
    eligible: bool  # speech_ratio met the configured minimum
    embedding: Optional[np.ndarray] = field(default=None)  # filled after embedding


@dataclass
class ChunkAnalysisResult:
    """
    Per-segment outcome of the chunk-based change analysis.

    Fields:
        segment_idx: index of the analysed segment (the splitter itself
            fills in -1; callers may overwrite it)
        segment_duration: segment length in seconds
        num_chunks: chunks kept after dropping too-short trailing audio
        num_eligible_chunks: chunks that passed the speech-ratio filter
        has_speaker_change: True when at least one change point was found
        change_points: timestamps (seconds) at which a change was detected
        chunk_similarities: ``(chunk_i, chunk_j, similarity)`` triples for
            consecutive eligible chunks
        detection_reasons: human-readable reason per change point
        time_taken_ms: wall-clock processing time for the segment
    """
    segment_idx: int
    segment_duration: float
    num_chunks: int
    num_eligible_chunks: int
    has_speaker_change: bool
    change_points: List[float]
    chunk_similarities: List[Tuple[int, int, float]]
    detection_reasons: List[str]
    time_taken_ms: float


@dataclass
class ReassignmentStats:
    """
    Aggregate statistics from a validation or execution run.

    Fields:
        total_segments: segments handed to the run
        segments_analyzed: segments > 3s that were chunked
        segments_with_changes: segments where a speaker change was detected
        total_chunks_processed: chunks produced across analysed segments
        total_eligible_chunks: chunks that passed the speech-ratio filter
        total_split_points: change points found across all segments
        new_speakers_created: portions given a new speaker id
        portions_marked_unusable: split portions flagged as unusable
        total_time_sec: wall-clock time for the whole run
        avg_time_per_segment_ms: mean per-segment processing time
        segments_affected_pct: share of segments affected, in percent
    """
    total_segments: int
    segments_analyzed: int
    segments_with_changes: int
    total_chunks_processed: int
    total_eligible_chunks: int
    total_split_points: int
    new_speakers_created: int
    portions_marked_unusable: int
    total_time_sec: float
    avg_time_per_segment_ms: float
    segments_affected_pct: float


@dataclass
class QualityMetrics:
    """
    Pipeline-health metrics for chunk reassignment (from ChatGPT).

    Filled in after processing to check whether reassignment improves
    speaker purity. All values start at zero.
    """
    # Mean per-dimension variance of embeddings assigned to the same speaker;
    # should decrease after reassignment.
    within_speaker_variance: float = field(default=0.0)
    # Mean (1 - cosine similarity) between speaker centroids; should increase.
    cross_speaker_separation: float = field(default=0.0)
    # Ambiguity indicator accumulated by
    # PrecisionSurgicalSplitter.reassign_with_margin for ambiguous or
    # below-floor reassignments (see that method for its exact units).
    ambiguous_fraction: float = field(default=0.0)
    # Number of circuit-breaker (severe-drop) detections.
    severe_changes_count: int = field(default=0)
    # Number of look-ahead-confirmed detections.
    lookahead_changes_count: int = field(default=0)


class PrecisionSurgicalSplitter:
    """
    Detects within-segment speaker changes and surgically splits/reassigns.

    This is the hybrid implementation combining insights from:
    - ChatGPT 5.2: VAD eligibility, margin-based reassignment, quality metrics
    - Gemini Maya: Severe threshold, 1.5s minimum
    - Gemini Stabilized: Look-ahead pattern
    - Original approach: Surgical reassignment

    ``process_segment`` runs the four phases for one segment:
    1. ``extract_chunks_with_eligibility`` - fixed-length chunks + VAD eligibility
    2. ``batch_extract_embeddings`` - batched embeddings for eligible chunks
    3. ``detect_speaker_changes`` - circuit-breaker + look-ahead change points
    4. ``split_and_reassign`` - split at change points, margin-based reassignment

    The instance is stateful: quality-metric counters, the new-speaker id
    counter and the (embedding, speaker) pairs of every assigned portion
    accumulate across calls and are never reset.
    """

    def __init__(
        self,
        config: "ChunkReassignmentConfig",
        embedding_model,
        vad_model=None,
        vad_utils=None,
        device: Optional[torch.device] = None
    ):
        """
        Initialize the splitter.

        Args:
            config: ChunkReassignmentConfig with all parameters
            embedding_model: Speaker encoder exposing
                ``encode_batch(wavs, wav_lens)`` (e.g. ECAPA-TDNN)
            vad_model: Silero VAD model (optional, for speech ratio)
            vad_utils: Silero VAD utilities; element 0 is used as
                ``get_speech_timestamps``
            device: torch device (defaults to CUDA when available, else CPU)
        """
        self.config = config
        self.embedding_model = embedding_model
        self.vad_model = vad_model
        self.vad_utils = vad_utils
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Track quality metrics across all processed segments
        self.quality_metrics = QualityMetrics()
        self._all_embeddings = []  # Embeddings of every assigned portion
        self._speaker_assignments = []  # (embedding, speaker) pairs for metrics
        # Numerator/denominator behind quality_metrics.ambiguous_fraction
        self._reassignment_attempts = 0
        self._ambiguous_count = 0
        # Next numeric suffix for SPEAKER_NEW_xx ids; monotonic so ids never repeat
        self._next_new_speaker_num = 0

    # ------------------------------------------------------------------
    # Numeric helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _cosine_sim(a, b) -> float:
        """
        Return the cosine similarity of two vectors (both flattened to 1-D).

        A zero-norm input yields 0.0, the same convention scikit-learn's
        ``cosine_similarity`` uses. NumPy is used directly because this runs
        once per chunk pair / centroid, where scikit-learn's per-call input
        validation dominates the cost.
        """
        u = np.asarray(a, dtype=np.float64).ravel()
        v = np.asarray(b, dtype=np.float64).ravel()
        denom = float(np.linalg.norm(u) * np.linalg.norm(v))
        if denom == 0.0:
            return 0.0
        return float(np.dot(u, v) / denom)

    @staticmethod
    def _as_2d_embeddings(raw: np.ndarray) -> np.ndarray:
        """
        Collapse encoder output to shape ``(batch, dim)``.

        A singleton middle axis is squeezed, any other 3-D output is flattened
        per row, and a bare 1-D vector is treated as a batch of one.
        """
        if raw.ndim == 3:
            return raw.squeeze(1) if raw.shape[1] == 1 else raw.reshape(raw.shape[0], -1)
        if raw.ndim == 1:
            return raw.reshape(1, -1)
        return raw

    # ------------------------------------------------------------------
    # Phase 1: chunking + eligibility
    # ------------------------------------------------------------------

    def compute_speech_ratio(self, audio_chunk: np.ndarray, sample_rate: int) -> float:
        """
        Compute fraction of chunk that contains speech using VAD.

        From ChatGPT: Use VAD-based eligibility instead of position-based exclusion.
        This correctly handles the REAL problem (low speech content) rather than
        assuming "last chunk is silence".

        Speech is measured in samples (``return_seconds=False``) because Silero
        rounds second-based timestamps to 0.1 s, which is a large error on a
        1.5 s chunk and can flip eligibility.

        Args:
            audio_chunk: Audio samples as numpy array
            sample_rate: Sample rate

        Returns:
            Speech ratio between 0.0 and 1.0. Returns 1.0 when VAD is not
            configured or fails (best-effort: never exclude on VAD trouble),
            and 0.0 for an empty chunk.
        """
        if self.vad_model is None or self.vad_utils is None:
            # Fallback: assume all speech if VAD not available
            return 1.0

        num_samples = len(audio_chunk)
        if num_samples == 0:
            return 0.0

        try:
            get_speech_timestamps = self.vad_utils[0]
            wav_tensor = torch.as_tensor(audio_chunk, dtype=torch.float32)
            speech_timestamps = get_speech_timestamps(
                wav_tensor,
                self.vad_model,
                threshold=0.5,
                min_speech_duration_ms=100,
                min_silence_duration_ms=100,
                return_seconds=False,
                sampling_rate=sample_rate
            )
            speech_samples = sum(ts['end'] - ts['start'] for ts in speech_timestamps)
        except Exception as e:
            logger.debug(f"VAD speech ratio computation failed: {e}")
            return 1.0  # Assume speech on error

        return min(1.0, max(0.0, speech_samples / num_samples))

    def extract_chunks_with_eligibility(
        self,
        segment: Dict,
        audio_buffer,
        sample_rate: int
    ) -> List["ChunkInfo"]:
        """
        Extract fixed-length chunks and compute speech ratio for each.

        Phase 1 of the algorithm:
        - Divide segment into ``config.chunk_duration_seconds`` chunks
        - Drop a trailing chunk shorter than 0.5s
        - Use VAD to compute speech ratio per chunk
        - Mark chunks as eligible only if speech_ratio >= min_speech_ratio

        Args:
            segment: Segment dict with start/end (seconds)
            audio_buffer: Full audio waveform as numpy array
            sample_rate: Audio sample rate

        Returns:
            List of ChunkInfo objects with absolute sample offsets and times.

        Raises:
            ValueError: if the configured chunk length is under one sample.
        """
        chunk_samples = int(self.config.chunk_duration_seconds * sample_rate)
        if chunk_samples <= 0:
            raise ValueError(
                f"chunk_duration_seconds={self.config.chunk_duration_seconds} gives "
                f"{chunk_samples} samples at {sample_rate} Hz"
            )
        min_chunk_samples = sample_rate * 0.5  # skip very short final chunks (< 0.5s)

        seg_start_time = segment['start']
        start_sample = int(seg_start_time * sample_rate)
        end_sample = int(segment['end'] * sample_rate)
        segment_audio = audio_buffer[start_sample:end_sample]
        total = len(segment_audio)

        chunks = []
        for offset in range(0, total, chunk_samples):
            chunk_end = min(offset + chunk_samples, total)
            chunk_audio = segment_audio[offset:chunk_end]
            if len(chunk_audio) < min_chunk_samples:
                continue

            # Compute speech ratio using VAD (from ChatGPT)
            speech_ratio = self.compute_speech_ratio(chunk_audio, sample_rate)
            chunks.append(ChunkInfo(
                idx=len(chunks),
                start_sample=start_sample + offset,
                end_sample=start_sample + chunk_end,
                start_time=seg_start_time + offset / sample_rate,
                end_time=seg_start_time + chunk_end / sample_rate,
                speech_ratio=speech_ratio,
                eligible=speech_ratio >= self.config.min_speech_ratio
            ))

        return chunks

    # ------------------------------------------------------------------
    # Phase 2: embeddings
    # ------------------------------------------------------------------

    def batch_extract_embeddings(
        self,
        chunks: List["ChunkInfo"],
        audio_buffer: np.ndarray,
        sample_rate: int,
        batch_size: int = 32
    ) -> Dict[int, np.ndarray]:
        """
        Extract embeddings for all eligible chunks with batched inference.

        Phase 2 of the algorithm:
        - Only extract for eligible chunks (speech_ratio >= threshold)
        - Zero-pad each batch to its longest chunk and pass relative lengths

        A failed batch is logged and skipped; non-finite embeddings are dropped
        so they can never feed the detection or reassignment logic.
        Successful embeddings are also stored on ``chunk.embedding``.

        Args:
            chunks: List of ChunkInfo objects
            audio_buffer: Full audio waveform
            sample_rate: Sample rate (unused; kept for interface compatibility)
            batch_size: GPU batch size

        Returns:
            Dict mapping chunk_idx to 1-D embedding array
        """
        eligible_chunks = [c for c in chunks if c.eligible]
        if not eligible_chunks:
            return {}

        embeddings = {}
        for batch_start in range(0, len(eligible_chunks), batch_size):
            batch_chunks = eligible_chunks[batch_start:batch_start + batch_size]
            audio_chunks = [audio_buffer[c.start_sample:c.end_sample] for c in batch_chunks]

            max_len = max(len(a) for a in audio_chunks)
            if max_len == 0:
                continue  # nothing to embed; avoids a zero division below

            padded = torch.zeros(len(audio_chunks), max_len)
            wav_lens = torch.zeros(len(audio_chunks))
            for row, chunk_audio in enumerate(audio_chunks):
                padded[row, :len(chunk_audio)] = torch.as_tensor(chunk_audio, dtype=torch.float32)
                wav_lens[row] = len(chunk_audio) / max_len

            padded = padded.to(self.device)
            wav_lens = wav_lens.to(self.device)

            try:
                with torch.no_grad():
                    raw = self.embedding_model.encode_batch(padded, wav_lens).cpu().numpy()
                embs = self._as_2d_embeddings(raw)
            except Exception as e:
                logger.error(f"Batch embedding extraction failed: {e}")
            else:
                for chunk, emb in zip(batch_chunks, embs):
                    if not np.all(np.isfinite(emb)):
                        logger.warning("Dropping non-finite embedding for chunk %d", chunk.idx)
                        continue
                    embeddings[chunk.idx] = emb
                    chunk.embedding = emb

            del padded, wav_lens

        # Release cached GPU blocks once per call rather than once per batch,
        # so consecutive batches can reuse the allocator cache.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return embeddings

    # ------------------------------------------------------------------
    # Phase 3: change detection
    # ------------------------------------------------------------------

    def detect_speaker_changes(
        self,
        chunks: List["ChunkInfo"],
        embeddings: Dict[int, np.ndarray]
    ) -> Tuple[List[float], List[str], List[Tuple[int, int, float]]]:
        """
        Detect speaker changes using look-ahead pattern.

        Phase 3 of the algorithm (from Stabilized + Maya), over eligible chunks
        that have an embedding, in chunk order:
        - Circuit-breaker: if sim(chunk[i], chunk[i+1]) < severe_threshold,
          split immediately after chunk[i]
        - Normal: split after chunk[i] only if both sim(chunk[i], chunk[i+1])
          and sim(chunk[i], chunk[i+2]) are below normal_threshold
        - The final pair has no chunk[i+2], so it can only trigger through the
          circuit-breaker

        This is MORE ROBUST than pure adjacent comparison because:
        - Handles noisy single chunks (coughs, etc.)
        - Asks "Is chunk[i] different from EVERYTHING after it?"

        Args:
            chunks: List of ChunkInfo objects
            embeddings: Dict mapping chunk_idx to embedding

        Returns:
            (split_points, detection_reasons, all_similarities) where
            split_points are chunk end times and all_similarities holds
            (idx, next_idx, sim) for every consecutive eligible pair.
        """
        eligible_indices = [c.idx for c in chunks if c.eligible and c.idx in embeddings]
        if len(eligible_indices) < 2:
            return [], [], []

        # First chunk wins on duplicate idx, matching a linear first-match search.
        chunk_by_idx = {c.idx: c for c in reversed(chunks)}
        cfg = self.config

        split_points = []
        detection_reasons = []
        all_similarities = []

        for pos in range(len(eligible_indices) - 1):
            idx = eligible_indices[pos]
            next_idx = eligible_indices[pos + 1]
            emb_current = embeddings[idx]

            sim_next = self._cosine_sim(emb_current, embeddings[next_idx])
            all_similarities.append((idx, next_idx, sim_next))

            # === CIRCUIT-BREAKER (from Maya) ===
            if sim_next < cfg.severe_threshold:
                split_points.append(chunk_by_idx[idx].end_time)
                detection_reasons.append(f"SEVERE: sim={sim_next:.3f} < {cfg.severe_threshold}")
                self.quality_metrics.severe_changes_count += 1
                continue

            # === LOOK-AHEAD PATTERN (from Stabilized) ===
            # Both comparisons must be low; skip the second one when the first
            # already rules the split out or there is no chunk[i+2].
            if sim_next >= cfg.normal_threshold or pos + 2 >= len(eligible_indices):
                continue

            sim_after = self._cosine_sim(emb_current, embeddings[eligible_indices[pos + 2]])
            if sim_after < cfg.normal_threshold:
                split_points.append(chunk_by_idx[idx].end_time)
                detection_reasons.append(
                    f"LOOKAHEAD: sim_next={sim_next:.3f}, sim_after={sim_after:.3f} < {cfg.normal_threshold}"
                )
                self.quality_metrics.lookahead_changes_count += 1

        return split_points, detection_reasons, all_similarities

    # ------------------------------------------------------------------
    # Phase 4: reassignment
    # ------------------------------------------------------------------

    def _record_reassignment(self, ambiguous: bool) -> None:
        """Update the counters behind ``quality_metrics.ambiguous_fraction``."""
        self._reassignment_attempts += 1
        if ambiguous:
            self._ambiguous_count += 1
        self.quality_metrics.ambiguous_fraction = (
            self._ambiguous_count / self._reassignment_attempts
        )

    def _allocate_new_speaker_id(self, speaker_centroids: Dict[str, np.ndarray]) -> str:
        """
        Return a fresh ``SPEAKER_NEW_xx`` id.

        The first id uses ``len(speaker_centroids)`` as its number. Later ids
        keep counting up, so two unrelated unknown portions never share a
        label. Ids that collide with an existing centroid key are skipped.
        """
        num = max(len(speaker_centroids), self._next_new_speaker_num)
        while f"SPEAKER_NEW_{num:02d}" in speaker_centroids:
            num += 1
        self._next_new_speaker_num = num + 1
        return f"SPEAKER_NEW_{num:02d}"

    def reassign_with_margin(
        self,
        embedding: np.ndarray,
        speaker_centroids: Dict[str, np.ndarray]
    ) -> Tuple[Optional[str], float, str]:
        """
        Reassign embedding to speaker with margin-based confidence.

        Phase 4b: Reassignment logic (from ChatGPT)

        Requirements for confident assignment:
        1. best_similarity >= assign_min_similarity
        2. margin (best - second_best) >= margin_min

        Otherwise the portion counts as ambiguous: with
        ``create_new_on_ambiguous`` it gets a freshly allocated
        ``SPEAKER_NEW_xx`` id (unique per call), else ``None`` (unusable).
        ``quality_metrics.ambiguous_fraction`` is updated on every call that
        has centroids, as ambiguous calls / such calls.

        Args:
            embedding: Embedding to reassign (flattened to 1-D for comparison)
            speaker_centroids: Dict mapping speaker_id to centroid embedding

        Returns:
            (assigned_speaker, confidence, reason). confidence is the best
            centroid similarity, or 0.0 for "no_centroids" /
            "non_finite_embedding".
        """
        if not speaker_centroids:
            return None, 0.0, "no_centroids"

        if not np.all(np.isfinite(np.asarray(embedding, dtype=np.float64))):
            # NaN similarities would sort arbitrarily and slip past every
            # threshold as a "confident" match; refuse instead.
            self._record_reassignment(ambiguous=False)
            return None, 0.0, "non_finite_embedding"

        ranked = sorted(
            ((speaker, self._cosine_sim(embedding, centroid))
             for speaker, centroid in speaker_centroids.items()),
            key=lambda item: item[1],
            reverse=True,
        )
        best_speaker, best_sim = ranked[0]
        cfg = self.config

        # Check minimum similarity threshold
        if best_sim < cfg.assign_min_similarity:
            self._record_reassignment(ambiguous=True)
            if cfg.create_new_on_ambiguous:
                new_id = self._allocate_new_speaker_id(speaker_centroids)
                return new_id, best_sim, f"low_similarity ({best_sim:.3f})"
            return None, best_sim, f"unassignable_low_sim ({best_sim:.3f})"

        # Check margin between best and second-best (from ChatGPT)
        if len(ranked) > 1:
            margin = best_sim - ranked[1][1]
            if margin < cfg.margin_min:
                self._record_reassignment(ambiguous=True)
                reason = f"ambiguous_margin ({margin:.3f})"
                if cfg.create_new_on_ambiguous:
                    return self._allocate_new_speaker_id(speaker_centroids), best_sim, reason
                return None, best_sim, reason

        self._record_reassignment(ambiguous=False)
        return best_speaker, best_sim, "confident"

    def _embed_portion(self, portion_audio) -> np.ndarray:
        """
        Embed one split portion and return a 1-D embedding.

        Audio is cast to float32 like the batched path. Raises ValueError for
        empty audio or a non-finite result so the caller marks the portion
        unusable.
        """
        if len(portion_audio) == 0:
            raise ValueError("empty portion audio")
        wav = torch.as_tensor(portion_audio, dtype=torch.float32).unsqueeze(0).to(self.device)
        wav_lens = torch.ones(1, device=self.device)
        try:
            with torch.no_grad():
                raw = self.embedding_model.encode_batch(wav, wav_lens).cpu().numpy()
        finally:
            del wav, wav_lens
        embedding = self._as_2d_embeddings(raw)[0]
        if not np.all(np.isfinite(embedding)):
            raise ValueError("non-finite portion embedding")
        return embedding

    def split_and_reassign(
        self,
        original_segment: Dict,
        split_points: List[float],
        audio_buffer: np.ndarray,
        sample_rate: int,
        speaker_centroids: Dict[str, np.ndarray]
    ) -> List[Dict]:
        """
        Split segment at change points and reassign portions to speakers.

        Phase 4 of the algorithm:
        - Create portions from split points strictly inside the segment
          (duplicates are collapsed)
        - Mark portions shorter than min_split_portion_seconds unusable
        - Extract embedding for each remaining portion
        - Reassign with margin-based confidence (from ChatGPT)

        Args:
            original_segment: Original segment dict (start/end, optional speaker)
            split_points: Timestamps (seconds) to split at
            audio_buffer: Full audio waveform
            sample_rate: Sample rate
            speaker_centroids: Dict mapping speaker_id to centroid

        Returns:
            ``[original_segment]`` when there is no usable split point,
            otherwise one dict per portion in time order.
        """
        if not split_points:
            return [original_segment]

        seg_start = original_segment['start']
        seg_end = original_segment['end']
        original_speaker = original_segment.get('speaker', 'UNKNOWN')

        # Out-of-range or repeated cuts would create empty/negative portions.
        cuts = sorted({t for t in split_points if seg_start < t < seg_end})
        if not cuts:
            return [original_segment]

        boundaries = [seg_start] + cuts + [seg_end]
        portions = []

        for start_time, end_time in zip(boundaries[:-1], boundaries[1:]):
            duration = end_time - start_time

            # === MINIMUM DURATION FILTER (from Maya/Stabilized) ===
            # Fragments below the minimum are "burned" - too short for quality embeddings
            if duration < self.config.min_split_portion_seconds:
                portions.append({
                    'start': start_time,
                    'end': end_time,
                    'duration': duration,
                    'speaker': original_speaker,
                    'status': 'unusable',
                    'unusable_reason': 'too_short_after_split',
                    'was_split': True,
                    'original_segment_start': seg_start
                })
                continue

            portion_audio = audio_buffer[int(start_time * sample_rate):int(end_time * sample_rate)]

            try:
                portion_embedding = self._embed_portion(portion_audio)
            except Exception as e:
                logger.error(f"Portion embedding extraction failed: {e}")
                portions.append({
                    'start': start_time,
                    'end': end_time,
                    'duration': duration,
                    'speaker': original_speaker,
                    'status': 'unusable',
                    'unusable_reason': 'embedding_failed',
                    'was_split': True
                })
                continue

            # === MARGIN-BASED REASSIGNMENT (from ChatGPT) ===
            assigned_speaker, confidence, reason = self.reassign_with_margin(
                portion_embedding, speaker_centroids
            )

            if assigned_speaker is None:
                portions.append({
                    'start': start_time,
                    'end': end_time,
                    'duration': duration,
                    'speaker': original_speaker,
                    'status': 'unusable',
                    'unusable_reason': reason,
                    'was_split': True,
                    'reassignment_confidence': confidence
                })
                continue

            portions.append({
                'start': start_time,
                'end': end_time,
                'duration': duration,
                'speaker': assigned_speaker,
                'status': 'usable',
                'was_split': True,
                'reassignment_confidence': confidence,
                'reassignment_reason': reason,
                'original_speaker': original_speaker
            })

            # Track for quality metrics
            self._all_embeddings.append(portion_embedding)
            self._speaker_assignments.append((portion_embedding, assigned_speaker))

        return portions

    # ------------------------------------------------------------------
    # Orchestration
    # ------------------------------------------------------------------

    @staticmethod
    def _make_analysis_result(
        duration: float,
        num_chunks: int,
        num_eligible: int,
        started_at: float,
        change_points: Optional[List[float]] = None,
        reasons: Optional[List[str]] = None,
        similarities: Optional[List[Tuple[int, int, float]]] = None
    ) -> "ChunkAnalysisResult":
        """Build the per-segment ChunkAnalysisResult (segment_idx left at -1)."""
        change_points = [] if change_points is None else change_points
        return ChunkAnalysisResult(
            segment_idx=-1,
            segment_duration=duration,
            num_chunks=num_chunks,
            num_eligible_chunks=num_eligible,
            has_speaker_change=len(change_points) > 0,
            change_points=change_points,
            chunk_similarities=[] if similarities is None else similarities,
            detection_reasons=[] if reasons is None else reasons,
            time_taken_ms=(time.time() - started_at) * 1000
        )

    def process_segment(
        self,
        segment: Dict,
        audio_buffer: np.ndarray,
        sample_rate: int,
        speaker_centroids: Dict[str, np.ndarray]
    ) -> Tuple[List[Dict], "ChunkAnalysisResult"]:
        """
        Process a single segment through the full algorithm.

        Full pipeline:
        1. Extract chunks with eligibility
        2. Batch extract embeddings
        3. Detect speaker changes
        4. Split and reassign (only when a change was detected)

        Args:
            segment: Segment dict with start/end; ``duration`` is used when
                present, otherwise end - start
            audio_buffer: Full audio waveform
            sample_rate: Sample rate
            speaker_centroids: Dict mapping speaker_id to centroid

        Returns:
            (resulting_segments, analysis_result). resulting_segments is
            ``[segment]`` whenever no split happens. analysis_result has
            segment_idx=-1 for the caller to fill in.
        """
        started_at = time.time()
        duration = segment.get('duration', segment['end'] - segment['start'])

        # Phase 1: Extract chunks with eligibility
        chunks = self.extract_chunks_with_eligibility(segment, audio_buffer, sample_rate)
        num_eligible = sum(1 for c in chunks if c.eligible)

        if len(chunks) < 2:
            # Too few chunks to analyze
            return [segment], self._make_analysis_result(
                duration, len(chunks), num_eligible, started_at
            )

        # Phase 2: Batch extract embeddings
        embeddings = self.batch_extract_embeddings(chunks, audio_buffer, sample_rate)
        if len(embeddings) < 2:
            # Too few eligible chunks with usable embeddings
            return [segment], self._make_analysis_result(
                duration, len(chunks), num_eligible, started_at
            )

        # Phase 3: Detect speaker changes
        split_points, detection_reasons, all_similarities = self.detect_speaker_changes(
            chunks, embeddings
        )

        # Phase 4: Split and reassign
        if split_points:
            result_segments = self.split_and_reassign(
                segment, split_points, audio_buffer, sample_rate, speaker_centroids
            )
        else:
            result_segments = [segment]

        return result_segments, self._make_analysis_result(
            duration, len(chunks), num_eligible, started_at,
            split_points, detection_reasons, all_similarities
        )

    # ------------------------------------------------------------------
    # Quality metrics
    # ------------------------------------------------------------------

    def compute_quality_metrics(
        self,
        speaker_centroids: Dict[str, np.ndarray]
    ) -> "QualityMetrics":
        """
        Compute quality metrics after processing (from ChatGPT).

        These metrics help verify that chunk reassignment is improving
        speaker cluster purity:

        - within_speaker_variance: mean per-dimension variance of the portion
          embeddings assigned to each speaker (speakers with >1 portion),
          averaged over speakers; should DECREASE
        - cross_speaker_separation: mean (1 - cosine similarity) over all
          pairs of the given centroids; should INCREASE
        - ambiguous_fraction: maintained by ``reassign_with_margin``; lower is better

        Nothing is recomputed when no portion has been assigned yet.

        Returns:
            The instance's QualityMetrics object (updated in place)
        """
        if not self._speaker_assignments:
            return self.quality_metrics

        # Group embeddings by speaker
        speaker_embeddings: Dict[str, list] = {}
        for emb, speaker in self._speaker_assignments:
            speaker_embeddings.setdefault(speaker, []).append(emb)

        variances = [
            float(np.mean(np.var(np.array(embs), axis=0)))
            for embs in speaker_embeddings.values()
            if len(embs) > 1
        ]
        if variances:
            self.quality_metrics.within_speaker_variance = float(np.mean(variances))

        # Vectorized pairwise cosine over centroids (upper triangle, no diagonal)
        if speaker_centroids and len(speaker_centroids) > 1:
            mat = np.vstack([
                np.asarray(c, dtype=np.float64).ravel() for c in speaker_centroids.values()
            ])
            norms = np.linalg.norm(mat, axis=1, keepdims=True)
            norms[norms == 0.0] = 1.0  # zero centroid -> similarity 0, like sklearn
            unit = mat / norms
            sims = unit @ unit.T
            upper = np.triu_indices(len(mat), k=1)
            self.quality_metrics.cross_speaker_separation = float(np.mean(1.0 - sims[upper]))

        return self.quality_metrics

    def get_quality_metrics(self) -> Dict:
        """Return quality metrics as dict for logging."""
        return {
            'within_speaker_variance': self.quality_metrics.within_speaker_variance,
            'cross_speaker_separation': self.quality_metrics.cross_speaker_separation,
            'ambiguous_fraction': self.quality_metrics.ambiguous_fraction,
            'severe_changes_count': self.quality_metrics.severe_changes_count,
            'lookahead_changes_count': self.quality_metrics.lookahead_changes_count
        }


# === CONVENIENCE FUNCTIONS (backward compatible with existing code) ===

def extract_chunk_embeddings_batched(
    audio_chunks: List[np.ndarray],
    sample_rate: int,
    embedding_model,
    device: torch.device,
    batch_size: int = 32
) -> np.ndarray:
    """
    Embed a list of raw audio chunks with batched inference.

    Backward-compatible wrapper for existing code. ``sample_rate`` is accepted
    for signature compatibility but is not used.

    Each batch of up to ``batch_size`` chunks is zero-padded to its longest
    member and passed to ``embedding_model.encode_batch`` along with each
    chunk's length relative to that longest member. A singleton middle axis in
    the model output is squeezed; any other 3-D output is flattened per row.
    Exceptions from the model propagate to the caller.

    Returns:
        Row-stacked embeddings in the order of ``audio_chunks``, or an empty
        1-D array when ``audio_chunks`` is empty.
    """
    if not audio_chunks:
        return np.array([])

    per_batch = []
    for offset in range(0, len(audio_chunks), batch_size):
        group = audio_chunks[offset:offset + batch_size]
        longest = max(len(piece) for piece in group)

        batch_wav = torch.zeros(len(group), longest)
        rel_lens = torch.zeros(len(group))
        for row, piece in enumerate(group):
            batch_wav[row, :len(piece)] = torch.from_numpy(piece).float()
            rel_lens[row] = len(piece) / longest

        batch_wav = batch_wav.to(device)
        rel_lens = rel_lens.to(device)

        with torch.no_grad():
            out = embedding_model.encode_batch(batch_wav, rel_lens).cpu().numpy()
            if out.ndim == 3:
                out = out.squeeze(1) if out.shape[1] == 1 else out.reshape(out.shape[0], -1)

        per_batch.append(out)
        del batch_wav, rel_lens

    return np.vstack(per_batch) if per_batch else np.array([])


def validate_reassignment_impact(
    segments: List[Dict],
    audio_buffer,
    embedding_model,
    device: torch.device,
    config: Optional[ChunkReassignmentConfig] = None,
    vad_model=None,
    vad_utils=None
) -> Tuple[ReassignmentStats, List[ChunkAnalysisResult]]:
    """
    VALIDATION MODE: Analyze segments without making changes.

    Uses the Precision Surgical Splitter to detect speaker changes
    and report statistics on how many segments would be affected.
    Segments whose 'speaker' is 'OVERLAP' or 'NON_SPEECH', or whose
    'duration' is below ``config.min_segment_for_chunking``, are skipped.

    Args:
        segments: List of segment dicts from diarization. Reads the keys
            'duration', 'start', 'end' and (optionally) 'speaker'.
        audio_buffer: Pre-loaded AudioBuffer (reads ``waveform_np`` and
            ``sample_rate``)
        embedding_model: ECAPA-TDNN model
        device: torch device
        config: Optional ChunkReassignmentConfig (uses defaults if None)
        vad_model: Optional Silero VAD model for speech ratio
        vad_utils: Optional Silero VAD utilities

    Returns:
        (stats, detailed_results): Overall stats and per-segment analysis.
        Each result's ``segment_idx`` is set to the segment's index in
        ``segments``.
    """
    config = config or ChunkReassignmentConfig()
    min_duration = config.min_segment_for_chunking
    # Per-segment details are only logged when few segments are affected.
    max_detailed_segments = 20

    logger.info("=" * 70)
    logger.info("🔍 PRECISION SURGICAL SPLITTER - VALIDATION MODE")
    logger.info("   Hybrid approach: ChatGPT + Maya + Stabilized")
    logger.info(f"   Normal threshold: {config.normal_threshold} (look-ahead)")
    logger.info(f"   Severe threshold: {config.severe_threshold} (circuit-breaker)")
    logger.info(f"   Min speech ratio: {config.min_speech_ratio} (VAD eligibility)")
    logger.info(f"   Min split portion: {config.min_split_portion_seconds}s")
    logger.info("=" * 70)

    start_time = time.time()

    # Initialize splitter
    splitter = PrecisionSurgicalSplitter(
        config=config,
        embedding_model=embedding_model,
        vad_model=vad_model,
        vad_utils=vad_utils,
        device=device
    )

    # Get audio data
    waveform_np = audio_buffer.waveform_np
    sr = audio_buffer.sample_rate

    # Keep (original_index, segment) pairs so results can point back into `segments`.
    analyzable = [
        (i, seg) for i, seg in enumerate(segments)
        if seg.get('speaker') not in ('OVERLAP', 'NON_SPEECH')
        and seg['duration'] >= min_duration
    ]

    logger.info(f"   Segments to analyze: {len(analyzable)}/{len(segments)} (duration >= {min_duration}s)")

    results = []
    total_chunks = 0
    total_eligible = 0
    segments_with_changes = 0
    total_split_points = 0

    for idx, (seg_idx, segment) in enumerate(analyzable):
        # Empty centroids: validation only needs change detection, not reassignment.
        _, result = splitter.process_segment(
            segment, waveform_np, sr, speaker_centroids={}
        )
        result.segment_idx = seg_idx
        results.append(result)

        total_chunks += result.num_chunks
        total_eligible += result.num_eligible_chunks

        if result.has_speaker_change:
            segments_with_changes += 1
            total_split_points += len(result.change_points)

        # Progress logging
        if (idx + 1) % 50 == 0 or idx == len(analyzable) - 1:
            logger.info(f"   Progress: {idx + 1}/{len(analyzable)} segments analyzed")

    total_time = time.time() - start_time
    avg_time_per_seg = (total_time * 1000 / len(analyzable)) if analyzable else 0
    affected_pct = (segments_with_changes / len(analyzable) * 100) if analyzable else 0

    stats = ReassignmentStats(
        total_segments=len(segments),
        segments_analyzed=len(analyzable),
        segments_with_changes=segments_with_changes,
        total_chunks_processed=total_chunks,
        total_eligible_chunks=total_eligible,
        total_split_points=total_split_points,
        new_speakers_created=0,  # Not applicable in validation mode
        portions_marked_unusable=0,
        total_time_sec=total_time,
        avg_time_per_segment_ms=avg_time_per_seg,
        segments_affected_pct=affected_pct
    )

    # Print summary
    logger.info("=" * 70)
    logger.info("📊 VALIDATION RESULTS (Precision Surgical Splitter)")
    logger.info("=" * 70)
    logger.info(f"   Total segments: {stats.total_segments}")
    # The filter above is inclusive (>=), so the label uses ≥ to match it.
    logger.info(f"   Segments analyzed (≥{min_duration}s): {stats.segments_analyzed}")
    logger.info(f"   Total chunks: {stats.total_chunks_processed}")
    logger.info(f"   Eligible chunks (≥{config.min_speech_ratio} speech): {stats.total_eligible_chunks}")
    logger.info(f"   Segments with detected changes: {stats.segments_with_changes} ({stats.segments_affected_pct:.1f}%)")
    logger.info(f"   Total split points: {stats.total_split_points}")
    logger.info(f"   Severe (circuit-breaker): {splitter.quality_metrics.severe_changes_count}")
    logger.info(f"   Look-ahead confirmed: {splitter.quality_metrics.lookahead_changes_count}")
    logger.info(f"   Total time: {stats.total_time_sec:.2f}s")
    logger.info(f"   Avg time per segment: {stats.avg_time_per_segment_ms:.1f}ms")
    logger.info("=" * 70)

    # Detailed breakdown of affected segments
    if 0 < segments_with_changes <= max_detailed_segments:
        logger.info("\n🔴 SEGMENTS WITH DETECTED SPEAKER CHANGES:")
        for result in results:
            if not result.has_speaker_change:
                continue
            seg = segments[result.segment_idx]
            logger.info(f"   Segment {result.segment_idx}: {seg['start']:.1f}-{seg['end']:.1f}s "
                        f"({result.segment_duration:.1f}s) speaker={seg.get('speaker', '?')}")
            logger.info(f"      Chunks: {result.num_chunks} (eligible: {result.num_eligible_chunks})")
            logger.info(f"      Change points: {result.change_points}")
            logger.info(f"      Reasons: {result.detection_reasons}")
            # Index 2 of each chunk_similarities entry is formatted as the similarity value.
            sims = [f"{s[2]:.2f}" for s in result.chunk_similarities[:5]]
            logger.info(f"      Similarities (first 5): {sims}")

    return stats, results


def execute_chunk_reassignment(
    segments: List[Dict],
    audio_buffer,
    embedding_model,
    speaker_centroids: Dict[str, np.ndarray],
    device: torch.device,
    config: Optional[ChunkReassignmentConfig] = None,
    vad_model=None,
    vad_utils=None
) -> Tuple[List[Dict], ReassignmentStats]:
    """
    EXECUTE MODE: Actually perform chunk-based reassignment.

    Uses the Precision Surgical Splitter to:
    1. Detect speaker changes within segments
    2. Split segments at change points
    3. Reassign each resulting portion to best matching speaker

    Segments whose 'speaker' is 'OVERLAP' or 'NON_SPEECH', or whose
    'duration' is below ``config.min_segment_for_chunking``, are passed
    through unchanged.

    Args:
        segments: Original segment list (dicts with 'start', 'duration' and
            optionally 'speaker')
        audio_buffer: Pre-loaded AudioBuffer (reads ``waveform_np`` and
            ``sample_rate``)
        embedding_model: ECAPA-TDNN model
        speaker_centroids: Dict mapping speaker_id -> centroid embedding
        device: torch device
        config: Optional ChunkReassignmentConfig
        vad_model: Optional Silero VAD model
        vad_utils: Optional Silero VAD utilities

    Returns:
        (refined_segments, stats): New segment list, with splits applied and
        sorted by 'start', plus stats
    """
    config = config or ChunkReassignmentConfig()

    logger.info("=" * 70)
    logger.info("⚡ PRECISION SURGICAL SPLITTER - EXECUTE MODE")
    logger.info("=" * 70)

    start_time = time.time()

    # Initialize splitter
    splitter = PrecisionSurgicalSplitter(
        config=config,
        embedding_model=embedding_model,
        vad_model=vad_model,
        vad_utils=vad_utils,
        device=device
    )

    # Get audio data
    waveform_np = audio_buffer.waveform_np
    sr = audio_buffer.sample_rate

    refined_segments = []
    segments_analyzed = 0
    segments_with_changes = 0
    total_chunks = 0
    total_eligible = 0
    total_splits = 0
    new_speakers = 0
    unusable_portions = 0

    for segment in segments:
        # Pass through segments that shouldn't be analyzed
        if (segment.get('speaker') in ('OVERLAP', 'NON_SPEECH')
                or segment['duration'] < config.min_segment_for_chunking):
            refined_segments.append(segment)
            continue

        segments_analyzed += 1

        # Process segment
        result_segs, analysis = splitter.process_segment(
            segment, waveform_np, sr, speaker_centroids
        )

        total_chunks += analysis.num_chunks
        total_eligible += analysis.num_eligible_chunks

        if analysis.has_speaker_change:
            segments_with_changes += 1
            total_splits += len(analysis.change_points)

        # Count new-speaker and unusable portions.
        # NOTE(review): this counts *segments* labelled SPEAKER_NEW_*, not
        # distinct new labels; confirm that is the intended meaning.
        for seg in result_segs:
            # `or ''` guards against an explicit None speaker value.
            if (seg.get('speaker') or '').startswith('SPEAKER_NEW_'):
                new_speakers += 1
            if seg.get('status') == 'unusable':
                unusable_portions += 1

        refined_segments.extend(result_segs)

    # Sort by start time (stable, so equal starts keep their relative order)
    refined_segments.sort(key=lambda s: s['start'])

    # Presumably populates splitter.quality_metrics, which
    # get_quality_metrics() reports below; the return value is not needed here.
    splitter.compute_quality_metrics(speaker_centroids)

    total_time = time.time() - start_time
    avg_time = (total_time * 1000 / segments_analyzed) if segments_analyzed > 0 else 0
    affected_pct = (segments_with_changes / segments_analyzed * 100) if segments_analyzed > 0 else 0

    stats = ReassignmentStats(
        total_segments=len(segments),
        segments_analyzed=segments_analyzed,
        segments_with_changes=segments_with_changes,
        total_chunks_processed=total_chunks,
        total_eligible_chunks=total_eligible,
        total_split_points=total_splits,
        new_speakers_created=new_speakers,
        portions_marked_unusable=unusable_portions,
        total_time_sec=total_time,
        avg_time_per_segment_ms=avg_time,
        segments_affected_pct=affected_pct
    )

    logger.info("✅ Reassignment complete:")
    logger.info(f"   Original segments: {len(segments)}")
    logger.info(f"   Refined segments: {len(refined_segments)}")
    logger.info(f"   Segments split: {segments_with_changes}")
    logger.info(f"   New speakers created: {new_speakers}")
    logger.info(f"   Portions marked unusable: {unusable_portions}")
    logger.info(f"   Quality metrics: {splitter.get_quality_metrics()}")
    logger.info(f"   Total time: {total_time:.2f}s")

    return refined_segments, stats


# === STANDALONE VALIDATION SCRIPT ===
if __name__ == "__main__":
    # Run standalone validation on existing processed data.
    #
    # Usage:
    #     python -m src.chunk_reassignment VIDEO_ID [--threshold 0.40]
    #         [--severe 0.25] [--min-speech 0.6]
    #
    # Errors are reported on stderr with exit status 1.
    import sys
    import json
    import argparse
    from pathlib import Path

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s | %(levelname)-8s | %(message)s',
        datefmt='%H:%M:%S'
    )

    parser = argparse.ArgumentParser(description="Validate chunk reassignment impact")
    parser.add_argument("video_id", help="Video ID to analyze")
    parser.add_argument("--threshold", type=float, default=0.40,
                        help="Normal threshold (default: 0.40)")
    parser.add_argument("--severe", type=float, default=0.25,
                        help="Severe threshold (default: 0.25)")
    parser.add_argument("--min-speech", type=float, default=0.6,
                        help="Minimum speech ratio (default: 0.6)")
    args = parser.parse_args()

    video_id = args.video_id
    data_dir = Path("data/fast_output_v6") / video_id

    # sys.exit(str) writes the message to stderr and exits with status 1.
    if not data_dir.exists():
        sys.exit(f"Error: Data directory not found: {data_dir}")

    # Load metadata
    metadata_file = data_dir / "metadata.json"
    if not metadata_file.exists():
        sys.exit(f"Error: metadata.json not found in {data_dir}")

    with open(metadata_file, encoding="utf-8") as f:
        metadata = json.load(f)

    segments = metadata['segments']

    # Prefer the trimmed audio; fall back to the full file.
    audio_candidates = (
        data_dir / f"{video_id}_trimmed.wav",
        data_dir / f"{video_id}.wav",
    )
    audio_file = next((p for p in audio_candidates if p.exists()), None)
    if audio_file is None:
        sys.exit(f"Error: Audio file not found in {data_dir}")

    print(f"\n📁 Loaded: {video_id}")
    print(f"   Segments: {len(segments)}")
    print(f"   Audio: {audio_file}")

    # Load models and audio buffer
    from src.audio_buffer import AudioBuffer
    from src.models import MODELS
    from src.config import Config

    config_obj = Config()
    MODELS.load_all(config_obj)

    audio_buffer = AudioBuffer.from_file(str(audio_file))
    device = MODELS.get_device()

    # Configure splitter
    splitter_config = ChunkReassignmentConfig(
        normal_threshold=args.threshold,
        severe_threshold=args.severe,
        min_speech_ratio=args.min_speech
    )

    # Run validation
    stats, results = validate_reassignment_impact(
        segments=segments,
        audio_buffer=audio_buffer,
        embedding_model=MODELS.embedding_model,
        device=device,
        config=splitter_config,
        vad_model=MODELS.silero_vad,
        vad_utils=MODELS.silero_utils
    )

    print("\n📊 Summary:")
    print(f"   Would affect: {stats.segments_with_changes}/{stats.segments_analyzed} segments ({stats.segments_affected_pct:.1f}%)")
    print(f"   Time cost: {stats.total_time_sec:.2f}s for this video")
