#!/usr/bin/env python3
"""
Chunk-Parallel Diarization for High-Throughput Pipeline

Stage 2 optimization: Parallelize diarization across GPUs for single video.

Key concepts:
- Split 1hr video into 8 chunks with 10s boundary overlap
- Run PyAnnote on 5-8 GPUs in parallel
- Reconcile speakers across chunk boundaries using overlap regions

Usage:
    from src.chunk_parallel import DiarizationPool, create_chunks_with_overlap

    pool = DiarizationPool.get_instance(gpu_ids=[0,1,2,3,4])

    # Create overlapping chunks
    chunk_specs = create_chunks_with_overlap(audio_path, vad_segments, config)

    # Process in parallel
    results = pool.diarize_parallel(audio_path, chunk_specs, config)

    # Reconcile speakers
    final_segments = reconcile_speakers(results)
"""

import os
import time
import logging
import threading
import tempfile
from queue import Queue, Empty
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Any, NamedTuple
from dataclasses import dataclass, field
from concurrent.futures import ProcessPoolExecutor, Future, as_completed
import multiprocessing as mp

import numpy as np
import torch

logger = logging.getLogger("ChunkParallel")


@dataclass
class ChunkSpec:
    """Specification for a diarization chunk with overlap.

    Each chunk has a *nominal* window (the slice of the timeline it is
    responsible for in the final output) and an *actual* window that is
    widened by the overlap so adjacent chunks hear the same boundary audio
    for speaker reconciliation. All times are seconds from file start.
    """
    chunk_idx: int           # 0-based position of this chunk in the file
    start_sec: float         # Actual start (includes overlap)
    end_sec: float           # Actual end (includes overlap)
    nominal_start: float     # Nominal start (for final output)
    nominal_end: float       # Nominal end (for final output)
    overlap_start: float     # Overlap region start (0.0 for chunk 0)
    overlap_end: float       # Overlap region end (0.0 for chunk 0)


@dataclass
class ChunkResult:
    """Result from diarizing a single chunk.

    Segment timestamps are global (whole-file) seconds; speaker labels are
    chunk-scoped ("chunk_XXX_SPEAKER_YY") until reconciliation assigns
    final global IDs.
    """
    chunk_idx: int           # Index of the chunk this result came from
    segments: List[Dict]     # Diarized segments with global timestamps
    overlaps: List[Dict]     # Reserved; always [] in this module (see diarize_parallel)
    speakers: List[str]      # Chunk-local speaker labels observed
    nominal_start: float     # Nominal window start (copied from ChunkSpec)
    nominal_end: float       # Nominal window end (copied from ChunkSpec)
    overlap_start: float     # Reconciliation overlap window start
    overlap_end: float       # Reconciliation overlap window end
    process_time: float = 0.0  # Wall-clock seconds spent diarizing this chunk


