#!/usr/bin/env python3
"""
Pipeline v7.1 (strict TTS) - Music Detection for TTS Quality (version changelog below)

=== NEW IN v6.7: MUSIC DETECTION ===
- PANNs CNN14 for music/instrument detection in 1.5s chunks
- Segments marked as: clean / needs_demucs / heavy_music
- Conservative thresholds (5%/25%) to protect TTS training data
- Reuses existing 1.5s chunk infrastructure (minimal overhead)

=== FROM v6.6: 18x FASTER EMBEDDINGS ===
- FIXED: 35s unexplained overhead in embedding extraction!
- Root cause: Per-batch tensor creation/transfer was serializing GPU work
- Solution: Pre-allocate ALL audio on GPU in ONE transfer, then slice
- Result: 37s → 2s for embeddings (18x faster!)
- GPU utilization: 3.6% → 80%+ during embedding phase

=== FROM v6.5: MAXIMUM GPU SATURATION ===
- Eliminated per-chunk VAD calls (was 55s overhead for 2800 chunks!)
- Fast energy-based eligibility (0.001s vs 0.02s per chunk)
- Pinned memory for faster CPU→GPU transfer
- Removed per-batch empty_cache() sync stalls

=== FROM v6.4: TURBO OPTIMIZATION ===
- Unified Chunk Embedding Cache: Compute ONCE, use EVERYWHERE
- Eliminates "Double Compute" bottleneck (was: Embeddings + ChunkReassignment both used GPU)
- Single mega-batch GPU pass extracts ALL 1.5s chunk embeddings
- Segment embeddings derived from cache (CPU only - instant)
- Chunk reassignment uses cached embeddings (CPU only - instant)

=== FROM v6.3 ===
- Chunk-based reassignment: Detects within-segment speaker changes
- Hybrid approach combining ChatGPT + Gemini Maya + Gemini Stabilized insights
- Look-ahead pattern with severe threshold circuit-breaker
- VAD-based chunk eligibility filtering
- Margin-based confident reassignment

=== FROM v6.2 ===
- Shared audio buffer: Load audio ONCE, pass to all stages
- Persistent VAD workers: Model loaded once per worker
- In-memory diarization chunks: No disk writes
- GPU memory-aware batching: Dynamic batch sizes
- Failure-safe cleanup: atexit handler

Flow (optimized for max TTS data quality):
1. Download (parallel workers for batch)
2. Quick VAD (parallel, get speech outline)
3. Chunk at silence boundaries (VAD-aware, IN-MEMORY)
4. Diarization + OSD (single pass, no disk I/O)
5. Quality filtering
6. UNIFIED Embeddings (single mega-batch for ALL chunks) - v6.4 TURBO
7. Conservative clustering (uses cached segment embeddings)
8. Chunk reassignment (uses cached chunk embeddings - CPU only!) - v6.4 TURBO
9. Duration filtering
10. Output metadata JSON
"""

import time
import json
import logging
import random
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, Future
import numpy as np

from src.config import Config
from src.compute import COMPUTE
from src.models import MODELS
from src.download import download_audio, validate_video, VideoValidationError, DownloadError
from src.vad import run_vad_parallel, find_optimal_cut_point
from src.overlap_detection import detect_nonspeech_regions, split_on_overlaps, filter_by_quality
# NOTE: segmentation module disabled - community-1 handles diarization+OSD in one pass
# from src.segmentation import detect_speaker_changes_framelevel, get_frame_level_activations
from src.diarization import create_chunks, run_diarization
from src.embeddings import extract_embeddings_batched
from src.clustering import merge_speakers, merge_adjacent_segments
from src.utils import generate_samples, cleanup_chunks
from src.audio_buffer import AudioBuffer, TEMP_MANAGER
from src.chunk_reassignment import (
    PrecisionSurgicalSplitter, 
    ChunkReassignmentConfig,
    execute_chunk_reassignment
)
# === v6.4 OPTIMIZATION: Unified Chunk Embedding Cache ===
# Eliminates double-compute: embeddings extracted ONCE, used for both clustering AND reassignment
from src.unified_embeddings import (
    UnifiedChunkEmbeddingCache,
    UnifiedEmbeddingConfig,
    build_segment_embeddings_from_cache,
    execute_chunk_reassignment_cached
)
# === v6.7 FEATURE: Music Detection ===
# Detect music/instruments to mark segments as clean/needs_demucs/heavy_music
from src.music_detection import (
    MusicDetectionCache,
    MusicDetectionConfig,
    build_segment_music_stats_batch
)

logger = logging.getLogger("FastPipelineV6.Pipeline")


