#!/usr/bin/env python3
"""
MASSIVE COMPUTE - 8x H100 Parallel Processing

Designed for maximum throughput on multi-GPU clusters.
Each GPU runs an independent pipeline worker processing videos in parallel.

Usage:
    # Process videos from file (one URL per line)
    python massive_process.py --input videos.txt

    # Process from YouTube playlist or channel
    python massive_process.py --playlist "PLAYLIST_URL"

    # Benchmark mode with test videos
    python massive_process.py --benchmark --num-videos 16

    # Custom GPU selection
    python massive_process.py --input videos.txt --gpus 0,1,2,3

Architecture:
    Main Process
        │
        ├── GPU Worker 0 (CUDA_VISIBLE_DEVICES=0)
        │     └── pipeline.py (full model stack)
        │
        ├── GPU Worker 1 (CUDA_VISIBLE_DEVICES=1)
        │     └── pipeline.py (full model stack)
        │
        ... (8 workers total)
        │
        └── GPU Worker 7 (CUDA_VISIBLE_DEVICES=7)
              └── pipeline.py (full model stack)

Result: 8 videos processed in parallel → 8x throughput
"""

# CRITICAL: Suppress NNPACK warnings BEFORE any imports
# These warnings are harmless (NNPACK is CPU-only, we use GPU) but spam logs
# Env vars must be set before torch is imported anywhere in this process,
# which is why this sits above the main import section.
import os
import warnings
os.environ['NNPACK_DISABLE'] = '1'  # Disable NNPACK entirely
os.environ['TORCH_CPP_LOG_LEVEL'] = 'ERROR'  # Suppress C++ warnings
warnings.filterwarnings('ignore', message='.*NNPACK.*')
warnings.filterwarnings('ignore', category=FutureWarning, module='pynvml')
warnings.filterwarnings('ignore', category=UserWarning, module='torchaudio')

import sys
import time
import json
import logging
import argparse
import signal
import traceback
from pathlib import Path
from datetime import datetime
from multiprocessing import Process, Queue, Value, Manager
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, asdict, field

# Supabase client for distributed queue
# Supabase client for distributed queue.
# SUPABASE_AVAILABLE lets the rest of the script degrade gracefully
# when the optional dependency is not installed.
try:
    from src.supabase_client import SupabaseClient, STATUS_COMPLETED, STATUS_FAILED
    SUPABASE_AVAILABLE = True
except ImportError:
    SUPABASE_AVAILABLE = False

# Optional: load .env for HF/R2/Supabase credentials
# (load_dotenv stays None when python-dotenv is unavailable;
# callers check for None before use)
try:
    from dotenv import load_dotenv  # type: ignore
except Exception:
    load_dotenv = None

# Clean logging utilities
from src.logger import silence_verbose_loggers

# Analytics reporter for distributed monitoring
from src.analytics import create_reporter, AnalyticsReporter

# Setup logging: aligned columns -> HH:MM:SS | name | level | message
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(name)-12s | %(levelname)-8s | %(message)s',
    datefmt='%H:%M:%S'
)
logger = logging.getLogger("MassiveCompute")


# Global stop signal for graceful shutdown
# (set by setup_signal_handlers; read inside the installed signal handlers)
_GLOBAL_STOP_SIGNAL = None


def validate_environment() -> bool:
    """
    Validate all required environment variables before starting workers.

    Fails fast on startup instead of producing confusing runtime errors
    deep inside a worker process.

    Returns:
        True if every required variable is set; False otherwise, with
        each missing variable logged at ERROR level.
    """
    required = {
        'HF_TOKEN': 'Hugging Face token for model downloads',
        'URL': 'Supabase project URL',
        'SUPABASE_ADMIN': 'Supabase service role key',
    }

    # NOTE: R2 credentials (R2_ENDPOINT_URL, R2_ACCESS_KEY_ID,
    # R2_SECRET_ACCESS_KEY) are only required when --r2-upload is used,
    # so they are validated later, if needed, rather than here.

    missing = [
        f"  {var}: {desc}"
        for var, desc in required.items()
        if not os.environ.get(var)
    ]

    if missing:
        logger.error("Missing required environment variables:")
        for m in missing:
            logger.error(m)
        return False

    return True


def setup_signal_handlers(stop_signal):
    """
    Install SIGTERM/SIGINT handlers for graceful shutdown.

    When either signal arrives, the shared stop flag is raised so that
    workers finish the video they are currently processing before
    exiting — preventing lost work on container shutdown.
    """
    global _GLOBAL_STOP_SIGNAL
    _GLOBAL_STOP_SIGNAL = stop_signal

    def graceful_shutdown(signum, frame):
        # Only SIGTERM and SIGINT are ever registered below.
        sig_name = 'SIGTERM' if signum == signal.SIGTERM else 'SIGINT'
        logger.warning(f"Received {sig_name}, initiating graceful shutdown...")
        logger.warning("Workers will complete current videos then exit")

        if _GLOBAL_STOP_SIGNAL is not None:
            _GLOBAL_STOP_SIGNAL.value = True

    # Register the same handler for both shutdown signals.
    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, graceful_shutdown)


@dataclass
class VideoResult:
    """Result from processing a single video.

    Emitted by both worker types onto the shared result queue; the main
    process aggregates these into BenchmarkStats.
    """
    video_url: str  # input URL (always set, even for failures)
    video_id: str = ""  # resolved video ID; "" when unknown or failed
    success: bool = False  # True only when the full pipeline completed
    error: str = ""  # error message when success is False
    duration_sec: float = 0.0  # processed audio duration in seconds
    process_time_sec: float = 0.0  # wall time spent in the worker for this video
    pipeline_time_sec: float = 0.0  # pipeline-reported total processing time
    rtf: float = 0.0  # real-time factor: audio seconds per pipeline second
    num_speakers: int = 0  # speaker count as reported by the pipeline
    usable_percentage: float = 0.0  # % of audio judged usable (quality_stats)
    output_path: str = ""  # directory where this video's outputs were written
    stage_timings: Dict[str, float] = field(default_factory=dict)  # per-stage seconds
    source_type: str = ""  # pipeline-reported source classification
    r2_upload_key: str = ""  # remote key if the result was uploaded to R2
    gpu_id: int = -1  # GPU the video was processed on (-1 = unset)
    worker_id: int = -1  # logical worker index (-1 = unset)


