"""R2 (S3-compatible) client for video download and encoded shard upload.

Supports concurrent downloads via a ThreadPoolExecutor and a prefetch queue
so the GPU never starves waiting for I/O. Also provides ffmpeg-based helpers
for extracting audio from downloaded videos.
"""

from __future__ import annotations

import logging
import os
import re
import subprocess
from concurrent.futures import ThreadPoolExecutor, Future
from pathlib import Path
from queue import Queue, Empty, Full
from threading import Event

import boto3
from botocore.config import Config as BotoConfig

from codecbench.pipeline.config import R2Config

logger = logging.getLogger(__name__)


class R2Client:
    """S3-compatible (Cloudflare R2) client for fetching videos and uploading shards.

    Wraps one boto3 S3 client configured from ``R2Config``. Its connection
    pool is sized to ``cfg.max_download_workers`` plus a few spare connections.
    """

    def __init__(self, cfg: R2Config):
        self._cfg = cfg
        self._client = boto3.client(
            "s3",
            endpoint_url=cfg.endpoint_url,
            aws_access_key_id=cfg.access_key_id,
            aws_secret_access_key=cfg.secret_access_key,
            region_name="auto",
            config=BotoConfig(
                # Headroom beyond the download workers for other concurrent calls.
                max_pool_connections=cfg.max_download_workers + 4,
                retries={"max_attempts": 5, "mode": "adaptive"},
            ),
        )

    # Language directories found in pt-indic (video stored as {lang}/{vid}.webm)
    _LANG_DIRS = [
        "indic", "telugu", "tamil", "hindi", "malayalam", "kannada",
        "marathi", "bengali", "gujarati", "punjabi", "assamese", "odia",
    ]

    # S3 error codes meaning "no such object". HEAD responses carry no body,
    # so botocore reports the bare HTTP status ("404") as the code.
    _NOT_FOUND_CODES = frozenset({"404", "NoSuchKey", "NotFound"})

    # create_bucket error codes meaning the bucket is already there.
    _BUCKET_EXISTS_CODES = ("BucketAlreadyExists", "BucketAlreadyOwnedByYou")

    @staticmethod
    def _error_code(exc: BaseException) -> str:
        """Return the S3 error code carried by a botocore ClientError, or "".

        Duck-typed on the ``response`` dict, so errors without one (network,
        credentials, ...) yield "".
        """
        response = getattr(exc, "response", None)
        if not isinstance(response, dict):
            return ""
        error = response.get("Error")
        if not isinstance(error, dict):
            return ""
        return str(error.get("Code", ""))

    def find_video_key(self, video_id: str, bucket: str, language: str | None = None) -> str | None:
        """Find the object key for a video_id in ``bucket``, or None if absent.

        Handles multiple R2 naming conventions:
          - Flat: {vid}_english_pretrain.webm, {vid}_indic_pretrain.webm
          - Directory: english_pretrain/{vid}.webm, bengali/{vid}.webm
          - Audio: audio/{vid}.webm
        Flat keys are found with one prefix listing. Directory keys are probed
        with HEAD requests. Pass ``language`` so that directory is probed first.

        A HEAD that fails for any reason other than "not found" is logged as a
        warning, and the search continues with the next directory.
        """
        if not video_id:
            # An empty prefix would list the whole bucket and match any key.
            return None

        # Fast path: root prefix listing (flat naming). The listing is a pure
        # prefix match, so also require a separator (or end of key) right after
        # the id. Otherwise "abc" would resolve to another video's "abcdef_..." key.
        id_len = len(video_id)
        resp = self._client.list_objects_v2(Bucket=bucket, Prefix=video_id, MaxKeys=5)
        for obj in resp.get("Contents", []):
            key = obj["Key"]
            if key.startswith(video_id) and key[id_len:id_len + 1] in ("", "_", "."):
                return key

        # Directory lookup: prioritize the known language, then try the others.
        if bucket == "pt-english":
            dirs_to_try = ["english_pretrain", "audio"]
        else:
            dirs_to_try = []
            if language and language in self._LANG_DIRS:
                dirs_to_try.append(language)
            dirs_to_try.append("indic")
            dirs_to_try.extend(d for d in self._LANG_DIRS if d not in dirs_to_try)

        for d in dirs_to_try:
            key = f"{d}/{video_id}.webm"
            try:
                self._client.head_object(Bucket=bucket, Key=key)
            except Exception as exc:
                if self._error_code(exc) not in self._NOT_FOUND_CODES:
                    # Not a plain miss (auth, throttling, network, ...). Keep the
                    # best-effort search going, but make the failure visible.
                    logger.warning("head_object %s/%s failed: %s", bucket, key, exc)
                continue
            return key

        return None

    def download_video(self, video_id: str, bucket: str, dest_dir: Path, language: str | None = None) -> Path | None:
        """Download a video file from R2 to ``dest_dir/{video_id}{ext}``.

        ``ext`` is taken from the object key (default ``.webm``). If a file
        already exists at that path, it is returned without re-downloading;
        its contents are not verified.

        Returns the local path, or None if the video was not found.
        Errors raised by the underlying boto3 download propagate.
        """
        key = self.find_video_key(video_id, bucket, language=language)
        if key is None:
            logger.warning("Video %s not found in bucket %s", video_id, bucket)
            return None

        dest_dir.mkdir(parents=True, exist_ok=True)
        ext = Path(key).suffix or ".webm"
        local_path = dest_dir / f"{video_id}{ext}"

        if local_path.exists():
            logger.debug("Already downloaded: %s", local_path)
            return local_path

        logger.info("Downloading %s/%s → %s", bucket, key, local_path)
        self._client.download_file(bucket, key, str(local_path))
        logger.info("Downloaded %s (%.1f MB)", local_path.name, local_path.stat().st_size / 1e6)
        return local_path

    def upload_shard(self, local_path: Path, shard_key: str) -> None:
        """Upload an encoded shard to ``cfg.output_bucket`` under ``shard_key``."""
        bucket = self._cfg.output_bucket
        logger.info("Uploading shard %s → %s/%s", local_path.name, bucket, shard_key)
        self._client.upload_file(
            str(local_path), bucket, shard_key,
            ExtraArgs={"ContentType": "application/octet-stream"},
        )
        logger.info("Uploaded shard %s (%.1f MB)", shard_key, local_path.stat().st_size / 1e6)

    def download_file(self, bucket: str, key: str, local_path: Path) -> Path:
        """Download ``bucket/key`` to ``local_path``, creating parent dirs; return ``local_path``."""
        local_path.parent.mkdir(parents=True, exist_ok=True)
        self._client.download_file(bucket, key, str(local_path))
        return local_path

    def ensure_bucket(self, bucket: str) -> None:
        """Create ``bucket`` if it doesn't exist (R2 specific).

        If ``head_bucket`` fails for any reason, a create is attempted. A
        create that reports the bucket already exists (owned by us or not) is
        treated as success. Any other create error is re-raised.
        """
        try:
            self._client.head_bucket(Bucket=bucket)
            return
        except Exception:
            pass  # Typically 404: fall through and try to create it.

        try:
            self._client.create_bucket(Bucket=bucket)
        except Exception as e:
            code = self._error_code(e)
            message = str(e)
            if code in self._BUCKET_EXISTS_CODES or any(c in message for c in self._BUCKET_EXISTS_CODES):
                return
            raise
        logger.info("Created bucket: %s", bucket)