def create_chunks_with_overlap(
    vad_segments: List[Dict],
    total_duration: float,
    chunk_duration: float = 450.0,  # 7.5 min nominal
    overlap_seconds: float = 10.0,
    min_chunk_duration: float = 30.0
) -> List[ChunkSpec]:
    """
    Create chunk specifications with boundary overlap for speaker reconciliation.

    Example for 1hr video with 7.5min chunks and 10s overlap:
        Chunk 0: [0:00 - 7:40]   (nominal 0-7:30 + 10s overlap)
        Chunk 1: [7:20 - 15:10]  (starts 10s early, ends 10s late)
        Chunk 2: [14:50 - 22:40]
        ... etc

    Args:
        vad_segments: Speech segments from VAD (for silence-aware cutting)
        total_duration: Total audio duration in seconds
        chunk_duration: Target nominal chunk duration
        overlap_seconds: Overlap duration at boundaries
        min_chunk_duration: Minimum chunk size

    Returns:
        List of ChunkSpec objects (empty when total_duration <= 0)
    """
    specs = []
    nominal_start = 0.0
    chunk_idx = 0

    while nominal_start < total_duration:
        # Calculate nominal end
        nominal_end = min(nominal_start + chunk_duration, total_duration)

        # Find VAD-safe cut point near nominal_end
        if nominal_end < total_duration:
            nominal_end = _find_vad_safe_cut(vad_segments, nominal_end, window=30.0)

            # The chosen silence gap may extend past the end of the audio;
            # clamp so chunk boundaries never exceed total_duration.
            nominal_end = min(nominal_end, total_duration)

            # If the silence-aligned cut made the chunk too short -- or failed
            # to advance past nominal_start at all, which would loop forever --
            # fall back to the fixed-size boundary.
            if nominal_end - nominal_start < max(min_chunk_duration, 1e-3):
                nominal_end = min(nominal_start + chunk_duration, total_duration)

        # Calculate actual boundaries with overlap
        actual_start = max(0.0, nominal_start - overlap_seconds) if chunk_idx > 0 else 0.0
        actual_end = min(total_duration, nominal_end + overlap_seconds)

        # Define overlap regions for reconciliation (chunk 0 has none)
        overlap_start = nominal_start - overlap_seconds if chunk_idx > 0 else 0.0
        overlap_end = nominal_start + overlap_seconds if chunk_idx > 0 else 0.0

        spec = ChunkSpec(
            chunk_idx=chunk_idx,
            start_sec=actual_start,
            end_sec=actual_end,
            nominal_start=nominal_start,
            nominal_end=nominal_end,
            overlap_start=max(0, overlap_start),
            overlap_end=min(total_duration, overlap_end)
        )
        specs.append(spec)

        # Next chunk picks up exactly where this one's nominal window ends.
        nominal_start = nominal_end
        chunk_idx += 1

    logger.info(f"Created {len(specs)} chunk specs with {overlap_seconds}s boundary overlap")
    return specs


def _find_vad_safe_cut(vad_segments: List[Dict], target_time: float, window: float = 30.0) -> float:
    """Pick the midpoint of the widest silence gap near target_time.

    Scans consecutive VAD segment pairs; a candidate cut is the center of
    the silence between them, considered only when it lies within `window`
    seconds of `target_time`. Falls back to `target_time` itself when no
    gap qualifies (including the empty/one-segment case).
    """
    cut = target_time
    widest = 0.0

    for prev_seg, next_seg in zip(vad_segments, vad_segments[1:]):
        gap_lo = prev_seg['end']
        gap_hi = next_seg['start']
        width = gap_hi - gap_lo
        center = (gap_lo + gap_hi) / 2

        # Prefer the widest gap among those close enough to the target.
        if abs(center - target_time) <= window and width > widest:
            widest = width
            cut = center

    return cut


# === GPU WORKER PROCESS ===

def _diarization_worker_init(gpu_id: int, hf_token: str):
    """Initialize diarization worker with PyAnnote on specific GPU.

    Runs once per worker process (as a ProcessPoolExecutor initializer).
    Stores the loaded pipeline in the process-global `_worker_pipeline`
    so `_diarization_worker_fn` can reuse it across jobs.

    Args:
        gpu_id: Physical GPU index this worker is pinned to.
        hf_token: HuggingFace token used to fetch the pipeline.
    """
    global _worker_gpu_id, _worker_pipeline

    # Pin this process to a single physical GPU. This must happen before the
    # first CUDA call in this process; torch initializes CUDA lazily, so
    # setting it here works even though torch may already be imported.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)

    import torch
    from pyannote.audio import Pipeline

    _worker_gpu_id = gpu_id
    device = torch.device('cuda:0')  # Always 0 since we set CUDA_VISIBLE_DEVICES

    # Load PyAnnote pipeline (community-1 is publicly accessible, 3.1 requires special access)
    _worker_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-community-1",
        token=hf_token
    )
    _worker_pipeline.to(device)

    # NOTE(review): logs from a 'spawn' child depend on logging being
    # configured in the child interpreter -- confirm handlers are set up.
    logger.info(f"Worker initialized on GPU {gpu_id}")