@dataclass
class BenchmarkStats:
    """Aggregate statistics from benchmark run.

    NOTE: not every producer fills every field — e.g. run_supabase_queue
    sets gpu_count/workers_per_gpu and *_timestamp but leaves gpus_used,
    start_time/end_time and throughput_videos_per_min at their defaults.
    """
    total_videos: int = 0  # results collected (success + failure)
    successful: int = 0
    failed: int = 0
    total_audio_duration_sec: float = 0.0  # sum of processed audio, seconds
    total_audio_hours: float = 0.0  # sum of processed audio, hours
    total_process_time_sec: float = 0.0  # sum of per-video worker wall time
    throughput_videos_per_min: float = 0.0
    throughput_hours_audio_per_hour: float = 0.0  # Hours of audio processed per hour of wall time
    avg_process_time_sec: float = 0.0  # mean over successful videos
    avg_rtf: float = 0.0  # mean real-time factor over successful videos
    parallel_efficiency: float = 0.0  # total_process_time / (wall_time * workers)
    gpus_used: int = 0  # NOTE(review): apparent duplicate of gpu_count — confirm
    gpu_count: int = 0
    workers_per_gpu: int = 1
    start_time: str = ""  # NOTE(review): apparent duplicate of start_timestamp — confirm
    end_time: str = ""
    start_timestamp: str = ""  # ISO-8601; set by run_supabase_queue
    end_timestamp: str = ""  # ISO-8601; set by run_supabase_queue
    wall_time_sec: float = 0.0  # total elapsed wall-clock time for the run


def gpu_worker(
    worker_id: int,
    gpu_id: int,
    video_queue: Queue,
    result_queue: Queue,
    stop_signal: Value,
    output_dir: str,
    config_overrides: Dict
):
    """
    GPU Worker Process - Runs complete pipeline on assigned GPU with prefetch.

    Each worker:
    1. Sets CUDA_VISIBLE_DEVICES to its assigned GPU
    2. Loads all models ONCE into GPU memory
    3. Prefetches next video(s) while processing current one
    4. Reports results back to main process

    OPTIMIZATION: Download prefetch hides 15-20s latency per video

    Args:
        worker_id: Logical worker index (tagged onto every VideoResult).
        gpu_id: Physical GPU index this worker is pinned to.
        video_queue: Source of video URLs; a None entry is the poison pill
            telling the worker to stop.
        result_queue: Receives one VideoResult per URL consumed.
        stop_signal: Shared boolean; when .value becomes True the worker
            exits after finishing the current video.
        output_dir: Directory that per-video outputs are written under.
        config_overrides: Keyword arguments forwarded to Config().
    """
    # CRITICAL: Set GPU and suppress warnings before ANY torch imports
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['NNPACK_DISABLE'] = '1'
    os.environ['TORCH_CPP_LOG_LEVEL'] = 'ERROR'
    import warnings
    warnings.filterwarnings('ignore', message='.*NNPACK.*')
    warnings.filterwarnings('ignore', category=FutureWarning)
    warnings.filterwarnings('ignore', category=UserWarning)

    # Setup worker logging
    worker_logger = logging.getLogger(f"Worker-{worker_id}")
    worker_logger.setLevel(logging.INFO)

    worker_logger.info(f"Starting on GPU {gpu_id}")

    try:
        # Load .env inside the worker as well (spawned processes may not have env populated)
        if load_dotenv is not None:
            try:
                load_dotenv(dotenv_path=Path("/workspace/maya3_data/.env"), override=False)
            except Exception:
                pass

        # Now import torch and pipeline (after CUDA_VISIBLE_DEVICES is set)
        import torch

        # Verify GPU assignment
        if not torch.cuda.is_available():
            worker_logger.error("CUDA not available!")
            return

        gpu_name = torch.cuda.get_device_name(0)  # Now 0 is our assigned GPU
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        worker_logger.info(f"GPU: {gpu_name} ({vram_gb:.1f}GB)")

        # Aggressively silence speechbrain before any imports
        for name in ['speechbrain', 'speechbrain.utils', 'speechbrain.utils.checkpoints',
                     'speechbrain.utils.fetching', 'speechbrain.utils.parameter_transfer']:
            sb_logger = logging.getLogger(name)
            sb_logger.setLevel(logging.CRITICAL)
            sb_logger.handlers = []  # Remove all handlers
            sb_logger.propagate = False

        from src.logger import silence_verbose_loggers
        silence_verbose_loggers()

        # Import pipeline components
        from src.config import Config
        from src.models import MODELS
        from pipeline import process_single_video

        # Create config with overrides and gpu_id
        config = Config(
            output_dir=output_dir,
            **config_overrides
        )
        config.gpu_id = gpu_id  # For PipelineLogger

        # Load models ONCE for this worker
        MODELS.load_all(config)

        # PREFETCH OPTIMIZATION: Download next video while processing current
        # Uses a background thread to hide download latency (15-20s per video)
        from threading import Thread
        from queue import Queue as ThreadQueue, Empty as ThreadEmpty
        from src.download import download_audio

        prefetch_queue = ThreadQueue(maxsize=2)  # Buffer of 2 pre-downloaded videos
        prefetch_stop = [False]  # Use list for mutability in closure

        # Prefetch queue item contract (3-tuples):
        #   (url, audio_path, metadata_dict) - download succeeded
        #   (url, None, error_string)        - download failed
        #   (None, None, None)               - poison pill, worker should stop
        def prefetch_worker():
            """Background thread that pre-downloads videos."""
            while not prefetch_stop[0]:
                try:
                    # Get next URL from main queue (non-blocking short timeout)
                    try:
                        url = video_queue.get(timeout=1.0)
                    except:  # NOTE(review): bare except also swallows queue.Empty (intended) - just poll again
                        continue

                    if url is None:  # Poison pill
                        prefetch_queue.put((None, None, None))
                        break

                    # Download audio
                    try:
                        audio_path, metadata = download_audio(url, config)
                        prefetch_queue.put((url, audio_path, metadata))
                        worker_logger.debug(f"Prefetched: {url}")
                    except Exception as e:
                        # Put URL with error marker
                        prefetch_queue.put((url, None, str(e)))

                except Exception as e:
                    if not prefetch_stop[0]:
                        worker_logger.debug(f"Prefetch error: {e}")

        # Start prefetch thread
        prefetch_thread = Thread(target=prefetch_worker, daemon=True)
        prefetch_thread.start()
        worker_logger.info("Prefetch thread started")

        # Process videos from prefetch queue
        videos_processed = 0
        while not stop_signal.value:
            try:
                # Get pre-downloaded video from prefetch queue
                # (third element is either a metadata dict or an error string,
                # disambiguated below by audio_path/isinstance checks)
                try:
                    video_url, audio_path, metadata_or_error = prefetch_queue.get(timeout=30.0)
                except ThreadEmpty:
                    worker_logger.warning("Prefetch queue timeout, checking main queue")
                    continue

                if video_url is None:  # Poison pill
                    worker_logger.info("Received stop signal")
                    break

                # Check if download failed
                if audio_path is None:
                    worker_logger.error(f"Prefetch failed for {video_url}: {metadata_or_error}")
                    result_queue.put(VideoResult(
                        video_url=video_url,
                        success=False,
                        error=f"Download failed: {metadata_or_error}",
                        gpu_id=gpu_id,
                        worker_id=worker_id
                    ))
                    videos_processed += 1
                    continue

                worker_logger.info(f"Processing (prefetched): {video_url}")
                start_time = time.time()

                try:
                    # Process video with pre-downloaded audio (skips download stage)
                    result = process_single_video(
                        video_url,
                        config,
                        prefetched_audio=str(audio_path),
                        prefetched_metadata=metadata_or_error if isinstance(metadata_or_error, dict) else None
                    )
                    process_time = time.time() - start_time

                    # Pull timing/quality fields out of the pipeline's result dict,
                    # defaulting defensively since keys may be absent or None.
                    stage_timings = result.get('timing', {}) or {}
                    pipeline_time = float(result.get('timing_total', 0.0) or 0.0)
                    audio_dur = float(result.get('processed_duration', 0.0) or 0.0)
                    rtf = (audio_dur / pipeline_time) if pipeline_time > 0 else 0.0

                    video_result = VideoResult(
                        video_url=video_url,
                        video_id=result.get('video_id', ''),
                        success=True,
                        duration_sec=audio_dur,
                        process_time_sec=process_time,
                        pipeline_time_sec=pipeline_time,
                        rtf=rtf,
                        num_speakers=result.get('num_speakers', 0),
                        usable_percentage=result.get('quality_stats', {}).get('usable_percentage', 0),
                        output_path=str(Path(output_dir) / result.get('video_id', '')),
                        stage_timings={k: float(v) for k, v in stage_timings.items()},
                        source_type=result.get('source_type', ''),
                        r2_upload_key=(result.get('r2_upload', {}) or {}).get('remote_key', ''),
                        gpu_id=gpu_id,
                        worker_id=worker_id
                    )

                    worker_logger.info(
                        f"Done: {result.get('video_id', 'unknown')} | "
                        f"{process_time:.1f}s | "
                        f"{result.get('num_speakers', 0)} speakers | "
                        f"{result.get('quality_stats', {}).get('usable_percentage', 0):.1f}% usable"
                    )

                except Exception as e:
                    process_time = time.time() - start_time
                    worker_logger.error(f"Failed: {video_url} - {e}")
                    traceback.print_exc()

                    video_result = VideoResult(
                        video_url=video_url,
                        success=False,
                        error=str(e),
                        process_time_sec=process_time,
                        gpu_id=gpu_id,
                        worker_id=worker_id
                    )

                result_queue.put(video_result)
                videos_processed += 1

                # Clear GPU cache between videos
                MODELS.clear_cache(aggressive=True)

            except Exception as e:
                # Name-based check: any *Empty exception type is routine queue
                # polling noise; only log genuinely unexpected queue errors.
                if "Empty" not in str(type(e).__name__):
                    worker_logger.error(f"Queue error: {e}")
                continue

        # Stop prefetch thread
        prefetch_stop[0] = True
        prefetch_thread.join(timeout=5.0)
        worker_logger.info(f"Shutting down. Processed {videos_processed} videos.")

    except Exception as e:
        worker_logger.error(f"Worker failed: {e}")
        traceback.print_exc()


