#!/usr/bin/env python3
"""
Unified GPU pull-worker (Azure Batch / RunPod / anywhere).

High-level flow:
1) Poll a dispatcher endpoint for the next job (video_id + youtube_url)
2) Run the existing pipeline for that single video
3) Package outputs into a per-video TAR
4) Upload TAR to Cloudflare R2 (S3-compatible)
5) Report completion/failure back to the dispatcher
6) Delete ALL local artifacts for the video (disk hygiene)

Design goals:
- Cross-provider: works on Azure + RunPod with the same image.
- No duplicates: the dispatcher leases jobs, and the worker stays idempotent by skipping any job whose R2 object already exists.
- OOM safety: on CUDA OOM, reduce settings and retry before failing the job.

Required env vars (names are intentionally explicit):
- DISPATCHER_URL                 e.g. https://dispatcher.example.com
- HF_TOKEN                       HuggingFace token for pyannote gated models
- R2_ENDPOINT_URL                e.g. https://<ACCOUNT_ID>.r2.cloudflarestorage.com
- R2_ACCESS_KEY_ID               (legacy alias also accepted: R2_AccessID)
- R2_SECRET_ACCESS_KEY           (legacy alias also accepted: R2_Secret_Access_Key)
- R2_BUCKET

Optional env vars:
- DISPATCHER_TOKEN               if set, sent as "Authorization: Bearer <token>"
- WORKER_ID                      default: "<hostname>:<pid>"
- WORKER_REGION                  informational ("eastus", "westus2", "runpod", ...)
- WORKER_PROVIDER                informational ("azure-batch", "runpod", ...)
- WORKER_OUTPUT_BASE_DIR         default: /tmp/maya3data_worker
- MAX_UTILIZATION                default: 0.80
- POLL_SECONDS                   default: 2.0
- IDLE_JITTER_SECONDS            default: 2.0
- R2_PREFIX                      default: "dataset/videos"
"""

from __future__ import annotations

import json
import os
import random
import shutil
import socket
import tarfile
import time
import traceback
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
from urllib import error, request

import boto3
from botocore.config import Config as BotoConfig
from botocore.exceptions import ClientError

from src.config import Config
from src.models import MODELS
from pipeline import process_single_video


@dataclass(frozen=True)
class Job:
    """One unit of work leased from the dispatcher: a single video to process.

    Instances are immutable (``frozen=True``), so a claimed job cannot be altered
    while it is being processed.
    """

    # Dispatcher-assigned identifier, echoed back on complete/fail.
    job_id: str
    # Video identifier; also used to name the local output dir and the R2 object key.
    video_id: str
    # Source URL handed to the pipeline.
    youtube_url: str
    # Dispatcher-reported attempt counter; 1 when the claim response omits it.
    attempt: int = 1


def _env_required(name: str) -> str:
    """Return the whitespace-stripped value of environment variable *name*.

    Raises:
        RuntimeError: if the variable is unset or contains only whitespace.
    """
    value = (os.environ.get(name) or "").strip()
    if value:
        return value
    raise RuntimeError(f"Missing required environment variable: {name}")


def _env_float(name: str, default: float) -> float:
    """Read environment variable *name* as a float, falling back to *default*.

    An unset or blank (whitespace-only) variable yields *default*; any other value
    must parse with ``float()``.

    Raises:
        RuntimeError: if the variable is set but is not a valid float (the
            original ``ValueError`` is chained as ``__cause__``).
    """
    text = os.environ.get(name, "").strip()
    if text:
        try:
            return float(text)
        except ValueError as exc:
            raise RuntimeError(f"Invalid float for {name}: {text}") from exc
    return default


def _env_str(name: str, default: str) -> str:
    """Return environment variable *name* stripped of whitespace, or *default* if unset/blank."""
    candidate = (os.environ.get(name) or "").strip()
    if not candidate:
        return default
    return candidate


