"""Unified codec interface: TokenBatch dataclass + NeuralCodec protocol."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable

import torch


@dataclass
class TokenBatch:
    """Container for codec token output, supporting diverse token structures.

    tokens can be:
      - torch.LongTensor [B, T]           single-stream (XCodec2, WavTokenizer)
      - list[torch.LongTensor]            multi-scale (SNAC: different T per level)
      - dict[str, torch.LongTensor]       named streams (BiCodec: semantic + global)
    """

    codec_name: str   # identifier of the codec that produced these tokens
    sample_rate: int  # sample rate (Hz) of the audio the tokens represent
    tokens: Any       # tensor / list / dict of token tensors (see class docstring)
    aux: dict[str, Any] = field(default_factory=dict)  # codec-specific extras

    def _tensor_list(self) -> list[torch.Tensor]:
        """Normalize ``tokens`` into a flat, non-empty list of tensors.

        Raises:
            ValueError: if ``tokens`` is not a tensor / dict / list / tuple,
                or if the container holds no tensors at all.
        """
        if isinstance(self.tokens, torch.Tensor):
            tensors = [self.tokens]
        elif isinstance(self.tokens, dict):
            tensors = list(self.tokens.values())
        elif isinstance(self.tokens, (list, tuple)):
            tensors = list(self.tokens)
        else:
            raise ValueError(f"Unsupported tokens type {type(self.tokens)}")
        if not tensors:
            raise ValueError("tokens container is empty: no tensors to inspect")
        return tensors

    @property
    def batch_size(self) -> int:
        """Batch dimension (dim 0) of the first token tensor.

        Raises:
            ValueError: for unsupported token types or empty containers.
                (Previously an empty dict leaked a bare StopIteration and an
                empty list raised IndexError instead of ValueError.)
        """
        return self._tensor_list()[0].shape[0]

    @property
    def token_count(self) -> int:
        """Total number of token elements across all streams/levels.

        Best-effort: returns 0 for unsupported token types or empty
        containers, matching the original fall-through behavior.
        """
        try:
            return sum(t.numel() for t in self._tensor_list())
        except ValueError:
            return 0

    def observed_vocab(self) -> tuple[int, int]:
        """Return (min_token, max_token) observed across all token tensors.

        Raises:
            ValueError: for unsupported token types or empty containers
                (previously surfaced as a confusing
                ``min() arg is an empty sequence``).
        """
        tensors = self._tensor_list()
        all_min = min(t.min().item() for t in tensors)
        all_max = max(t.max().item() for t in tensors)
        return int(all_min), int(all_max)

    def shapes_summary(self) -> str:
        """Human-readable summary of token shapes, per structure kind.

        Falls back to "unknown" for unsupported token types (best-effort,
        intended for logging).
        """
        if isinstance(self.tokens, torch.Tensor):
            return str(list(self.tokens.shape))
        if isinstance(self.tokens, dict):
            return str({k: list(v.shape) for k, v in self.tokens.items()})
        if isinstance(self.tokens, (list, tuple)):
            return str([list(t.shape) for t in self.tokens])
        return "unknown"


@runtime_checkable
class NeuralCodec(Protocol):
    """Unified interface every codec wrapper must satisfy.

    This is a structural (duck-typed) protocol: any class exposing these
    attributes and methods conforms without inheriting from it.
    ``@runtime_checkable`` permits ``isinstance(obj, NeuralCodec)`` checks,
    which verify member *presence* only — not signatures or types.
    """

    # Codec identifier — presumably the same string wrappers put into
    # TokenBatch.codec_name; verify against concrete implementations.
    name: str
    # Native sample rate (Hz) the codec model operates at.
    native_sr: int

    def load(self, device: str, dtype: torch.dtype) -> None:
        """Load model weights and move to device. Call once.

        Args:
            device: torch device string (e.g. "cuda", "cpu").
            dtype:  parameter dtype to load/cast the weights to.
        """
        ...

    def warmup(self, batch_seconds: float, batch_size: int) -> None:
        """Run throwaway forward passes so JIT/CUDA caches are warm.

        Args:
            batch_seconds: duration (seconds) of the dummy audio per item —
                presumably at ``native_sr``; confirm in implementations.
            batch_size:    number of items in the dummy batch.
        """
        ...

    def encode(self, wav: torch.Tensor, sr: int) -> TokenBatch:
        """Encode waveform to tokens.

        Args:
            wav: float32 [-1, 1], shape [B, 1, T]
            sr:  sample rate of wav (wrapper resamples if needed)

        Returns:
            TokenBatch holding this codec's token structure (see TokenBatch).
        """
        ...

    def decode(self, tb: TokenBatch) -> torch.Tensor:
        """Decode tokens back to waveform. Returns [B, 1, T'].

        T' may differ from the original encode-input length due to codec
        framing — NOTE(review): confirm per implementation.
        """
        ...

    def flatten_for_lm(self, tb: TokenBatch) -> torch.LongTensor:
        """Flatten token structure to a single 1-D sequence per batch element.

        Returns [B, T_flat]. Used to estimate modeling complexity.
        How multi-scale/named streams are interleaved is implementation-
        defined; this protocol only fixes the output shape.
        """
        ...
