#!/usr/bin/env python3
"""
Fast Pipeline v6.0 - ADAPTIVE COMPUTE-AWARE
Main entry point for speaker diarization pipeline.

Usage:
    python main.py URL1 [URL2 ...] [options]

Example:
    python main.py https://youtube.com/watch?v=... --merge-threshold 0.80

Key Improvements (v6.0):
1. ADAPTIVE compute: Auto-detects nproc, vCPUs, GPU vRAM
2. community-1 model FIRST (better performance per pyannote)
3. OSD (Overlap Detection) BEFORE diarization (prevent poison merges)
4. Frame-level segmentation (17ms resolution for 0.4s events)
5. Compute monitoring at each stage (CPU/GPU utilization)
6. Metadata-only output by default (no sample clips)
"""

import os
import sys
import logging
import argparse

# === PERFORMANCE TUNING - BEFORE IMPORTS ===
# These must be in the environment before numeric/ML libraries are imported,
# otherwise the thread caps and cache locations are silently ignored.
_ENV_TUNING = {
    'OMP_NUM_THREADS': '8',
    'MKL_NUM_THREADS': '8',
    'NUMEXPR_NUM_THREADS': '8',
    'TOKENIZERS_PARALLELISM': 'false',
    'HF_HOME': '/ephemeral/hf_cache',
    'TORCH_FORCE_WEIGHTS_ONLY_LOAD': '0',
}
os.environ.update(_ENV_TUNING)

# Prevent PyTorch 2.6+ weights_only warnings
import torch
import warnings
warnings.filterwarnings('ignore')

_original_torch_load = torch.load
def _patched_torch_load(*args, **kwargs):
    kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)
torch.load = _patched_torch_load

# Patch PyTorch Lightning if available: route its checkpoint loader through
# the unpatched torch.load with weights_only=False so Lightning checkpoints
# (which contain pickled objects) still deserialize under PyTorch 2.6+.
try:
    import pytorch_lightning.core.saving as pl_saving
except ImportError:
    pass
else:
    _orig_pl_load = pl_saving.pl_load  # kept for potential restoration/debugging

    def _patched_pl_load(path_or_url, map_location=None, **kwargs):
        # Extra kwargs are accepted for signature compatibility but not
        # forwarded, matching the behavior of the original patch.
        return _original_torch_load(path_or_url, map_location=map_location,
                                    weights_only=False)

    pl_saving.pl_load = _patched_pl_load

# Patch lightning_fabric too: its cloud_io._load is another checkpoint entry
# point that must bypass the weights_only=True default.
try:
    import lightning_fabric.utilities.cloud_io as cloud_io
    _orig_fabric_load = cloud_io._load  # AttributeError here is tolerated below
except (ImportError, AttributeError):
    pass
else:
    def _patched_fabric_load(path_or_url, map_location=None):
        return _original_torch_load(path_or_url, map_location=map_location,
                                    weights_only=False)

    cloud_io._load = _patched_fabric_load

# Setup logging: timestamped, aligned level column, short time format.
_LOG_FORMAT = '%(asctime)s | %(levelname)-8s | %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt='%H:%M:%S')
logger = logging.getLogger("FastPipelineV6")

# Import pipeline components
from pathlib import Path
# Put ./src on sys.path so modules inside src/ can be imported by their bare
# names: "pipeline" below resolves to src/pipeline.py via this path entry,
# while "src.config" / "src.compute" resolve as a package from the repo root.
sys.path.insert(0, str(Path(__file__).parent / 'src'))

from src.config import Config
from src.compute import COMPUTE
from pipeline import process_batch  # resolved via the src/ sys.path entry above