def _diarization_worker_fn(job: Dict) -> Dict:
    """
    Worker function to diarize a single chunk.

    Requires `_diarization_worker_init` to have run in this process so the
    module-global `_worker_pipeline` is loaded.

    Args:
        job: Dict with audio_path, start_sec, end_sec, chunk_idx, config_dict

    Returns:
        Dict with chunk_idx, segments (global timestamps, chunk-scoped
        speaker labels), speakers, process_time, success flag, error text
    """
    global _worker_pipeline

    import torchaudio

    audio_path = job['audio_path']
    start_sec = job['start_sec']
    end_sec = job['end_sec']
    chunk_idx = job['chunk_idx']
    config_dict = job['config_dict']

    start_time = time.time()

    try:
        # Decode only the requested slice instead of loading the whole file
        # and slicing in memory -- with N parallel workers the old approach
        # decoded the full (e.g. 1hr) recording N times per video.
        sr = torchaudio.info(audio_path).sample_rate
        start_sample = int(start_sec * sr)
        num_samples = max(0, int(end_sec * sr) - start_sample)
        chunk_waveform, sr = torchaudio.load(
            audio_path,
            frame_offset=start_sample,
            num_frames=num_samples
        )

        # Run diarization on the in-memory slice
        chunk_data = {"waveform": chunk_waveform, "sample_rate": sr}

        result = _worker_pipeline(
            chunk_data,
            min_speakers=config_dict.get('min_speakers', 1),
            max_speakers=config_dict.get('max_speakers', 10)
        )

        # Extract annotation (pyannote 4.x wraps it in .speaker_diarization;
        # older versions return the Annotation directly)
        if hasattr(result, 'speaker_diarization'):
            annotation = result.speaker_diarization
        else:
            annotation = result

        # Convert to segments with whole-file timestamps
        segments = []
        speakers = set()

        for turn, _, speaker in annotation.itertracks(yield_label=True):
            # Shift chunk-local times back to global timestamps.
            global_start = start_sec + turn.start
            global_end = start_sec + turn.end
            duration = global_end - global_start

            if duration >= config_dict.get('min_segment_duration', 0.2):
                seg = {
                    'start': global_start,
                    'end': global_end,
                    'duration': duration,
                    # Chunk prefix keeps labels globally unique until
                    # reconciliation assigns canonical speaker IDs.
                    'speaker': f"chunk_{chunk_idx:03d}_{speaker}",
                    'local_speaker': speaker,
                    'chunk_idx': chunk_idx
                }
                segments.append(seg)
                speakers.add(speaker)

        process_time = time.time() - start_time

        return {
            'chunk_idx': chunk_idx,
            'segments': segments,
            'speakers': list(speakers),
            'process_time': process_time,
            'success': True,
            'error': None
        }

    except Exception as e:
        # Best-effort: report failure in-band so the parent can log and
        # continue with the remaining chunks.
        logger.error(f"Chunk {chunk_idx} failed: {e}")
        return {
            'chunk_idx': chunk_idx,
            'segments': [],
            'speakers': [],
            'process_time': time.time() - start_time,
            'success': False,
            'error': str(e)
        }