def process_single_video(
    video_url: str, 
    config: Config,
    prefetched_audio: Optional[str] = None,
    prefetched_metadata: Optional[Dict] = None
) -> Dict[str, Any]:
    """
    Process a single video through the complete pipeline with compute monitoring.
    
    === v7.0 OPTIMIZATION: Prefetched Download Support ===
    If prefetched_audio and prefetched_metadata are provided (from batch prefetch),
    the download stage is skipped entirely, saving ~60s per video.
    
    Pipeline stages (per instructions.md):
    1. Download audio (or use prefetched)
    2. VAD (parallel CPU) - get speech activity outline
    3. OSD (GPU) - mark overlaps as unusable FIRST
    4. Chunking (VAD-aware) - cut at silence boundaries
    5. Diarization (GPU) - speaker identification
    6. Embeddings (GPU) - speaker vectors
    7. Clustering (CPU) - conservative merging
    8. Output metadata JSON
    
    Each stage is monitored for CPU/GPU utilization.
    
    NOTE: Models are loaded automatically if not already loaded.
    This allows calling this function directly without going through main().
    
    Args:
        video_url: YouTube URL to process
        config: Pipeline configuration
        prefetched_audio: Optional path to already-downloaded audio (from batch prefetch)
        prefetched_metadata: Optional metadata dict (from batch prefetch)
    
    Returns:
        Result dict (also written to ``<output_dir>/metadata.json``) with the
        final segments, speaker/quality/VAD statistics, per-stage timings,
        compute stats, and a snapshot of the effective config.
    """
    # === FIX: Ensure models are loaded before processing ===
    # When called from backend/API, main() is skipped so models may not be loaded
    if not MODELS._loaded:
        logger.info("Loading models (first call to process_single_video)...")
        MODELS.load_all(config)
    
    logger.info("-" * 70)
    logger.info(f"📹 Processing: {video_url}")
    logger.info("-" * 70)
    
    total_start = time.time()
    # Per-stage wall-clock seconds and CPU/GPU utilization, surfaced in the result JSON.
    timings: Dict[str, float] = {}
    compute_stats: Dict[str, Dict[str, Any]] = {}
    
    # ========================================
    # STAGE 1: Download (or use prefetched)
    # ========================================
    if prefetched_audio is not None and prefetched_metadata is not None:
        # === v7.0: Use prefetched download (batch optimization) ===
        logger.info("⚡ Using prefetched download (batch optimization)")
        audio_path = prefetched_audio
        metadata = prefetched_metadata
        timings['download'] = 0.0  # Prefetched, no additional time
        compute_stats['download'] = {'cpu': 0, 'gpu': 0}
    else:
        # Standard download
        with COMPUTE.monitor_stage("Download") as metrics:
            audio_path, metadata = download_audio(video_url, config)
        timings['download'] = metrics.duration
        compute_stats['download'] = {'cpu': metrics.cpu_percent_avg, 'gpu': metrics.gpu_percent_avg}
    
    # ========================================
    # === OPTIMIZATION (v6.2): SHARED AUDIO BUFFER ===
    # Load audio ONCE here, pass buffer to all stages
    # Previously: Each stage called torchaudio.load() separately (4-5 times!)
    # Now: Single load, reuse across VAD, chunking, quality, embeddings
    # ========================================
    with COMPUTE.monitor_stage("LoadBuffer") as metrics:
        audio_buffer = AudioBuffer.from_file(audio_path)
    total_duration = audio_buffer.duration
    timings['load_buffer'] = metrics.duration
    
    # ========================================
    # STAGE 2: VAD (Parallel CPU with persistent workers)
    # === OPTIMIZATION (v6.2): Uses audio_buffer (no file re-read) ===
    # === OPTIMIZATION (v6.2): Model loaded ONCE per worker via initializer ===
    # ========================================
    with COMPUTE.monitor_stage("VAD", items=int(total_duration / config.vad_chunk_size)) as metrics:
        vad_segments = run_vad_parallel(audio_path, config, audio_buffer=audio_buffer)
    timings['vad'] = metrics.duration
    compute_stats['vad'] = {'cpu': metrics.cpu_percent_avg, 'gpu': metrics.gpu_percent_avg}
    
    # ========================================
    # STAGE 3: Chunking (VAD-aware, IN-MEMORY)
    # === OPTIMIZATION (v6.2): No disk writes for chunks ===
    # Pass waveform tensors directly to diarization
    # ========================================
    with COMPUTE.monitor_stage("Chunking") as metrics:
        chunks = create_chunks(audio_path, vad_segments, config, 
                              audio_buffer=audio_buffer, in_memory=True)
    timings['chunking'] = metrics.duration
    compute_stats['chunking'] = {'cpu': metrics.cpu_percent_avg, 'gpu': metrics.gpu_percent_avg}
    
    # ========================================
    # STAGE 4: UNIFIED Diarization + OSD (GPU) - SINGLE PASS!
    # === OPTIMIZATION (v6.1) ===
    # Previously: OSD ran diarization, then Diarization stage ran it again
    # Now: Single diarization pass extracts BOTH segments AND overlaps
    # Saves ~50% GPU compute time!
    # ========================================
    overlap_segments = []
    nonspeech_segments = []
    
    with COMPUTE.monitor_stage("Diarization+OSD", items=len(chunks)) as metrics:
        # Run unified diarization + overlap extraction
        segments, overlap_segments = run_diarization(
            chunks, config, 
            extract_overlaps=config.detect_overlap
        )
        
        # Detect non-speech from VAD gaps
        nonspeech_segments = detect_nonspeech_regions(vad_segments, total_duration)
    
    timings['diarization'] = metrics.duration
    compute_stats['diarization'] = {'cpu': metrics.cpu_percent_avg, 'gpu': metrics.gpu_percent_avg}
    
    # ========================================
    # STAGE 4b: Split segments on overlaps (salvage clean portions)
    # === OPTIMIZATION (v6.1) ===
    # Previously: remove_overlap_regions() dropped entire segments if >50% overlap
    # Now: split_on_overlaps() salvages clean portions around overlaps
    # ========================================
    if overlap_segments and config.detect_overlap:
        original_count = len(segments)
        segments = split_on_overlaps(segments, overlap_segments, min_segment=config.min_segment_duration)
        logger.info(f"   Split overlaps: {original_count} → {len(segments)} segments (salvaged clean portions)")
    
    # ========================================
    # STAGE 4c: Quality filtering (SNR, clipping, etc.)
    # === OPTIMIZATION (v6.2): Uses audio_buffer (no file re-read) ===
    # Mark low-quality segments as unusable to protect TTS training data
    # ========================================
    low_quality_segments = []
    if config.filter_by_quality:
        with COMPUTE.monitor_stage("QualityFilter") as metrics:
            segments, low_quality_segments = filter_by_quality(
                segments, audio_path,
                min_snr_db=config.min_snr_db,
                min_quality_score=config.min_quality_score,
                audio_buffer=audio_buffer  # v6.2: Use buffer
            )
        timings['quality_filter'] = metrics.duration
        compute_stats['quality_filter'] = {'cpu': metrics.cpu_percent_avg, 'gpu': metrics.gpu_percent_avg}
    
    # NOTE: Duration filtering moved to AFTER clustering (Stage 7b) per user request
    # This ensures we don't waste compute on segments that will be filtered anyway
    # AND allows us to sort by duration (keep ones closest to 1s threshold)
    
    # ========================================
    # STAGE 5: FUSED EMBEDDINGS + MUSIC DETECTION (v6.9 - Phase 1 Optimization)
    # === PARALLEL CPU PREP: Music cache prep runs during embedding GPU inference ===
    # BEFORE: Embeddings(8s) → Music(24s) = 32s total
    # AFTER:  Embeddings(8s) + Music(16s overlapped) ≈ 24s (25% faster)
    # ========================================
    # NOTE(review): ThreadPoolExecutor is re-imported here but never used in this
    # function (it is also already imported at module level); only threading is used.
    from concurrent.futures import ThreadPoolExecutor
    import threading
    
    # Pre-create configs and caches (no GPU yet)
    embedding_config = UnifiedEmbeddingConfig(
        chunk_duration=1.5,  # Kimi-Audio standard
        min_speech_ratio=config.chunk_reassignment_min_speech,
        batch_size=128,  # Aggressive batching
        severe_threshold=config.chunk_reassignment_severe,
        normal_threshold=config.chunk_reassignment_threshold,
    )
    
    embedding_cache = UnifiedChunkEmbeddingCache(
        audio_buffer=audio_buffer,
        vad_segments=vad_segments,
        embedding_model=MODELS.embedding_model,
        vad_model=MODELS.silero_vad,
        vad_utils=MODELS.silero_utils,
        config=embedding_config,
        device=MODELS.get_device()
    )
    
    # Pre-initialize music detection (will prep in parallel)
    music_cache = None
    music_detection_stats = None
    music_prep_ready = threading.Event()
    music_cache_holder = [None]  # Mutable container for thread result
    
    def prepare_music_cache():
        """Prepare music detection cache (CPU-intensive: chunk creation + 32kHz resample)."""
        try:
            if not config.enable_music_detection:
                return
            panns_model = MODELS.load_panns(device=str(MODELS.get_device()))
            music_config = MusicDetectionConfig(
                chunk_duration=config.music_chunk_duration,
                batch_size=config.music_batch_size,
                # Music thresholds
                music_prob_threshold=config.music_prob_threshold,
                music_ratio_clean=config.music_ratio_clean,
                music_ratio_demucs=config.music_ratio_demucs,
                music_mean_clean=config.music_mean_clean,
                music_mean_demucs=config.music_mean_demucs,
                # Noise thresholds (v7.1)
                noise_prob_threshold=config.noise_prob_threshold,
                noise_ratio_clean=config.noise_ratio_clean,
                noise_ratio_demucs=config.noise_ratio_demucs,
                noise_mean_clean=config.noise_mean_clean,
                noise_mean_demucs=config.noise_mean_demucs,
                # Strict TTS mode (v7.1)
                strict_tts_mode=config.strict_tts_mode,
                # Early-exit
                early_exit_enabled=config.music_early_exit,
                early_exit_sample_ratio=config.music_early_exit_sample_ratio,
                early_exit_threshold=config.music_early_exit_threshold,
            )
            music_cache_holder[0] = MusicDetectionCache(
                audio_buffer=audio_buffer,
                vad_segments=vad_segments,
                panns_model=panns_model,
                config=music_config,
                device=MODELS.get_device()
            )
            # Pre-compute 32kHz resampled audio (CPU intensive)
            music_cache_holder[0]._precompute_resampled_audio()
        except Exception as e:
            logger.warning(f"Music prep failed: {e}")
        finally:
            # Always signal readiness so the main thread never blocks forever,
            # even when prep failed (holder stays None and Stage 6c is skipped).
            music_prep_ready.set()
    
    # Start music prep in background thread
    music_prep_thread = threading.Thread(target=prepare_music_cache, daemon=True)
    if config.enable_music_detection:
        music_prep_thread.start()
    
    # Build embedding cache (GPU intensive)
    with COMPUTE.monitor_stage("UnifiedEmbeddings") as metrics:
        num_chunks = embedding_cache.build()
        cache_stats = embedding_cache.get_stats()
        logger.info(f"   Unified cache: {num_chunks} chunk embeddings in {cache_stats['build_time_sec']:.2f}s")
    
    timings['unified_embeddings'] = metrics.duration
    compute_stats['unified_embeddings'] = {'cpu': metrics.cpu_percent_avg, 'gpu': metrics.gpu_percent_avg}
    
    # Derive segment embeddings from cache (CPU only - instant!)
    with COMPUTE.monitor_stage("SegmentEmbeddings") as metrics:
        embeddings = build_segment_embeddings_from_cache(segments, embedding_cache)
        logger.info(f"   Derived {len(embeddings)} segment embeddings from cache (CPU only)")
    timings['segment_embeddings'] = metrics.duration
    compute_stats['segment_embeddings'] = {'cpu': metrics.cpu_percent_avg, 'gpu': metrics.gpu_percent_avg}
    
    # Wait for music prep to complete (should already be done)
    # NOTE(review): wait() returns False on timeout, which is ignored here. If the
    # 30s timeout ever fires, music_cache_holder[0] may hold a cache whose
    # _precompute_resampled_audio() is still running on the prep thread; Stage 6c's
    # None-check does not guard against that partially-prepared state — TODO confirm
    # whether MusicDetectionCache.build() is safe to race with the resample.
    if config.enable_music_detection:
        music_prep_ready.wait(timeout=30)
        music_cache = music_cache_holder[0]
    
    # ========================================
    # STAGE 6: Clustering (CPU)
    # Goal: Merge same-speaker fragments (CONSERVATIVE)
    # ========================================
    with COMPUTE.monitor_stage("Clustering") as metrics:
        segments = merge_speakers(audio_path, segments, embeddings, config)
    timings['clustering'] = metrics.duration
    compute_stats['clustering'] = {'cpu': metrics.cpu_percent_avg, 'gpu': metrics.gpu_percent_avg}
    
    # ========================================
    # STAGE 6b: Chunk Reassignment (v6.4 - Using Unified Cache - TURBO!)
    # === MAJOR OPTIMIZATION: Uses CACHED chunk embeddings (CPU only!) ===
    # BEFORE: Each segment required GPU inference for chunks + portions
    # AFTER:  All operations use cached embeddings (instant CPU lookups)
    # Predicted speedup: 74s → 0.5s (from GPU inference to cache lookup)
    # ========================================
    chunk_reassignment_stats = None
    if config.enable_chunk_reassignment:
        with COMPUTE.monitor_stage("ChunkReassignment") as metrics:
            # Compute speaker centroids from embeddings
            # (embeddings is keyed by segment index; skip pseudo-speakers)
            from collections import defaultdict
            speaker_embeddings = defaultdict(list)
            for i, seg in enumerate(segments):
                if i in embeddings and seg.get('speaker') not in ['OVERLAP', 'NON_SPEECH']:
                    speaker_embeddings[seg['speaker']].append(embeddings[i])
            
            speaker_centroids = {
                spk: np.mean(embs, axis=0) 
                for spk, embs in speaker_embeddings.items() 
                if embs
            }
            
            # Execute chunk reassignment using CACHED embeddings (mostly CPU!)
            segments, reassignment_stats = execute_chunk_reassignment_cached(
                segments=segments,
                cache=embedding_cache,
                speaker_centroids=speaker_centroids,
                min_segment_for_analysis=config.chunk_reassignment_min_portion * 2,  # 3.0s
                min_split_portion=config.chunk_reassignment_min_portion,
                assign_min_similarity=0.55,
                margin_min=0.10,
                create_new_on_ambiguous=True
            )
            
            # Create stats object for compatibility
            # (downstream quality_stats reads attribute access, not dict keys)
            class ReassignmentStatsCompat:
                def __init__(self, stats_dict):
                    self.segments_with_changes = stats_dict.get('segments_with_changes', 0)
                    self.new_speakers_created = stats_dict.get('new_speakers_created', 0)
            
            chunk_reassignment_stats = ReassignmentStatsCompat(reassignment_stats)
            
            logger.info(f"   Chunk reassignment (cached): {reassignment_stats['segments_with_changes']} segments split, "
                       f"{reassignment_stats['new_speakers_created']} new speakers, "
                       f"severe={reassignment_stats['severe_triggers']}, lookahead={reassignment_stats['lookahead_triggers']}")
        
        timings['chunk_reassignment'] = metrics.duration
        compute_stats['chunk_reassignment'] = {'cpu': metrics.cpu_percent_avg, 'gpu': metrics.gpu_percent_avg}
    
    # ========================================
    # STAGE 6c: MUSIC DETECTION (v6.9 - Phase 1: Pre-prepared cache)
    # === CPU prep was done in parallel during embeddings ===
    # Only GPU inference runs here (no resampling overhead!)
    # ========================================
    if config.enable_music_detection and music_cache is not None:
        with COMPUTE.monitor_stage("MusicDetection") as metrics:
            try:
                # Build uses pre-prepared cache (32kHz resample already done!)
                num_music_chunks = music_cache.build()
                
                # Compute segment-level music stats and mark segments
                segments, music_detection_stats = build_segment_music_stats_batch(segments, music_cache)
                
                logger.info(f"   Music detection: {num_music_chunks} chunks analyzed")
                logger.info(f"   Clean: {music_detection_stats['segments_clean']} segments "
                           f"({music_detection_stats.get('pct_clean', 0)}%)")
                logger.info(f"   Needs Demucs: {music_detection_stats['segments_needs_demucs']} segments "
                           f"({music_detection_stats.get('pct_needs_demucs', 0)}%)")
                logger.info(f"   Heavy Contamination: {music_detection_stats.get('segments_heavy_contamination', music_detection_stats.get('segments_heavy_music', 0))} segments "
                           f"({music_detection_stats.get('pct_heavy_contamination', music_detection_stats.get('pct_heavy_music', 0))}%)")
                
            except ImportError as e:
                logger.warning(f"⚠️ Music detection disabled: {e}")
                music_detection_stats = {'error': str(e)}
            except Exception as e:
                logger.error(f"❌ Music detection failed: {e}")
                import traceback
                traceback.print_exc()
                music_detection_stats = {'error': str(e)}
        
        timings['music_detection'] = metrics.duration
        compute_stats['music_detection'] = {'cpu': metrics.cpu_percent_avg, 'gpu': metrics.gpu_percent_avg}
    
    # ========================================
    # STAGE 7: Duration filtering for TTS (POST-adjacent-merge)
    #
    # We intentionally run duration filtering AFTER adjacent merging (Stage 8)
    # so short same-speaker fragments (e.g. 0.4-0.9s) can be absorbed into a
    # longer continuous turn instead of becoming hard boundaries.
    # ========================================
    dropped_short_count = 0
    dropped_short_duration = 0.0
    
    # ========================================
    # STAGE 8: Finalization
    # ========================================
    with COMPUTE.monitor_stage("Finalization") as metrics:
        # 1) Merge adjacent same-speaker segments (speaker-labeled only).
        #
        # Include low-quality segments as UNUSABLE boundaries so we never merge
        # across them (otherwise we'd accidentally "smear" bad audio into a clean clip).
        speaker_segments_for_merge = list(segments)
        for seg in low_quality_segments:
            seg['status'] = 'unusable'
            seg['unusable_reason'] = seg.get('unusable_reason', 'low_quality')
            speaker_segments_for_merge.append(seg)
        
        # Treat overlap regions as hard boundaries so we never merge a "clean"
        # speaker segment across an overlap gap (would re-include multi-speaker audio).
        for seg in overlap_segments:
            speaker_segments_for_merge.append({
                'start': seg['start'],
                'end': seg['end'],
                'duration': seg['duration'],
                'speaker': 'OVERLAP',
                'status': 'unusable',
                'unusable_reason': 'overlap'
            })
        
        speaker_segments_for_merge.sort(key=lambda x: x['start'])
        merged_all_segments = merge_adjacent_segments(speaker_segments_for_merge, config)
        
        merged_speaker_segments = [
            s for s in merged_all_segments
            if s.get('speaker') not in ['OVERLAP', 'NON_SPEECH']
        ]
        overlap_output_segments = [s for s in merged_all_segments if s.get('speaker') == 'OVERLAP']
        
        # 2) Drop <min_tts_duration speaker segments entirely (no "keep 1%" behavior).
        if config.min_tts_duration > 0:
            kept = []
            dropped = []
            for seg in merged_speaker_segments:
                if seg['duration'] >= config.min_tts_duration:
                    kept.append(seg)
                else:
                    dropped.append(seg)
            
            dropped_short_count = len(dropped)
            dropped_short_duration = sum(s['duration'] for s in dropped)
            
            if dropped_short_count > 0:
                logger.info(
                    f"   Duration filter (post-merge): dropped {dropped_short_count} speaker segments "
                    f"<{config.min_tts_duration}s ({dropped_short_duration:.1f}s total)"
                )
            
            merged_speaker_segments = kept
        
        # Start final output segments from merged speaker segments + overlap boundaries
        segments = merged_speaker_segments + overlap_output_segments
        
        # 3) Add non-speech segments (marked as unusable)
        # Filter internal NON_SPEECH within (merged) speaker segments.
        #
        # NOTE: this is containment-based. If we merged same-speaker fragments
        # around a natural pause, that pause becomes internal and will be filtered.
        speaker_segs = [s for s in segments if s.get('speaker') not in ['OVERLAP', 'NON_SPEECH', None]]
        overlap_segs = [s for s in segments if s.get('speaker') == 'OVERLAP']
        
        internal_filtered_count = 0
        internal_filtered_duration = 0.0
        
        for seg in nonspeech_segments:
            ns_start, ns_end = seg['start'], seg['end']
            
            # Check if this NON_SPEECH falls entirely within any single speaker segment
            is_internal = False
            for sp in speaker_segs:
                if sp['start'] <= ns_start and ns_end <= sp['end']:
                    has_overlap_nearby = any(
                        ov['start'] < ns_end and ov['end'] > ns_start  # overlap intersects
                        for ov in overlap_segs
                    )
                    if not has_overlap_nearby:
                        is_internal = True
                        break
            
            if is_internal:
                internal_filtered_count += 1
                internal_filtered_duration += seg['duration']
            else:
                segments.append({
                    'start': seg['start'],
                    'end': seg['end'],
                    'duration': seg['duration'],
                    'speaker': 'NON_SPEECH',
                    'status': 'unusable',
                    'unusable_reason': 'non_speech'
                })
        
        if internal_filtered_count > 0:
            logger.info(
                f"   Filtered {internal_filtered_count} internal pauses "
                f"({internal_filtered_duration:.1f}s) within speaker segments"
            )
        
        # Sort all by start time
        segments.sort(key=lambda x: x['start'])
        
        # === OVERLAP DENSITY FILTER (v6.9) ===
        # Mark short segments sandwiched between overlaps as unusable
        # This catches "islands" of supposedly clean audio in high-overlap regions
        from src.overlap_detection import filter_overlap_sandwich_segments
        segments = filter_overlap_sandwich_segments(segments, config)
    
    timings['finalization'] = metrics.duration
    
    # Cleanup temp files (failure-safe via TEMP_MANAGER)
    cleanup_chunks(audio_path)
    TEMP_MANAGER.cleanup_for_audio(audio_path)  # v6.2: Also cleanup via manager
    
    # ========================================
    # STAGE 9: Sample Generation (optional)
    # ========================================
    output_dir = Path(metadata['output_dir'])
    sample_clips = {}
    
    if config.generate_sample_clips:
        with COMPUTE.monitor_stage("Samples") as metrics:
            sample_clips = generate_samples(audio_path, segments, config, output_dir)
        timings['samples'] = metrics.duration
    
    total_time = time.time() - total_start
    
    # ========================================
    # Compute Statistics
    # ========================================
    all_speakers = list(set(s['speaker'] for s in segments if s['speaker'] not in ['OVERLAP', 'NON_SPEECH']))
    usable_segs = [s for s in segments if s.get('status') == 'usable']
    overlap_segs = [s for s in segments if s['speaker'] == 'OVERLAP']
    nonspeech_segs = [s for s in segments if s['speaker'] == 'NON_SPEECH']
    low_quality_segs = [s for s in segments if s.get('unusable_reason') == 'low_quality']
    overlap_proximity_segs = [s for s in segments if s.get('unusable_reason') == 'overlap_proximity']  # v6.9
    
    usable_duration = sum(s['duration'] for s in usable_segs)
    overlap_duration = sum(s['duration'] for s in overlap_segs)
    nonspeech_duration = sum(s['duration'] for s in nonspeech_segs)
    low_quality_duration = sum(s['duration'] for s in low_quality_segs)
    overlap_proximity_duration = sum(s['duration'] for s in overlap_proximity_segs)  # v6.9
    # Short (<min_tts_duration) speaker segments are DROPPED entirely (post-merge),
    # so they won't appear in `segments`. We still surface them in stats.
    too_short_count = dropped_short_count
    too_short_duration = dropped_short_duration
    
    # Prepare output segments with timestamps
    output_segments = []
    intro_offset = metadata.get('intro_skipped', 0)
    
    for seg in segments:
        output_segments.append({
            'start': round(seg['start'], 3),
            'end': round(seg['end'], 3),
            'duration': round(seg['duration'], 3),
            'speaker': seg['speaker'],
            'status': seg.get('status', 'usable'),
            'original_start': round(seg['start'] + intro_offset, 3),
            'original_end': round(seg['end'] + intro_offset, 3),
        })
        # Preserve music detection stats (v6.7)
        if 'music_stats' in seg:
            output_segments[-1]['music_stats'] = seg['music_stats']
        if seg.get('needs_demucs'):
            output_segments[-1]['needs_demucs'] = True
        if 'unusable_reason' in seg:
            output_segments[-1]['unusable_reason'] = seg['unusable_reason']
    
    # Build result
    result = {
        'video_id': metadata['video_id'],
        'video_title': metadata.get('title', 'Unknown'),
        'youtube_url': metadata.get('youtube_url', video_url),
        'original_duration': round(metadata.get('original_duration', total_duration), 2),
        'intro_skipped': round(intro_offset, 2),
        'processed_duration': round(metadata.get('processed_duration', total_duration), 2),
        # === v6.8: High-quality audio info ===
        'sample_rate': metadata.get('sample_rate', 16000),
        'original_sample_rate': metadata.get('original_sample_rate', 16000),
        'original_audio_path': metadata.get('original_audio_path'),  # Path to high-quality audio
        'original_audio_preserved': metadata.get('original_audio_preserved', False),
        'num_speakers': len(all_speakers),
        'speakers': all_speakers,
        'total_segments': len(segments),
        'segments': output_segments,
        'quality_stats': {
            'usable_segments': len(usable_segs),
            'usable_duration': round(usable_duration, 2),
            'usable_percentage': round(usable_duration / total_duration * 100, 1) if total_duration > 0 else 0,
            'overlap_segments': len(overlap_segs),
            'overlap_duration': round(overlap_duration, 2),
            'overlap_percentage': round(overlap_duration / total_duration * 100, 1) if total_duration > 0 else 0,
            'nonspeech_segments': len(nonspeech_segs),
            'nonspeech_duration': round(nonspeech_duration, 2),
            'nonspeech_percentage': round(nonspeech_duration / total_duration * 100, 1) if total_duration > 0 else 0,
            # v6.1 additions
            'low_quality_segments': len(low_quality_segs),
            'low_quality_duration': round(low_quality_duration, 2),
            'too_short_segments': int(too_short_count),
            'too_short_duration': round(float(too_short_duration), 2),
            # v6.9 overlap density filter
            'overlap_proximity_segments': len(overlap_proximity_segs),
            'overlap_proximity_duration': round(overlap_proximity_duration, 2),
            # v6.3 chunk reassignment stats
            'split_segments': chunk_reassignment_stats.segments_with_changes if chunk_reassignment_stats else 0,
            'new_speakers_from_splits': chunk_reassignment_stats.new_speakers_created if chunk_reassignment_stats else 0,
            # v6.7 music detection stats
            'music_detection': music_detection_stats if music_detection_stats else {},
        },
        'vad_stats': {
            'speech_segments': len(vad_segments),
            'total_speech': round(sum(s['duration'] for s in vad_segments), 2),
        },
        'sample_clips': {spk: [Path(p).name for p in paths] for spk, paths in sample_clips.items()},
        'timing': {k: round(v, 2) for k, v in timings.items()},
        'timing_total': round(total_time, 2),
        'compute_stats': compute_stats,
        'config': {
            'model': MODELS._model_name,
            'merge_threshold': config.cluster_merge_threshold,
            'vad_workers': config.vad_workers,
            'embedding_batch_size': config.embedding_batch_size,
            'min_segment_duration': config.min_segment_duration,
            # v6.1 config
            'min_tts_duration': config.min_tts_duration,
            'min_snr_db': config.min_snr_db,
            'filter_by_quality': config.filter_by_quality,
            # v6.3 chunk reassignment config
            'chunk_reassignment_enabled': config.enable_chunk_reassignment,
            'chunk_reassignment_threshold': config.chunk_reassignment_threshold,
            'chunk_reassignment_severe': config.chunk_reassignment_severe,
            # v6.7 music detection config
            'music_detection_enabled': config.enable_music_detection,
            'music_ratio_clean': config.music_ratio_clean,
            'music_ratio_demucs': config.music_ratio_demucs,
            # v6.8 intro skip + audio config
            'auto_intro_skip': config.auto_intro_skip,
            'preserve_original_audio': config.preserve_original_audio,
        },
        'processed_at': datetime.now().isoformat(),
        'pipeline_version': 'v7.1-strict-tts',
        # v6.4 cache stats
        # NOTE(review): `'cache_stats' in dir()` checks the local namespace (dir()
        # with no args ≈ sorted(locals())); since Stage 5 assigns cache_stats
        # unconditionally, this guard is dead defensiveness — plain `cache_stats`
        # (or the `locals()` idiom) would be clearer.
        'embedding_cache_stats': cache_stats if 'cache_stats' in dir() else {},
    }
    
    # Save metadata JSON
    with open(output_dir / "metadata.json", 'w') as f:
        json.dump(result, f, indent=2)
    
    # Summary
    # rate = seconds of processing per hour of audio (lower is better)
    processing_hours = total_duration / 3600
    rate = total_time / processing_hours if processing_hours > 0 else 0
    
    logger.info(f"📊 Result: {total_duration/60:.1f}min | "
                f"{result['num_speakers']} speakers | {result['quality_stats']['usable_percentage']}% usable | "
                f"{total_time:.1f}s ({rate:.1f}s/hr)")
    
    # === v6.8+: Cleanup 16kHz processing file if original is preserved ===
    # Keep only the high-quality original for final export
    if config.preserve_original_audio and metadata.get('original_audio_preserved'):
        trimmed_file = Path(audio_path)
        original_file = Path(metadata.get('original_audio_path', ''))
        if trimmed_file.exists() and original_file.exists():
            trimmed_size_mb = trimmed_file.stat().st_size / (1024 * 1024)
            trimmed_file.unlink()
            logger.info(f"🧹 Cleaned up 16kHz processing file ({trimmed_size_mb:.1f}MB), keeping original quality")
    
    # Clear cache for next video
    MODELS.clear_cache(aggressive=True)
    COMPUTE.refresh_gpu_memory()
    
    return result