def _post_json(url: str, payload: Dict[str, Any], headers: Dict[str, str], timeout_s: float = 15.0) -> Tuple[int, str]:
    """POST *payload* as JSON to *url* and return ``(status_code, body_text)``.

    HTTP error statuses (4xx/5xx) are returned rather than raised, so callers can
    branch on the code. Transport failures (DNS, refused connection, timeouts while
    connecting) are not caught here and propagate to the caller.

    Args:
        url: Absolute endpoint URL.
        payload: JSON-serialisable request body.
        headers: Extra request headers; ``Content-Type: application/json`` is always
            added (and overrides any caller-supplied value).
        timeout_s: Socket timeout in seconds passed to ``urlopen``.

    Returns:
        The HTTP status code and the response body decoded as UTF-8 (undecodable
        bytes replaced). For an error status whose body cannot be read, the body
        is ``""`` so the status code is still reported.
    """
    data = json.dumps(payload).encode("utf-8")
    req = request.Request(url, data=data, headers={**headers, "Content-Type": "application/json"}, method="POST")
    try:
        with request.urlopen(req, timeout=timeout_s) as resp:
            return int(resp.status), resp.read().decode("utf-8", errors="replace")
    except error.HTTPError as e:
        # HTTPError wraps the still-open error response. Read it for diagnostics, but
        # never let a failed read hide the status code, and close it so the underlying
        # connection is released now rather than whenever the exception is collected.
        try:
            body = e.read().decode("utf-8", errors="replace")
        except (OSError, ValueError):
            body = ""
        finally:
            e.close()
        return int(e.code), body


class DispatcherClient:
    """HTTP client for the dispatcher's ``/v1/claim``, ``/v1/complete`` and ``/v1/fail`` endpoints.

    Every call is a JSON POST through ``_post_json`` with a 30 second timeout. When a
    token is supplied it is sent as ``Authorization: Bearer <token>``.
    """

    def __init__(self, base_url: str, token: Optional[str] = None) -> None:
        self._base_url = base_url.rstrip("/")
        self._headers: Dict[str, str] = {"Authorization": f"Bearer {token}"} if token else {}

    def _post(self, endpoint: str, payload: Dict[str, Any]) -> Tuple[int, str]:
        """POST *payload* to ``<base_url>/v1/<endpoint>`` and return ``(status, body)``."""
        return _post_json(
            url=f"{self._base_url}/v1/{endpoint}",
            payload=payload,
            headers=self._headers,
            timeout_s=30.0,
        )

    @staticmethod
    def _raise_unless(action: str, status: int, body: str, accepted: Tuple[int, ...]) -> None:
        """Raise ``RuntimeError`` (with a truncated body) if *status* is not in *accepted*."""
        if status in accepted:
            return
        raise RuntimeError(f"Dispatcher {action} failed: HTTP {status}: {body[:500]}")

    def claim(self, worker_info: Dict[str, Any]) -> Optional[Job]:
        """Ask the dispatcher for the next job.

        Returns:
            A ``Job`` on HTTP 200, or ``None`` on HTTP 204/404 (nothing to do).

        Raises:
            RuntimeError: on any other status. JSON/key errors from a malformed 200
                body also propagate.
        """
        status, body = self._post("claim", {"worker": worker_info})
        if status in (204, 404):
            return None
        self._raise_unless("claim", status, body, (200,))
        fields = json.loads(body)
        return Job(
            job_id=str(fields["job_id"]),
            video_id=str(fields["video_id"]),
            youtube_url=str(fields["youtube_url"]),
            attempt=int(fields.get("attempt", 1)),
        )

    def complete(self, job: Job, r2_key: str, result_summary: Dict[str, Any]) -> None:
        """Report *job* as done, pointing at its uploaded R2 object; raises RuntimeError unless HTTP 200/204."""
        status, body = self._post(
            "complete",
            {"job_id": job.job_id, "video_id": job.video_id, "r2_key": r2_key, "result": result_summary},
        )
        self._raise_unless("complete", status, body, (200, 204))

    def fail(self, job: Job, error_type: str, message: str) -> None:
        """Report *job* as failed with a category and message; raises RuntimeError unless HTTP 200/204."""
        status, body = self._post(
            "fail",
            {"job_id": job.job_id, "video_id": job.video_id, "error_type": error_type, "message": message},
        )
        self._raise_unless("fail", status, body, (200, 204))