class DiarizationPool:
    """
    Multi-GPU diarization pool for parallel chunk processing.

    Spawns persistent workers on specified GPUs, each with PyAnnote loaded.
    Distributes chunks via job queue for maximum utilization.

    Implementation: one single-worker ProcessPoolExecutor per GPU, created
    with the 'spawn' start method so each worker owns a clean CUDA context.
    Obtain the pool via the thread-safe singleton `get_instance()`.
    """

    # Singleton storage; writes guarded by _lock (double-checked locking).
    _instance: Optional['DiarizationPool'] = None
    _lock = threading.Lock()

    def __init__(self, gpu_ids: List[int], hf_token: str):
        """
        Initialize pool on specified GPUs.

        Does not start any workers; `_initialize()` (called from
        `get_instance`) creates the executors.

        Args:
            gpu_ids: List of GPU IDs to use (e.g., [0,1,2,3,4])
            hf_token: HuggingFace token for PyAnnote
        """
        self.gpu_ids = gpu_ids
        self.hf_token = hf_token
        # One single-worker executor per GPU; populated by _initialize().
        self._executors: List[ProcessPoolExecutor] = []
        self._ready = False

    @classmethod
    def get_instance(
        cls,
        gpu_ids: List[int] = None,
        hf_token: str = None
    ) -> 'DiarizationPool':
        """Get or create singleton pool instance.

        NOTE(review): gpu_ids/hf_token only take effect on the first call;
        subsequent calls return the existing pool regardless of arguments.
        """
        if cls._instance is None:
            with cls._lock:
                # Double-checked locking: re-test now that we hold the lock.
                if cls._instance is None:
                    if gpu_ids is None:
                        gpu_ids = list(range(5))  # Default: GPUs 0-4
                    if hf_token is None:
                        hf_token = os.environ.get('HF_TOKEN', '')

                    cls._instance = cls(gpu_ids, hf_token)
                    cls._instance._initialize()

        return cls._instance

    def _initialize(self):
        """Initialize worker processes on each GPU (idempotent)."""
        if self._ready:
            return

        logger.info(f"Initializing DiarizationPool on GPUs {self.gpu_ids}...")
        start = time.time()

        # Create one executor per GPU (each with 1 worker)
        # CRITICAL: Use 'spawn' context to avoid CUDA re-initialization issues in forked processes
        spawn_ctx = mp.get_context('spawn')
        for gpu_id in self.gpu_ids:
            executor = ProcessPoolExecutor(
                max_workers=1,
                mp_context=spawn_ctx,
                initializer=_diarization_worker_init,
                initargs=(gpu_id, self.hf_token)
            )
            self._executors.append(executor)

        elapsed = time.time() - start
        self._ready = True
        # NOTE(review): executors start workers lazily, so model loading may
        # still be in progress when this logs "ready".
        logger.info(f"DiarizationPool ready in {elapsed:.1f}s ({len(self.gpu_ids)} GPUs)")

    def diarize_parallel(
        self,
        audio_path: str,
        chunk_specs: List[ChunkSpec],
        config
    ) -> List[ChunkResult]:
        """
        Diarize chunks in parallel across GPUs.

        Chunks are assigned to GPUs round-robin. Chunks whose future raises
        are logged and OMITTED from the returned list, so the result may be
        shorter than chunk_specs.

        Args:
            audio_path: Path to audio file (shared storage)
            chunk_specs: List of chunk specifications
            config: Pipeline configuration (needs min_speakers, max_speakers,
                min_segment_duration attributes)

        Returns:
            List of ChunkResult objects, sorted by chunk_idx
        """
        if not self._ready:
            raise RuntimeError("DiarizationPool not initialized")

        logger.info(f"Processing {len(chunk_specs)} chunks on {len(self.gpu_ids)} GPUs...")
        start = time.time()

        # Convert config to dict for pickling (spawn workers can't receive
        # arbitrary non-picklable config objects)
        config_dict = {
            'min_speakers': config.min_speakers,
            'max_speakers': config.max_speakers,
            'min_segment_duration': config.min_segment_duration,
        }

        # Create jobs
        jobs = []
        for spec in chunk_specs:
            job = {
                'audio_path': audio_path,
                'start_sec': spec.start_sec,
                'end_sec': spec.end_sec,
                'chunk_idx': spec.chunk_idx,
                'config_dict': config_dict
            }
            jobs.append(job)

        # Submit jobs round-robin to executors
        futures: Dict[Future, Tuple[int, ChunkSpec]] = {}
        for i, (job, spec) in enumerate(zip(jobs, chunk_specs)):
            executor = self._executors[i % len(self._executors)]
            future = executor.submit(_diarization_worker_fn, job)
            futures[future] = (i, spec)

        # Collect results
        results = []
        for future in as_completed(futures):
            idx, spec = futures[future]
            try:
                # NOTE(review): futures yielded by as_completed are already
                # done, so this timeout is effectively inert.
                result_dict = future.result(timeout=600)  # 10 min max per chunk

                result = ChunkResult(
                    chunk_idx=result_dict['chunk_idx'],
                    segments=result_dict['segments'],
                    overlaps=[],
                    speakers=result_dict['speakers'],
                    nominal_start=spec.nominal_start,
                    nominal_end=spec.nominal_end,
                    overlap_start=spec.overlap_start,
                    overlap_end=spec.overlap_end,
                    process_time=result_dict['process_time']
                )
                results.append(result)

                logger.info(f"  Chunk {spec.chunk_idx}: {len(result.segments)} segments, "
                           f"{result.process_time:.1f}s")

            except Exception as e:
                # Failed chunk is dropped from results (logged only).
                logger.error(f"Chunk {spec.chunk_idx} failed: {e}")

        # Sort by chunk index
        results.sort(key=lambda r: r.chunk_idx)

        elapsed = time.time() - start
        total_segments = sum(len(r.segments) for r in results)
        logger.info(f"Parallel diarization complete: {elapsed:.1f}s, {total_segments} segments")

        return results

    def shutdown(self, wait: bool = True):
        """Shutdown all worker processes.

        Args:
            wait: If True, block until executors have shut down.
        """
        for executor in self._executors:
            executor.shutdown(wait=wait)
        self._executors.clear()
        self._ready = False

    @classmethod
    def reset(cls):
        """Reset singleton instance (shuts down workers without waiting)."""
        with cls._lock:
            if cls._instance is not None:
                cls._instance.shutdown(wait=False)
            cls._instance = None


