"""Codec wrappers with unified NeuralCodec interface."""

from codecbench.codecs.base import NeuralCodec, TokenBatch

# Maps each codec's ``name`` class attribute to the class implementing it.
# Populated at import time by the ``@register_codec`` decorator below.
CODEC_REGISTRY: dict[str, type] = {}


def register_codec(cls: type) -> type:
    """Record *cls* in ``CODEC_REGISTRY`` under its ``name`` attribute.

    Intended for use as a class decorator (``@register_codec``).

    If a *different* class is already registered under the same name, the
    new class still replaces it (last registration wins), but a
    ``RuntimeWarning`` is emitted so that accidental name collisions between
    codec modules do not silently shadow each other. Re-registering the same
    class (same module and qualified name, e.g. after a module reload) is
    silent.

    Args:
        cls: Codec class to register; must define a ``name`` attribute.

    Returns:
        *cls* itself, unchanged, so the decorator is transparent.

    Raises:
        AttributeError: If *cls* has no ``name`` attribute.
    """
    import warnings

    key = cls.name
    previous = CODEC_REGISTRY.get(key)
    # Compare by (module, qualname) rather than identity: importlib.reload
    # creates a new class object for the same definition, which is not a
    # collision worth warning about.
    if previous is not None and (previous.__module__, previous.__qualname__) != (
        cls.__module__,
        cls.__qualname__,
    ):
        warnings.warn(
            f"Codec name {key!r} is already registered to "
            f"{previous.__module__}.{previous.__qualname__}; replacing it with "
            f"{cls.__module__}.{cls.__qualname__}",
            RuntimeWarning,
            stacklevel=2,
        )
    CODEC_REGISTRY[key] = cls
    return cls


def get_codec(name: str, **kwargs) -> "NeuralCodec":
    """Instantiate the codec class registered under *name*.

    Args:
        name: Registry key, i.e. the ``name`` attribute the codec class was
            registered with via ``@register_codec``.
        **kwargs: Forwarded unchanged to the codec class constructor. With
            no keyword arguments the class is called with no arguments.

    Returns:
        The object produced by calling the registered class.

    Raises:
        KeyError: If no codec is registered under *name*. The message lists
            the registered names, or says that none are registered.
    """
    try:
        codec_cls = CODEC_REGISTRY[name]
    except KeyError:
        registered = ", ".join(sorted(CODEC_REGISTRY))
        # An empty registry typically means every codec module was skipped
        # during discovery (e.g. missing optional dependencies); say so
        # instead of emitting a bare "Available: ".
        available = registered or (
            "(none registered; codec modules may have failed to import)"
        )
        # `from None` keeps the traceback free of the internal lookup failure.
        raise KeyError(f"Unknown codec {name!r}. Available: {available}") from None
    # Construct outside the `try` so a KeyError raised by the constructor
    # itself is not misreported as an unknown codec name.
    return codec_cls(**kwargs)


def _discover() -> None:
    """Import all codec modules so @register_codec decorators fire."""
    import importlib
    for mod in ("xcodec2", "xcodec2_fast", "bicodec", "bicodec_fast"):
        try:
            importlib.import_module(f"codecbench.codecs.{mod}")
        except ImportError:
            pass


# Populate CODEC_REGISTRY eagerly when the package is imported; codec modules
# that raise ImportError are skipped by _discover() rather than failing here.
_discover()

__all__ = [
    "NeuralCodec",
    "TokenBatch",
    "CODEC_REGISTRY",
    "register_codec",
    "get_codec",
]
