#!/usr/bin/env python3
"""Speaker clustering and merging with temporal constraints."""

import time
import logging
from typing import List, Dict
from collections import defaultdict
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

logger = logging.getLogger("FastPipelineV5.Clustering")


def speakers_overlap_in_time(segments: List[Dict], spk1: str, spk2: str, min_overlap: float = 0.2) -> bool:
    """
    Cannot-link constraint: temporal overlap = different speakers.

    Two people cannot physically speak over each other for long, so if any
    segment of ``spk1`` overlaps any segment of ``spk2`` by at least
    ``min_overlap`` seconds, the two labels must belong to different speakers.

    Args:
        segments: Segment dicts with 'speaker', 'start', 'end' keys.
        spk1: First speaker label.
        spk2: Second speaker label.
        min_overlap: Minimum overlap (seconds) that counts as a conflict.

    Returns:
        True if any pair of segments overlaps by >= min_overlap seconds.
    """
    intervals_a = [(s['start'], s['end']) for s in segments if s['speaker'] == spk1]
    intervals_b = [(s['start'], s['end']) for s in segments if s['speaker'] == spk2]

    # Overlap length of [a0, a1] and [b0, b1] is min(a1, b1) - max(a0, b0);
    # a non-positive value means the intervals are disjoint.
    return any(
        min(a_end, b_end) - max(a_start, b_start) >= min_overlap
        for a_start, a_end in intervals_a
        for b_start, b_end in intervals_b
    )


def merge_speakers(
    audio_path: str,
    segments: List[Dict],
    embeddings: Dict[int, np.ndarray],
    config
) -> List[Dict]:
    """
    Conservative speaker merging with cannot-link constraints.

    Strategy:
    1. Compute speaker centroids from embeddings
    2. Build similarity matrix
    3. Merge high-similarity speakers (respecting temporal constraints)
    4. Assign final speaker labels

    Args:
        audio_path: Path of the source audio (kept for interface compatibility;
            not read here).
        segments: Segment dicts with at least 'speaker', 'start', 'end'.
            Mutated in place: 'speaker' is rewritten to the final label.
        embeddings: Map of segment index -> embedding vector; segments without
            an embedding do not contribute to centroids but are still relabeled.
        config: Needs attribute `cluster_merge_threshold` (cosine similarity
            cutoff for merging).

    Returns:
        The same `segments` list with merged/renumbered 'speaker' labels.
    """
    logger.info(f"🔧 Merging speakers (threshold={config.cluster_merge_threshold})...")
    start = time.time()

    if not segments:
        return segments

    # Group segments by speaker label (only segments that have an embedding).
    speaker_segments = defaultdict(list)
    for i, seg in enumerate(segments):
        if i in embeddings:
            speaker_segments[seg['speaker']].append(embeddings[i])

    # Compute one centroid embedding per speaker.
    speaker_embeddings = {
        speaker: np.mean(embs, axis=0)
        for speaker, embs in speaker_segments.items()
        if embs
    }

    speakers = list(speaker_embeddings.keys())
    if len(speakers) <= 1:
        logger.info(f"✅ Only {len(speakers)} speaker(s), no merging needed")
        return segments

    # Full pairwise similarity in one vectorized call instead of O(n^2)
    # per-pair calls. The diagonal is 1.0 but is never consulted below
    # (we only iterate i < j pairs).
    emb_matrix = np.stack([speaker_embeddings[s] for s in speakers])
    sim_matrix = cosine_similarity(emb_matrix)

    # Union-find over speaker labels. `parent` tracks cluster structure;
    # `members` lists all original labels in each root's cluster so the
    # cannot-link check can cover the WHOLE cluster, not just the pair
    # endpoints (checking only endpoints could transitively merge two
    # temporally-overlapping speakers).
    parent = {s: s for s in speakers}
    members = {s: [s] for s in speakers}

    def _find(s: str) -> str:
        """Follow parent pointers to the cluster root of speaker `s`."""
        while parent[s] != s:
            s = parent[s]
        return s

    stats = {
        'attempted': 0,
        'blocked_threshold': 0,
        'blocked_overlap': 0,
        'successful': 0
    }

    # Candidate pairs sorted by similarity (highest first) — greedy merge.
    n = len(speakers)
    pairs = [
        (speakers[i], speakers[j], sim_matrix[i, j])
        for i in range(n) for j in range(i + 1, n)
    ]
    pairs.sort(key=lambda x: x[2], reverse=True)

    for spk1, spk2, sim in pairs:
        stats['attempted'] += 1

        # Similarity threshold.
        if sim < config.cluster_merge_threshold:
            stats['blocked_threshold'] += 1
            continue

        root1 = _find(spk1)
        root2 = _find(spk2)
        if root1 == root2:
            # Already in the same cluster via an earlier merge.
            continue

        # Cannot-link: any temporal overlap between ANY member of the two
        # clusters blocks the merge.
        if any(
            speakers_overlap_in_time(segments, m1, m2)
            for m1 in members[root1]
            for m2 in members[root2]
        ):
            stats['blocked_overlap'] += 1
            continue

        # Merge cluster root2 into root1.
        parent[root2] = root1
        members[root1].extend(members.pop(root2))
        stats['successful'] += 1

    # Resolve every label to its final root (handles merge chains like
    # C -> B -> A correctly), then renumber roots as SPEAKER_00, 01, ...
    merge_map = {s: _find(s) for s in speakers}
    unique_speakers = sorted(set(merge_map.values()))
    name_map = {s: f"SPEAKER_{i:02d}" for i, s in enumerate(unique_speakers)}

    for seg in segments:
        merged = merge_map.get(seg['speaker'], seg['speaker'])
        seg['speaker'] = name_map.get(merged, merged)

    elapsed = time.time() - start
    final_speakers = len(set(seg['speaker'] for seg in segments))

    logger.info(f"✅ Merged: {len(speakers)} → {final_speakers} speakers ({elapsed:.1f}s)")
    logger.info(f"   Stats: {stats}")

    return segments