# === SPEAKER RECONCILIATION ===

def reconcile_speakers(
    chunk_results: List[ChunkResult],
    similarity_threshold: float = 0.75,
    embedding_fn=None
) -> List[Dict]:
    """
    Merge per-chunk diarization results into one speaker-consistent timeline.

    Works in two phases: first link speakers that are co-active inside the
    boundary overlap regions, then assign stable global IDs to every
    remaining speaker. Finally, segments are clipped to their chunk's
    nominal window and relabeled.

    Args:
        chunk_results: List of ChunkResult from parallel diarization
        similarity_threshold: Cosine similarity threshold for matching
        embedding_fn: Function to get embeddings for segments

    Returns:
        List of final segments with global speaker IDs
    """
    # Trivial cases: nothing to merge.
    if not chunk_results:
        return []
    if len(chunk_results) == 1:
        return _relabel_single_chunk(chunk_results[0])

    logger.info(f"Reconciling speakers across {len(chunk_results)} chunks...")

    # Phase 1: link speakers seen on both sides of each chunk boundary.
    boundary_links = _reconcile_via_overlap(chunk_results, similarity_threshold, embedding_fn)

    # Phase 2: collapse links into sequential global speaker IDs.
    global_mapping = _reconcile_global_clustering(chunk_results, boundary_links, similarity_threshold)

    # Relabel, clip to nominal windows, and merge into one sorted list.
    merged = _apply_speaker_mapping(chunk_results, global_mapping)

    unique_speakers = len({s['speaker'] for s in merged})
    logger.info(f"Reconciliation complete: {unique_speakers} unique speakers, "
               f"{len(merged)} segments")

    return merged


