"""
R2/S3 client: download videoID.tar, extract segments, pack results tar, upload.
Supports mock mode with local filesystem.
"""
from __future__ import annotations

import json
import logging
import os
import shutil
import tarfile
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)


@dataclass
class ExtractedVideo:
    """Represents an extracted tar with metadata and segment paths."""
    # Identifier of the source video; also used as the tar's top-level
    # directory name when packing/unpacking.
    video_id: str
    # Directory the tar was extracted into (later removed by R2Client.cleanup).
    work_dir: Path
    # Parsed contents of metadata.json; empty dict when the file is absent.
    metadata: dict
    # Sorted paths of the extracted .flac segment files; may be empty when
    # no "segments" directory was found in the tar.
    segment_paths: list[Path]
    # Language code taken from metadata["language"]; "en" when missing.
    language: str = "en"


class R2Client:
    """S3-compatible client for Cloudflare R2.

    Handles the job lifecycle: download ``<video_id>.tar``, extract its
    segments, pack transcription results into an output tar, and upload it.
    In mock mode (``config.mock_mode``) no network access is required:
    downloads fabricate a local tar of synthetic segments and uploads are
    logged but skipped.
    """

    def __init__(self, config):
        """Initialise from a config object.

        Args:
            config: must expose ``mock_mode``, ``r2_bucket``,
                ``r2_output_bucket`` and ``r2_output_prefix``; outside mock
                mode also ``r2_endpoint_url``, ``r2_access_key_id`` and
                ``r2_secret_access_key``.
        """
        self.config = config
        self.mock_mode = config.mock_mode
        if not self.mock_mode:
            # Imported lazily so mock mode does not require boto3.
            import boto3
            self.s3 = boto3.client(
                "s3",
                endpoint_url=config.r2_endpoint_url,
                aws_access_key_id=config.r2_access_key_id,
                aws_secret_access_key=config.r2_secret_access_key,
                region_name="auto",  # R2 ignores region; boto3 requires one
            )
        self.bucket = config.r2_bucket
        self.output_bucket = config.r2_output_bucket
        self.output_prefix = config.r2_output_prefix

    def download_tar(self, video_id: str, work_dir: Path, r2_tar_key: str = "") -> Path:
        """Download ``<video_id>.tar`` into *work_dir* and return its path.

        If *r2_tar_key* is given it is used verbatim; otherwise the bucket
        root is probed first, then the ``cleaned/trail/`` prefix.  In mock
        mode a synthetic tar is generated locally instead.

        Raises:
            Whatever ``download_file`` raises when the object is missing.
        """
        tar_path = work_dir / f"{video_id}.tar"
        if self.mock_mode:
            self._create_mock_tar(video_id, tar_path)
            return tar_path

        key = r2_tar_key or self._resolve_tar_key(video_id)
        logger.info("Downloading s3://%s/%s -> %s", self.bucket, key, tar_path)
        self.s3.download_file(self.bucket, key, str(tar_path))
        return tar_path

    def _resolve_tar_key(self, video_id: str) -> str:
        """Probe known locations for the input tar; return the key to use.

        Falls back to the root key even when neither probe succeeds, so the
        subsequent ``download_file`` surfaces a meaningful error.
        """
        key = f"{video_id}.tar"
        try:
            self.s3.head_object(Bucket=self.bucket, Key=key)
            return key
        except Exception:  # broad: boto3 raises several client-error variants
            pass
        alt_key = f"cleaned/trail/{video_id}.tar"
        try:
            self.s3.head_object(Bucket=self.bucket, Key=alt_key)
            return alt_key
        except Exception:
            return key  # let download_file raise the real error

    def extract_tar(self, tar_path: Path, video_id: str) -> "ExtractedVideo":
        """Extract *tar_path* into a sibling directory and collect contents.

        Returns an ExtractedVideo carrying the parsed metadata.json (empty
        dict when absent), sorted .flac segment paths (empty list when no
        "segments" directory exists) and the metadata language ("en" default).
        """
        work_dir = tar_path.parent / video_id
        work_dir.mkdir(parents=True, exist_ok=True)

        with tarfile.open(tar_path, "r:*") as tf:
            # "data" filter blocks path traversal and special members (PEP 706).
            tf.extractall(work_dir, filter="data")

        # metadata.json may be at the root or nested one level (arcname dir).
        metadata_path = self._find_file(work_dir, "metadata.json")
        # Explicit UTF-8: metadata may contain non-ASCII text and the
        # platform default encoding is not guaranteed to handle it.
        metadata = (
            json.loads(metadata_path.read_text(encoding="utf-8"))
            if metadata_path else {}
        )

        segments_dir = self._find_dir(work_dir, "segments")
        segment_paths = sorted(segments_dir.glob("*.flac")) if segments_dir else []

        return ExtractedVideo(
            video_id=video_id,
            work_dir=work_dir,
            metadata=metadata,
            segment_paths=segment_paths,
            language=metadata.get("language", "en"),
        )

    def pack_results_tar(self, video_id: str, work_dir: Path,
                         segment_paths: list[Path],
                         transcription_jsons: dict[str, dict],
                         metadata: dict) -> Path:
        """Pack polished segments + transcription JSONs into output tar.

        Args:
            video_id: names the tar and the archive's top-level directory.
            work_dir: scratch directory; ``output/`` is created inside it.
            segment_paths: audio files copied into ``segments/``.
            transcription_jsons: maps segment file name -> result dict,
                written as ``transcriptions/<stem>.json``.
            metadata: written as ``metadata.json``.

        Returns:
            Path of ``<video_id>_transcribed.tar`` inside *work_dir*.
        """
        out_dir = work_dir / "output"
        segments_out = out_dir / "segments"
        transcriptions_out = out_dir / "transcriptions"
        for d in (out_dir, segments_out, transcriptions_out):
            d.mkdir(parents=True, exist_ok=True)

        for seg_path in segment_paths:
            shutil.copy2(seg_path, segments_out / seg_path.name)

        for seg_name, result_json in transcription_jsons.items():
            json_name = Path(seg_name).stem + ".json"
            # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters.
            (transcriptions_out / json_name).write_text(
                json.dumps(result_json, ensure_ascii=False, indent=2),
                encoding="utf-8",
            )

        (out_dir / "metadata.json").write_text(
            json.dumps(metadata, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )

        tar_path = work_dir / f"{video_id}_transcribed.tar"
        with tarfile.open(tar_path, "w") as tf:
            tf.add(out_dir, arcname=video_id)
        return tar_path

    def upload_tar(self, tar_path: Path, video_id: str):
        """Upload the results tar to the output bucket (log-only in mock mode)."""
        key = f"{self.output_prefix}{video_id}_transcribed.tar"
        if self.mock_mode:
            logger.info("[MOCK] Would upload %s -> s3://%s/%s",
                        tar_path, self.output_bucket, key)
            return

        logger.info("Uploading %s -> s3://%s/%s", tar_path, self.output_bucket, key)
        self.s3.upload_file(str(tar_path), self.output_bucket, key)

    def cleanup(self, work_dir: Path):
        """Remove *work_dir* and everything under it, if it exists."""
        if work_dir.exists():
            shutil.rmtree(work_dir)

    # === Internal helpers ===

    def _find_file(self, root: Path, name: str) -> Optional[Path]:
        """Return the first path named *name* under *root*, or None."""
        return next(root.rglob(name), None)

    def _find_dir(self, root: Path, name: str) -> Optional[Path]:
        """Return the first directory named *name* under *root*, or None."""
        return next((p for p in root.rglob(name) if p.is_dir()), None)

    def _create_mock_tar(self, video_id: str, tar_path: Path):
        """Create a mock tar with fake segments for testing.

        Generates five sine-wave FLAC segments (each with an attenuated tail
        so silence-based splitting can be exercised) plus a metadata.json,
        tars them under *video_id*, and removes the staging directory.
        """
        # Local imports: only needed in mock mode.
        import numpy as np
        import soundfile as sf

        tmp_dir = tar_path.parent / f"_mock_{video_id}"
        tmp_dir.mkdir(exist_ok=True)
        segments_dir = tmp_dir / "segments"
        segments_dir.mkdir(exist_ok=True)

        sr = 16000
        # Five mock segments of varying lengths (seconds).
        durations = [3.0, 5.0, 8.0, 12.0, 4.0]
        segment_files = []
        for i, dur in enumerate(durations):
            samples = int(dur * sr)
            # Sine wave plus noise approximates speech-like audio.
            t = np.linspace(0, dur, samples)
            freq = 200 + i * 50
            audio = 0.3 * np.sin(2 * np.pi * freq * t) + 0.05 * np.random.randn(samples)
            # Quiet region near the end to simulate silence for split testing.
            silence_start = int(0.7 * samples)
            silence_end = min(silence_start + int(0.3 * sr), samples)
            audio[silence_start:silence_end] *= 0.01

            seg_name = f"speaker0_{i * 1000:05d}_{(i * 1000 + int(dur * 1000)):05d}.flac"
            sf.write(str(segments_dir / seg_name), audio.astype(np.float32), sr)
            segment_files.append(seg_name)

        metadata = {
            "video_id": video_id,
            "language": "te",
            "total_segments": len(durations),
            "segment_files": segment_files,
        }
        (tmp_dir / "metadata.json").write_text(json.dumps(metadata), encoding="utf-8")

        with tarfile.open(tar_path, "w") as tf:
            tf.add(tmp_dir, arcname=video_id)
        shutil.rmtree(tmp_dir)