def gpu_worker_supabase(
    worker_id: int,
    gpu_id: int,
    result_queue,
    stop_signal,
    output_dir: str,
    config_overrides: dict,
    lease_duration_sec: int = 900,
    poll_interval_sec: int = 10,
    max_videos: int = 0  # 0 = unlimited
):
    """
    GPU worker that claims videos from Supabase queue.

    Instead of reading from a file queue, this worker:
    1. Claims the next PENDING video from Supabase
    2. Processes it
    3. Updates Supabase with results (COMPLETED/FAILED)
    4. Repeats until stop signal or no more videos

    Args:
        worker_id: Logical worker index (tagged onto every VideoResult).
        gpu_id: Physical GPU index; exported via CUDA_VISIBLE_DEVICES.
        result_queue: Receives one VideoResult per processed video.
        stop_signal: Shared boolean; True requests graceful shutdown.
        output_dir: Directory that per-video outputs are written under.
        config_overrides: Config attribute overrides (unknown keys ignored).
        lease_duration_sec: How long a claimed video is leased to this
            worker before the queue may re-issue it.
        poll_interval_sec: Sleep between polls when the queue is empty.
        max_videos: Stop after this many videos (0 = unlimited).
    """
    # Set GPU for this process
    # NOTE(review): unlike gpu_worker, this does not set NNPACK/log-level
    # env vars before torch-dependent imports - confirm intentional.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)

    # Worker-specific logging
    worker_logger = logging.getLogger(f"Worker-{worker_id}")
    worker_logger.setLevel(logging.INFO)

    # Get pod ID for unique worker identification
    pod_id = os.environ.get('RUNPOD_POD_ID', os.environ.get('HOSTNAME', 'local'))
    supabase_worker_id = f"gpu-{gpu_id}-{pod_id}"

    worker_logger.info(f"Starting worker {supabase_worker_id} on GPU {gpu_id}")

    # Initialize analytics to None (will be set in try block)
    analytics = None

    try:
        # Aggressively silence speechbrain before any imports
        for name in ['speechbrain', 'speechbrain.utils', 'speechbrain.utils.checkpoints',
                     'speechbrain.utils.fetching', 'speechbrain.utils.parameter_transfer']:
            sb_logger = logging.getLogger(name)
            sb_logger.setLevel(logging.CRITICAL)
            sb_logger.handlers = []  # Remove all handlers
            sb_logger.propagate = False

        from src.logger import silence_verbose_loggers
        silence_verbose_loggers()

        # Import pipeline (loads models)
        from pipeline import process_single_video
        from src.config import Config
        from src.models import MODELS
        from src.download import download_audio  # NOTE(review): imported but unused in this worker

        # Initialize Supabase client
        supabase = SupabaseClient()

        # Initialize analytics reporter (uses same Supabase credentials)
        # Background thread handles async reporting - zero overhead to main processing
        analytics = create_reporter(
            gpu_id=gpu_id,
            supabase_url=os.environ.get('URL') or os.environ.get('SUPABASE_URL'),
            supabase_key=os.environ.get('SUPABASE_ADMIN'),
            machine_id=pod_id,
            enabled=True
        )
        analytics.session_start()
        worker_logger.info(f"Analytics initialized: {analytics.worker_id}")

        # Build config with gpu_id for pipeline logging
        config = Config(output_dir=output_dir)
        for key, value in config_overrides.items():
            if hasattr(config, key):  # silently skip overrides Config doesn't know
                setattr(config, key, value)
        config.gpu_id = gpu_id  # For PipelineLogger

        # Load models once
        MODELS.load_all(config)

        # Process loop
        videos_processed = 0
        consecutive_empty = 0
        max_consecutive_empty = 6  # Stop after 6 empty polls (1 minute at the default 10s interval)

        while not stop_signal.value:
            # Check max videos limit
            if max_videos > 0 and videos_processed >= max_videos:
                worker_logger.info(f"Reached max videos limit ({max_videos})")
                break

            # Claim next video from Supabase
            video = supabase.claim_next_video(
                worker_id=supabase_worker_id,
                lease_duration_sec=lease_duration_sec
            )

            if not video:
                consecutive_empty += 1
                if consecutive_empty >= max_consecutive_empty:
                    worker_logger.info("No more pending videos, stopping worker")
                    break
                worker_logger.debug(f"No pending videos, sleeping {poll_interval_sec}s...")
                time.sleep(poll_interval_sec)
                continue

            consecutive_empty = 0
            video_id = video['youtube_id']
            video_url = video['youtube_url']

            # Report video pick to analytics (non-blocking)
            analytics.pick(video_id)

            # Fetch full video record to get chapters
            full_video = supabase.get_video(video_id)
            if not full_video:
                full_video = video  # Fallback to claim result

            worker_logger.info(f"Processing: {video_id} - {full_video.get('title', '')[:50]}")
            has_chapters = full_video.get('has_chapters', False)
            if has_chapters:
                worker_logger.info(f"   Chapters: {full_video.get('chapter_count', 0)} from Supabase")
            start_time = time.time()

            try:
                # Process video with Supabase metadata (chapters, title, etc.)
                result = process_single_video(video_url, config, supabase_metadata=full_video)
                process_time = time.time() - start_time

                # Extract results for DB update
                # ("or {}" guards against keys present but set to None)
                quality_stats = result.get('quality_stats', {}) or {}
                segment_summary = result.get('segment_summary', {}) or {}
                download_meta = result.get('download_meta', {}) or {}

                # Update Supabase with success
                update_data = {
                    'num_speakers': result.get('num_speakers', 0),
                    'total_segments': result.get('total_segments', 0),
                    'usable_percentage': quality_stats.get('usable_percentage', 0),
                    'usable_segments': quality_stats.get('usable_segments', 0),
                    'usable_duration_sec': quality_stats.get('usable_duration', 0),
                    'r2_tar_key': result.get('r2_upload_key', ''),
                    'pipeline_version': result.get('pipeline_version', ''),
                    'audio_native_sample_rate': download_meta.get('native_sample_rate'),
                    'audio_channels': download_meta.get('channels'),
                    'audio_duration_sec': result.get('processed_duration', 0),
                    'needs_demucs_count': quality_stats.get('music_detection', {}).get('segments_needs_demucs', 0),
                    'heavy_music_count': quality_stats.get('music_detection', {}).get('segments_heavy_contamination', 0),
                    'processing_meta': {
                        'worker_id': supabase_worker_id,
                        'gpu_id': gpu_id,
                        'timing_total': result.get('timing_total', 0),
                        'completed_at': datetime.now().isoformat(),
                        'pipeline_version': result.get('pipeline_version', ''),
                    },
                    'quality_stats': quality_stats,
                    'segment_summary': segment_summary,
                    'download_meta': download_meta,
                }

                supabase.update_status(video_id, STATUS_COMPLETED, update_data)

                worker_logger.info(
                    f"Completed: {video_id} | "
                    f"{process_time:.1f}s | "
                    f"{result.get('num_speakers', 0)} speakers | "
                    f"{quality_stats.get('usable_percentage', 0):.1f}% usable"
                )

                # Report success to analytics (non-blocking)
                analytics.done(
                    video_id=video_id,
                    total_time=result.get('timing_total', process_time),
                    audio_min=result.get('processed_duration', 0) / 60.0,
                    speakers=result.get('num_speakers', 0),
                    usable_pct=quality_stats.get('usable_percentage', 0)
                )

                # Put result in queue for stats collection
                video_result = VideoResult(
                    video_url=video_url,
                    video_id=video_id,
                    success=True,
                    duration_sec=result.get('processed_duration', 0),
                    process_time_sec=process_time,
                    pipeline_time_sec=result.get('timing_total', 0),
                    rtf=(result.get('processed_duration', 0) / result.get('timing_total', 1)) if result.get('timing_total') else 0,
                    num_speakers=result.get('num_speakers', 0),
                    usable_percentage=quality_stats.get('usable_percentage', 0),
                    output_path=str(Path(output_dir) / video_id),
                    stage_timings={k: float(v) for k, v in (result.get('timing', {}) or {}).items()},
                    source_type=result.get('source_type', ''),
                    r2_upload_key=result.get('r2_upload_key', ''),
                    gpu_id=gpu_id,
                    worker_id=worker_id
                )

            except Exception as e:
                process_time = time.time() - start_time
                error_msg = str(e)
                error_type = type(e).__name__
                worker_logger.error(f"Failed: {video_id} - {error_msg}")
                traceback.print_exc()

                # Report failure to analytics (non-blocking)
                analytics.fail(
                    video_id=video_id,
                    stage="processing",
                    error_type=error_type,
                    message=error_msg,
                    stack_trace=traceback.format_exc()
                )

                # Update Supabase with failure
                supabase.update_status(video_id, STATUS_FAILED, {
                    'last_error': error_msg[:1000],  # Truncate long errors
                    'last_error_type': type(e).__name__,
                    'processing_meta': {
                        'worker_id': supabase_worker_id,
                        'gpu_id': gpu_id,
                        'failed_at': datetime.now().isoformat(),
                    },
                })

                video_result = VideoResult(
                    video_url=video_url,
                    video_id=video_id,
                    success=False,
                    error=error_msg,
                    process_time_sec=process_time,
                    gpu_id=gpu_id,
                    worker_id=worker_id
                )

            result_queue.put(video_result)
            videos_processed += 1

            # Clear GPU cache between videos
            MODELS.clear_cache(aggressive=True)

        # Graceful shutdown - flush remaining analytics events
        analytics.shutdown()
        worker_logger.info(f"Shutting down. Processed {videos_processed} videos.")

    except Exception as e:
        worker_logger.error(f"Worker failed: {e}")
        traceback.print_exc()
        # Try to shutdown analytics even on error (may not be initialized)
        try:
            if analytics is not None:
                analytics.shutdown()
        except Exception:
            pass