class PrefetchItem:
    """A prefetched video handed from download workers to the consumer.

    ``path`` is the local file, or ``None`` when the video was not found or
    its download failed.
    """

    __slots__ = ("video_id", "language", "bucket", "path")

    def __init__(self, video_id: str, language: str, bucket: str, path: Path | None):
        self.video_id = video_id
        self.language = language
        self.bucket = bucket
        self.path = path

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(video_id={self.video_id!r}, language={self.language!r}, "
            f"bucket={self.bucket!r}, path={self.path!r})"
        )


class VideoPrefetcher:
    """Background prefetcher: downloads videos ahead of GPU processing.

    Videos passed to ``submit`` are downloaded on a small thread pool and
    pushed onto a bounded ready queue. The consumer pulls them with ``get``,
    so when the encoder finishes one video the next is already on disk.
    Missing or failed downloads are still enqueued, with ``path=None``.
    """

    # How often a worker blocked on a full ready queue re-checks for stop().
    _PUT_POLL_S = 0.5

    def __init__(self, r2: R2Client, max_prefetch: int = 2, download_workers: int = 2):
        self._r2 = r2
        # Bounded so workers cannot run arbitrarily far ahead of the consumer.
        self._ready_queue: Queue[PrefetchItem] = Queue(maxsize=max_prefetch + 2)
        self._pool = ThreadPoolExecutor(max_workers=download_workers, thread_name_prefix="prefetch")
        self._stop = Event()
        self._pending_futures: list[Future] = []

    def submit(self, video_id: str, language: str, bucket: str, dest_dir: Path) -> None:
        """Queue a video for background download. No-op after ``stop()``."""
        if self._stop.is_set():
            return
        # Drop finished futures so this list doesn't grow for the whole run.
        self._pending_futures = [f for f in self._pending_futures if not f.done()]
        try:
            fut = self._pool.submit(self._download_and_enqueue, video_id, language, bucket, dest_dir)
        except RuntimeError:
            # stop() may have shut the pool down between the check above and here.
            if self._stop.is_set():
                return
            raise
        self._pending_futures.append(fut)

    def _download_and_enqueue(self, video_id: str, language: str, bucket: str, dest_dir: Path) -> None:
        """Worker body: download one video and enqueue its PrefetchItem."""
        if self._stop.is_set():
            return
        try:
            # Pass the language through so R2Client probes that directory first.
            path = self._r2.download_video(video_id, bucket, dest_dir, language=language)
        except Exception as e:
            logger.error("Prefetch failed for %s: %s", video_id, e)
            path = None
        self._enqueue(PrefetchItem(video_id, language, bucket, path))

    def _enqueue(self, item: PrefetchItem) -> bool:
        """Put ``item`` on the ready queue; give up only if it is full after ``stop()``.

        A plain blocking ``put`` would hang forever on a full queue once the
        consumer stops. ThreadPoolExecutor joins its workers at interpreter
        exit, so such a hang would keep the process from exiting.
        Returns True if the item was enqueued.
        """
        while True:
            try:
                self._ready_queue.put(item, timeout=self._PUT_POLL_S)
                return True
            except Full:
                if self._stop.is_set():
                    return False

    def get(self, timeout: float = 300.0) -> PrefetchItem:
        """Block until a prefetched video is ready.

        Raises ``queue.Empty`` if nothing arrives within ``timeout`` seconds.
        """
        return self._ready_queue.get(timeout=timeout)

    def stop(self) -> None:
        """Stop prefetching without waiting for in-flight downloads.

        Cancels downloads that have not started yet. Workers that are already
        running finish their download. If the ready queue is full, they drop
        the item instead of blocking forever.
        """
        self._stop.set()
        for fut in list(self._pending_futures):
            fut.cancel()
        self._pool.shutdown(wait=False)

    @property
    def queue_size(self) -> int:
        """Approximate number of items waiting in the ready queue."""
        return self._ready_queue.qsize()