def merge_adjacent_segments(segments: List[Dict], config) -> List[Dict]:
    """
    Merge adjacent same-speaker segments with small gaps.

    Rules:
    1. Same speaker
    2. Gap <= max_silence_gap

    Important semantics:
    - The "gap" is typically a VAD-derived non-speech region (silence/breath).
      If we MERGE, we do NOT "ignore" that non-speech; we *include* it inside the
      merged [start, end] span by extending `end` to the later segment.
    - This is intentionally conservative by default: larger allowed gaps reduce
      fragmentation, but also increase the amount of silence included in a clip
      and can (rarely) hide missed micro-interjections if upstream diarization
      failed to label them as speech.

    Status handling (v6.8 fix):
    - Preserve any existing `status` (especially 'unusable') on incoming segments.
    - If a segment has no `status` key, default it to 'usable' so downstream code
      (stats, sample generation, UI) can reliably filter by status.
    - Only merge two segments when BOTH are usable. Unusable segments act as
      hard boundaries so we never "smear" low-quality/too-short regions into
      an otherwise clean clip.

    Args:
        segments: Segment dicts with 'speaker', 'start', 'end' (and optionally
            'status', 'duration'). The input list is not mutated; merged
            segments are shallow copies.
        config: Needs attribute `max_silence_gap` (seconds).

    Returns:
        New list of merged segment dicts, sorted by start time.
    """
    if not segments:
        return []

    segments = sorted(segments, key=lambda x: x['start'])
    merged = []
    current = None

    for seg in segments:
        if current is None:
            current = seg.copy()
            current.setdefault('status', 'usable')
            continue

        same_speaker = seg['speaker'] == current['speaker']
        gap = seg['start'] - current['end']
        current_usable = current.get('status', 'usable') == 'usable'
        seg_usable = seg.get('status', 'usable') == 'usable'

        if (
            current_usable
            and seg_usable
            and same_speaker
            and gap <= config.max_silence_gap
        ):
            # Merge: extend the span. max() guards against a segment fully
            # contained inside `current` (input is sorted by start only),
            # which would otherwise SHRINK the merged end time.
            current['end'] = max(current['end'], seg['end'])
            current['duration'] = current['end'] - current['start']
        else:
            # Save current and start new
            merged.append(current)
            current = seg.copy()
            current.setdefault('status', 'usable')

    if current:
        merged.append(current)

    logger.info(f"   Adjacent merge: {len(segments)} → {len(merged)} segments")

    return merged