class R2Client:
    """Minimal wrapper around a boto3 S3 client configured for a Cloudflare R2 bucket.

    The client uses SigV4 signing and region ``"auto"``; all operations target the
    single bucket given at construction time.
    """

    def __init__(self, endpoint_url: str, bucket: str, access_key_id: str, secret_access_key: str) -> None:
        self.bucket = bucket
        client_kwargs = dict(
            endpoint_url=endpoint_url,
            aws_access_key_id=access_key_id,
            aws_secret_access_key=secret_access_key,
            region_name="auto",
            config=BotoConfig(signature_version="s3v4"),
        )
        self.s3 = boto3.client("s3", **client_kwargs)

    def object_exists(self, key: str) -> bool:
        """Return True if *key* exists in the bucket (via ``head_object``).

        Only "not found" error codes map to False; any other ``ClientError``
        (e.g. permission problems) is re-raised.
        """
        try:
            self.s3.head_object(Bucket=self.bucket, Key=key)
        except ClientError as exc:
            error_code = str(exc.response.get("Error", {}).get("Code", ""))
            if error_code not in {"404", "NoSuchKey", "NotFound"}:
                raise
            return False
        return True

    def upload_file(self, local_path: str, key: str) -> None:
        """Upload the file at *local_path* to *key*, tagged with a TAR content type."""
        extra = {"ContentType": "application/x-tar"}
        self.s3.upload_file(local_path, self.bucket, key, ExtraArgs=extra)


def _detect_hardware() -> Dict[str, Any]:
    """Collect a small hardware fingerprint to send with each claim request.

    Always contains ``hostname``, ``cpu_count`` (1 if the CPU count is unknown) and
    ``cuda_available``. When CUDA is available, GPU 0's name, total VRAM in GiB
    (rounded to 2 decimals) and compute capability (``"major.minor"``) are added.
    Any failure importing or querying torch is swallowed and reported as
    ``cuda_available=False`` so the worker still starts; the dispatcher decides what
    to do with such a worker.
    """
    hardware: Dict[str, Any] = dict(hostname=socket.gethostname(), cpu_count=os.cpu_count() or 1)
    try:
        import torch

        cuda_ok = bool(torch.cuda.is_available())
        hardware["cuda_available"] = cuda_ok
        if not cuda_ok:
            return hardware
        props = torch.cuda.get_device_properties(0)
        hardware["gpu_name"] = torch.cuda.get_device_name(0)
        hardware["gpu_vram_total_gb"] = round(props.total_memory / 1024**3, 2)
        hardware["gpu_sm"] = f"{props.major}.{props.minor}"
    except Exception:
        hardware["cuda_available"] = False
    return hardware


def _is_cuda_oom(exc: BaseException) -> bool:
    """Return True if *exc*, or an exception it wraps, looks like a GPU out-of-memory error.

    Matching is message based: the lower-cased message must contain "out of memory"
    together with "cuda", "cublas" or "torch". A plain host ``MemoryError("out of
    memory")`` therefore does not match.

    The exception chain is followed too (``__cause__``, or ``__context__`` unless
    suppressed with ``raise ... from None``). An OOM that a pipeline stage re-raised as
    a different exception (``raise StageError(...) from oom``) is therefore still
    classified as OOM and gets the tune-down-and-retry treatment. Cycles in the chain
    are guarded against.
    """
    seen = set()
    current: Optional[BaseException] = exc
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        msg = str(current).lower()
        if "out of memory" in msg and ("cuda" in msg or "cublas" in msg or "torch" in msg):
            return True
        if current.__cause__ is not None:
            current = current.__cause__
        elif current.__suppress_context__:
            current = None
        else:
            current = current.__context__
    return False


