#!/usr/bin/env python3
"""
Universal GPU Worker - Main Entry Point

This is the main worker process that:
1. Auto-detects GPU and configures batch sizes
2. Connects to coordinator API (or runs standalone)
3. Processes videos in a loop
4. Uploads results to R2
5. Handles OOM errors with automatic recovery

Can run on:
- Azure NC4as_T4_v3 (T4 16GB)
- Runpod (A100, A40, etc.)
- Any machine with NVIDIA GPU

Environment Variables:
- COORDINATOR_URL: URL of the coordinator API (optional, standalone mode if not set)
- R2_ACCESS_ID, R2_SECRET_KEY, R2_ACCOUNT_ID, R2_BUCKET: R2 storage credentials
- HF_TOKEN: HuggingFace token for model access
- WORKER_ID: Override worker ID (optional)
- MAX_RETRIES: Max OOM retries (default: 3)
- LOG_LEVEL: Logging level (default: INFO)
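
Example invocations (paths and IDs are illustrative):

    # Standalone: work through a local CSV of video IDs
    python worker/main.py --csv videos.csv

    # Coordinator mode: pull jobs from a remote API
    COORDINATOR_URL=https://coordinator.example.com python worker/main.py

    # Test mode: process one video, then exit
    python worker/main.py --single <VIDEO_ID>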
"""

import os
import sys
import gc
import time
import signal
import asyncio
import logging
import argparse
import traceback
from pathlib import Path
from typing import Optional, Dict, Any

# Add the repo root (parent of worker/) to sys.path so worker.* and src.* resolve
sys.path.insert(0, str(Path(__file__).parent.parent))

from worker.gpu_config import (
    auto_configure_gpu,
    reduce_batch_sizes,
    clear_gpu_memory,
    GPUConfig,
)
from worker.storage import R2Client, R2Config
from worker.coordinator import (
    CoordinatorClient,
    LocalCoordinator,
    WorkerInfo,
    Job,
)

# Configure logging
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=getattr(logging, log_level, logging.INFO),
    format='%(asctime)s | %(name)s | %(levelname)s | %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger("Worker")

# Global shutdown flag: checked between jobs in DiarizationWorker.run(), so a
# SIGINT/SIGTERM lets the in-flight job finish before the loop exits.
shutdown_requested = False


def signal_handler(signum, frame):
    """Handle shutdown signals gracefully."""
    global shutdown_requested
    logger.info("🛑 Shutdown signal received, finishing current job...")
    shutdown_requested = True


def setup_signals():
    """Setup signal handlers for graceful shutdown."""
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)


class DiarizationWorker:
    """
    Main worker class that processes videos.
    
    Handles:
    - GPU configuration and auto-tuning
    - Job coordination (API or standalone)
    - Video processing pipeline
    - R2 upload
    - Error recovery (OOM, network, etc.)
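
    A minimal standalone run (a sketch; the CSV path is illustrative):

        worker = DiarizationWorker(csv_path="videos.csv")
        await worker.initialize()
        await worker.run()
        await worker.cleanup()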
    """
    
    def __init__(
        self,
        coordinator_url: Optional[str] = None,
        csv_path: Optional[str] = None,
        r2_config: Optional[R2Config] = None,
        max_retries: int = 3,
    ):
        self.coordinator_url = coordinator_url
        self.csv_path = csv_path
        self.r2_config = r2_config
        self.max_retries = max_retries
        
        # Populated in initialize()
        self.gpu_config: Optional[GPUConfig] = None
        self.coordinator: Optional[CoordinatorClient] = None
        self.local_coordinator: Optional[LocalCoordinator] = None
        self.r2_client: Optional[R2Client] = None
        self.worker_info: Optional[WorkerInfo] = None
        
        # Stats
        self.jobs_completed = 0
        self.jobs_failed = 0
        self.total_processing_time = 0
        self.start_time = time.time()
    
    async def initialize(self):
        """Initialize worker components."""
        logger.info("=" * 70)
        logger.info("🚀 INITIALIZING DIARIZATION WORKER")
        logger.info("=" * 70)
        
        # 1. Detect GPU and configure
        logger.info("🔍 Detecting GPU...")
        self.gpu_config = auto_configure_gpu(target_utilization=0.80)
        
        # 2. Create worker info
        self.worker_info = WorkerInfo(
            worker_id=os.environ.get("WORKER_ID", f"worker-{os.getpid()}"),
            gpu_name=self.gpu_config.gpu_name,
            gpu_vram_gb=self.gpu_config.gpu_vram_gb,
            region=os.environ.get("AZURE_REGION", "local"),
        )
        logger.info(f"👷 Worker ID: {self.worker_info.worker_id}")
        
        # 3. Setup coordinator
        if self.coordinator_url:
            logger.info(f"🔗 Connecting to coordinator: {self.coordinator_url}")
            self.coordinator = CoordinatorClient(
                base_url=self.coordinator_url,
                api_key=os.environ.get("COORDINATOR_API_KEY"),
            )
            await self.coordinator.register(self.worker_info)
        elif self.csv_path:
            logger.info(f"📄 Using local coordinator with CSV: {self.csv_path}")
            self.local_coordinator = LocalCoordinator(self.csv_path)
            logger.info(f"   Pending: {self.local_coordinator.stats['pending']}")
        else:
            logger.warning("⚠️ No coordinator configured - will need manual job input")
        
        # 4. Setup R2 storage
        if self.r2_config:
            logger.info(f"☁️ Connecting to R2: {self.r2_config.bucket}")
            self.r2_client = R2Client(self.r2_config)
            self.r2_client.ensure_bucket()
        else:
            logger.warning("⚠️ No R2 config - results will be stored locally only")
        
        # 5. Apply GPU config to pipeline
        self._configure_pipeline()
        
        logger.info("=" * 70)
        logger.info("✅ Worker initialized and ready!")
        logger.info("=" * 70)
    
    def _configure_pipeline(self):
        """Apply GPU config to the diarization pipeline."""
        try:
            from src.config import Config
            
            # Create config with our GPU-optimized settings
            config = Config()
            config.embedding_batch_size = self.gpu_config.embedding_batch_size
            config.music_batch_size = self.gpu_config.music_batch_size
            config.vad_workers = self.gpu_config.vad_workers
            config.chunk_workers = self.gpu_config.chunk_workers
            
            # Store for later use
            self._pipeline_config = config
            
            logger.info(f"📊 Pipeline config applied: embed={config.embedding_batch_size}, "
                       f"music={config.music_batch_size}, vad={config.vad_workers}")
            
        except ImportError as e:
            logger.warning(f"Could not import pipeline config: {e}")
            self._pipeline_config = None
    
    async def get_next_job(self) -> Optional[Job]:
        """Get the next job to process."""
        if self.coordinator:
            return await self.coordinator.request_job()
        elif self.local_coordinator:
            video_id = self.local_coordinator.get_next()
            if video_id:
                return Job(
                    video_id=video_id,
                    youtube_url=f"https://www.youtube.com/watch?v={video_id}",
                )
        return None
    
    async def mark_job_complete(self, video_id: str, result: Dict[str, Any], r2_path: str):
        """Mark a job as completed."""
        if self.coordinator:
            await self.coordinator.complete_job(video_id, result, r2_path)
        elif self.local_coordinator:
            self.local_coordinator.mark_complete(video_id)
    
    async def mark_job_failed(self, video_id: str, error: str, error_type: str = "unknown"):
        """Mark a job as failed."""
        if self.coordinator:
            await self.coordinator.fail_job(video_id, error, error_type)
        elif self.local_coordinator:
            self.local_coordinator.mark_failed(video_id)
    
    def process_video(self, job: Job) -> Dict[str, Any]:
        """
        Process a single video through the diarization pipeline.

        Builds a fresh Config from the current (possibly OOM-reduced) GPU
        settings, runs process_single_video, and annotates the result with
        worker metadata under underscore-prefixed keys.
        """
        from pipeline import process_single_video
        from src.config import Config
        
        video_id = job.video_id
        url = job.youtube_url
        
        logger.info(f"🎬 Processing: {video_id}")
        start_time = time.time()
        
        # Create config with our GPU settings
        config = Config()
        config.embedding_batch_size = self.gpu_config.embedding_batch_size
        config.music_batch_size = self.gpu_config.music_batch_size
        config.vad_workers = self.gpu_config.vad_workers
        config.chunk_workers = self.gpu_config.chunk_workers
        
        # Use a unique output directory for this video
        config.output_dir = f"/tmp/diarization/{video_id}"
        os.makedirs(config.output_dir, exist_ok=True)
        
        # Process
        result = process_single_video(url, config)
        
        # Add processing metadata
        result['_processing_time'] = time.time() - start_time
        result['_worker_id'] = self.worker_info.worker_id if self.worker_info else "unknown"
        result['_gpu'] = self.gpu_config.gpu_name
        result['_batch_sizes'] = {
            'embedding': self.gpu_config.embedding_batch_size,
            'music': self.gpu_config.music_batch_size,
        }
        
        logger.info(f"✅ Processed {video_id} in {result['_processing_time']:.1f}s")
        
        return result
    
    async def upload_result(self, video_id: str, result: Dict[str, Any]) -> str:
        """Upload processing result to R2."""
        if not self.r2_client:
            return ""
        
        # Upload metadata
        r2_path = self.r2_client.upload_metadata(video_id, result)
        
        # Upload the original audio alongside the metadata, if it exists
        output_dir = f"/tmp/diarization/{video_id}"
        original_audio = Path(output_dir) / f"{video_id}_original.wav"
        if original_audio.exists():
            self.r2_client.upload_file(video_id, str(original_audio), "original.wav")
        
        return r2_path
    
    def cleanup_local(self, video_id: str):
        """Clean up local files after upload."""
        import shutil
        # ignore_errors=True already swallows missing paths and permission
        # errors, so no surrounding try/except is needed.
        shutil.rmtree(f"/tmp/diarization/{video_id}", ignore_errors=True)
    
    async def process_job_with_retry(self, job: Job) -> bool:
        """
        Process a job with OOM retry logic.
        
        Returns:
            True if successful, False if failed
        """
        video_id = job.video_id
        retry_count = 0
        # Snapshot the tuned batch sizes so OOM-driven reductions below do not
        # leak into subsequent jobs.
        original_embed_batch = self.gpu_config.embedding_batch_size
        original_music_batch = self.gpu_config.music_batch_size
        
        while retry_count < self.max_retries:
            try:
                # Clear GPU memory before processing
                clear_gpu_memory()
                
                # Process
                result = self.process_video(job)
                
                # Upload to R2
                r2_path = await self.upload_result(video_id, result)
                
                # Mark complete
                await self.mark_job_complete(video_id, result, r2_path)
                
                # Cleanup
                self.cleanup_local(video_id)
                
                # Restore the original batch sizes (in case we reduced them)
                self.gpu_config.embedding_batch_size = original_embed_batch
                self.gpu_config.music_batch_size = original_music_batch
                
                # Stats
                self.jobs_completed += 1
                self.total_processing_time += result.get('_processing_time', 0)
                
                return True
                
            except RuntimeError as e:
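                # torch.cuda.OutOfMemoryError is a RuntimeError subclass, so
                # matching on the message here catches CUDA OOMs across
                # PyTorch versions.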
                if "CUDA out of memory" in str(e) or "out of memory" in str(e).lower():
                    retry_count += 1
                    logger.warning(f"⚠️ OOM on {video_id} (attempt {retry_count}/{self.max_retries})")
                    
                    # Clear memory
                    clear_gpu_memory()
                    gc.collect()
                    
                    # Shrink batch sizes multiplicatively: after k OOMs they
                    # sit at ~0.7^k of the original (restored once a job
                    # completes successfully).
                    self.gpu_config = reduce_batch_sizes(self.gpu_config, factor=0.7)
                    self._configure_pipeline()
                    
                    if retry_count >= self.max_retries:
                        await self.mark_job_failed(video_id, str(e), "oom")
                        self.jobs_failed += 1
                        return False
                    
                    # Wait a bit before retry
                    await asyncio.sleep(2)
                else:
                    raise
                    
            except Exception as e:
                error_type = "processing"
                if "download" in str(e).lower() or "yt-dlp" in str(e).lower():
                    error_type = "download"
                elif "network" in str(e).lower() or "connection" in str(e).lower():
                    error_type = "network"
                
                logger.error(f"❌ Error processing {video_id}: {e}")
                logger.debug(traceback.format_exc())
                
                await self.mark_job_failed(video_id, str(e), error_type)
                self.jobs_failed += 1
                self.cleanup_local(video_id)
                return False
        
        return False
    
    async def run(self):
        """Main worker loop."""
        global shutdown_requested
        
        logger.info("🔄 Starting main worker loop...")
        
        while not shutdown_requested:
            try:
                # Get next job
                job = await self.get_next_job()
                
                if not job:
                    logger.info("💤 No jobs available, waiting...")
                    await asyncio.sleep(10)
                    continue
                
                # Idempotency check: skip videos a previous run or another
                # worker already uploaded to R2
                if self.r2_client and self.r2_client.check_exists(job.video_id):
                    logger.info(f"⏭️ Skipping {job.video_id} (already in R2)")
                    await self.mark_job_complete(job.video_id, {}, f"{job.video_id}/metadata.json")
                    continue
                
                # Process with retry
                await self.process_job_with_retry(job)
                
                # Brief pause between jobs
                await asyncio.sleep(1)
                
            except Exception as e:
                logger.error(f"❌ Worker loop error: {e}")
                logger.debug(traceback.format_exc())
                await asyncio.sleep(5)
        
        # Print final stats
        self._print_stats()
    
    def _print_stats(self):
        """Print worker statistics."""
        elapsed = time.time() - self.start_time
        
        logger.info("=" * 70)
        logger.info("📊 WORKER STATISTICS")
        logger.info("=" * 70)
        logger.info(f"   Jobs completed: {self.jobs_completed}")
        logger.info(f"   Jobs failed: {self.jobs_failed}")
        logger.info(f"   Total time: {elapsed/60:.1f} minutes")
        if self.jobs_completed > 0:
            logger.info(f"   Avg time/job: {self.total_processing_time/self.jobs_completed:.1f}s")
        logger.info("=" * 70)
    
    async def cleanup(self):
        """Cleanup worker resources."""
        if self.coordinator:
            await self.coordinator.close()


async def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="Diarization Worker")
    parser.add_argument("--coordinator", type=str, help="Coordinator API URL")
    parser.add_argument("--csv", type=str, help="CSV file for standalone mode")
    parser.add_argument("--r2-bucket", type=str, default="diarization-output", help="R2 bucket name")
    parser.add_argument("--max-retries", type=int, default=3, help="Max OOM retries")
    parser.add_argument("--single", type=str, help="Process single video ID (test mode)")
    args = parser.parse_args()
    
    # Set up signal handlers
    setup_signals()
    
    # R2 config from environment
    r2_config = None
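    # R2_ACCESS_ID acts as the opt-in switch; once it is set, the remaining
    # R2_* variables are read with [] so a missing one fails fast with a
    # KeyError instead of running with broken credentials.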
    if os.environ.get("R2_ACCESS_ID"):
        r2_config = R2Config(
            access_key_id=os.environ["R2_ACCESS_ID"],
            secret_access_key=os.environ["R2_SECRET_KEY"],
            account_id=os.environ["R2_ACCOUNT_ID"],
            bucket=args.r2_bucket or os.environ.get("R2_BUCKET", "diarization-output"),
        )
    
    # Create worker
    worker = DiarizationWorker(
        coordinator_url=args.coordinator or os.environ.get("COORDINATOR_URL"),
        csv_path=args.csv,
        r2_config=r2_config,
        max_retries=args.max_retries,
    )
    
    try:
        await worker.initialize()
        
        if args.single:
            # Single video test mode
            job = Job(
                video_id=args.single,
                youtube_url=f"https://www.youtube.com/watch?v={args.single}",
            )
            await worker.process_job_with_retry(job)
        else:
            # Normal loop mode
            await worker.run()
            
    finally:
        await worker.cleanup()


if __name__ == "__main__":
    asyncio.run(main())