#!/usr/bin/env python3
"""
Supabase Client for Distributed Video Processing Queue

Provides:
- Atomic video claiming with lease-based locking
- Status updates after processing
- Bulk upsert for CSV imports
- Queue statistics

Usage:
    from src.supabase_client import SupabaseClient

    client = SupabaseClient()
    video = client.claim_next_video("worker-1")
    if video:
        # Process video...
        client.update_status(video['youtube_id'], 'COMPLETED', results)
"""

import os
import json
import logging
import time
from datetime import datetime, timezone
from typing import Dict, Any, Optional, List
from urllib.request import Request, urlopen
from urllib.error import HTTPError, URLError
from urllib.parse import urlencode

logger = logging.getLogger("SupabaseClient")

# Values written to / filtered on the `videos.status` column.
STATUS_PENDING = "PENDING"      # waiting for a worker to claim it
STATUS_CLAIMED = "CLAIMED"      # held by a worker under a lease
STATUS_COMPLETED = "COMPLETED"  # processing finished successfully
STATUS_FAILED = "FAILED"
STATUS_SKIPPED = "SKIPPED"


class SupabaseClient:
    """
    Supabase REST API client for video processing queue.

    Uses atomic claims via RPC function for distributed workers, with a
    REST-based optimistic-lock fallback. The class is a process-wide
    singleton: every ``SupabaseClient(...)`` call returns the same instance,
    configured by the first call.
    """

    _instance = None

    # Timeout (seconds) applied to every HTTP request.
    REQUEST_TIMEOUT_SEC = 30

    # Keys copied from a `results` dict into `videos` columns of the same name.
    _RESULT_COLUMNS = (
        'num_speakers',
        'total_segments',
        'usable_percentage',
        'usable_segments',
        'usable_duration_sec',
        'r2_tar_key',
        'r2_tar_size_bytes',
        'pipeline_version',
        'processing_meta',
        'quality_stats',
        'segment_summary',
        'download_meta',
        'last_error',
        'last_error_type',
        'audio_native_sample_rate',
        'audio_channels',
        'audio_duration_sec',
        'needs_demucs_count',
        'heavy_music_count',
    )

    def __new__(cls, *args, **kwargs):
        """Singleton pattern - reuse connection."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self, url: Optional[str] = None, key: Optional[str] = None):
        """
        Initialize Supabase client.

        Only the first construction configures the singleton; later calls
        return immediately (with a warning if they pass different credentials).

        Args:
            url: Supabase project URL (loads from URL env var if None).
                 A trailing '/' is stripped so endpoint URLs never contain '//'.
            key: Service role key (loads from SUPABASE_ADMIN env var if None)
        """
        if self._initialized:
            if (url and url.rstrip('/') != self.url) or (key and key != self.key):
                logger.warning(
                    "SupabaseClient already initialized; ignoring new url/key arguments"
                )
            return

        self.url = (url or os.environ.get('URL', '')).rstrip('/')
        self.key = key or os.environ.get('SUPABASE_ADMIN', '')

        if not self.url or not self.key:
            logger.warning("Supabase credentials not configured. Set URL and SUPABASE_ADMIN env vars.")

        self.rest_url = f"{self.url}/rest/v1"
        self.headers = {
            'apikey': self.key,
            'Authorization': f'Bearer {self.key}',
            'Content-Type': 'application/json',
            'Prefer': 'return=representation'
        }

        self._initialized = True
        logger.info(f"SupabaseClient initialized: {self.url}")

    # ------------------------------------------------------------------ #
    # Low-level HTTP helpers
    # ------------------------------------------------------------------ #

    def _execute(self, req: Request, error_label: str) -> Any:
        """
        Send a prepared request and decode its JSON response.

        Returns:
            Parsed JSON body, or None when the body is empty.

        Raises:
            HTTPError: on non-2xx responses (logged with ``error_label``).
            URLError: on network failures (logged).
        """
        try:
            with urlopen(req, timeout=self.REQUEST_TIMEOUT_SEC) as resp:
                content = resp.read().decode('utf-8')
        except HTTPError as e:
            # errors='replace': an undecodable error body must not mask the HTTPError.
            error_body = e.read().decode('utf-8', errors='replace')
            logger.error(f"{error_label} {e.code}: {error_body}")
            raise
        except URLError as e:
            logger.error(f"Network error: {e.reason}")
            raise
        return json.loads(content) if content else None

    def _request(
        self,
        method: str,
        endpoint: str,
        data: Any = None,
        params: Optional[dict] = None,
        headers_override: Optional[dict] = None
    ) -> Any:
        """
        Make HTTP request to Supabase REST API.

        Args:
            method: HTTP method ('GET', 'POST', 'PATCH', ...).
            endpoint: Path below ``/rest/v1`` (e.g. ``'videos'``).
            data: JSON-serialisable payload (dict or list); None sends no body.
            params: Query-string parameters (PostgREST filters, order, limit...).
            headers_override: Headers merged over the default headers.
        """
        url = f"{self.rest_url}/{endpoint}"
        if params:
            url = f"{url}?{urlencode(params)}"

        headers = {**self.headers, **(headers_override or {})}

        # `is not None` so an empty payload ({} / []) is still sent as JSON.
        body = json.dumps(data).encode('utf-8') if data is not None else None

        req = Request(url, data=body, headers=headers, method=method)
        return self._execute(req, "Supabase error")

    def _rpc(self, function_name: str, params: Optional[dict] = None) -> Any:
        """Call Supabase RPC function (always POSTs a JSON object body)."""
        url = f"{self.rest_url}/rpc/{function_name}"
        body = json.dumps(params or {}).encode('utf-8')
        req = Request(url, data=body, headers=self.headers, method='POST')
        return self._execute(req, "RPC error")

    @staticmethod
    def _first_row(result: Any) -> Optional[Dict[str, Any]]:
        """
        Return the first row of a PostgREST response, or None if there is none.

        Table endpoints and set-returning functions yield a JSON array; a
        function returning a single composite yields a bare object. An object
        whose fields are all null is treated as "no row".
        """
        if isinstance(result, list):
            return result[0] if result else None
        if isinstance(result, dict):
            return result if any(v is not None for v in result.values()) else None
        return None

    @staticmethod
    def _parse_count_total(content_range: str) -> int:
        """Extract the total from a Content-Range header like ``'0-9/100'`` or ``'*/0'``."""
        return int(content_range.rsplit('/', 1)[-1])

    # ------------------------------------------------------------------ #
    # Claiming
    # ------------------------------------------------------------------ #

    def claim_next_video(
        self,
        worker_id: str,
        lease_duration_sec: int = 900,
        language: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Atomically claim the next PENDING video.

        Uses RPC function with FOR UPDATE SKIP LOCKED for safe concurrent access.
        Falls back to REST API claim if RPC not available (HTTP 404).

        Args:
            worker_id: Unique identifier for this worker
            lease_duration_sec: How long to hold the claim (default 15 min)
            language: Optional language filter.
                NOTE(review): only the REST fallback applies this filter; the
                RPC call is not passed a language argument.

        Returns:
            Video dict with youtube_id, youtube_url, title, etc. or None if queue empty
        """
        try:
            # Try RPC function first (atomic)
            result = self._rpc('claim_next_video', {
                'p_worker_id': worker_id,
                'p_lease_duration_sec': lease_duration_sec
            })
        except HTTPError as e:
            if e.code == 404:
                # RPC function not found, fall back to REST API
                logger.warning("claim_next_video RPC not found, using REST fallback")
                return self._claim_video_rest(worker_id, lease_duration_sec, language)
            raise

        video = self._first_row(result)
        if video is not None:
            logger.info(f"Claimed video: {video.get('youtube_id')} by {worker_id}")
        return video

    def _claim_video_rest(
        self,
        worker_id: str,
        lease_duration_sec: int,
        language: Optional[str] = None,
        max_attempts: int = 3
    ) -> Optional[Dict[str, Any]]:
        """
        Fallback claim using REST API (less atomic but works without RPC).

        SELECTs the oldest PENDING video, then PATCHes it with a
        ``status=eq.PENDING`` guard (optimistic lock). If another worker wins
        the race the select/patch is retried, up to ``max_attempts`` times, so
        that None normally means "nothing pending" rather than "lost a race".
        For production, use the RPC function.

        Returns:
            The claimed row, or None if no pending video was found or every
            attempt lost the race.
        """
        select_params = {
            'status': f'eq.{STATUS_PENDING}',
            'order': 'created_at.asc',
            'limit': '1'
        }
        if language:
            select_params['language'] = f'eq.{language}'

        for _ in range(max(1, max_attempts)):
            candidate = self._first_row(self._request('GET', 'videos', params=select_params))
            if candidate is None:
                return None

            video_id = candidate['youtube_id']

            now = datetime.now(timezone.utc).isoformat()
            lease_expires = datetime.fromtimestamp(
                time.time() + lease_duration_sec,
                tz=timezone.utc
            ).isoformat()

            update_data = {
                'status': STATUS_CLAIMED,
                'claimed_by': worker_id,
                'claimed_at': now,
                'lease_expires_at': lease_expires,
                'attempt_count': (candidate.get('attempt_count') or 0) + 1,
                'updated_at': now,
            }

            # Only update if still PENDING (optimistic lock)
            lock_params = {
                'youtube_id': f'eq.{video_id}',
                'status': f'eq.{STATUS_PENDING}'
            }

            claimed = self._first_row(
                self._request('PATCH', 'videos', data=update_data, params=lock_params)
            )
            if claimed is not None:
                logger.info(f"Claimed video (REST): {video_id} by {worker_id}")
                return claimed

            # Another worker got it first; look for the next pending video.
            logger.debug(f"Lost race for video: {video_id}")

        return None

    # ------------------------------------------------------------------ #
    # Status transitions
    # ------------------------------------------------------------------ #

    def update_status(
        self,
        youtube_id: str,
        status: str,
        results: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update video status and processing results.

        Args:
            youtube_id: YouTube video ID
            status: New status (COMPLETED, FAILED, SKIPPED)
            results: Processing results dict (num_speakers, usable_percentage, etc.).
                Only keys listed in ``_RESULT_COLUMNS`` are written; None values
                are skipped so they do not overwrite existing column values.

        Returns:
            True if a row was updated; False if the request failed or no
            video with ``youtube_id`` exists.
        """
        now = datetime.now(timezone.utc).isoformat()

        update_data: Dict[str, Any] = {
            'status': status,
            'updated_at': now,
        }

        if status == STATUS_COMPLETED:
            update_data['processing_completed_at'] = now

        if results:
            for column in self._RESULT_COLUMNS:
                value = results.get(column)
                if value is not None:
                    update_data[column] = value

        params = {'youtube_id': f'eq.{youtube_id}'}

        try:
            result = self._request('PATCH', 'videos', data=update_data, params=params)
        except Exception as e:
            logger.error(f"Failed to update status for {youtube_id}: {e}")
            return False

        # With `Prefer: return=representation` the updated rows are echoed back;
        # an empty list means the filter matched nothing.
        if isinstance(result, list) and not result:
            logger.warning(f"No video matched {youtube_id}; status not updated")
            return False

        logger.info(f"Updated status: {youtube_id} -> {status}")
        return True

    def release_claim(self, youtube_id: str, worker_id: Optional[str] = None) -> bool:
        """
        Release a claim without completing (e.g., worker shutdown).

        Resets a CLAIMED video to PENDING so another worker can pick it up.
        Rows in any other status (e.g. COMPLETED) are left untouched.

        Args:
            youtube_id: YouTube video ID.
            worker_id: If given, only release the claim when it is held by this
                worker (protects a claim re-acquired by someone else after the
                lease expired).

        Returns:
            True if a claim was released; False if the request failed or no
            matching claimed row exists.
        """
        update_data = {
            'status': STATUS_PENDING,
            'claimed_by': None,
            'claimed_at': None,
            'lease_expires_at': None,
            'updated_at': datetime.now(timezone.utc).isoformat()
        }

        params = {
            'youtube_id': f'eq.{youtube_id}',
            'status': f'eq.{STATUS_CLAIMED}',
        }
        if worker_id:
            params['claimed_by'] = f'eq.{worker_id}'

        try:
            result = self._request('PATCH', 'videos', data=update_data, params=params)
        except Exception as e:
            logger.error(f"Failed to release claim for {youtube_id}: {e}")
            return False

        if isinstance(result, list) and not result:
            logger.warning(f"No active claim to release for {youtube_id}")
            return False

        logger.info(f"Released claim: {youtube_id}")
        return True

    # ------------------------------------------------------------------ #
    # Bulk import / reads
    # ------------------------------------------------------------------ #

    def bulk_upsert(
        self,
        videos: List[Dict[str, Any]],
        batch_size: int = 1000,
        on_conflict: str = 'youtube_id'
    ) -> int:
        """
        Batch upsert videos to database.

        Args:
            videos: List of video dicts
            batch_size: Number of rows per batch (must be positive)
            on_conflict: Column for conflict resolution

        Returns:
            Number of rows upserted

        Raises:
            ValueError: if ``batch_size`` is not positive.
            HTTPError / URLError: if a batch fails (earlier batches stay committed).
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # Same for every batch: merge on conflict, don't echo rows back.
        headers_override = {
            'Prefer': 'resolution=merge-duplicates,return=minimal'
        }
        params = {'on_conflict': on_conflict}

        total = 0
        for i in range(0, len(videos), batch_size):
            batch = videos[i:i + batch_size]

            try:
                self._request(
                    'POST',
                    'videos',
                    data=batch,
                    params=params,
                    headers_override=headers_override
                )
            except Exception as e:
                logger.error(f"Batch upsert failed at row {i}: {e}")
                raise

            total += len(batch)
            logger.info(f"Upserted batch {i//batch_size + 1}: {len(batch)} rows (total: {total})")

        return total

    def get_video(self, youtube_id: str) -> Optional[Dict[str, Any]]:
        """Get a single video by YouTube ID, or None if it does not exist."""
        params = {
            'youtube_id': f'eq.{youtube_id}',
            'limit': '1'
        }

        return self._first_row(self._request('GET', 'videos', params=params))

    def _count_by_status(self, status: str) -> int:
        """
        Count videos with the given status, or -1 if the count could not be read.

        Uses a HEAD request with ``Prefer: count=exact`` so PostgREST reports the
        total in the Content-Range header without transferring any rows.
        """
        params = {
            'status': f'eq.{status}',
            'select': 'youtube_id'
        }
        url = f"{self.rest_url}/videos?{urlencode(params)}"
        req = Request(
            url,
            headers={**self.headers, 'Prefer': 'count=exact'},
            method='HEAD'
        )

        try:
            with urlopen(req, timeout=self.REQUEST_TIMEOUT_SEC) as resp:
                content_range = resp.headers.get('Content-Range', '0-0/0')
            return self._parse_count_total(content_range)
        except Exception as e:
            logger.warning(f"Failed to get count for {status}: {e}")
            return -1

    def get_stats(self) -> Dict[str, int]:
        """
        Get queue statistics.

        Returns:
            Dict with counts by status (-1 for a status whose count failed)
        """
        statuses = (STATUS_PENDING, STATUS_CLAIMED, STATUS_COMPLETED, STATUS_FAILED, STATUS_SKIPPED)
        return {status: self._count_by_status(status) for status in statuses}

    def get_pending_count(self) -> int:
        """Get count of pending videos (-1 if the count could not be read)."""
        return self._count_by_status(STATUS_PENDING)

    def reclaim_expired_leases(self) -> int:
        """
        Reset videos with expired leases back to PENDING.

        Useful for recovering from worker crashes. Clears the same claim
        fields as ``release_claim``.

        Returns:
            Number of videos reset (0 on failure)
        """
        now = datetime.now(timezone.utc).isoformat()

        # Find claimed videos with expired leases
        params = {
            'status': f'eq.{STATUS_CLAIMED}',
            'lease_expires_at': f'lt.{now}'
        }

        update_data = {
            'status': STATUS_PENDING,
            'claimed_by': None,
            'claimed_at': None,
            'lease_expires_at': None,
            'updated_at': now
        }

        try:
            result = self._request('PATCH', 'videos', data=update_data, params=params)
        except Exception as e:
            logger.error(f"Failed to reclaim expired leases: {e}")
            return 0

        count = len(result) if result else 0
        if count > 0:
            logger.info(f"Reclaimed {count} expired leases")
        return count


# Module-level convenience accessor
def get_supabase_client() -> SupabaseClient:
    """Return the shared SupabaseClient instance, creating it on the first call."""
    shared_client = SupabaseClient()
    return shared_client