def _tune_down_for_oom(config: Config) -> None:
    """
    Shrink the memory-hungry settings on `config` in place ahead of an OOM retry.

    Each call lowers max_utilization by 0.10 (floor 0.50), halves
    embedding_batch_size (floor 64) and music_batch_size (floor 16), scales
    chunk_duration by 0.66 (floor 60.0), and forces a cache clear after every
    chunk. Stability is preferred over speed on small-VRAM GPUs.
    """
    lowered_util = float(config.max_utilization) - 0.10
    config.max_utilization = lowered_util if lowered_util > 0.50 else 0.50

    half_embed = int(config.embedding_batch_size) // 2
    config.embedding_batch_size = half_embed if half_embed > 64 else 64

    half_music = int(config.music_batch_size) // 2
    config.music_batch_size = half_music if half_music > 16 else 16

    shorter_chunk = float(config.chunk_duration) * 0.66
    config.chunk_duration = shorter_chunk if shorter_chunk > 60.0 else 60.0

    config.clear_cache_every_n_chunks = 1


def _build_video_tar(video_dir: str, tar_path: str) -> None:
    """
    Package the per-video output directory into a single uncompressed TAR.

    Only regular files directly inside `video_dir` are archived (subdirectories are
    skipped). Each is stored as "<dir name>/<filename>", so the archive extracts into
    a folder named after the video. Members are added in sorted filename order, which
    makes the archive layout deterministic across runs and hosts.

    Expected contents in `video_dir`:
    - metadata.json
    - <video_id>_original.wav (or whatever download stage produced)

    Raises:
        FileNotFoundError: if `video_dir` does not exist, or contains no regular
            files. An empty archive is refused before `tar_path` is created: the
            worker treats an existing R2 object as "already processed", so uploading
            an empty TAR would permanently mask the failure.
    """
    if not os.path.isdir(video_dir):
        raise FileNotFoundError(f"Video output dir missing: {video_dir}")

    # Keep archive paths stable: store as "<video_id>/...". normpath drops a trailing
    # separator so basename() never returns "".
    video_id = os.path.basename(os.path.normpath(video_dir))
    filenames = [
        name for name in sorted(os.listdir(video_dir))
        if os.path.isfile(os.path.join(video_dir, name))
    ]
    if not filenames:
        raise FileNotFoundError(f"Video output dir has no files to package: {video_dir}")

    with tarfile.open(tar_path, "w") as tf:
        for filename in filenames:
            tf.add(os.path.join(video_dir, filename), arcname=f"{video_id}/{filename}", recursive=False)


def _run_pipeline_once(youtube_url: str, output_dir: str, max_utilization: float) -> Dict[str, Any]:
    """Run the pipeline a single time (no retries) for `youtube_url`.

    A fresh `Config` is built on every call. Config is expected to auto-tune itself in
    __post_init__ (ComputeMonitor); max_utilization is still passed explicitly so the
    caller controls the utilization cap. Returns whatever `process_single_video` returns.
    """
    run_config = Config(max_utilization=max_utilization, output_dir=output_dir)
    outcome = process_single_video(youtube_url, run_config)
    return outcome


def _result_summary(result: Dict[str, Any]) -> Dict[str, Any]:
    """Project the pipeline result onto the small, stable dict sent to the dispatcher.

    Missing keys become None. `usable_percentage` is lifted out of the nested
    `quality_stats` mapping, which may itself be absent or None.
    """
    passthrough = ("video_id", "pipeline_version", "timing_total", "num_speakers")
    summary: Dict[str, Any] = {key: result.get(key) for key in passthrough}
    quality = result.get("quality_stats") or {}
    summary["usable_percentage"] = quality.get("usable_percentage")
    summary["processed_duration"] = result.get("processed_duration")
    return summary


def _is_safe_path_component(name: str) -> bool:
    """Return True if `name` can be used as a single file/directory name.

    Rejects "", "." and "..", NUL bytes, and anything containing a path separator.
    The worker joins `video_id` into local paths that are later removed with
    `shutil.rmtree`, so an id like "" or "../x" from a misbehaving dispatcher would
    otherwise aim that cleanup at `output_dir` itself or one of its parents.
    """
    if name in ("", ".", ".."):
        return False
    forbidden = {"/", "\\", "\x00", os.sep}
    if os.altsep:
        forbidden.add(os.altsep)
    return not any(ch in name for ch in forbidden)