def _prefetch_download(url: str, config: Config) -> Tuple[Optional[str], Optional[Dict], Optional[Exception]]:
    """
    Download one video's audio in a background thread for batch prefetching.

    Never raises: any exception from the download is captured and handed
    back to the caller as the third tuple element, so a failed prefetch
    cannot kill the batch loop's worker thread.

    Returns:
        (audio_path, metadata, error) — error is None on success; on
        failure audio_path and metadata are None and error holds the
        caught exception.
    """
    try:
        audio_path, metadata = download_audio(url, config, validate=False)
    except Exception as exc:  # deliberate catch-all: errors are returned, not raised
        return None, None, exc
    return audio_path, metadata, None


def process_batch(
    video_urls: List[str], 
    config: Optional[Config] = None,
    validate_videos: bool = True,
    skip_on_validation_fail: bool = True
) -> List[Dict]:
    """
    Process multiple videos with adaptive compute-aware settings.
    
    === v7.0 OPTIMIZATION: Batch Download Overlap ===
    Downloads video N+1 in background while processing video N.
    This hides ~60s download time for each video after the first.
    
    For 5 videos: saves ~4 × 60s = 240s (26% faster batch processing)
    
    Flow:
    1. Detect system resources (COMPUTE)
    2. Apply adaptive settings to config
    3. Validate all videos (optional)
    4. Load models once (hot inference)
    5. Process each video with PREFETCHED download for next video
    6. Print utilization summary
    
    Args:
        video_urls: List of YouTube URLs to process
        config: Pipeline configuration (auto-created if None)
        validate_videos: Whether to pre-validate videos (default: True)
        skip_on_validation_fail: If True, skip invalid videos; if False, raise error
    
    Returns:
        List of result dicts, one per input URL. Failures carry an 'error'
        message plus 'error_type' of 'validation', 'download', 'memory',
        or 'processing'; successes are the dicts produced by
        process_single_video.
    
    Raises:
        VideoValidationError: Only when skip_on_validation_fail is False
            and a video fails pre-validation.
    """
    config = config or Config()
    
    # NOTE: Adaptive settings already applied in main.py with CLI override priority
    # Only apply if config was created without CLI (e.g. direct API call)
    # Don't call apply_adaptive_settings here - it would override user's CLI args
    
    logger.info("=" * 70)
    logger.info(f"🚀 BATCH PROCESSING: {len(video_urls)} videos")
    logger.info(f"   Model Priority: community-1 (fallback: 3.1)")
    logger.info(f"   VAD Workers: {config.vad_workers}")
    logger.info(f"   Embedding Batch: {config.embedding_batch_size}")
    logger.info(f"   Merge Threshold: {config.cluster_merge_threshold}")
    logger.info(f"   Min Segment: {config.min_segment_duration}s (for 0.4s events)")
    logger.info(f"   ⚡ Batch Prefetch: ENABLED (v7.0 optimization)")
    logger.info("=" * 70)
    
    # === EDGE CASE: Pre-validate all videos ===
    # valid_urls preserves input order; validation_results is keyed by URL
    # so failures can be reported in the results list later.
    valid_urls: List[str] = []
    validation_results: Dict[str, Dict] = {}
    
    if validate_videos:
        logger.info("\n🔍 PRE-VALIDATING VIDEOS...")
        for i, url in enumerate(video_urls):
            logger.info(f"   [{i+1}/{len(video_urls)}] Validating: {url}")
            is_valid, message, info = validate_video(url)
            validation_results[url] = {'valid': is_valid, 'message': message, 'info': info}
            
            if is_valid:
                valid_urls.append(url)
                title = info.get('title', 'Unknown')[:50]
                duration = info.get('duration', 0)
                logger.info(f"       ✅ '{title}' ({duration}s)")
            else:
                logger.warning(f"       ❌ {message}")
                if not skip_on_validation_fail:
                    raise VideoValidationError(f"Video validation failed: {url} - {message}")
        
        logger.info(f"\n   Validation complete: {len(valid_urls)}/{len(video_urls)} videos valid")
        
        if not valid_urls:
            logger.error("❌ No valid videos to process!")
            # Early exit: one error dict per input URL, same shape as the
            # validation-failure entries added below.
            return [{'error': f'Validation failed: {validation_results[url]["message"]}', 'url': url} 
                   for url in video_urls]
    else:
        valid_urls = video_urls
    
    # Load all models ONCE
    MODELS.load_all(config)
    
    # Process each video
    results = []
    total_start = time.time()
    
    # Add validation failures to results
    for url in video_urls:
        if url not in valid_urls:
            results.append({
                'error': f'Validation failed: {validation_results[url]["message"]}',
                'url': url,
                'error_type': 'validation'
            })
    
    # === v7.0 OPTIMIZATION: Batch Download Overlap ===
    # Use ThreadPoolExecutor to download next video while processing current
    # This hides download latency (~60s) behind GPU processing time
    # max_workers=1 keeps at most one background download in flight at a time.
    prefetch_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="prefetch")
    prefetch_future: Optional[Future] = None
    prefetch_url: Optional[str] = None
    
    # Start prefetching first video
    if len(valid_urls) > 0:
        logger.info(f"⚡ Starting prefetch for first video...")
        prefetch_future = prefetch_executor.submit(_prefetch_download, valid_urls[0], config)
        prefetch_url = valid_urls[0]
    
    for i, url in enumerate(valid_urls):
        logger.info(f"\n{'='*70}")
        logger.info(f"📹 VIDEO {i+1}/{len(valid_urls)} (total: {len(video_urls)})")
        logger.info(f"{'='*70}")
        
        try:
            # === v7.0: Use prefetched download if available ===
            prefetched_audio = None
            prefetched_metadata = None
            
            # The URL guard ensures a stale future (e.g. left over after a
            # prior iteration's failure) is never consumed for the wrong video.
            if prefetch_future is not None and prefetch_url == url:
                logger.info(f"⚡ Using prefetched download for {url}")
                # 5-minute cap. NOTE(review): a concurrent.futures.TimeoutError
                # here falls through to the generic Exception handler below and
                # is recorded as a processing error — the video is not retried
                # with a main-thread download. Confirm this is intended.
                audio_path, metadata, prefetch_error = prefetch_future.result(timeout=300)
                
                if prefetch_error is None and audio_path is not None:
                    prefetched_audio = audio_path
                    prefetched_metadata = metadata
                    logger.info(f"   ✅ Prefetch ready: {Path(audio_path).name}")
                else:
                    # Prefetch failure is non-fatal: process_single_video will
                    # download in the main thread when prefetched_audio is None.
                    logger.warning(f"   ⚠️ Prefetch failed: {prefetch_error}, will download in main thread")
            
            # Start prefetching NEXT video while we process current
            # Clear the consumed (or mismatched) future first so it can never
            # be reused on a later iteration.
            prefetch_future = None
            prefetch_url = None
            if i + 1 < len(valid_urls):
                next_url = valid_urls[i + 1]
                logger.info(f"⚡ Prefetching next video in background: {next_url}")
                prefetch_future = prefetch_executor.submit(_prefetch_download, next_url, config)
                prefetch_url = next_url
            
            # Process video (with or without prefetched download)
            result = process_single_video(
                url, config, 
                prefetched_audio=prefetched_audio,
                prefetched_metadata=prefetched_metadata
            )
            results.append(result)
            
            # === EDGE CASE: Clear GPU cache between videos ===
            MODELS.clear_cache(aggressive=True)
            COMPUTE.refresh_gpu_memory()
            
        except VideoValidationError as e:
            logger.error(f"❌ Validation Error: {url} - {e}")
            results.append({'error': str(e), 'url': url, 'error_type': 'validation'})
            
        except DownloadError as e:
            logger.error(f"❌ Download Error: {url} - {e}")
            results.append({'error': str(e), 'url': url, 'error_type': 'download'})
            
        except MemoryError as e:
            logger.error(f"❌ Memory Error: {url} - {e}")
            results.append({'error': str(e), 'url': url, 'error_type': 'memory'})
            # Try to recover
            MODELS.clear_cache(aggressive=True)
            import gc
            gc.collect()
            
        except Exception as e:
            # Catch-all so one bad video never aborts the rest of the batch.
            logger.error(f"❌ Processing Error: {url} - {e}")
            import traceback
            traceback.print_exc()
            results.append({'error': str(e), 'url': url, 'error_type': 'processing'})
            # Clear cache and continue
            MODELS.clear_cache(aggressive=True)
    
    # Cleanup prefetch executor
    # wait=False: don't block batch completion on a still-running download.
    # NOTE(review): an in-flight prefetch may finish after shutdown and leave
    # an orphaned audio file on disk — confirm downstream cleanup covers this.
    prefetch_executor.shutdown(wait=False)
    
    total_time = time.time() - total_start
    
    # ========================================
    # Batch Summary
    # ========================================
    logger.info("\n" + "=" * 70)
    logger.info("📊 BATCH SUMMARY")
    logger.info("=" * 70)
    
    # Presence of the 'error' key is the success/failure discriminator.
    successful = [r for r in results if 'error' not in r]
    failed = [r for r in results if 'error' in r]
    
    logger.info(f"✅ Successful: {len(successful)}/{len(video_urls)}")
    logger.info(f"❌ Failed: {len(failed)}/{len(video_urls)}")
    
    if successful:
        total_duration = sum(r['processed_duration'] for r in successful)
        total_usable = sum(r['quality_stats']['usable_duration'] for r in successful)
        avg_usable = np.mean([r['quality_stats']['usable_percentage'] for r in successful])
        total_process = sum(r['timing_total'] for r in successful)
        
        # Music detection summary
        # Reads 'segments_heavy_contamination' first, falling back to
        # 'segments_heavy_music' — presumably the older key name used by
        # earlier pipeline versions; verify before removing the fallback.
        total_clean = sum(r['quality_stats'].get('music_detection', {}).get('segments_clean', 0) for r in successful)
        total_demucs = sum(r['quality_stats'].get('music_detection', {}).get('segments_needs_demucs', 0) for r in successful)
        total_heavy = sum(r['quality_stats'].get('music_detection', {}).get('segments_heavy_contamination', 
                          r['quality_stats'].get('music_detection', {}).get('segments_heavy_music', 0)) for r in successful)
        
        logger.info(f"\n📈 AGGREGATE:")
        logger.info(f"   Total processed: {total_duration/60:.1f} min")
        logger.info(f"   Total usable: {total_usable/60:.1f} min ({avg_usable:.1f}% avg)")
        logger.info(f"   Total time: {total_process:.1f}s")
        if total_duration > 0:
            logger.info(f"   Rate: {total_process/(total_duration/3600):.1f}s per hour of audio")
        
        if total_clean + total_demucs + total_heavy > 0:
            logger.info(f"\n🎵 MUSIC DETECTION:")
            logger.info(f"   Clean segments: {total_clean}")
            logger.info(f"   Needs Demucs: {total_demucs}")
            logger.info(f"   Heavy Music: {total_heavy}")
        
        logger.info(f"\n📋 PER-VIDEO RESULTS:")
        for r in successful:
            music_info = ""
            if r['quality_stats'].get('music_detection'):
                md = r['quality_stats']['music_detection']
                music_info = f" | 🎵 clean={md.get('segments_clean', 0)}, demucs={md.get('segments_needs_demucs', 0)}"
            
            logger.info(f"   • {r['video_id']}: {r['num_speakers']} speakers, "
                       f"{r['quality_stats']['usable_percentage']}% usable, {r['timing_total']:.0f}s{music_info}")
    
    if failed:
        logger.info(f"\n❌ FAILED VIDEOS:")
        for r in failed:
            error_type = r.get('error_type', 'unknown')
            logger.info(f"   • [{error_type}] {r['url']}: {r['error']}")
    
    # Print compute utilization summary
    logger.info(COMPUTE.summary())
    
    logger.info("=" * 70)
    
    return results
