"""XCodec2 wrapper: HKUSTAudio/xcodec2, 16 kHz, single VQ, ~50 TPS."""

from __future__ import annotations

import logging
from typing import Any

import torch
import torchaudio

from codecbench.codecs.base import NeuralCodec, TokenBatch
from codecbench.codecs import register_codec

logger = logging.getLogger(__name__)


@register_codec
class XCodec2Codec:
    """Wrapper around HKUSTAudio/xcodec2: 16 kHz audio, single VQ token stream.

    Implements the project's codec interface: ``load`` / ``warmup`` /
    ``encode`` / ``decode`` / ``flatten_for_lm``.  Tokens are a single
    stream, so the flattened LM view is just ``[B, T_tok]``.
    """

    name: str = "xcodec2"
    native_sr: int = 16_000  # model's fixed sample rate (Hz)

    def __init__(self, model_id: str = "HKUSTAudio/xcodec2"):
        self._model_id = model_id
        self._model: Any = None  # populated by load()
        self._device = "cpu"
        # NOTE(review): _dtype is recorded but never applied to the model
        # weights (load() moves to device only) — confirm whether half
        # precision is actually supported before wiring it through.
        self._dtype = torch.float32

    def load(self, device: str = "cuda", dtype: torch.dtype = torch.float32) -> None:
        """Instantiate the pretrained model and move it to *device*.

        *dtype* is stored for bookkeeping; weights stay in their
        checkpoint dtype (see NOTE in __init__).
        """
        # Imported lazily so merely importing this module doesn't require
        # the optional xcodec2 package.
        from xcodec2.modeling_xcodec2 import XCodec2Model

        self._device = device
        self._dtype = dtype
        self._model = XCodec2Model.from_pretrained(self._model_id)
        self._model.eval().to(device)
        logger.info("XCodec2 loaded on %s", device)

    def warmup(self, batch_seconds: float = 6.0, batch_size: int = 1) -> None:
        """Run a few dummy encode/decode round-trips to pay one-time costs.

        Raises:
            RuntimeError: if load() has not been called.
        """
        # Raise instead of assert: asserts are stripped under `python -O`.
        if self._model is None:
            raise RuntimeError("Call load() first")
        n_samples = int(batch_seconds * self.native_sr)
        dummy = torch.randn(batch_size, 1, n_samples, device=self._device)
        for _ in range(3):
            with torch.inference_mode():
                tb = self.encode(dummy, self.native_sr)
                _ = self.decode(tb)
        # Fix: only synchronize on CUDA devices — the previous unconditional
        # torch.cuda.synchronize() raised on CPU-only installations.
        if torch.cuda.is_available() and str(self._device).startswith("cuda"):
            torch.cuda.synchronize()
        logger.info("XCodec2 warmup complete")

    def _resample_if_needed(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
        """Resample *wav* from *sr* to the model's native 16 kHz if they differ."""
        if sr != self.native_sr:
            wav = torchaudio.functional.resample(wav, sr, self.native_sr)
        return wav

    @torch.inference_mode()
    def encode(self, wav: torch.Tensor, sr: int) -> TokenBatch:
        """Encode [B, 1, T] -> TokenBatch with tokens [B, T_tok].

        XCodec2 expects input_waveform as [B, T] (no channel dim).
        Attempts true batching first; falls back to sequential if it fails.
        """
        wav = self._resample_if_needed(wav, sr).to(self._device)
        # XCodec2 API expects [B, T] — squeeze channel dim
        wav_2d = wav.squeeze(1)

        try:
            codes = self._model.encode_code(input_waveform=wav_2d)
            # codes shape varies by version; ensure [B, T_tok]
            if codes.ndim == 3:
                codes = codes.squeeze(1)
        except Exception:
            # exc_info=True: keep the traceback so batch-mode failures stay
            # diagnosable instead of being silently downgraded.
            logger.warning(
                "XCodec2 batch encode failed, falling back to sequential",
                exc_info=True,
            )
            code_list = []
            for i in range(wav_2d.shape[0]):
                c = self._model.encode_code(input_waveform=wav_2d[i : i + 1])
                if c.ndim == 3:
                    c = c.squeeze(1)
                code_list.append(c)
            codes = torch.cat(code_list, dim=0)

        return TokenBatch(
            codec_name=self.name,
            sample_rate=self.native_sr,
            tokens=codes,
        )

    @torch.inference_mode()
    def decode(self, tb: TokenBatch) -> torch.Tensor:
        """Decode tokens -> [B, 1, T] waveform on self._device."""
        codes = tb.tokens.to(self._device)
        # encode_code returns [B, T_tok]; decode_code wants a quantizer axis,
        # so promote to [B, 1, T_tok] when it is missing.
        if codes.ndim == 2:
            codes = codes.unsqueeze(1)
        audio = self._model.decode_code(codes)
        # ensure [B, 1, T] for callers regardless of model version
        if audio.ndim == 2:
            audio = audio.unsqueeze(1)
        return audio

    def flatten_for_lm(self, tb: TokenBatch) -> torch.LongTensor:
        """Return tokens as a single LM-ready stream: [B, T_tok] int64."""
        t = tb.tokens
        # Drop a stray quantizer axis if an older encode path left one in.
        if t.ndim == 3:
            t = t.squeeze(1)
        return t.long()
