"""
Parquet shard packer: accumulates segment results from N videos,
writes a parquet file, uploads to R2.
Each GPU produces parquet shards of PARQUET_SHARD_SIZE videos.
"""
from __future__ import annotations

import json
import logging
import time
from pathlib import Path
from typing import Optional

import pyarrow as pa
import pyarrow.parquet as pq

from .config import ValidationConfig, PARQUET_SHARD_SIZE
from .pipeline import SegmentResult

logger = logging.getLogger(__name__)

# Parquet schema — all columns explicitly typed
PARQUET_SCHEMA = pa.schema([
    ("video_id", pa.string()),
    ("segment_file", pa.string()),
    ("duration_s", pa.float32()),
    # Gemini
    ("gemini_lang", pa.string()),
    ("gemini_transcription", pa.string()),
    ("gemini_tagged", pa.string()),
    ("gemini_quality_score", pa.float32()),
    ("speaker_info", pa.string()),
    # MMS LID
    ("mms_lang_iso3", pa.string()),
    ("mms_lang_iso1", pa.string()),
    ("mms_confidence", pa.float32()),
    ("mms_top3", pa.string()),
    # VoxLingua
    ("vox_lang", pa.string()),
    ("vox_lang_iso1", pa.string()),
    ("vox_confidence", pa.float32()),
    ("vox_top3", pa.string()),
    ("vox_speaker_embedding", pa.binary()),
    # IndicConformer multi
    ("conformer_multi_transcription", pa.string()),
    ("conformer_multi_ctc_raw", pa.float32()),
    ("conformer_multi_ctc_normalized", pa.float32()),
    # IndicWav2Vec
    ("wav2vec_transcription", pa.string()),
    ("wav2vec_ctc_raw", pa.float32()),
    ("wav2vec_ctc_normalized", pa.float32()),
    ("wav2vec_model_used", pa.string()),
    # Consensus
    ("lid_consensus", pa.bool_()),
    ("lid_agree_count", pa.int32()),
    ("consensus_lang", pa.string()),
])
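
# A shard can be read back against this schema, e.g. (the worker id "gpu0"
# in the file name below is hypothetical):
#
#     df = pq.read_table("validation_gpu0_shard_0001.parquet").to_pandas()
#     agreed = df[df["lid_consensus"]]  # segments where the LID models agreed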


class ParquetPacker:
    """
    Accumulates segment results, writes parquet shards.
    Thread-safe accumulation via simple list append + flush.
    """

    def __init__(self, config: ValidationConfig, output_dir: Path):
        self.config = config
        self.output_dir = output_dir
        self.output_dir.mkdir(parents=True, exist_ok=True)
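        # Rows pending the next shard write, plus the video IDs they came
        # from; both flush together once PARQUET_SHARD_SIZE videos accumulate.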
        self._buffer: list[SegmentResult] = []
        self._video_ids: list[str] = []
        self._shard_count = 0
        self._total_segments = 0
        self._total_mb = 0.0
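        # boto3 client is created lazily by _get_s3() (never in mock mode)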
        self._s3 = None

    def _get_s3(self):
        if self._s3 is None and not self.config.mock_mode:
            import boto3
            self._s3 = boto3.client(
                "s3",
                endpoint_url=self.config.r2_endpoint_url,
                aws_access_key_id=self.config.r2_access_key_id,
                aws_secret_access_key=self.config.r2_secret_access_key,
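                # Cloudflare R2 is region-agnostic; its S3 API expects "auto"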
                region_name="auto",
            )
        return self._s3

    def add_video_results(self, video_id: str, results: list[SegmentResult]):
        """Add a video's results to the buffer. Flushes when shard size reached."""
        self._buffer.extend(results)
        self._video_ids.append(video_id)
        self._total_segments += len(results)

        if len(self._video_ids) >= PARQUET_SHARD_SIZE:
            self.flush()

    def flush(self) -> Optional[Path]:
        """Write current buffer to a parquet shard and optionally upload."""
        if not self._buffer:
            return None

        self._shard_count += 1
        worker = self.config.worker_id
        shard_name = f"validation_{worker}_shard_{self._shard_count:04d}.parquet"
        shard_path = self.output_dir / shard_name

        t0 = time.time()
        table = self._to_arrow_table(self._buffer)
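        # zstd trades a little CPU for good compression on text-heavy columns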
        pq.write_table(table, shard_path, compression="zstd")
        write_time = time.time() - t0

        n_videos = len(self._video_ids)
        n_segs = len(self._buffer)
        size_mb = shard_path.stat().st_size / 1e6

        self._total_mb += size_mb
        logger.info(
            f"Wrote shard {shard_name}: {n_videos} videos, {n_segs} segments, "
            f"{size_mb:.1f}MB, {write_time:.1f}s"
        )

        # Upload to R2
        self._upload_shard(shard_path, shard_name)

        # Write a manifest for this shard (video IDs it contains)
        manifest_path = shard_path.with_suffix(".manifest.json")
        manifest_path.write_text(json.dumps({
            "shard": shard_name,
            "worker_id": worker,
            "shard_index": self._shard_count,
            "video_ids": self._video_ids.copy(),
            "total_segments": n_segs,
            "size_mb": round(size_mb, 2),
        }, indent=2))

        self._buffer.clear()
        self._video_ids.clear()

        return shard_path

    def _to_arrow_table(self, results: list[SegmentResult]) -> pa.Table:
        """Convert list of SegmentResult to Arrow table with explicit schema."""
        columns = {field.name: [] for field in PARQUET_SCHEMA}

        for r in results:
            columns["video_id"].append(r.video_id)
            columns["segment_file"].append(r.segment_file)
            columns["duration_s"].append(r.duration_s)
            columns["gemini_lang"].append(r.gemini_lang)
            columns["gemini_transcription"].append(r.gemini_transcription)
            columns["gemini_tagged"].append(r.gemini_tagged)
            columns["gemini_quality_score"].append(r.gemini_quality_score)
            columns["speaker_info"].append(r.speaker_info or "")
            columns["mms_lang_iso3"].append(r.mms_lang_iso3)
            columns["mms_lang_iso1"].append(r.mms_lang_iso1)
            columns["mms_confidence"].append(r.mms_confidence)
            columns["mms_top3"].append(r.mms_top3)
            columns["vox_lang"].append(r.vox_lang)
            columns["vox_lang_iso1"].append(r.vox_lang_iso1)
            columns["vox_confidence"].append(r.vox_confidence)
            columns["vox_top3"].append(r.vox_top3)
            columns["vox_speaker_embedding"].append(r.vox_speaker_embedding or b"")
            columns["conformer_multi_transcription"].append(r.conformer_multi_transcription)
            columns["conformer_multi_ctc_raw"].append(r.conformer_multi_ctc_raw)
            columns["conformer_multi_ctc_normalized"].append(r.conformer_multi_ctc_normalized)
            columns["wav2vec_transcription"].append(r.wav2vec_transcription)
            columns["wav2vec_ctc_raw"].append(r.wav2vec_ctc_raw)
            columns["wav2vec_ctc_normalized"].append(r.wav2vec_ctc_normalized)
            columns["wav2vec_model_used"].append(r.wav2vec_model_used)
            columns["lid_consensus"].append(r.lid_consensus)
            columns["lid_agree_count"].append(r.lid_agree_count)
            columns["consensus_lang"].append(r.consensus_lang)

        arrays = []
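        # pa.array() with an explicit type raises on mismatched values, so
        # schema drift fails loudly here rather than producing a bad shard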
        for fld in PARQUET_SCHEMA:
            col = columns[fld.name]
            arrays.append(pa.array(col, type=fld.type))

        return pa.table(arrays, schema=PARQUET_SCHEMA)

    def _upload_shard(self, shard_path: Path, shard_name: str):
        """Upload parquet shard to R2."""
        if self.config.mock_mode:
            logger.info(f"[MOCK] Would upload {shard_name}")
            return

        try:
            s3 = self._get_s3()
            if s3:
                key = f"shards/{self.config.worker_id}/{shard_name}"
                s3.upload_file(str(shard_path), self.config.r2_bucket_output, key)
                logger.info(f"Uploaded {shard_name} → s3://{self.config.r2_bucket_output}/{key}")
        except Exception as e:
            logger.error(f"Shard upload failed (kept locally): {e}")

    @property
    def stats(self) -> dict:
        return {
            "shards_written": self._shard_count,
            "total_segments": self._total_segments,
            "buffered_videos": len(self._video_ids),
            "buffered_segments": len(self._buffer),
            "total_mb": self._total_mb,
        }
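

# Usage sketch (illustrative; assumes a populated ValidationConfig and
# SegmentResult objects from the pipeline; iter_video_results() is a
# hypothetical helper):
#
#     packer = ParquetPacker(config, Path("/data/shards"))
#     for video_id, results in iter_video_results():
#         packer.add_video_results(video_id, results)  # auto-flushes per shard
#     packer.flush()  # write the final partial shard
#     logger.info(f"Packer stats: {packer.stats}")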
