"""
R2/S3 client: download videoID.tar, extract segments, pack results tar, upload.
Supports mock mode with local filesystem.
"""
from __future__ import annotations

import json
import logging
import os
import shutil
import tarfile
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from botocore.config import Config as BotoConfig

logger = logging.getLogger(__name__)


@dataclass
class ExtractedVideo:
    """Contents of one unpacked ``<video_id>.tar``: its metadata and audio segments."""

    # Identifier the archive was fetched and extracted under.
    video_id: str
    # Directory the archive contents were extracted into.
    work_dir: Path
    # Parsed metadata.json contents (R2Client.extract_tar stores {} when absent).
    metadata: dict
    # Segment audio files located inside the extracted archive.
    segment_paths: list[Path]
    # Language code associated with the video; "en" unless supplied.
    language: str = field(default="en")


class R2Client:
    """S3-compatible client for Cloudflare R2.

    Covers the per-video lifecycle: download ``<video_id>.tar`` from the
    input bucket, extract it, pack results into ``<video_id>_transcribed.tar``
    and upload that to the output bucket.

    When ``config.mock_mode`` is true no S3 client is created: downloads are
    replaced by a locally generated synthetic tar and uploads are only logged.
    """

    def __init__(self, config):
        """Build the client from ``config``.

        Always reads ``mock_mode``, ``r2_bucket``, ``r2_output_bucket`` and
        ``r2_output_prefix``; additionally reads ``r2_endpoint_url``,
        ``r2_access_key_id`` and ``r2_secret_access_key`` outside mock mode.
        """
        self.config = config
        self.mock_mode = config.mock_mode
        if not self.mock_mode:
            # boto3 is only imported when a real S3 client is needed.
            import boto3
            self.s3 = boto3.client(
                "s3",
                endpoint_url=config.r2_endpoint_url,
                aws_access_key_id=config.r2_access_key_id,
                aws_secret_access_key=config.r2_secret_access_key,
                region_name="auto",
                config=BotoConfig(max_pool_connections=_s3_pool_size()),
            )
        self.bucket = config.r2_bucket
        self.output_bucket = config.r2_output_bucket
        self.output_prefix = config.r2_output_prefix

    def download_tar(self, video_id: str, work_dir: Path, r2_tar_key: str = "") -> Path:
        """Fetch the input tar for ``video_id`` into ``work_dir/<video_id>.tar``.

        Args:
            video_id: Video identifier; names the local file and, when
                ``r2_tar_key`` is empty, the object key to look up.
            work_dir: Directory to write the tar into (must already exist).
            r2_tar_key: Explicit object key in ``self.bucket``. When empty,
                the key is resolved by probing the candidate locations.

        Returns:
            Path to the downloaded tar (in mock mode, a generated one).

        Raises:
            Whatever ``download_file`` raises (e.g. a botocore
            ``ClientError``) when the object cannot be fetched.
        """
        tar_path = work_dir / f"{video_id}.tar"
        if self.mock_mode:
            self._create_mock_tar(video_id, tar_path)
            return tar_path

        key = r2_tar_key if r2_tar_key else self._resolve_tar_key(video_id)

        logger.info("Downloading s3://%s/%s -> %s", self.bucket, key, tar_path)
        self.s3.download_file(self.bucket, key, str(tar_path))
        return tar_path

    def extract_tar(self, tar_path: Path, video_id: str) -> ExtractedVideo:
        """Extract ``tar_path`` into ``<tar dir>/<video_id>/`` and index it.

        ``metadata.json`` and the ``segments`` directory may sit at the
        archive root or be nested; the shallowest match is used. Segments are
        the ``*.flac`` files directly inside that directory, sorted by path.

        Returns:
            An ``ExtractedVideo``. ``metadata`` is ``{}`` and
            ``segment_paths`` is empty when the corresponding entries are
            missing; ``language`` falls back to ``"en"`` when metadata does
            not provide a non-empty value.
        """
        work_dir = tar_path.parent / video_id
        work_dir.mkdir(exist_ok=True)

        with tarfile.open(tar_path, "r:*") as tf:
            # The "data" filter refuses members that would land outside
            # work_dir and special files such as device nodes.
            tf.extractall(work_dir, filter="data")

        metadata_path = self._find_file(work_dir, "metadata.json")
        metadata = {}
        if metadata_path:
            # Writers emit UTF-8 with ensure_ascii=False; don't rely on locale.
            metadata = json.loads(metadata_path.read_text(encoding="utf-8"))

        segments_dir = self._find_dir(work_dir, "segments")
        segment_paths = sorted(segments_dir.glob("*.flac")) if segments_dir else []

        # `or` also covers an explicit null/empty language in metadata.
        language = metadata.get("language") or "en"

        return ExtractedVideo(
            video_id=video_id,
            work_dir=work_dir,
            metadata=metadata,
            segment_paths=segment_paths,
            language=language,
        )

    def pack_results_tar(self, video_id: str, work_dir: Path,
                         segment_paths: list[Path],
                         transcription_jsons: dict[str, dict],
                         metadata: dict) -> Path:
        """Pack polished segments + transcription JSONs into output tar.

        Stages files under ``work_dir/output`` and archives that directory as
        ``<video_id>/`` containing ``segments/<segment files>``,
        ``transcriptions/<segment stem>.json`` and ``metadata.json``.

        Args:
            video_id: Used for the archive's top-level folder and file name.
            work_dir: Directory that receives the staging dir and the tar.
            segment_paths: Audio files copied into ``segments/``.
            transcription_jsons: Maps a segment file name to its JSON result.
            metadata: Written to ``metadata.json``.

        Returns:
            Path to ``work_dir/<video_id>_transcribed.tar``.
        """
        out_dir = work_dir / "output"
        out_dir.mkdir(exist_ok=True)
        segments_out = out_dir / "segments"
        segments_out.mkdir(exist_ok=True)
        transcriptions_out = out_dir / "transcriptions"
        transcriptions_out.mkdir(exist_ok=True)

        for seg_path in segment_paths:
            shutil.copy2(seg_path, segments_out / seg_path.name)

        for seg_name, result_json in transcription_jsons.items():
            json_name = Path(seg_name).stem + ".json"
            # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters.
            (transcriptions_out / json_name).write_text(
                json.dumps(result_json, ensure_ascii=False, indent=2), encoding="utf-8"
            )

        metadata_out = out_dir / "metadata.json"
        metadata_out.write_text(
            json.dumps(metadata, ensure_ascii=False, indent=2), encoding="utf-8"
        )

        tar_path = work_dir / f"{video_id}_transcribed.tar"
        with tarfile.open(tar_path, "w") as tf:
            tf.add(out_dir, arcname=video_id)
        return tar_path

    def upload_tar(self, tar_path: Path, video_id: str):
        """Upload ``tar_path`` to ``output_bucket`` under
        ``<output_prefix><video_id>_transcribed.tar``.

        The prefix is concatenated verbatim, so it should carry its own
        trailing ``/`` if a folder-like layout is wanted. In mock mode the
        upload is only logged.
        """
        key = f"{self.output_prefix}{video_id}_transcribed.tar"
        if self.mock_mode:
            logger.info("[MOCK] Would upload %s -> s3://%s/%s", tar_path, self.output_bucket, key)
            return

        logger.info("Uploading %s -> s3://%s/%s", tar_path, self.output_bucket, key)
        self.s3.upload_file(str(tar_path), self.output_bucket, key)

    def cleanup(self, work_dir: Path):
        """Recursively delete ``work_dir`` if it exists."""
        if work_dir.exists():
            shutil.rmtree(work_dir)

    # === Internal helpers ===

    def _resolve_tar_key(self, video_id: str) -> str:
        """Return the first candidate key for ``video_id``'s tar that exists.

        Candidates are probed with HEAD in order: ``<video_id>.tar`` at the
        bucket root, then ``cleaned/trail/<video_id>.tar``. If neither is
        found the root key is returned, so the subsequent download raises.
        """
        from botocore.exceptions import ClientError

        candidates = (f"{video_id}.tar", f"cleaned/trail/{video_id}.tar")
        for candidate in candidates:
            try:
                self.s3.head_object(Bucket=self.bucket, Key=candidate)
            except ClientError as exc:
                logger.debug("HEAD s3://%s/%s failed: %s", self.bucket, candidate, exc)
                continue
            return candidate
        return candidates[0]

    def _find_file(self, root: Path, name: str) -> Optional[Path]:
        """Return the shallowest regular file named ``name`` under ``root``, or None."""
        return self._shallowest(p for p in root.rglob(name) if p.is_file())

    def _find_dir(self, root: Path, name: str) -> Optional[Path]:
        """Return the shallowest directory named ``name`` under ``root``, or None."""
        return self._shallowest(p for p in root.rglob(name) if p.is_dir())

    @staticmethod
    def _shallowest(paths) -> Optional[Path]:
        """Pick the path with the fewest components (ties broken by path string).

        Makes lookups independent of filesystem iteration order.
        """
        return min(paths, key=lambda p: (len(p.parts), str(p)), default=None)

    def _create_mock_tar(self, video_id: str, tar_path: Path):
        """Create a mock tar with fake segments for testing.

        The archive holds ``<video_id>/metadata.json`` (language ``"te"``)
        and five 16 kHz single-channel FLAC files under
        ``<video_id>/segments/``. Each is a noisy sine tone, 3-12 s long, with
        a ~0.3 s near-silent stretch starting at 70% of its length for
        exercising silence-based splitting.
        """
        import numpy as np
        import soundfile as sf

        tmp_dir = tar_path.parent / f"_mock_{video_id}"
        # Start clean so leftovers from an aborted run are not packed.
        if tmp_dir.exists():
            shutil.rmtree(tmp_dir)
        tmp_dir.mkdir()
        try:
            segments_dir = tmp_dir / "segments"
            segments_dir.mkdir()

            sr = 16000
            durations = [3.0, 5.0, 8.0, 12.0, 4.0]
            segment_files = []
            for i, dur in enumerate(durations):
                samples = int(dur * sr)
                # Sine wave plus noise as a crude speech stand-in.
                t = np.linspace(0, dur, samples)
                freq = 200 + i * 50
                audio = 0.3 * np.sin(2 * np.pi * freq * t) + 0.05 * np.random.randn(samples)
                # Attenuate a short region so split logic has a silence to find.
                silence_start = int(0.7 * samples)
                silence_end = min(silence_start + int(0.3 * sr), samples)
                audio[silence_start:silence_end] *= 0.01

                seg_name = f"speaker0_{i * 1000:05d}_{(i * 1000 + int(dur * 1000)):05d}.flac"
                sf.write(str(segments_dir / seg_name), audio.astype(np.float32), sr)
                segment_files.append(seg_name)

            metadata = {
                "video_id": video_id,
                "language": "te",
                "total_segments": len(durations),
                "segment_files": segment_files,
            }
            (tmp_dir / "metadata.json").write_text(json.dumps(metadata), encoding="utf-8")

            with tarfile.open(tar_path, "w") as tf:
                tf.add(tmp_dir, arcname=video_id)
        finally:
            # Never leave the staging dir behind, even if generation failed.
            shutil.rmtree(tmp_dir, ignore_errors=True)


def _s3_pool_size() -> int:
    # Raw tar download + prefetch can open many parallel S3 connections.
    # The default pool size of 10 is too small and causes noisy discard warnings.
    return 32