def extract_audio_from_video(video_path: Path, output_path: Path, target_sr: int = 16_000) -> Path:
    """Extract mono audio from a video with ffmpeg and loudness-normalize it.

    Drops the video stream, downmixes to mono, and resamples to ``target_sr``.
    The audio is normalized with ffmpeg's ``loudnorm`` filter (I=-23, LRA=7,
    TP=-2) and written as 16-bit PCM to ``output_path``. The container is
    inferred by ffmpeg from the extension, normally ``.wav``.

    LEGACY: the loudnorm filter is slow. Prefer ``extract_audio_pipe()`` for
    the async pipeline.

    Args:
        video_path: Input media file.
        output_path: Destination file; parent directories are created.
        target_sr: Output sample rate in Hz.

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: ffmpeg exited non-zero (message includes the stderr tail).
        subprocess.TimeoutExpired: ffmpeg ran longer than 300 s.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    cmd = [
        "ffmpeg", "-y", "-i", str(video_path),
        "-vn",  # drop video
        "-ac", "1",  # mono
        "-ar", str(target_sr),
        "-af", "loudnorm=I=-23:LRA=7:TP=-2",  # EBU R128 loudness norm
        "-c:a", "pcm_s16le",
        str(output_path),
    ]

    # stdin=DEVNULL: otherwise ffmpeg reads the parent's stdin for interactive
    # key commands, which can stall or consume input when run from workers.
    # errors="replace": ffmpeg stderr may contain non-UTF-8 bytes (e.g. file
    # metadata). A strict decode would hide the real failure behind a
    # UnicodeDecodeError.
    result = subprocess.run(
        cmd,
        stdin=subprocess.DEVNULL,
        capture_output=True,
        text=True,
        encoding="utf-8",
        errors="replace",
        timeout=300,
    )
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg failed for {video_path.name}: {result.stderr[-500:]}")

    logger.info("Extracted audio: %s (%.1f MB)", output_path.name, output_path.stat().st_size / 1e6)
    return output_path


def extract_audio_pipe(
    video_path: Path,
    target_sr: int = 16_000,
    ffmpeg_threads: int = 2,
) -> "torch.Tensor":
    """Decode a video's audio straight into a tensor via an ffmpeg pipe.

    Single-pass decode with no loudnorm filter. ffmpeg downmixes to mono,
    resamples to ``target_sr``, and streams raw signed 16-bit little-endian
    PCM to stdout, so no intermediate file is written. The whole decoded
    track is buffered in memory.

    Args:
        video_path: Input media file.
        target_sr: Output sample rate in Hz.
        ffmpeg_threads: Value passed to ffmpeg's ``-threads`` option.

    Returns:
        float32 tensor of shape ``[1, T]`` with samples scaled to ``[-1, 1)``.

    Raises:
        RuntimeError: ffmpeg exited non-zero (message includes the stderr tail).
        subprocess.TimeoutExpired: ffmpeg ran longer than 600 s.
    """
    import numpy as np
    import torch

    cmd = [
        "ffmpeg", "-y",
        "-threads", str(ffmpeg_threads),
        "-i", str(video_path),
        "-vn", "-ac", "1", "-ar", str(target_sr),
        "-f", "s16le", "-acodec", "pcm_s16le",
        "pipe:1",
    ]

    # stdin=DEVNULL: otherwise ffmpeg reads the parent's stdin for interactive
    # key commands, which can stall or consume input when run from workers.
    result = subprocess.run(cmd, stdin=subprocess.DEVNULL, capture_output=True, timeout=600)
    if result.returncode != 0:
        stderr = result.stderr.decode("utf-8", errors="replace")[-500:]
        raise RuntimeError(f"ffmpeg pipe failed for {video_path.name}: {stderr}")

    # int16 -> float32 in [-1, 1). astype() copies, so the tensor owns writable memory.
    pcm = np.frombuffer(result.stdout, dtype=np.int16).astype(np.float32) / 32768.0
    wav = torch.from_numpy(pcm).unsqueeze(0)  # [1, T]
    return wav


def normalize_audio_peak(wav: "torch.Tensor", target_peak: float = 0.95) -> "torch.Tensor":
    """Scale ``wav`` so its largest absolute sample equals ``target_peak``.

    A cheap replacement for ffmpeg's EBU R128 ``loudnorm`` filter: one abs/max
    reduction and one multiply. Empty and all-zero inputs are returned
    unchanged. A new tensor is returned when scaling happens; the input is
    not modified.
    """
    if wav.numel() == 0:
        # .max() on an empty tensor raises; there is nothing to scale anyway.
        return wav
    peak = wav.abs().max()
    if peak > 0:
        wav = wav * (target_peak / peak)
    return wav