def _relabel_single_chunk(result: ChunkResult) -> List[Dict]:
    """Assign sequential SPEAKER_NN labels to a lone chunk's speakers.

    With only one chunk there are no boundaries to reconcile; speakers are
    simply renumbered in order of first appearance. Input segments are not
    mutated.
    """
    labels: Dict[str, str] = {}
    relabeled = []

    for seg in result.segments:
        key = seg.get('local_speaker', seg['speaker'])
        # First appearance gets the next sequential global ID.
        if key not in labels:
            labels[key] = f"SPEAKER_{len(labels):02d}"

        out = dict(seg)
        out['speaker'] = labels[key]
        relabeled.append(out)

    return relabeled


def _reconcile_via_overlap(
    chunk_results: List[ChunkResult],
    threshold: float,
    embedding_fn=None
) -> Dict[str, str]:
    """
    Match speakers across adjacent chunks using boundary overlap regions.

    For each chunk boundary, segments that both chunks produced inside the
    shared overlap window are compared; speakers active at the same time
    (>0.5s of co-activity) are linked. Links are transitive: a speaker
    already linked at an earlier boundary propagates its canonical label
    forward, so identities survive across 3+ chunks.

    Args:
        chunk_results: Chunk results ordered by chunk index.
        threshold: Currently unused (reserved for embedding matching).
        embedding_fn: Currently unused (reserved for embedding matching).

    Returns:
        Dict mapping chunk-scoped speaker label -> canonical speaker label.
    """
    speaker_mapping = {}  # Maps chunk_X_speaker -> canonical speaker

    for i in range(len(chunk_results) - 1):
        chunk_a = chunk_results[i]
        chunk_b = chunk_results[i + 1]

        # Overlap window is defined on the later chunk.
        overlap_start = chunk_b.overlap_start
        overlap_end = chunk_b.overlap_end

        if overlap_start >= overlap_end:
            continue

        # Find segments in overlap from both chunks
        a_segs = [s for s in chunk_a.segments
                  if s['end'] > overlap_start and s['start'] < overlap_end]
        b_segs = [s for s in chunk_b.segments
                  if s['end'] > overlap_start and s['start'] < overlap_end]

        if not a_segs or not b_segs:
            continue

        # Match by temporal overlap (speakers active at same time are likely same)
        for a_seg in a_segs:
            a_spk = a_seg['speaker']
            # Resolve to the canonical label even if a_spk was already linked
            # at an earlier boundary. (Previously such speakers were skipped,
            # which silently broke identity chains spanning 3+ chunks.)
            canonical = speaker_mapping.get(a_spk, a_spk)

            best_match = None
            best_overlap = 0.0

            for b_seg in b_segs:
                # Calculate temporal overlap between the two segments
                overlap_s = max(a_seg['start'], b_seg['start'])
                overlap_e = min(a_seg['end'], b_seg['end'])
                temporal_overlap = max(0.0, overlap_e - overlap_s)

                if temporal_overlap > best_overlap:
                    best_overlap = temporal_overlap
                    best_match = b_seg['speaker']

            # Require at least 0.5s of co-activity before declaring a match.
            if best_match and best_overlap > 0.5:
                speaker_mapping[best_match] = canonical
                # Canonical roots map to themselves so later lookups resolve.
                speaker_mapping.setdefault(a_spk, canonical)

    return speaker_mapping


