import torch
import torch.nn as nn
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
from einops import rearrange


class VoxCPMLocEnc(nn.Module):
    def __init__(self, config: MiniCPM4Config, input_dim: int = 64):
        super().__init__()
        self.config = config
        self.special_token = nn.Parameter(torch.randn(1, 1, 1, config.hidden_size))
        self.in_proj = nn.Linear(input_dim, config.hidden_size, bias=True)

        assert config.vocab_size == 0, "vocab_size must be 0 for local encoder"
        self.encoder = MiniCPMModel(config)

    def forward(self, x):
        """
        x: [B, T, P, D]
        """
        B, T, P, D = x.shape

        x = self.in_proj(x)
        special_tokens = self.special_token.expand(B, T, 1, -1)
        x = torch.cat([special_tokens, x], dim=2)
        x = rearrange(x, "b t p c -> (b t) p c")
        outputs, _ = self.encoder(x, is_causal=False)
        cls_output = outputs[:, 0, :]

        return rearrange(cls_output, "(b t) c -> b t c", b=B)