def run_supabase_queue(
    gpu_ids: List[int],
    output_dir: str,
    config_overrides: Optional[Dict] = None,
    max_videos_per_worker: int = 0,
    lease_duration_sec: int = 900,
    poll_interval_sec: int = 10
) -> BenchmarkStats:
    """
    Run distributed processing with Supabase queue.

    Workers claim videos from Supabase, process them, and update status.
    Continues until no more PENDING videos or stop signal.

    Args:
        gpu_ids: GPU indices to spawn workers on.
        output_dir: Output directory (created if missing).
        config_overrides: Optional Config overrides; the special key
            'workers_per_gpu' also controls how many workers share each GPU.
        max_videos_per_worker: Per-worker cap (0 = unlimited).
        lease_duration_sec: Lease length passed through to each worker.
        poll_interval_sec: Empty-queue poll interval passed to workers.

    Returns:
        BenchmarkStats aggregated over all collected VideoResults.
    """
    config_overrides = config_overrides or {}

    workers_per_gpu = config_overrides.get('workers_per_gpu', 1)
    total_workers = len(gpu_ids) * workers_per_gpu

    # Get initial queue stats
    supabase = SupabaseClient()
    initial_stats = supabase.get_stats()
    pending_count = initial_stats.get('PENDING', 0)

    logger.info("=" * 70)
    logger.info("SUPABASE QUEUE - Distributed Processing")
    logger.info("=" * 70)
    logger.info(f"Pending videos: {pending_count:,}")
    logger.info(f"GPUs: {len(gpu_ids)} ({gpu_ids})")
    logger.info(f"Workers/GPU: {workers_per_gpu} (total: {total_workers})")
    logger.info(f"Output: {output_dir}")
    logger.info(f"Lease duration: {lease_duration_sec}s")
    if max_videos_per_worker > 0:
        logger.info(f"Max videos/worker: {max_videos_per_worker}")
    logger.info("=" * 70)

    # Create output directory
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Setup multiprocessing (Manager proxies are shareable across
    # spawned worker processes, unlike plain Queue/Value)
    manager = Manager()
    result_queue = manager.Queue()
    stop_signal = manager.Value('b', False)

    # Setup graceful shutdown handlers
    setup_signal_handlers(stop_signal)

    # Start workers
    workers = []
    start_time = time.time()

    worker_id = 0
    for gpu_id in gpu_ids:
        for w in range(workers_per_gpu):
            p = Process(
                target=gpu_worker_supabase,
                args=(
                    worker_id, gpu_id, result_queue, stop_signal,
                    output_dir, config_overrides, lease_duration_sec,
                    poll_interval_sec, max_videos_per_worker
                ),
                name=f"GPU{gpu_id}-Supabase-{w}"
            )
            p.start()
            workers.append(p)
            logger.info(f"Started Supabase worker {worker_id} on GPU {gpu_id}")
            worker_id += 1

    # Collect results
    results: List[VideoResult] = []
    active_workers = len(workers)

    try:
        # NOTE(review): results still sitting in result_queue when the last
        # worker dies are not drained after the loop exits - confirm acceptable.
        while active_workers > 0:
            # Check for results
            try:
                result = result_queue.get(timeout=30)
                results.append(result)

                status = "OK" if result.success else "FAIL"
                logger.info(
                    f"[{len(results)}] {status}: {result.video_id or result.video_url[:50]} "
                    f"({result.process_time_sec:.1f}s on GPU {result.gpu_id})"
                )

            except Exception:
                pass  # Timeout, check worker status

            # Check if workers are still alive
            active_workers = sum(1 for p in workers if p.is_alive())

    except KeyboardInterrupt:
        logger.warning("Interrupted! Stopping workers...")
        stop_signal.value = True

    # Wait for workers to finish
    for p in workers:
        p.join(timeout=60)
        if p.is_alive():
            logger.warning(f"Worker {p.name} did not stop, terminating...")
            p.terminate()

    end_time = time.time()
    wall_time = end_time - start_time

    # Compute statistics
    successful = [r for r in results if r.success]
    failed = [r for r in results if not r.success]

    total_audio_hours = sum(r.duration_sec for r in successful) / 3600
    total_process_time = sum(r.process_time_sec for r in successful)

    # NOTE: fields not listed here (gpus_used, start_time/end_time,
    # throughput_videos_per_min, ...) keep their dataclass defaults.
    stats = BenchmarkStats(
        total_videos=len(results),
        successful=len(successful),
        failed=len(failed),
        wall_time_sec=wall_time,
        total_audio_hours=total_audio_hours,
        total_process_time_sec=total_process_time,
        avg_process_time_sec=(total_process_time / len(successful)) if successful else 0,
        avg_rtf=sum(r.rtf for r in successful) / len(successful) if successful else 0,
        throughput_hours_audio_per_hour=(total_audio_hours / (wall_time / 3600)) if wall_time > 0 else 0,
        parallel_efficiency=(total_process_time / (wall_time * total_workers)) if wall_time > 0 else 0,
        gpu_count=len(gpu_ids),
        workers_per_gpu=workers_per_gpu,
        start_timestamp=datetime.fromtimestamp(start_time).isoformat(),
        end_timestamp=datetime.fromtimestamp(end_time).isoformat(),
    )

    # Get final queue stats
    final_stats = supabase.get_stats()

    logger.info("=" * 70)
    logger.info("SUPABASE QUEUE COMPLETE")
    logger.info("=" * 70)
    logger.info(f"Processed: {stats.total_videos} ({stats.successful} OK, {stats.failed} failed)")
    logger.info(f"Wall time: {wall_time/60:.1f} min")
    logger.info(f"Audio hours: {total_audio_hours:.1f}h")
    logger.info("-" * 70)
    logger.info("Queue Status:")
    for status, count in sorted(final_stats.items()):
        logger.info(f"   {status}: {count:,}")
    logger.info("=" * 70)

    if failed:
        logger.warning(f"\nFailed videos:")
        for r in failed[:10]:
            logger.warning(f"   {r.video_id or r.video_url}: {r.error[:80]}")
        if len(failed) > 10:
            logger.warning(f"   ... and {len(failed) - 10} more")

    return stats


def get_available_gpus() -> List[int]:
    """Detect available NVIDIA GPUs via nvidia-smi.

    Returns:
        The GPU indices reported by `nvidia-smi --query-gpu=index`, or
        [0] as a single-GPU fallback when detection fails for any reason
        (missing binary, timeout, unparsable output).
    """
    try:
        import subprocess
        proc = subprocess.run(
            ['nvidia-smi', '--query-gpu=index', '--format=csv,noheader'],
            capture_output=True, text=True, timeout=10
        )
        # One index per non-empty output line.
        return [
            int(token.strip())
            for token in proc.stdout.strip().split('\n')
            if token.strip()
        ]
    except Exception as e:
        logger.error(f"Failed to detect GPUs: {e}")
        return [0]


def load_video_urls(input_path: str) -> List[str]:
    """Load video URLs from a text file (one URL per line).

    Blank lines and lines starting with '#' (comments) are skipped;
    surrounding whitespace is stripped.

    Args:
        input_path: Path to the URL list file.

    Returns:
        List of URL strings in file order.
    """
    urls: List[str] = []
    # Explicit UTF-8: URL lists may be authored on machines whose default
    # locale encoding differs (e.g. Windows cp1252), and URLs can contain
    # non-ASCII characters.
    with open(input_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line and not line.startswith('#'):
                urls.append(line)
    return urls


def get_test_videos(num_videos: int = 8) -> List[str]:
    """
    Build a list of test video URLs for benchmarking.

    Uses a small base set of short podcasts/interviews for consistent
    testing; the set is tiled (repeated in order) and truncated to exactly
    `num_videos` entries.
    """
    # Mix of video lengths for realistic benchmarking
    base_urls = [
        # Short videos (~5-15 min) - good for quick benchmarks
        "https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # 3:32 (classic test)
        "https://www.youtube.com/watch?v=9bZkp7q19f0",  # 4:12
        # Add your test video URLs here
    ]

    # Tile the base list until it covers the request, then truncate.
    repeats = (num_videos // len(base_urls)) + 1
    return (base_urls * repeats)[:num_videos]


def run_massive_process(
    video_urls: List[str],
    gpu_ids: List[int],
    output_dir: str,
    config_overrides: Optional[Dict] = None
) -> BenchmarkStats:
    """
    Run massive parallel processing across multiple GPUs.

    Spawns `workers_per_gpu` worker processes per GPU (target: gpu_worker),
    feeds them video URLs through a shared queue, collects per-video results,
    then aggregates throughput and per-stage timing statistics and writes
    them to <output_dir>/benchmark_results.json.

    Args:
        video_urls: List of YouTube URLs to process
        gpu_ids: List of GPU indices to use
        output_dir: Output directory for results
        config_overrides: Optional config parameter overrides. The
            'workers_per_gpu' key (default 1) is consumed here and is NOT
            forwarded to the workers.

    Returns:
        BenchmarkStats with aggregate metrics
    """
    # Work on a copy: we pop 'workers_per_gpu' below and must not mutate
    # the caller's dict.
    config_overrides = dict(config_overrides or {})

    # Consume workers_per_gpu exactly once. (Previously it was read with
    # .get() here for logging and popped again later — redundant, and the
    # pop mutated the caller's dict.)
    workers_per_gpu = config_overrides.pop('workers_per_gpu', 1)
    total_workers = len(gpu_ids) * workers_per_gpu

    logger.info("=" * 70)
    logger.info("MASSIVE COMPUTE - Multi-GPU Parallel Processing")
    logger.info("=" * 70)
    logger.info(f"Videos: {len(video_urls)}")
    logger.info(f"GPUs: {len(gpu_ids)} ({gpu_ids})")
    logger.info(f"Workers/GPU: {workers_per_gpu} (total: {total_workers})")
    logger.info(f"Output: {output_dir}")
    logger.info("=" * 70)

    # Create output directory
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Setup multiprocessing primitives shared with the workers
    manager = Manager()
    video_queue = manager.Queue()
    result_queue = manager.Queue()
    stop_signal = manager.Value('b', False)

    # Add all videos to queue
    for url in video_urls:
        video_queue.put(url)

    # Add poison pills (one per worker) so every worker sees a None and exits
    for _ in range(total_workers):
        video_queue.put(None)

    # Start workers
    workers = []
    start_time = time.time()
    start_timestamp = datetime.now().isoformat()

    worker_id = 0
    for gpu_id in gpu_ids:
        for w in range(workers_per_gpu):
            p = Process(
                target=gpu_worker,
                args=(worker_id, gpu_id, video_queue, result_queue, stop_signal, output_dir, config_overrides),
                name=f"GPU{gpu_id}-Worker-{w}"
            )
            p.start()
            workers.append(p)
            logger.info(f"Started worker {worker_id} on GPU {gpu_id}")
            worker_id += 1

    # Collect results until every video is accounted for (or we time out)
    results: List[VideoResult] = []
    completed = 0

    try:
        while completed < len(video_urls):
            try:
                result = result_queue.get(timeout=600)  # 10 min timeout per video
                results.append(result)
                completed += 1

                status = "OK" if result.success else "FAIL"
                logger.info(
                    f"[{completed}/{len(video_urls)}] {status}: {result.video_id or result.video_url[:50]} "
                    f"({result.process_time_sec:.1f}s on GPU {result.gpu_id})"
                )

            except Exception as e:
                # Queue.get timeout: workers stalled or died — stop collecting
                # and fall through to join/terminate below.
                logger.warning(f"Timeout waiting for result: {e}")
                break

    except KeyboardInterrupt:
        logger.warning("Interrupted! Stopping workers...")
        stop_signal.value = True

    # Wait for workers to finish; escalate to terminate() if they hang
    for p in workers:
        p.join(timeout=30)
        if p.is_alive():
            logger.warning(f"Worker {p.name} did not stop, terminating...")
            p.terminate()

    end_time = time.time()
    wall_time = end_time - start_time

    # Compute statistics over completed results
    successful = [r for r in results if r.success]
    failed = [r for r in results if not r.success]

    total_audio_duration = sum(r.duration_sec for r in successful)
    total_process_time = sum(r.process_time_sec for r in successful)

    # Aggregate per-stage timings (average over successful videos)
    stage_sums: Dict[str, float] = {}
    for r in successful:
        for stage, t in (r.stage_timings or {}).items():
            stage_sums[stage] = stage_sums.get(stage, 0.0) + float(t)
    stage_avgs = {k: (v / len(successful)) for k, v in stage_sums.items()} if successful else {}
    stage_total_avg = sum(stage_avgs.values()) if stage_avgs else 0.0
    stage_pcts = {k: (v / stage_total_avg) * 100.0 for k, v in stage_avgs.items()} if stage_total_avg > 0 else {}

    stats = BenchmarkStats(
        total_videos=len(video_urls),
        successful=len(successful),
        failed=len(failed),
        total_audio_duration_sec=total_audio_duration,
        total_process_time_sec=total_process_time,
        avg_process_time_sec=total_process_time / len(successful) if successful else 0,
        throughput_videos_per_min=(len(successful) / wall_time) * 60 if wall_time > 0 else 0,
        throughput_hours_audio_per_hour=(total_audio_duration / wall_time) if wall_time > 0 else 0,
        # Normalize by total worker count (not GPU count) so the metric stays
        # meaningful when workers_per_gpu > 1 — consistent with the
        # supabase-queue path, which divides by total workers.
        parallel_efficiency=(total_process_time / (wall_time * total_workers)) if wall_time > 0 else 0,
        gpus_used=len(gpu_ids),
        start_time=start_timestamp,
        end_time=datetime.now().isoformat(),
        wall_time_sec=wall_time
    )

    # Print summary
    logger.info("\n" + "=" * 70)
    logger.info("BENCHMARK RESULTS")
    logger.info("=" * 70)
    logger.info(f"Videos: {stats.successful}/{stats.total_videos} successful")
    logger.info(f"Failed: {stats.failed}")
    logger.info(f"Wall time: {stats.wall_time_sec:.1f}s ({stats.wall_time_sec/60:.1f} min)")
    logger.info(f"Total audio: {stats.total_audio_duration_sec/3600:.2f} hours")
    logger.info("-" * 70)
    logger.info(f"THROUGHPUT:")
    logger.info(f"   Videos/min: {stats.throughput_videos_per_min:.2f}")
    logger.info(f"   Audio hours/wall hour: {stats.throughput_hours_audio_per_hour:.2f}x realtime")
    logger.info(f"   Parallel efficiency: {stats.parallel_efficiency*100:.1f}%")
    logger.info("-" * 70)
    logger.info(f"Avg process time/video: {stats.avg_process_time_sec:.1f}s")
    if stage_avgs:
        logger.info("-" * 70)
        logger.info("AVG PER-STAGE TIME (successful videos):")
        # Slowest stages first
        for stage, avg_t in sorted(stage_avgs.items(), key=lambda x: -x[1]):
            pct = stage_pcts.get(stage, 0.0)
            logger.info(f"   {stage:22s} {avg_t:7.2f}s  ({pct:5.1f}%)")
    logger.info("=" * 70)

    # Save full results (stats + per-stage aggregates + per-video records)
    results_file = Path(output_dir) / "benchmark_results.json"
    with open(results_file, 'w') as f:
        json.dump({
            'stats': asdict(stats),
            'stage_avgs_sec': {k: round(v, 4) for k, v in stage_avgs.items()},
            'stage_pcts': {k: round(v, 2) for k, v in stage_pcts.items()},
            'results': [asdict(r) for r in results]
        }, f, indent=2)
    logger.info(f"Results saved to: {results_file}")

    if failed:
        logger.warning(f"\nFailed videos:")
        for r in failed:
            logger.warning(f"   {r.video_url}: {r.error}")

    return stats


def main():
    """CLI entry point.

    Parses arguments, validates the environment (env vars, HF token, GPUs),
    builds the pipeline config overrides, then dispatches to either the
    Supabase distributed-queue mode or the traditional file/URL-list mode.
    Exits 0 on full success, 2 on partial success, 1 on total failure.
    """
    parser = argparse.ArgumentParser(
        description="MASSIVE COMPUTE - 8x H100 Parallel Processing",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Process videos from file
  python massive_process.py --input videos.txt

  # Benchmark with 16 test videos
  python massive_process.py --benchmark --num-videos 16

  # Use specific GPUs
  python massive_process.py --input videos.txt --gpus 0,1,2,3

  # Custom output directory
  python massive_process.py --input videos.txt --output /data/results
        """
    )

    # Input options (exactly one source must be chosen)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument('--input', '-i', type=str, help='File with video URLs (one per line)')
    input_group.add_argument('--benchmark', action='store_true', help='Run benchmark with test videos')
    input_group.add_argument('--urls', nargs='+', help='Direct URL list')
    input_group.add_argument('--supabase-queue', action='store_true',
                            help='Use Supabase as distributed video queue (claims PENDING videos from DB)')

    # Processing options
    parser.add_argument('--gpus', type=str, default=None,
                       help='Comma-separated GPU IDs (default: all available)')
    parser.add_argument('--workers-per-gpu', type=int, default=1,
                       help='Number of workers per GPU (default: 1, try 2-4 for H100)')
    parser.add_argument('--num-videos', type=int, default=8,
                       help='Number of videos for benchmark mode (default: 8)')
    parser.add_argument('--output', '-o', type=str, default='data/massive_output',
                       help='Output directory (default: data/massive_output)')

    # Pipeline config overrides
    parser.add_argument('--merge-threshold', type=float, default=0.80,
                       help='Speaker merge threshold (default: 0.80)')
    parser.add_argument('--min-segment', type=float, default=0.2,
                       help='Minimum segment duration (default: 0.2s)')
    parser.add_argument('--no-music-detection', action='store_true',
                       help='Disable music detection')

    # R2 upload (processed .tar -> 1-cleaned-data)
    parser.add_argument('--r2-upload', action='store_true',
                       help='Upload per-video .tar into R2 bucket 1-cleaned-data (no deletes)')
    parser.add_argument('--r2-upload-prefix', type=str, default='',
                       help="Optional prefix within 1-cleaned-data for uploads (e.g. 'benchmarks/rtx4090/2026-01-03')")
    parser.add_argument('--r2-upload-skip-if-exists', action='store_true',
                       help='Skip upload if remote key already exists')

    # Export options (--segment-audio defaults on; --no-segment-audio disables)
    parser.add_argument('--segment-audio', action='store_true', default=True,
                       help='Segment original audio into individual speaker clips (default: True)')
    parser.add_argument('--no-segment-audio', action='store_false', dest='segment_audio',
                       help='Disable audio segmentation')
    parser.add_argument('--background-export', action='store_true',
                       help='Run export (segmentation+tar+upload) in background threads')
    parser.add_argument('--export-workers', type=int, default=4,
                       help='Number of background export workers (default: 4)')

    # Supabase queue options
    parser.add_argument('--max-videos', type=int, default=0,
                       help='Max videos per worker in supabase-queue mode (0=unlimited)')
    parser.add_argument('--lease-duration', type=int, default=900,
                       help='Lease duration in seconds for claimed videos (default: 900)')
    parser.add_argument('--poll-interval', type=int, default=10,
                       help='Seconds to wait when queue empty (default: 10)')

    args = parser.parse_args()

    # Load .env once in the parent process so workers inherit env vars when possible.
    if load_dotenv is not None:
        try:
            load_dotenv(dotenv_path=Path("/workspace/maya3_data/.env"), override=False)
        except Exception:
            pass

    # Validate environment before starting workers
    if not validate_environment():
        logger.error("Environment validation failed. Please set required variables.")
        sys.exit(1)

    # Get GPU list; tolerate stray commas/whitespace (e.g. "--gpus 0,1,")
    # which previously crashed with ValueError on int('').
    if args.gpus:
        gpu_ids = [int(tok) for tok in (p.strip() for p in args.gpus.split(',')) if tok]
    else:
        gpu_ids = get_available_gpus()

    if not gpu_ids:
        logger.error("No GPUs available!")
        sys.exit(1)

    logger.info(f"Detected {len(gpu_ids)} GPUs: {gpu_ids}")

    # Get video URLs (not needed for supabase-queue mode)
    video_urls = None
    supabase_mode = getattr(args, 'supabase_queue', False)

    if supabase_mode:
        if not SUPABASE_AVAILABLE:
            logger.error("Supabase client not available! Install dependencies or check src/supabase_client.py")
            sys.exit(1)
        logger.info("Supabase queue mode: videos will be claimed from database")
    elif args.benchmark:
        video_urls = get_test_videos(args.num_videos)
        logger.info(f"Benchmark mode: {len(video_urls)} test videos")
    elif args.input:
        video_urls = load_video_urls(args.input)
        logger.info(f"Loaded {len(video_urls)} URLs from {args.input}")
    else:
        video_urls = args.urls

    if not supabase_mode and not video_urls:
        logger.error("No videos to process!")
        sys.exit(1)

    # CRITICAL: Capture HF_TOKEN before spawning workers
    # Spawned processes may not inherit environment variables
    hf_token = os.environ.get('HF_TOKEN', '')
    if not hf_token:
        # Try loading from .env file
        env_file = Path('.env')
        if env_file.exists():
            with open(env_file) as f:
                for line in f:
                    # Strip first so indented 'HF_TOKEN=' lines are found too
                    # (previously leading whitespace made the line invisible).
                    line = line.strip()
                    if line.startswith('HF_TOKEN='):
                        hf_token = line.split('=', 1)[1].strip('"\'')
                        break

    if not hf_token:
        logger.error("HF_TOKEN not found! Set via environment or .env file")
        sys.exit(1)

    logger.info(f"HF_TOKEN captured (length={len(hf_token)})")

    # Config overrides forwarded to every worker
    config_overrides = {
        'cluster_merge_threshold': args.merge_threshold,
        'min_segment_duration': args.min_segment,
        'enable_music_detection': not args.no_music_detection,
        'workers_per_gpu': args.workers_per_gpu,
        'hf_token': hf_token,  # Pass token to spawned workers
        # R2 source (always R2, metadata from Supabase)
        'input_source': 'r2',
        'fetch_youtube_metadata_for_r2': False,
        'r2_source_prefix': 'podcasts/',
        'r2_source_extensions': '.webm',
        # R2 output upload
        'r2_upload_enabled': bool(args.r2_upload),
        'r2_upload_bucket_type': 'production',
        'r2_upload_prefix': args.r2_upload_prefix,
        'r2_upload_skip_if_exists': bool(args.r2_upload_skip_if_exists),
        # Export options
        'segment_audio': args.segment_audio,
        'background_export': args.background_export,
        'background_export_workers': args.export_workers,
    }

    # Initialize background export manager if enabled
    export_manager = None
    if args.background_export:
        from src.export import init_export_manager
        export_manager = init_export_manager(max_workers=args.export_workers)
        logger.info(f"Background export enabled ({args.export_workers} workers)")

    # Run!
    if supabase_mode:
        # Supabase distributed queue mode
        stats = run_supabase_queue(
            gpu_ids=gpu_ids,
            output_dir=args.output,
            config_overrides=config_overrides,
            max_videos_per_worker=args.max_videos,
            lease_duration_sec=args.lease_duration,
            poll_interval_sec=args.poll_interval
        )
    else:
        # Traditional file-based mode
        stats = run_massive_process(
            video_urls=video_urls,
            gpu_ids=gpu_ids,
            output_dir=args.output,
            config_overrides=config_overrides
        )

    # Wait for background exports to complete (if enabled)
    if args.background_export:
        from src.export import shutdown_export_manager
        logger.info("Waiting for background exports to complete...")
        export_results = shutdown_export_manager(wait=True, timeout=600)

        export_success = sum(1 for r in export_results.values() if r.success)
        export_failed = sum(1 for r in export_results.values() if not r.success)
        total_segments = sum(r.segments_exported for r in export_results.values())
        total_export_time = sum(r.export_time_sec for r in export_results.values())
        total_upload_time = sum(r.r2_upload_time_sec for r in export_results.values())

        logger.info(f"Background exports: {export_success}/{len(export_results)} successful")
        logger.info(f"  Total segments: {total_segments}")
        logger.info(f"  Export time: {total_export_time:.1f}s")
        logger.info(f"  Upload time: {total_upload_time:.1f}s")

        if export_failed > 0:
            logger.warning(f"  Failed exports: {export_failed}")
            for vid, r in export_results.items():
                if not r.success:
                    logger.warning(f"    {vid}: {r.error}")

    # Exit with appropriate code: 0 = all OK, 2 = partial, 1 = total failure
    if stats.failed > 0:
        sys.exit(2 if stats.successful > 0 else 1)
    sys.exit(0)


# Script entry point: parse CLI args, spawn per-GPU workers, aggregate results.
if __name__ == "__main__":
    main()