def _remove_local_artifacts(tar_path: str, video_dir: str) -> None:
    """Best-effort deletion of a job's local TAR and per-video output directory.

    Never raises: a cleanup problem is logged rather than allowed to kill the worker loop.
    """
    try:
        os.remove(tar_path)
    except FileNotFoundError:
        pass
    except OSError as e:
        print(f"[worker] cleanup: could not remove {tar_path}: {e}", flush=True)
    shutil.rmtree(video_dir, ignore_errors=True)


def _report_failure(dispatcher: DispatcherClient, job: Job, error_type: str, message: str) -> None:
    """Tell the dispatcher `job` failed, logging (not raising) if the report itself fails."""
    try:
        dispatcher.fail(job, error_type=error_type, message=message)
    except Exception as report_err:
        print(f"[worker] dispatcher fail report error: {report_err}", flush=True)
    print(f"[worker] job failed video_id={job.video_id}: {message}", flush=True)


def _process_video_with_oom_retries(
    job: Job, output_dir: str, max_utilization: float, max_oom_retries: int = 2
) -> Dict[str, Any]:
    """Run the pipeline for `job`, retrying up to `max_oom_retries` times on CUDA OOM.

    One Config is built per job. After each OOM the model cache is cleared (best
    effort) and `_tune_down_for_oom` shrinks that Config in place before the next
    attempt. Tuning is scoped to this job, so one OOM does not permanently slow down
    every later job. Non-OOM errors, and an OOM on the final attempt, are re-raised
    unchanged.
    """
    config = Config(output_dir=output_dir, max_utilization=max_utilization)
    total_attempts = max_oom_retries + 1
    attempt = 1
    while True:
        try:
            return process_single_video(job.youtube_url, config)
        except Exception as e:
            if attempt >= total_attempts or not _is_cuda_oom(e):
                raise
            print(
                f"[worker] CUDA OOM for video_id={job.video_id} attempt={attempt}/{total_attempts} -> tuning down + retry",
                flush=True,
            )
        # Recovery runs *after* leaving the except clause: the OOM exception's traceback
        # references the failed frames (and any GPU tensors they held), and Python only
        # drops it once the handler exits. Clearing caches while it is alive frees little.
        try:
            MODELS.clear_cache(aggressive=True)
        except Exception as clear_err:
            print(f"[worker] clear_cache failed (continuing): {clear_err}", flush=True)
        _tune_down_for_oom(config)
        time.sleep(2.0)
        attempt += 1


def _handle_job(
    job: Job,
    *,
    dispatcher: DispatcherClient,
    r2: R2Client,
    output_base: str,
    output_dir: str,
    r2_prefix: str,
    max_utilization: float,
) -> None:
    """Process one claimed job end to end, then always delete its local artifacts.

    - If `<r2_prefix>/<video_id>.tar` already exists in R2, the job is reported
      complete as skipped.
    - Otherwise the pipeline runs (with OOM retries), the output is tarred, the TAR is
      uploaded, and completion is reported.
    - Any failure is reported via `dispatcher.fail` with error_type "oom" or
      "processing"; the message is bounded to keep dispatcher records small.
    """
    if not _is_safe_path_component(job.video_id):
        # Refuse before building any path: this id would make cleanup escape its directory.
        _report_failure(
            dispatcher, job, "processing", f"ValueError: unsafe video_id {job.video_id[:200]!r}"
        )
        return

    r2_key = f"{r2_prefix}/{job.video_id}.tar"
    tar_path = os.path.join(output_base, f"{job.video_id}.tar")
    video_dir = os.path.join(output_dir, job.video_id)

    try:
        # Idempotency guard: if output already exists, mark complete immediately.
        if r2.object_exists(r2_key):
            dispatcher.complete(job, r2_key=r2_key, result_summary={"skipped": True, "reason": "r2_exists"})
            return

        result = _process_video_with_oom_retries(job, output_dir=output_dir, max_utilization=max_utilization)
        _build_video_tar(video_dir=video_dir, tar_path=tar_path)
        r2.upload_file(tar_path, r2_key)
        dispatcher.complete(job, r2_key=r2_key, result_summary=_result_summary(result))
    except Exception as e:
        # Report failure with a bounded message (avoid giant tracebacks in DB).
        msg = f"{type(e).__name__}: {str(e)[:800]}"
        _report_failure(dispatcher, job, ("oom" if _is_cuda_oom(e) else "processing"), msg)
        traceback.print_exc()
    finally:
        # Disk hygiene: always remove local artifacts.
        _remove_local_artifacts(tar_path, video_dir)