def _reconcile_global_clustering(
    chunk_results: List[ChunkResult],
    overlap_mapping: Dict[str, str],
    threshold: float
) -> Dict[str, str]:
    """
    Finalize the speaker mapping with stable global IDs.

    Speakers linked in overlap regions share one canonical label; every
    canonical label (matched or not) is renamed to a sequential
    "SPEAKER_NN" ID, deterministic via sorted label order. Note: despite
    the name, no embedding-based clustering happens here -- unmatched
    speakers simply receive fresh IDs.

    Args:
        chunk_results: All chunk results (defines the speaker universe).
        overlap_mapping: Links produced by _reconcile_via_overlap.
        threshold: Unused (reserved for embedding-based clustering).

    Returns:
        Mapping from every observed speaker label -> "SPEAKER_NN".
    """
    # Start with overlap mapping
    final_mapping = overlap_mapping.copy()

    # Every chunk-scoped speaker label that actually appears in a segment.
    all_speakers = {seg['speaker']
                    for result in chunk_results
                    for seg in result.segments}

    # Assign sequential IDs to canonical labels in deterministic order.
    canonical_ids = {}
    speaker_idx = 0
    for spk in sorted(all_speakers):
        canonical = final_mapping.get(spk, spk)
        if canonical not in canonical_ids:
            canonical_ids[canonical] = f"SPEAKER_{speaker_idx:02d}"
            speaker_idx += 1

    # Point every speaker at its canonical label's final ID. Every canonical
    # was inserted by the loop above, so the lookup cannot miss.
    for spk in all_speakers:
        final_mapping[spk] = canonical_ids[final_mapping.get(spk, spk)]

    return final_mapping


def _apply_speaker_mapping(
    chunk_results: List[ChunkResult],
    speaker_mapping: Dict[str, str]
) -> List[Dict]:
    """Rename speakers to global IDs and clip segments to nominal windows.

    Segments that spill into a chunk's overlap region are clipped to the
    chunk's nominal window; clipped remainders shorter than 0.2s are
    dropped. Chunk-local bookkeeping keys are removed. Output is sorted
    by start time.
    """
    merged = []

    for result in chunk_results:
        lo, hi = result.nominal_start, result.nominal_end

        for raw in result.segments:
            seg = dict(raw)

            # Clip anything leaking past the nominal window (overlap audio
            # belongs to the neighboring chunk's output).
            if raw['start'] < lo or raw['end'] > hi:
                clipped_start = max(raw['start'], lo)
                clipped_end = min(raw['end'], hi)

                if clipped_end - clipped_start < 0.2:
                    continue  # Too little remains after clipping.

                seg['start'] = clipped_start
                seg['end'] = clipped_end
                seg['duration'] = clipped_end - clipped_start

            # Swap chunk-scoped label for the reconciled global ID.
            seg['speaker'] = speaker_mapping.get(raw['speaker'], raw['speaker'])
            seg.pop('local_speaker', None)
            seg.pop('chunk_idx', None)
            merged.append(seg)

    merged.sort(key=lambda item: item['start'])
    return merged


# === CONVENIENCE FUNCTIONS ===

def diarize_video_parallel(
    audio_path: str,
    vad_segments: List[Dict],
    total_duration: float,
    config,
    gpu_ids: List[int] = None,
    chunk_duration: float = 450.0,
    overlap_seconds: float = 10.0
) -> List[Dict]:
    """
    One-call entry point: chunk, diarize in parallel, reconcile.

    Splits the audio into overlapping chunks, runs them through the shared
    multi-GPU DiarizationPool, and merges the per-chunk results into a
    single speaker-consistent segment list.

    Args:
        audio_path: Path to audio file
        vad_segments: Speech segments from VAD
        total_duration: Total audio duration
        config: Pipeline configuration
        gpu_ids: GPUs to use (default: [0,1,2,3,4])
        chunk_duration: Nominal chunk size (default: 7.5 min)
        overlap_seconds: Boundary overlap (default: 10s)

    Returns:
        List of diarization segments with global speaker IDs
    """
    selected_gpus = gpu_ids if gpu_ids is not None else list(range(5))

    # Plan the overlapping chunks (silence-aware boundaries).
    specs = create_chunks_with_overlap(
        vad_segments,
        total_duration,
        chunk_duration=chunk_duration,
        overlap_seconds=overlap_seconds
    )

    # Fan out across the (lazily created) GPU worker pool, then stitch the
    # per-chunk speaker labels back into one global timeline.
    pool = DiarizationPool.get_instance(gpu_ids=selected_gpus)
    chunk_results = pool.diarize_parallel(audio_path, specs, config)
    return reconcile_speakers(chunk_results)
