"""Token packing, saving, and loading for codec outputs.

Designed to produce artifacts compatible with the planned R2 storage layout:
  r2://<bucket>/tokens/codec=<name>/sr=<sr>/shard=<id>.tar

Individual samples stored as .npz with zstd compression at shard level.
"""

from __future__ import annotations

import io
import json
from pathlib import Path
from typing import Any

import numpy as np
import torch

from codecbench.codecs.base import TokenBatch


def token_batch_to_numpy(tb: TokenBatch) -> dict[str, np.ndarray]:
    """Convert TokenBatch tokens to numpy arrays for serialization.

    Handles the three token layouts codecs produce: a single tensor
    (stored as ``"tokens"``), a dict of named streams (``"tokens_<key>"``),
    or a list/tuple of RVQ levels (``"tokens_level_<i>"``).

    Returns dict of array_name -> ndarray, using uint16 when all values
    fit in [0, 65535], int32 otherwise.
    """
    arrays: dict[str, np.ndarray] = {}

    def _convert(t: torch.Tensor, name: str) -> None:
        # detach() so tensors carrying autograd history can be converted;
        # plain .numpy() raises on tensors that require grad.
        arr = t.detach().cpu().numpy()
        if arr.size == 0:
            # min()/max() raise ValueError on empty arrays; store empty
            # streams with the compact dtype.
            arrays[name] = arr.astype(np.uint16)
            return
        vmin, vmax = arr.min(), arr.max()
        if vmin >= 0 and vmax <= 65535:
            arr = arr.astype(np.uint16)
        else:
            arr = arr.astype(np.int32)
        arrays[name] = arr

    if isinstance(tb.tokens, torch.Tensor):
        _convert(tb.tokens, "tokens")
    elif isinstance(tb.tokens, dict):
        for key, tensor in tb.tokens.items():
            _convert(tensor, f"tokens_{key}")
    elif isinstance(tb.tokens, (list, tuple)):
        for i, tensor in enumerate(tb.tokens):
            _convert(tensor, f"tokens_level_{i}")

    return arrays


def save_tokens_npz(tb: TokenBatch, path: str | Path) -> Path:
    """Save TokenBatch to .npz file with metadata."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    arrays = token_batch_to_numpy(tb)
    meta = {
        "codec_name": tb.codec_name,
        "sample_rate": tb.sample_rate,
        "batch_size": tb.batch_size,
        "shapes_summary": tb.shapes_summary(),
    }
    if tb.aux:
        serializable_aux = {}
        for k, v in tb.aux.items():
            if isinstance(v, (int, float, str, bool)):
                serializable_aux[k] = v
            elif isinstance(v, torch.Tensor):
                serializable_aux[k] = v.tolist()
        meta["aux"] = serializable_aux

    meta_bytes = json.dumps(meta).encode("utf-8")
    arrays["_meta_json"] = np.frombuffer(meta_bytes, dtype=np.uint8)

    np.savez_compressed(str(path), **arrays)
    return path


def load_tokens_npz(path: str | Path) -> tuple[dict[str, np.ndarray], dict]:
    """Load tokens and metadata from .npz file.

    Returns ``(arrays, meta)``: all arrays except the reserved
    ``"_meta_json"`` entry, plus the decoded JSON metadata dict.
    """
    # NpzFile keeps the underlying zip handle open until closed; use the
    # context manager so the handle is released deterministically (the
    # arrays are fully materialized before exit, so they stay valid).
    with np.load(str(path), allow_pickle=False) as data:
        meta = json.loads(data["_meta_json"].tobytes().decode("utf-8"))
        arrays = {k: data[k] for k in data.files if k != "_meta_json"}
    return arrays, meta


def compress_zstd(data: bytes, level: int = 3) -> bytes:
    """Compress *data* with zstandard at the given compression level."""
    # Imported lazily so the module loads even without python-zstandard.
    import zstandard as zstd

    return zstd.ZstdCompressor(level=level).compress(data)


def decompress_zstd(data: bytes) -> bytes:
    """Decompress zstandard-compressed *data*."""
    # Imported lazily so the module loads even without python-zstandard.
    import zstandard as zstd

    return zstd.ZstdDecompressor().decompress(data)
