"""Shard packer: accumulates encoded tokens from multiple videos, packs into
compressed shards, and uploads to R2.

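Storage format per shard: a zstd-compressed tar (.tar.zst) containing:
  - manifest.json: shard metadata (shard id, codecs, sample rate, video and
    segment counts, total audio seconds, and per-video id/language/durations),
    written as the last tar member
  - segments/<video_id>/<segment_idx:06d>.npz: one .npz per segment holding the
    token ids (uint16 when every id fits, otherwise the original dtype) and a
    UTF-8 JSON `_meta` array with the segment's metadata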

Designed for efficient sequential reads during LM training.
"""

from __future__ import annotations

import io
import json
import logging
import tarfile
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
import zstandard as zstd

from codecbench.pipeline.config import PipelineConfig
from codecbench.pipeline.encoder import EncodedSegment

logger = logging.getLogger(__name__)


@dataclass
class VideoTokens:
    """Bundle of every encoded segment produced for a single video.

    Only ``usable_audio_s`` is optional; it defaults to ``0.0``.
    """
    video_id: str  # identifier used in shard member paths and metadata
    language: str  # language label copied into shard metadata
    duration_s: float  # duration of the source video, in seconds
    segments: list[EncodedSegment]  # encoded segments, in the order they are packed
    usable_audio_s: float = field(default=0.0)  # audio seconds counted toward shard totals


@dataclass
class ShardBuffer:
    """In-memory staging area that collects VideoTokens for one shard.

    Running totals are updated on every ``add`` so the caller can decide when
    to flush (presumably against ``shard_pack_count``) without re-scanning
    ``videos``.
    """
    videos: list[VideoTokens] = field(default_factory=list)  # staged videos, insertion order
    total_segments: int = 0  # sum of len(v.segments) over staged videos
    total_audio_s: float = 0.0  # sum of v.usable_audio_s over staged videos

    def add(self, vt: VideoTokens) -> None:
        """Stage *vt* and fold its segment count and usable audio into the totals."""
        seg_count = len(vt.segments)
        self.videos.append(vt)
        self.total_segments += seg_count
        self.total_audio_s += vt.usable_audio_s

    @property
    def video_count(self) -> int:
        """Number of videos currently staged."""
        staged = self.videos
        return len(staged)

    def clear(self) -> None:
        """Empty the buffer in place and reset both running totals."""
        # Deleting the full slice keeps the same list object, as list.clear() does.
        del self.videos[:]
        self.total_segments, self.total_audio_s = 0, 0.0


def _tar_add_bytes(tar: tarfile.TarFile, name: str, data: bytes) -> None:
    """Append the in-memory byte string *data* to *tar* as regular file *name*."""
    info = tarfile.TarInfo(name=name)
    info.size = len(data)
    tar.addfile(info, io.BytesIO(data))


def _narrow_tokens(tokens: np.ndarray, video_id: str, segment_idx: int) -> np.ndarray:
    """Return *tokens* cast to uint16 when every value fits, else unchanged.

    An empty array is treated as fitting; calling ``min()``/``max()`` on it
    would raise ``ValueError``.
    """
    if tokens.size == 0 or (
        tokens.min() >= 0 and tokens.max() <= np.iinfo(np.uint16).max
    ):
        return tokens.astype(np.uint16, copy=False)
    logger.warning(
        "Segment %s/%d has token ids outside the uint16 range; storing dtype %s",
        video_id, segment_idx, tokens.dtype,
    )
    return tokens


def _segment_npz_bytes(seg: EncodedSegment, vt: VideoTokens) -> bytes:
    """Serialize one segment's tokens plus a JSON ``_meta`` array into .npz bytes."""
    # NOTE(review): .numpy() presumably targets a CPU torch tensor without
    # grad; confirm upstream detaches/moves it before packing.
    tokens = _narrow_tokens(seg.xcodec2_tokens.numpy(), vt.video_id, seg.segment_idx)
    seg_meta = json.dumps({
        "segment_idx": seg.segment_idx,
        "start_s": seg.start_s,
        "end_s": seg.end_s,
        "duration_s": seg.duration_s,
        "num_chunks": seg.num_chunks,
        "xcodec2_token_count": int(tokens.shape[-1]),
        "video_id": vt.video_id,
        "language": vt.language,
    }).encode("utf-8")
    npz_buf = io.BytesIO()
    # Plain (uncompressed) savez: the whole tar is zstd-compressed afterwards.
    np.savez(
        npz_buf,
        xcodec2=tokens,
        _meta=np.frombuffer(seg_meta, dtype=np.uint8),
    )
    return npz_buf.getvalue()


def pack_shard(
    buffer: ShardBuffer,
    output_dir: Path,
    codecs: list[str] | None = None,
    *,
    sample_rate: int = 16_000,
    zstd_level: int = 3,
) -> tuple[str, Path, int]:
    """Pack the videos staged in *buffer* into one zstd-compressed tar shard.

    Each segment becomes ``segments/<video_id>/<segment_idx:06d>.npz``, and a
    ``manifest.json`` describing the whole shard is appended as the last tar
    member. The tar is built in memory and compressed in one shot. The result
    is written to a temporary file and then renamed with ``Path.replace``, so
    a failed write never leaves a truncated ``.tar.zst`` under the final name.

    Args:
        buffer: staged videos. This function only reads it and does not clear it.
        output_dir: directory for the shard. It is created if missing.
        codecs: codec names recorded in the manifest. Defaults to
            ``["xcodec2_fast"]``.
        sample_rate: sample rate recorded in the manifest (keyword-only).
        zstd_level: zstd compression level (keyword-only).

    Returns:
        ``(shard_id, local_path, size_bytes)``, where ``size_bytes`` is the
        compressed file size.
    """
    if codecs is None:
        codecs = ["xcodec2_fast"]

    shard_id = f"shard_{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir.mkdir(parents=True, exist_ok=True)
    shard_path = output_dir / f"{shard_id}.tar.zst"

    videos_meta: list[dict] = []
    cctx = zstd.ZstdCompressor(level=zstd_level, threads=-1)

    with io.BytesIO() as raw_buf:
        with tarfile.open(fileobj=raw_buf, mode="w") as tar:
            for vt in buffer.videos:
                videos_meta.append({
                    "video_id": vt.video_id,
                    "language": vt.language,
                    "duration_s": vt.duration_s,
                    "usable_audio_s": vt.usable_audio_s,
                    "num_segments": len(vt.segments),
                })
                for seg in vt.segments:
                    # NOTE(review): video_id is used verbatim in the member
                    # path; a '/' or '..' in it would create nested or unsafe
                    # paths on extraction. Confirm ids are sanitized upstream.
                    member_name = f"segments/{vt.video_id}/{seg.segment_idx:06d}.npz"
                    _tar_add_bytes(tar, member_name, _segment_npz_bytes(seg, vt))

            manifest = {
                "shard_id": shard_id,
                "codecs": codecs,
                "sample_rate": sample_rate,
                "video_count": buffer.video_count,
                "total_segments": buffer.total_segments,
                "total_audio_s": buffer.total_audio_s,
                "videos": videos_meta,
            }
            _tar_add_bytes(
                tar, "manifest.json", json.dumps(manifest, indent=2).encode("utf-8"),
            )

        # Compress directly from BytesIO's internal buffer. getvalue() would
        # copy the whole uncompressed tar once more.
        with raw_buf.getbuffer() as raw_view:
            raw_size = raw_view.nbytes
            compressed = cctx.compress(raw_view)

    tmp_path = shard_path.with_name(shard_path.name + ".tmp")
    try:
        tmp_path.write_bytes(compressed)
        tmp_path.replace(shard_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise

    size = len(compressed)
    ratio = raw_size / max(size, 1)
    logger.info(
        "Packed shard %s: %d videos, %d segments, %.1f s audio, "
        "%.1f MB (%.1fx compression)",
        shard_id, buffer.video_count, buffer.total_segments,
        buffer.total_audio_s, size / 1e6, ratio,
    )

    return shard_id, shard_path, size
