"""BiCodec wrapper: Spark-TTS audio tokenizer, 16 kHz.

Produces two token types:
  - semantic_tokens: 50 TPS, variable-length, captures linguistic content
  - global_tokens:   fixed-length, captures speaker/prosody attributes

Both are required for reconstruction.
"""

from __future__ import annotations

import logging

import torch
import torchaudio

from codecbench.codecs.base import NeuralCodec, TokenBatch
from codecbench.codecs import register_codec

logger = logging.getLogger(__name__)


@register_codec
class BiCodecCodec(NeuralCodec):
    name: str = "bicodec"
    native_sr: int = 16_000

    def __init__(self, model_dir: str | None = None):
        self._model_dir = model_dir
        self._model = None
        self._device = "cpu"
        self._dtype = torch.float32

    def load(self, device: str = "cuda", dtype: torch.dtype = torch.float32) -> None:
        """Load BiCodec tokenizer from Spark-TTS.

        Expects a local checkout or HF download of the Spark-TTS model directory
        containing the BiCodec checkpoint and config.
        """
        self._device = device
        self._dtype = dtype

        if self._model_dir is None:
            raise RuntimeError(
                "BiCodec requires model_dir pointing to a Spark-TTS checkout. "
                "Download from: https://huggingface.co/SparkAudio/Spark-TTS-0.5B"
            )
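
        # If no local checkout exists, the checkpoint can be fetched first,
        # e.g. with huggingface_hub (a sketch; the target directory is yours):
        #
        #   from huggingface_hub import snapshot_download
        #   snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="...")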

        try:
            # Spark-TTS BiCodecTokenizer interface
            import sys
            from pathlib import Path

            spark_root = Path(self._model_dir)
            if (spark_root / "sparktts").exists():
                sys.path.insert(0, str(spark_root))

            from sparktts.models.audio_tokenizer import BiCodecTokenizer

            self._model = BiCodecTokenizer(self._model_dir, device=device)
            logger.info("BiCodec loaded from %s on %s", self._model_dir, device)
        except ImportError as e:
            raise RuntimeError(
                f"Could not import BiCodecTokenizer. Ensure Spark-TTS repo is available. "
                f"Error: {e}"
            ) from e

    def warmup(self, batch_seconds: float = 6.0, batch_size: int = 1) -> None:
        assert self._model is not None, "Call load() first"
        n_samples = int(batch_seconds * self.native_sr)
        dummy = torch.randn(batch_size, 1, n_samples, device=self._device)
        for _ in range(3):
            with torch.inference_mode():
                tb = self.encode(dummy, self.native_sr)
                _ = self.decode(tb)
        if self._device.startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize()
        logger.info("BiCodec warmup complete")

    def _resample_if_needed(self, wav: torch.Tensor, sr: int) -> torch.Tensor:
        if sr != self.native_sr:
            wav = torchaudio.functional.resample(wav, sr, self.native_sr)
        return wav

    @torch.inference_mode()
    def encode(self, wav: torch.Tensor, sr: int) -> TokenBatch:
        """Encode [B, 1, T] -> TokenBatch with semantic + global tokens.

        BiCodecTokenizer.tokenize returns (semantic_tokens, global_tokens).
        We store as dict-based TokenBatch for explicit naming.
        """
        wav = self._resample_if_needed(wav, sr).to(self._device)
        B = wav.shape[0]

        semantic_list = []
        global_list = []

        for i in range(B):
            single_wav = wav[i]  # [1, T]
            result = self._model.tokenize(single_wav)
            if isinstance(result, tuple) and len(result) == 2:
                # Upstream tuple order is (global_tokens, semantic_tokens).
                glob_tok, sem_tok = result
            elif hasattr(result, "semantic_tokens"):
                sem_tok = result.semantic_tokens
                glob_tok = result.global_tokens
            else:
                raise ValueError(f"Unexpected tokenize output type: {type(result)}")

            if isinstance(sem_tok, torch.Tensor):
                semantic_list.append(sem_tok)
            else:
                semantic_list.append(torch.tensor(sem_tok, device=self._device))

            if isinstance(glob_tok, torch.Tensor):
                global_list.append(glob_tok)
            else:
                global_list.append(torch.tensor(glob_tok, device=self._device))

        # Pad semantic tokens to same length across batch
        max_sem_len = max(s.shape[-1] for s in semantic_list)
        padded_sem = torch.zeros(B, max_sem_len, dtype=torch.long, device=self._device)
        for i, s in enumerate(semantic_list):
            padded_sem[i, : s.shape[-1]] = s.flatten()

        # Global tokens are fixed-length, just stack
        global_tokens = torch.stack([g.flatten() for g in global_list], dim=0)

        return TokenBatch(
            codec_name=self.name,
            sample_rate=self.native_sr,
            tokens={"semantic": padded_sem, "global": global_tokens},
            aux={
                "semantic_lengths": [s.shape[-1] for s in semantic_list],
                "global_token_dim": global_tokens.shape[-1],
            },
        )
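
    # Rough shape expectations for a 1 s clip (hedged; the global-token count
    # of 32 is an assumption about the Spark-TTS 0.5B release):
    #
    #   tb = codec.encode(torch.randn(1, 1, 16_000), sr=16_000)
    #   tb.tokens["semantic"].shape   # ~[1, 50] at 50 TPS
    #   tb.tokens["global"].shape     # [1, 32], fixed per utterance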

    @torch.inference_mode()
    def decode(self, tb: TokenBatch) -> torch.Tensor:
        """Decode semantic + global tokens -> [B, 1, T].

        Both token types are required for reconstruction.
        """
        semantic = tb.tokens["semantic"].to(self._device)
        global_tok = tb.tokens["global"].to(self._device)
        B = semantic.shape[0]

        # The fallback must cover every batch item, not just index 0.
        sem_lengths = tb.aux.get("semantic_lengths") or [semantic.shape[-1]] * B
        audio_list = []
        for i in range(B):
            sem_i = semantic[i, : sem_lengths[i]]
            glob_i = global_tok[i]
            # Upstream signature is detokenize(global_tokens, semantic_tokens).
            audio = self._model.detokenize(glob_i.unsqueeze(0), sem_i.unsqueeze(0))
            if isinstance(audio, torch.Tensor):
                if audio.ndim == 1:
                    audio = audio.unsqueeze(0).unsqueeze(0)
                elif audio.ndim == 2:
                    audio = audio.unsqueeze(1)
                audio_list.append(audio)
            else:
                audio_list.append(
                    torch.tensor(audio, device=self._device).unsqueeze(0).unsqueeze(0)
                )

        # Pad to same output length
        max_len = max(a.shape[-1] for a in audio_list)
        result = torch.zeros(B, 1, max_len, device=self._device)
        for i, a in enumerate(audio_list):
            result[i, :, : a.shape[-1]] = a.squeeze(0)
        return result
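
    # At 50 semantic tokens/s and 16 kHz output, one token spans ~320 samples,
    # so the padded batch output can be trimmed per item (a sketch):
    #
    #   out = codec.decode(tb)
    #   n = tb.aux["semantic_lengths"][0] * 320   # approx. valid samples
    #   trimmed = out[0, :, :n]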

    def flatten_for_lm(self, tb: TokenBatch) -> torch.LongTensor:
        """Concatenate semantic + global tokens for LM sequence length estimate.

        Global tokens use a separate vocab range (offset by max semantic ID + 1).
        """
        semantic = tb.tokens["semantic"]
        global_tok = tb.tokens["global"]

        # Offset global token IDs to avoid collision
        offset = semantic.max().item() + 1
        global_offset = global_tok + offset

        return torch.cat([semantic, global_offset], dim=-1).long()
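
    # Worked example of the batch-local offset (illustrative IDs):
    #
    #   semantic   = [[12, 980, 7]]   # max semantic ID in this batch = 980
    #   global_tok = [[3, 5]]         # shifted by 981 -> [[984, 986]]
    #   flattened  -> [[12, 980, 7, 984, 986]]   # length = T_sem + T_glob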
