"""
LeWM TTS v6 — JEPA + VQ: Discrete token prediction for stable AR synthesis.

Key insight: continuous AR prediction compounds a small error at every step, so
long rollouts degrade into noise. Discrete tokens cannot drift: every predicted
token maps to a valid codebook entry by construction.

Architecture:
  - TextEncoder: byte-level transformer (same as model.py)
  - AudioEncoder: CNN + Transformer → continuous embeddings (same as model.py)
  - VectorQuantizer: continuous → discrete codebook indices + quantized embeddings
  - JEPAPredictor: predicts next codebook INDEX (classification, not regression)
  - MelDecoder: codebook embeddings → mel spectrogram
  - Vocos: mel → waveform (external, frozen)
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# ─── Vector Quantizer ─────────────────────────────────────────────────────────

class VectorQuantizer(nn.Module):
    """
    Discretizes continuous embeddings into codebook indices.
    Uses EMA codebook updates (no codebook loss needed in optimizer).
    Straight-through estimator for gradients.
    """

    def __init__(self, d_model=256, n_codes=1024, commitment_weight=0.25, ema_decay=0.99):
        super().__init__()
        self.d_model = d_model
        self.n_codes = n_codes
        self.commitment_weight = commitment_weight
        self.ema_decay = ema_decay

        # Codebook: [n_codes, d_model]
        self.register_buffer("codebook", torch.randn(n_codes, d_model))
        self.register_buffer("ema_count", torch.ones(n_codes))
        self.register_buffer("ema_sum", self.codebook.clone())
        # Buffer (not a plain attribute) so the flag survives checkpointing;
        # otherwise resuming would re-run the data-dependent init below and
        # overwrite the learned codebook on the first training batch.
        self.register_buffer("_initialized", torch.tensor(False))

    def _init_codebook(self, flat_z):
        """Initialize codebook from first batch using k-means++ style."""
        if self._initialized:
            return
        n = min(flat_z.shape[0], self.n_codes)
        indices = torch.randperm(flat_z.shape[0], device=flat_z.device)[:n]
        self.codebook[:n] = flat_z[indices].detach()
        self.ema_sum.copy_(self.codebook)
        self._initialized.fill_(True)

    def forward(self, z):
        """
        Args:
            z: [B, T, d_model] continuous embeddings
        Returns:
            z_q: [B, T, d_model] quantized embeddings (straight-through)
            indices: [B, T] codebook indices
            commit_loss: scalar commitment loss
        """
        B, T, D = z.shape
        flat_z = z.reshape(-1, D)  # [B*T, D]

        # Initialize codebook from first batch
        if self.training and not self._initialized:
            self._init_codebook(flat_z)

        # Find nearest codebook entry: ||z - e||^2 = ||z||^2 - 2<z,e> + ||e||^2
        dists = (
            flat_z.pow(2).sum(dim=1, keepdim=True)
            - 2 * flat_z @ self.codebook.t()
            + self.codebook.pow(2).sum(dim=1, keepdim=True).t()
        )  # [B*T, n_codes]

        indices = dists.argmin(dim=1)  # [B*T]
        z_q = self.codebook[indices]   # [B*T, D]

        # EMA codebook update (only during training)
        if self.training:
            with torch.no_grad():
                one_hot = F.one_hot(indices, self.n_codes).float()  # [B*T, n_codes]
                counts = one_hot.sum(dim=0)  # [n_codes]
                sums = one_hot.t() @ flat_z  # [n_codes, D]

                self.ema_count.mul_(self.ema_decay).add_(counts, alpha=1 - self.ema_decay)
                self.ema_sum.mul_(self.ema_decay).add_(sums, alpha=1 - self.ema_decay)

                # Laplace smoothing: keeps rarely used codes from dividing by ~0
                n = self.ema_count.sum()
                count_smooth = (self.ema_count + 1e-5) / (n + self.n_codes * 1e-5) * n
                self.codebook.copy_(self.ema_sum / count_smooth.unsqueeze(1))

                # Reset dead codes: replace with random encoder outputs
                dead = counts == 0
                n_dead = dead.sum().item()
                if n_dead > 0:
                    # Replace each dead code with a random encoder output from this batch
                    rand_idx = torch.randint(0, flat_z.shape[0], (n_dead,), device=flat_z.device)
                    self.codebook[dead] = flat_z[rand_idx].detach()
                    self.ema_sum[dead] = flat_z[rand_idx].detach()
                    self.ema_count[dead] = 1.0

        # Commitment loss: push encoder output toward codebook
        commit_loss = F.mse_loss(flat_z, z_q.detach())

        # Straight-through estimator: z_q gets gradients of z (both are flat here)
        z_q = flat_z + (z_q - flat_z).detach()

        z_q = z_q.reshape(B, T, D)
        indices = indices.reshape(B, T)

        return z_q, indices, self.commitment_weight * commit_loss

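# Illustrative helper (not referenced by the training code): round-trip a random
# batch through the quantizer and report how much of the codebook it touches.
# Sizes are arbitrary; a healthy quantizer keeps most codes in use over time.
def _vq_usage_example():
    vq = VectorQuantizer(d_model=256, n_codes=1024).train()
    z = torch.randn(8, 120, 256)                    # [B, T, d_model] fake encoder output
    z_q, indices, commit_loss = vq(z)               # z_q: [8, 120, 256], indices: [8, 120]
    assert z_q.shape == z.shape and commit_loss.ndim == 0
    return indices.unique().numel() / vq.n_codes    # fraction of codes selected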

# ─── Shared components (same as model.py) ─────────────────────────────────────

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=8192):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.shape[1]]


class TextEncoder(nn.Module):
    def __init__(self, vocab_size=256, d_model=256, nhead=4, num_layers=4, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.char_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = SinusoidalPositionalEncoding(d_model, max_len=2048)
        self.dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, text_tokens, text_mask=None):
        x = self.char_embed(text_tokens) * math.sqrt(self.d_model)
        x = self.pos_embed(x)
        x = self.dropout(x)
        x = self.transformer(x, src_key_padding_mask=text_mask)
        return self.proj(x)


class AudioEncoder(nn.Module):
    def __init__(self, n_mels=100, d_model=256, nhead=4, num_layers=4,
                 downsample_factor=4, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.downsample_factor = downsample_factor

        if downsample_factor == 2:
            self.conv_pre = nn.Sequential(
                nn.Conv1d(n_mels, d_model, kernel_size=7, padding=3), nn.GELU(),
                nn.Conv1d(d_model, d_model, kernel_size=4, stride=2, padding=1), nn.GELU(),
            )
        else:
            self.conv_pre = nn.Sequential(
                nn.Conv1d(n_mels, d_model, kernel_size=7, padding=3), nn.GELU(),
                nn.Conv1d(d_model, d_model, kernel_size=4, stride=2, padding=1), nn.GELU(),
                nn.Conv1d(d_model, d_model, kernel_size=4, stride=2, padding=1), nn.GELU(),
            )

        self.pos_embed = SinusoidalPositionalEncoding(d_model, max_len=4096)
        self.dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, mel, mel_mask=None):
        x = self.conv_pre(mel)
        x = x.transpose(1, 2)
        if mel_mask is not None:
            T_down = x.shape[1]
            mel_mask = mel_mask[:, ::self.downsample_factor][:, :T_down]
        x = self.pos_embed(x)
        x = self.dropout(x)
        x = self.transformer(x, src_key_padding_mask=mel_mask)
        return self.proj(x)  # [B, T_down, d_model]

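# Illustrative shape check (not referenced elsewhere): with downsample_factor=4
# the two stride-2 convs reduce T mel frames to T // 4 latent steps, which is
# the token rate the VQ codebook and predictor operate at.
def _audio_encoder_shape_example():
    enc = AudioEncoder(n_mels=100, d_model=256, downsample_factor=4).eval()
    mel = torch.randn(2, 100, 400)                  # [B, n_mels, T]
    with torch.no_grad():
        z = enc(mel)
    assert z.shape == (2, 400 // 4, 256)            # [B, T_down, d_model]
    return z.shape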

# ─── Predictor (now predicts codebook logits) ─────────────────────────────────

class TokenPredictor(nn.Module):
    """
    Causal transformer that predicts next CODEBOOK INDEX.
    Input: quantized embeddings (codebook lookups)
    Output: logits over codebook entries [B, T, n_codes]
    """

    def __init__(self, d_model=256, nhead=4, num_layers=6, n_codes=1024, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_codes = n_codes

        self.input_proj = nn.Linear(d_model, d_model)
        self.pos_embed = SinusoidalPositionalEncoding(d_model, max_len=4096)
        self.dropout = nn.Dropout(dropout)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Output head: predict codebook index (classification, not regression!)
        self.output_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, n_codes),
        )

    def forward(self, audio_emb, text_emb, audio_mask=None, text_mask=None):
        x = self.input_proj(audio_emb)
        x = self.pos_embed(x)
        x = self.dropout(x)

        T = x.shape[1]
        causal_mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)

        x = self.transformer(
            x, text_emb,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=audio_mask,
            memory_key_padding_mask=text_mask,
        )

        return self.output_head(x)  # [B, T, n_codes]

    def init_cache(self, num_layers):
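        """Empty per-layer K/V caches for self- and cross-attention; filled
        lazily during step-wise decoding (cross K/V are computed once)."""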
        return {
            'self_k': [None] * num_layers,
            'self_v': [None] * num_layers,
            'cross_k': [None] * num_layers,
            'cross_v': [None] * num_layers,
        }

    def inference_step(self, new_emb, text_emb, step_idx, cache, text_mask=None):
        """One-step cached inference. Returns logits [B, 1, n_codes] and updated cache."""
        x = self.input_proj(new_emb)
        x = x + self.pos_embed.pe[:, step_idx:step_idx + 1]

        for i, layer in enumerate(self.transformer.layers):
            # Self-attention with cache
            sa_out, cache['self_k'][i], cache['self_v'][i] = self._cached_mha(
                layer.self_attn, x, x, cache['self_k'][i], cache['self_v'][i]
            )
            x = layer.norm1(x + layer.dropout1(sa_out))

            # Cross-attention with cached text
            ca_out, cache['cross_k'][i], cache['cross_v'][i] = self._cached_cross_attn(
                layer.multihead_attn, x, text_emb,
                cache['cross_k'][i], cache['cross_v'][i], text_mask
            )
            x = layer.norm2(x + layer.dropout2(ca_out))

            # FFN
            ff_out = layer.linear2(layer.dropout(layer.activation(layer.linear1(x))))
            x = layer.norm3(x + layer.dropout3(ff_out))

        return self.output_head(x), cache  # [B, 1, n_codes]

    def _cached_mha(self, mha, q, kv, cache_k, cache_v):
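        """Single-step self-attention with a growing K/V cache. Assumes q is the
        one new position [B, 1, d]; the returned k/v include the new step and
        are stored back into the cache by the caller."""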
        d, nhead = mha.embed_dim, mha.num_heads
        head_dim = d // nhead
        B = q.shape[0]
        Wq, Wk, Wv = mha.in_proj_weight.chunk(3, dim=0)
        bq, bk, bv = mha.in_proj_bias.chunk(3, dim=0)

        q_proj = F.linear(q, Wq, bq)
        k_new = F.linear(kv, Wk, bk)
        v_new = F.linear(kv, Wv, bv)

        k = torch.cat([cache_k, k_new], dim=1) if cache_k is not None and cache_k.shape[1] > 0 else k_new
        v = torch.cat([cache_v, v_new], dim=1) if cache_v is not None and cache_v.shape[1] > 0 else v_new

        q_proj = q_proj.view(B, 1, nhead, head_dim).transpose(1, 2)
        k_mh = k.view(B, -1, nhead, head_dim).transpose(1, 2)
        v_mh = v.view(B, -1, nhead, head_dim).transpose(1, 2)

        attn = torch.matmul(q_proj, k_mh.transpose(-2, -1)) / (head_dim ** 0.5)
        out = torch.matmul(F.softmax(attn, dim=-1), v_mh)
        out = out.transpose(1, 2).contiguous().view(B, 1, d)
        return F.linear(out, mha.out_proj.weight, mha.out_proj.bias), k, v

    def _cached_cross_attn(self, mha, q, memory, cache_k, cache_v, memory_mask=None):
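        """Single-step cross-attention over the text memory. K/V are projected
        once on the first call and reused from the cache on later steps."""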
        d, nhead = mha.embed_dim, mha.num_heads
        head_dim = d // nhead
        B = q.shape[0]
        Wq, Wk, Wv = mha.in_proj_weight.chunk(3, dim=0)
        bq, bk, bv = mha.in_proj_bias.chunk(3, dim=0)

        q_proj = F.linear(q, Wq, bq)
        if cache_k is None:
            k = F.linear(memory, Wk, bk)
            v = F.linear(memory, Wv, bv)
        else:
            k, v = cache_k, cache_v

        q_proj = q_proj.view(B, 1, nhead, head_dim).transpose(1, 2)
        k_mh = k.view(B, -1, nhead, head_dim).transpose(1, 2)
        v_mh = v.view(B, -1, nhead, head_dim).transpose(1, 2)

        attn = torch.matmul(q_proj, k_mh.transpose(-2, -1)) / (head_dim ** 0.5)
        if memory_mask is not None:
            attn = attn.masked_fill(memory_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        out = torch.matmul(F.softmax(attn, dim=-1), v_mh)
        out = out.transpose(1, 2).contiguous().view(B, 1, d)
        return F.linear(out, mha.out_proj.weight, mha.out_proj.bias), k, v

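# Illustrative consistency check (not referenced elsewhere): in eval mode the
# step-wise cached decode should reproduce the full parallel forward pass up to
# float tolerance, since dropout is off and the causal mask hides the future.
def _cached_decode_example():
    torch.manual_seed(0)
    pred = TokenPredictor(d_model=64, nhead=4, num_layers=2, n_codes=32).eval()
    audio = torch.randn(1, 5, 64)                   # teacher-forced quantized embeddings
    text = torch.randn(1, 7, 64)                    # encoded text memory
    with torch.no_grad():
        full = pred(audio, text)                    # [1, 5, 32]
        cache = pred.init_cache(len(pred.transformer.layers))
        steps = []
        for t in range(audio.shape[1]):
            logits, cache = pred.inference_step(audio[:, t:t + 1], text, t, cache)
            steps.append(logits)
    assert torch.allclose(full, torch.cat(steps, dim=1), atol=1e-4)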

# ─── Mel Decoder ──────────────────────────────────────────────────────────────

class ResConvBlock(nn.Module):
    def __init__(self, channels, kernel_size=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2),
            nn.GELU(),
            nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2),
        )

    def forward(self, x):
        return x + self.net(x)


class MelDecoder(nn.Module):
    def __init__(self, d_model=256, n_mels=100, upsample_factor=4):
        super().__init__()
        hidden = d_model * 2

        self.proj = nn.Sequential(nn.Linear(d_model, hidden), nn.GELU())

        if upsample_factor == 2:
            self.upsample = nn.Sequential(
                nn.ConvTranspose1d(hidden, hidden, kernel_size=4, stride=2, padding=1), nn.GELU(),
                ResConvBlock(hidden, kernel_size=3),
            )
        else:
            self.upsample = nn.Sequential(
                nn.ConvTranspose1d(hidden, hidden, kernel_size=4, stride=2, padding=1), nn.GELU(),
                ResConvBlock(hidden, kernel_size=3),
                nn.ConvTranspose1d(hidden, hidden, kernel_size=4, stride=2, padding=1), nn.GELU(),
                ResConvBlock(hidden, kernel_size=3),
            )

        self.out = nn.Sequential(
            nn.Conv1d(hidden, hidden, kernel_size=7, padding=3), nn.GELU(),
            ResConvBlock(hidden, kernel_size=5),
            nn.Conv1d(hidden, n_mels, kernel_size=7, padding=3),
        )

    def forward(self, z):
        x = self.proj(z).transpose(1, 2)
        x = self.upsample(x)
        return self.out(x)

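# Illustrative shape check (not referenced elsewhere): the transposed convs undo
# the encoder's 4x temporal downsampling, so T_down latent steps decode back to
# 4 * T_down mel frames.
def _mel_decoder_shape_example():
    dec = MelDecoder(d_model=256, n_mels=100, upsample_factor=4).eval()
    z_q = torch.randn(2, 80, 256)                   # [B, T_down, d_model]
    with torch.no_grad():
        mel = dec(z_q)
    assert mel.shape == (2, 100, 80 * 4)            # [B, n_mels, T_down * 4]
    return mel.shape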

# ─── Multi-Resolution Spectral Loss ──────────────────────────────────────────

class MultiResolutionSpectralLoss(nn.Module):
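    """Spectral convergence + log-magnitude L1 over several STFT resolutions.
    Inputs are mel spectrograms [B, n_mels, T]; each mel band is treated as an
    independent 1-D signal along time, so temporal texture is penalized at
    multiple scales rather than only per-frame. Resolutions longer than the
    input are skipped."""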
    def __init__(self, resolutions=((64, 16, 64), (128, 32, 128), (256, 64, 256))):
        super().__init__()
        self.resolutions = resolutions

    def _stft_loss(self, pred, target, n_fft, hop_length, win_length):
        window = torch.hann_window(win_length, device=pred.device)
        B, M, T = pred.shape
        pred_stft = torch.stft(pred.reshape(B*M, T), n_fft, hop_length, win_length, window=window, return_complex=True)
        target_stft = torch.stft(target.reshape(B*M, T), n_fft, hop_length, win_length, window=window, return_complex=True)
        sc = torch.norm(target_stft.abs() - pred_stft.abs(), p="fro") / (torch.norm(target_stft.abs(), p="fro") + 1e-8)
        log_mag = F.l1_loss(torch.log(pred_stft.abs() + 1e-7), torch.log(target_stft.abs() + 1e-7))
        return sc + log_mag

    def forward(self, pred, target):
        loss = torch.tensor(0.0, device=pred.device)
        count = 0
        for n_fft, hop, win in self.resolutions:
            if pred.shape[2] >= n_fft:
                loss = loss + self._stft_loss(pred, target, n_fft, hop, win)
                count += 1
        return loss / max(count, 1)


# ─── Full Model ──────────────────────────────────────────────────────────────

class LeWMTTSvq(nn.Module):
    """
    JEPA + VQ TTS model.

    Training flow:
      mel → AudioEncoder → continuous z → VQ → discrete indices + quantized z_q
      [start, z_q[:-1]] + text → TokenPredictor → logits → cross_entropy(logits, indices)
      z_q → MelDecoder → mel_recon → L1 + spectral loss

    Inference flow:
      text → TextEncoder → text_emb
      start_emb → Predictor → logits → top-k sample → codebook lookup → next input → repeat
      all codebook embeddings → MelDecoder → mel → Vocos → waveform

    No continuous drift possible — every predicted embedding is a valid codebook entry.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.get("d_model", 256)
        nhead = config.get("nhead", 4)
        n_mels = config.get("n_mels", 100)
        text_vocab = config.get("text_vocab_size", 256)
        text_layers = config.get("text_encoder_layers", 4)
        audio_layers = config.get("audio_encoder_layers", 4)
        predictor_layers = config.get("predictor_layers", 6)
        dropout = config.get("dropout", 0.1)
        n_codes = config.get("n_codes", 1024)
        downsample_factor = config.get("downsample_factor", 4)

        self.text_encoder = TextEncoder(
            vocab_size=text_vocab, d_model=d_model, nhead=nhead,
            num_layers=text_layers, dropout=dropout,
        )
        self.audio_encoder = AudioEncoder(
            n_mels=n_mels, d_model=d_model, nhead=nhead,
            num_layers=audio_layers, downsample_factor=downsample_factor,
            dropout=dropout,
        )
        self.vq = VectorQuantizer(
            d_model=d_model, n_codes=n_codes,
            commitment_weight=config.get("commitment_weight", 0.25),
        )
        self.predictor = TokenPredictor(
            d_model=d_model, nhead=nhead,
            num_layers=predictor_layers, n_codes=n_codes, dropout=dropout,
        )
        self.mel_decoder = MelDecoder(
            d_model=d_model, n_mels=n_mels, upsample_factor=downsample_factor,
        )

        # Learnable start embedding
        self.start_emb = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)

        # Loss weights
        self.pred_weight = config.get("pred_weight", 1.0)
        self.recon_weight = config.get("recon_weight", 1.0)
        self.spectral_weight = config.get("spectral_weight", 0.5)

        # Spectral loss
        self.spectral_loss_fn = MultiResolutionSpectralLoss()

        # Label smoothing for token prediction (prevents overconfidence)
        self.label_smoothing = config.get("label_smoothing", 0.1)

        self.config = config
        self.downsample_factor = downsample_factor

    def forward(self, mel, text_tokens, mel_mask=None, text_mask=None, **kwargs):
        B = mel.shape[0]

        # Encode text
        text_emb = self.text_encoder(text_tokens, text_mask)

        # Encode audio → continuous → quantize to discrete
        z_continuous = self.audio_encoder(mel, mel_mask)  # [B, T_down, d]
        z_q, indices, commit_loss = self.vq(z_continuous)  # [B, T_down, d], [B, T_down]

        T_down = z_q.shape[1]

        # Predictor input: [start, z_q[0], z_q[1], ..., z_q[T-2]]
        # Predictor target: [idx[0], idx[1], ..., idx[T-1]]
        start = self.start_emb.expand(B, -1, -1)
        input_emb = torch.cat([start, z_q[:, :-1]], dim=1)  # [B, T_down, d]

        # Masks
        if mel_mask is not None:
            ds_mask = mel_mask[:, ::self.downsample_factor][:, :T_down]
            start_mask = torch.zeros(B, 1, dtype=torch.bool, device=mel.device)
            pred_mask = torch.cat([start_mask, ds_mask[:, :-1]], dim=1)
            loss_mask = ds_mask
        else:
            pred_mask = None
            loss_mask = None

        # Predict next token logits
        logits = self.predictor(input_emb, text_emb, pred_mask, text_mask)  # [B, T_down, n_codes]

        # ─── Loss 1: Token prediction (cross-entropy with label smoothing) ───
        if loss_mask is not None:
            # Flatten valid positions only
            valid = ~loss_mask  # [B, T_down]
            logits_flat = logits[valid]  # [N_valid, n_codes]
            targets_flat = indices[valid]  # [N_valid]
            if logits_flat.shape[0] > 0:
                token_loss = F.cross_entropy(logits_flat, targets_flat,
                                              label_smoothing=self.label_smoothing)
            else:
                token_loss = torch.tensor(0.0, device=mel.device)
        else:
            token_loss = F.cross_entropy(
                logits.reshape(-1, logits.shape[-1]), indices.reshape(-1),
                label_smoothing=self.label_smoothing,
            )

        # ─── Loss 2: VQ commitment loss ───
        # (already computed in VQ forward)

        # ─── Loss 3: Mel reconstruction from quantized embeddings ───
        mel_recon = self.mel_decoder(z_q)
        T_mel = mel.shape[2]
        T_recon = mel_recon.shape[2]
        T_min = min(T_mel, T_recon)
        mel = mel[:, :, :T_min]
        mel_recon = mel_recon[:, :, :T_min]
        if mel_mask is not None:
            mel_mask = mel_mask[:, :T_min]

        if mel_mask is not None:
            valid_mel = (~mel_mask).unsqueeze(1).float()
            recon_loss = F.l1_loss(
                mel_recon * valid_mel, mel * valid_mel, reduction="sum"
            ) / (valid_mel.sum() * mel.shape[1] + 1e-8)
        else:
            recon_loss = F.l1_loss(mel_recon, mel)

        # ─── Loss 4: Spectral loss ───
        if mel_mask is not None:
            spectral_loss = self.spectral_loss_fn(mel_recon * valid_mel, mel * valid_mel)
        else:
            spectral_loss = self.spectral_loss_fn(mel_recon, mel)

        # ─── Token accuracy (for monitoring) ───
        with torch.no_grad():
            if loss_mask is not None and logits_flat.shape[0] > 0:
                token_acc = (logits_flat.argmax(dim=-1) == targets_flat).float().mean()
            else:
                token_acc = (logits.argmax(dim=-1) == indices).float().mean()

        total_loss = (self.pred_weight * token_loss
                      + commit_loss
                      + self.recon_weight * recon_loss
                      + self.spectral_weight * spectral_loss)

        return {
            "total_loss": total_loss,
            "token_loss": token_loss,
            "commit_loss": commit_loss,
            "recon_loss": recon_loss,
            "spectral_loss": spectral_loss,
            "token_accuracy": token_acc,
        }

    # ─── Inference ────────────────────────────────────────────────────────────

    @torch.no_grad()
    def encode_to_tokens(self, mel):
        """Encode mel → discrete token indices. Intended for eval mode: in
        training mode the VQ forward also updates the EMA codebook stats."""
        z = self.audio_encoder(mel)
        _, indices, _ = self.vq(z)
        return indices

    @torch.no_grad()
    def synthesize_tokens(self, text_tokens, max_steps=300, temperature=1.0,
                          top_k=50, text_mask=None):
        """AR token generation: text → token indices."""
        text_emb = self.text_encoder(text_tokens, text_mask)

        cache = self.predictor.init_cache(len(self.predictor.transformer.layers))
        emb = self.start_emb.expand(text_tokens.shape[0], -1, -1)  # [B, 1, d]
        all_indices = []

        for step in range(max_steps):
            logits, cache = self.predictor.inference_step(
                emb, text_emb, step, cache, text_mask
            )  # [B, 1, n_codes]

            logits = logits.squeeze(1)  # [B, n_codes]

            # Temperature + top-k sampling
            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                topk_vals, _ = logits.topk(top_k, dim=-1)
                logits[logits < topk_vals[:, -1:]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            idx = torch.multinomial(probs, 1)  # [B, 1]
            all_indices.append(idx)

            # Lookup codebook for next input
            emb = self.vq.codebook[idx]  # [B, 1, d]

        return torch.cat(all_indices, dim=1)  # [B, max_steps]

    @torch.no_grad()
    def tokens_to_mel(self, indices):
        """Convert token indices → mel spectrogram."""
        z_q = self.vq.codebook[indices]  # [B, T, d]
        return self.mel_decoder(z_q)  # [B, n_mels, T_up]

    @torch.no_grad()
    def synthesize_mel(self, text_tokens, max_steps=300, temperature=1.0,
                       top_k=50, text_mask=None):
        """Full pipeline: text → tokens → mel."""
        indices = self.synthesize_tokens(
            text_tokens, max_steps, temperature, top_k, text_mask
        )
        return self.tokens_to_mel(indices), indices


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def build_model_vq(config=None):
    if config is None:
        config = {
            "d_model": 256,
            "nhead": 4,
            "n_mels": 100,
            "text_vocab_size": 256,
            "text_encoder_layers": 4,
            "audio_encoder_layers": 4,
            "predictor_layers": 6,
            "n_codes": 1024,
            "dropout": 0.1,
            "pred_weight": 1.0,
            "recon_weight": 1.0,
            "spectral_weight": 0.5,
            "commitment_weight": 0.25,
            "label_smoothing": 0.1,
            "downsample_factor": 4,
        }
    model = LeWMTTSvq(config)
    print(f"LeWM TTS VQ model: {count_parameters(model)/1e6:.2f}M parameters")
    return model, config

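if __name__ == "__main__":
    # Minimal smoke test with random tensors (sizes are illustrative): one
    # training-style forward pass, then a short sampling rollout in eval mode.
    model, cfg = build_model_vq()
    mel = torch.randn(2, cfg["n_mels"], 200)                   # [B, n_mels, T]
    text = torch.randint(0, cfg["text_vocab_size"], (2, 32))   # byte-level tokens
    losses = model(mel, text)
    print({k: round(v.item(), 4) for k, v in losses.items()})

    model.eval()
    mel_out, tokens = model.synthesize_mel(text, max_steps=25)
    print("tokens:", tuple(tokens.shape), "mel:", tuple(mel_out.shape))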