def main() -> None:
    """Worker entry point: read settings from the environment, then claim and process jobs forever.

    Raises:
        RuntimeError: at startup if a required environment variable is missing or a
            numeric one is malformed. After startup, per-job and dispatcher errors are
            logged and the loop continues.
    """
    dispatcher_url = _env_required("DISPATCHER_URL")
    dispatcher_token = os.environ.get("DISPATCHER_TOKEN", "").strip() or None

    # Required secrets for R2 are read from env (do not hardcode).
    r2_endpoint = _env_required("R2_ENDPOINT_URL")
    r2_bucket = _env_required("R2_BUCKET")
    # Support both AWS-style names and the legacy env naming.
    r2_access_key = (os.environ.get("R2_ACCESS_KEY_ID") or os.environ.get("R2_AccessID") or "").strip()
    r2_secret_key = (os.environ.get("R2_SECRET_ACCESS_KEY") or os.environ.get("R2_Secret_Access_Key") or "").strip()
    if not r2_access_key:
        raise RuntimeError("Missing required environment variable: R2_ACCESS_KEY_ID (or R2_AccessID)")
    if not r2_secret_key:
        raise RuntimeError("Missing required environment variable: R2_SECRET_ACCESS_KEY (or R2_Secret_Access_Key)")

    # Fail fast at startup if the HF token is absent, instead of on the first job.
    _env_required("HF_TOKEN")

    worker_id = _env_str("WORKER_ID", f"{socket.gethostname()}:{os.getpid()}")
    worker_region = _env_str("WORKER_REGION", "")
    worker_provider = _env_str("WORKER_PROVIDER", "")

    output_base = _env_str("WORKER_OUTPUT_BASE_DIR", "/tmp/maya3data_worker")
    os.makedirs(output_base, exist_ok=True)
    output_dir = os.path.join(output_base, "fast_output_v6")
    os.makedirs(output_dir, exist_ok=True)

    max_utilization = _env_float("MAX_UTILIZATION", 0.80)
    # Clamp so a negative setting cannot make time.sleep() raise and kill the loop.
    poll_seconds = max(0.0, _env_float("POLL_SECONDS", 2.0))
    idle_jitter = max(0.0, _env_float("IDLE_JITTER_SECONDS", 2.0))
    r2_prefix = _env_str("R2_PREFIX", "dataset/videos").strip("/")

    dispatcher = DispatcherClient(dispatcher_url, dispatcher_token)
    r2 = R2Client(r2_endpoint, r2_bucket, r2_access_key, r2_secret_key)

    worker_info: Dict[str, Any] = {
        "worker_id": worker_id,
        "region": worker_region,
        "provider": worker_provider,
        **_detect_hardware(),
    }

    # flush=True: stdout is block-buffered when not a TTY (container logs).
    print(f"[worker] started worker_id={worker_id} region={worker_region or '-'} provider={worker_provider or '-'}", flush=True)
    print(f"[worker] r2_bucket={r2_bucket} r2_prefix={r2_prefix}", flush=True)
    print(f"[worker] output_dir={output_dir} max_utilization={max_utilization}", flush=True)

    while True:
        try:
            job = dispatcher.claim(worker_info)
        except Exception as e:
            print(f"[worker] dispatcher claim error: {e}", flush=True)
            time.sleep(min(30.0, poll_seconds + random.random() * idle_jitter))
            continue

        if job is None:
            time.sleep(poll_seconds + random.random() * idle_jitter)
            continue

        _handle_job(
            job,
            dispatcher=dispatcher,
            r2=r2,
            output_base=output_base,
            output_dir=output_dir,
            r2_prefix=r2_prefix,
            max_utilization=max_utilization,
        )


if __name__ == "__main__":
    # Start the blocking worker loop only when run as a script, not when imported.
    main()