def _build_parser() -> argparse.ArgumentParser:
    """Construct the CLI parser. Flags and defaults are the v6.x public interface."""
    parser = argparse.ArgumentParser(
        description="Fast Pipeline v6.0 - Adaptive Compute-Aware Speaker Diarization",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Single video with auto-detected settings
  python main.py https://youtube.com/watch?v=...
  
  # Batch processing
  python main.py URL1 URL2 URL3
  
  # Override adaptive settings
  python main.py URL --vad-workers 64 --embedding-batch-size 16

Flow (per instructions.md):
  1. Download (parallel)
  2. VAD (parallel CPU, 200ms min for 0.4s events)
  3. OSD (mark overlaps BEFORE diarization)
  4. Chunk at silence (VAD-aware)
  5. Diarization (community-1 first)
  6. Frame refinement (17ms resolution)
  7. Embeddings + Clustering (conservative)
  8. Output metadata JSON

Compute Monitoring:
  - Auto-detects nproc, CPU cores, GPU VRAM
  - Adapts workers/batch sizes to your system
  - Logs CPU/GPU utilization per stage
  - Suggests optimizations if underutilized
        """
    )

    parser.add_argument("urls", nargs='+', help="YouTube URL(s) to process")

    # Audio processing
    parser.add_argument("--intro-skip", type=float, default=0.0,
                       help="Skip N seconds from start (default: 0, overrides auto-skip)")
    parser.add_argument("--outro-skip", type=float, default=0.0,
                       help="Skip N seconds from end (default: 0)")
    parser.add_argument("--no-auto-intro", action="store_true",
                       help="Disable automatic intro skip (default: enabled)")
    parser.add_argument("--preserve-original-audio", action="store_true",
                       help="Keep original quality audio alongside 16kHz (for high-quality cutting)")

    # Parallelism (override adaptive)
    parser.add_argument("--vad-workers", type=int, default=None,
                       help="VAD parallel workers (default: auto-detected)")
    parser.add_argument("--workers", type=int, default=None,
                       help="General CPU workers (default: auto-detected)")

    # Model settings
    parser.add_argument("--merge-threshold", type=float, default=0.80,
                       help="Speaker merge threshold (default: 0.80, higher=fewer speakers)")
    parser.add_argument("--embedding-batch-size", type=int, default=None,
                       help="Embedding batch size (default: auto-detected)")
    parser.add_argument("--no-overlap-detection", action="store_true",
                       help="Disable overlap detection (faster but lower quality)")

    # Segmentation
    parser.add_argument("--min-segment", type=float, default=0.2,
                       help="Minimum segment duration in seconds (default: 0.2 for 0.4s events)")

    # Output
    parser.add_argument("--output-dir", default="data/fast_output_v6",
                       help="Output directory (default: data/fast_output_v6)")
    parser.add_argument("--with-samples", action="store_true",
                       help="Generate sample audio clips (default: metadata only)")

    # Chunk reassignment (v6.3)
    parser.add_argument("--no-chunk-reassignment", action="store_true",
                       help="Disable chunk-based speaker reassignment")
    parser.add_argument("--reassign-threshold", type=float, default=0.40,
                       help="Chunk reassignment normal threshold (default: 0.40)")
    parser.add_argument("--reassign-severe", type=float, default=0.25,
                       help="Chunk reassignment severe threshold (default: 0.25)")

    # Advanced
    parser.add_argument("--vad-chunk-size", type=float, default=60.0,
                       help="VAD chunk size in seconds (default: 60)")
    parser.add_argument("--max-embedding-length", type=int, default=160000,
                       help="Max samples per embedding (default: 160000 = 10s)")

    return parser


def _build_config(args):
    """Translate parsed CLI args into a pipeline Config.

    Values the user did not supply fall back to the adaptive settings from
    the compute monitor. The fallback check is `is not None` (not truthiness)
    so an explicit `--vad-workers 0` etc. is not silently replaced by the
    auto-detected value — consistent with the `is None` checks used when
    logging 'auto' vs 'manual'.
    """
    adaptive = COMPUTE.get_optimal_config()

    return Config(
        intro_skip_seconds=args.intro_skip,
        outro_skip_seconds=args.outro_skip,
        output_dir=args.output_dir,
        vad_workers=(args.vad_workers if args.vad_workers is not None
                     else adaptive['vad_workers']),
        max_workers=(args.workers if args.workers is not None
                     else adaptive['max_workers']),
        cluster_merge_threshold=args.merge_threshold,
        embedding_batch_size=(args.embedding_batch_size
                              if args.embedding_batch_size is not None
                              else adaptive['embedding_batch_size']),
        generate_sample_clips=args.with_samples,
        vad_chunk_size=args.vad_chunk_size,
        max_embedding_length=args.max_embedding_length,
        detect_overlap=not args.no_overlap_detection,
        min_segment_duration=args.min_segment,
        # v6.3 chunk reassignment
        enable_chunk_reassignment=not args.no_chunk_reassignment,
        chunk_reassignment_threshold=args.reassign_threshold,
        chunk_reassignment_severe=args.reassign_severe,
        # v6.8 dynamic intro skip + high-quality audio
        auto_intro_skip=not args.no_auto_intro,
        preserve_original_audio=args.preserve_original_audio,
    )


def _log_config(config, args):
    """Log the effective configuration banner (args needed to show auto/manual origin)."""
    logger.info("=" * 70)
    logger.info("🚀 FAST PIPELINE V6.8 - ADAPTIVE COMPUTE")
    logger.info("=" * 70)
    logger.info(f"Videos: {len(args.urls)}")
    logger.info(f"VAD Workers: {config.vad_workers} ({'auto' if args.vad_workers is None else 'manual'})")
    logger.info(f"Embedding Batch: {config.embedding_batch_size} ({'auto' if args.embedding_batch_size is None else 'manual'})")
    logger.info(f"Merge Threshold: {config.cluster_merge_threshold}")
    logger.info(f"Min Segment: {config.min_segment_duration}s")
    logger.info(f"Overlap Detection: {'ON' if config.detect_overlap else 'OFF'}")
    logger.info(f"Chunk Reassignment: {'ON' if config.enable_chunk_reassignment else 'OFF'}")
    if config.enable_chunk_reassignment:
        logger.info(f"   Normal/Severe thresholds: {config.chunk_reassignment_threshold}/{config.chunk_reassignment_severe}")
    logger.info(f"Auto Intro Skip: {'ON' if config.auto_intro_skip else 'OFF'}")
    if config.intro_skip_seconds > 0:
        logger.info(f"   Manual intro override: {config.intro_skip_seconds}s")
    logger.info(f"Preserve Original Audio: {'ON' if config.preserve_original_audio else 'OFF'}")
    logger.info(f"Output: {config.output_dir}")
    logger.info("=" * 70)


def main():
    """Main entry point with adaptive compute configuration.

    Exit codes: 0 all succeeded, 1 all failed (or fatal error),
    2 partial success, 130 interrupted by user.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Defensive guard: with nargs='+' argparse already rejects an empty URL
    # list before we get here, so this branch is unreachable in practice.
    if not args.urls:
        parser.print_help()
        sys.exit(1)

    config = _build_config(args)
    _log_config(config, args)

    # Run pipeline
    try:
        results = process_batch(args.urls, config)

        # Exit code based on results (an entry with an 'error' key = failed video)
        successful = [r for r in results if 'error' not in r]
        if not successful:
            logger.error("❌ All videos failed")
            sys.exit(1)
        elif len(successful) < len(args.urls):
            logger.warning(f"⚠️  Partial success: {len(successful)}/{len(args.urls)}")
            sys.exit(2)
        else:
            logger.info("✅ All videos processed successfully")
            sys.exit(0)

    except KeyboardInterrupt:
        logger.info("\n⚠️  Interrupted by user")
        sys.exit(130)
    except Exception as e:
        logger.error(f"❌ Pipeline failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)


# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
