"""
LeWM TTS Model — JEPA-based Text-to-Speech

Architecture:
  - TextEncoder: Character-level transformer for Hindi/Devanagari text
  - AudioEncoder: 1D CNN + Transformer that encodes mel spectrograms to embeddings
    - Returns intermediate layer outputs for multi-scale targets
  - JEPAPredictor: Transformer that predicts next audio embedding given text + audio context
  - SpeakerConditioner: FiLM-based speaker conditioning (scale + shift)
  - MelDecoder: FiLM-conditioned decoder with speaker-aware residual blocks
  - Losses: prediction (MSE + cosine), KL, reconstruction, spectral, speaker consistency
"""

import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio


# ─── FiLM Conditioning ────────────────────────────────────────────────────────

class FiLMConditioner(nn.Module):
    """Feature-wise Linear Modulation: x → gamma * x + beta.

    Learns a per-channel scale (gamma) and shift (beta) from a conditioning
    vector. Much more expressive than simple addition — it can reshape the
    entire feature distribution.
    """

    def __init__(self, cond_dim, channels):
        """
        Args:
            cond_dim: size of the conditioning vector.
            channels: number of feature channels that get modulated.
        """
        super().__init__()
        # One linear layer emits both halves: [gamma | beta].
        self.proj = nn.Linear(cond_dim, channels * 2)
        # Initialize near-identity: gamma≈1, beta≈0.
        # Small non-zero weights so gradients flow back to the conditioning input.
        nn.init.normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)
        with torch.no_grad():
            self.proj.bias[:channels].fill_(1.0)  # gamma starts at 1

    def forward(self, x, cond, channel_last=None):
        """
        Args:
            x: features to modulate. 3-D inputs may be [B, T, d] (channel-last)
                or [B, d, T] (channel-first, conv layout). Any other rank is
                modulated by plain broadcasting of the [B, d] parameters.
            cond: [B, d_cond] conditioning vector.
            channel_last: layout of a 3-D ``x``. ``True`` means [B, T, d],
                ``False`` means [B, d, T]. ``None`` (default) guesses from the
                shape: channel-last iff ``x.shape[-1] == d``. The guess is
                ambiguous for a [B, d, T] tensor with ``T == d``; pass the
                layout explicitly in that case.
        Returns:
            Modulated tensor with the same shape as ``x``.
        """
        params = self.proj(cond)  # [B, channels*2]
        gamma, beta = params.chunk(2, dim=-1)  # each [B, channels]

        if x.dim() == 3:
            if channel_last is None:
                channel_last = x.shape[-1] == gamma.shape[-1]
            # Insert the broadcast axis where the time dimension lives.
            time_axis = 1 if channel_last else 2
            gamma = gamma.unsqueeze(time_axis)
            beta = beta.unsqueeze(time_axis)
        return gamma * x + beta


# ─── Text Encoder ───────────────────────────────────────────────────────────

class TextEncoder(nn.Module):
    """Transformer encoder over character/byte-level tokens of Hindi text."""

    def __init__(self, vocab_size=256, d_model=256, nhead=4, num_layers=4, dropout=0.1):
        """
        Args:
            vocab_size: number of distinct token IDs in the embedding table.
            d_model: embedding / hidden width.
            nhead: attention heads per transformer layer.
            num_layers: number of stacked transformer encoder layers.
            dropout: dropout probability used on the input and inside layers.
        """
        super().__init__()
        self.d_model = d_model

        # Token table (byte-level IDs cover all Unicode Hindi characters),
        # fixed sinusoidal positions, and input dropout.
        self.char_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = SinusoidalPositionalEncoding(d_model, max_len=2048)
        self.dropout = nn.Dropout(dropout)

        # Batch-first GELU encoder stack with a 4x feed-forward expansion.
        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs),
            num_layers=num_layers,
        )
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, text_tokens, text_mask=None):
        """
        Args:
            text_tokens: [B, T_text] long tensor of byte-level token IDs
            text_mask: [B, T_text] bool, True = padding
        Returns:
            text_emb: [B, T_text, d_model]
        """
        # Scale embeddings by sqrt(d_model) before adding positions.
        scaled = self.char_embed(text_tokens) * math.sqrt(self.d_model)
        hidden = self.dropout(self.pos_embed(scaled))
        encoded = self.transformer(hidden, src_key_padding_mask=text_mask)
        return self.proj(encoded)


# ─── Audio Encoder ──────────────────────────────────────────────────────────

class AudioEncoder(nn.Module):
    """
    Turns a mel spectrogram into a sequence of latent audio embeddings.

    A strided 1D CNN shortens the time axis, a stack of transformer encoder
    layers contextualizes the frames, and two linear heads emit the mean and
    log-variance of a diagonal Gaussian per frame. Per-layer (layer-normed)
    outputs can be returned for multi-scale EMA targets.
    """

    def __init__(self, n_mels=100, d_model=256, nhead=4, num_layers=4,
                 downsample_factor=4, dropout=0.1):
        """
        Args:
            n_mels: number of mel bins (input channels).
            d_model: embedding / hidden width.
            nhead: attention heads per transformer layer.
            num_layers: number of transformer encoder layers.
            downsample_factor: 2 builds one stride-2 conv (total stride 2);
                any other value builds two stride-2 convs (total stride 4).
                The stored value is also the step used to subsample masks.
            dropout: dropout probability.
        """
        super().__init__()
        self.d_model = d_model
        self.downsample_factor = downsample_factor
        self.num_layers = num_layers

        # Conv stem: a width-7 input conv followed by one or two stride-2 convs.
        stride2_stages = 1 if downsample_factor == 2 else 2
        stem = [nn.Conv1d(n_mels, d_model, kernel_size=7, padding=3), nn.GELU()]
        for _ in range(stride2_stages):
            stem += [
                nn.Conv1d(d_model, d_model, kernel_size=4, stride=2, padding=1),
                nn.GELU(),
            ]
        self.conv_pre = nn.Sequential(*stem)

        self.pos_embed = SinusoidalPositionalEncoding(d_model, max_len=4096)
        self.dropout = nn.Dropout(dropout)

        # Layers are kept in a ModuleList (not nn.TransformerEncoder) so the
        # output of every layer can be tapped.
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=d_model, nhead=nhead,
                dim_feedforward=d_model * 4,
                dropout=dropout, batch_first=True,
                activation="gelu",
            )
            for _ in range(num_layers)
        ])

        # One LayerNorm per layer; applied only to the tapped intermediates.
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(num_layers)
        ])

        # Gaussian heads: mean and log-variance.
        self.proj_mu = nn.Linear(d_model, d_model)
        self.proj_logvar = nn.Linear(d_model, d_model)

    def forward(self, mel, mel_mask=None, return_intermediates=False):
        """
        Args:
            mel: [B, n_mels, T_mel] — log mel spectrogram
            mel_mask: [B, T_mel] bool, True = padding (before downsampling)
            return_intermediates: if True, also return the per-layer outputs
        Returns:
            z: [B, T_down, d_model] — sampled latent (equals mu in eval mode)
            mu: [B, T_down, d_model]
            logvar: [B, T_down, d_model]
            intermediates: list of [B, T_down, d_model], one per layer
                (only if return_intermediates)
        """
        # [B, n_mels, T] -> [B, d_model, T_down] -> [B, T_down, d_model]
        feats = self.conv_pre(mel).transpose(1, 2)

        if mel_mask is not None:
            # Approximate the downsampled mask by striding the original one
            # and trimming it to the conv output length.
            mel_mask = mel_mask[:, ::self.downsample_factor][:, :feats.shape[1]]

        feats = self.dropout(self.pos_embed(feats))

        intermediates = []
        for block, norm in zip(self.layers, self.layer_norms):
            feats = block(feats, src_key_padding_mask=mel_mask)
            if return_intermediates:
                intermediates.append(norm(feats))

        mu = self.proj_mu(feats)
        logvar = self.proj_logvar(feats)

        # Reparameterization trick while training; deterministic mean otherwise.
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mu + eps * std
        else:
            z = mu

        if not return_intermediates:
            return z, mu, logvar
        return z, mu, logvar, intermediates


# ─── JEPA Predictor ─────────────────────────────────────────────────────────

class JEPAPredictor(nn.Module):
    """
    Predicts next audio embedding given:
    - Text context (from TextEncoder)
    - Previous audio embeddings (from AudioEncoder)

    Uses cross-attention to condition on text, and causal self-attention
    over the audio embedding sequence. ``forward`` processes a whole sequence
    at once; ``inference_step`` processes one position at a time using a
    key/value cache created by ``init_cache``.
    """

    def __init__(self, d_model=256, nhead=4, num_layers=6, dropout=0.1):
        """
        Args:
            d_model: embedding / hidden width.
            nhead: attention heads per decoder layer.
            num_layers: number of transformer decoder layers.
            dropout: dropout probability.
        """
        super().__init__()
        self.d_model = d_model

        self.audio_input_proj = nn.Linear(d_model, d_model)
        self.pos_embed = SinusoidalPositionalEncoding(d_model, max_len=4096)
        self.dropout = nn.Dropout(dropout)

        # Transformer decoder with causal self-attention + cross-attention to text
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True,
            activation="gelu",
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Output projection to predict next embedding
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, audio_emb, text_emb, audio_mask=None, text_mask=None):
        """
        Args:
            audio_emb: [B, T_audio, d_model] — audio embeddings (from encoder)
            text_emb: [B, T_text, d_model] — text embeddings
            audio_mask: [B, T_audio] bool, True = padding
            text_mask: [B, T_text] bool, True = padding
        Returns:
            predicted: [B, T_audio, d_model] — predicted next embeddings
        """
        x = self.audio_input_proj(audio_emb)
        x = self.pos_embed(x)
        x = self.dropout(x)

        # Causal mask: each position can only attend to itself and earlier ones
        T = x.shape[1]
        causal_mask = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
        )

        predicted = self.transformer(
            x, text_emb,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=audio_mask,
            memory_key_padding_mask=text_mask,
        )

        predicted = self.output_proj(predicted)
        return predicted

    @staticmethod
    def _in_proj(mha):
        """Split the packed input projection of ``mha`` into per-role parts.

        ``in_proj_weight`` is [3d, d] = [Wq; Wk; Wv] and ``in_proj_bias`` is
        [3d]. Returns ``((Wq, Wk, Wv), (bq, bk, bv))``.
        """
        weights = mha.in_proj_weight.chunk(3, dim=0)
        biases = mha.in_proj_bias.chunk(3, dim=0)
        return weights, biases

    @staticmethod
    def _attend(mha, q, k, v, key_padding_mask=None):
        """Single-query scaled dot-product attention plus output projection.

        q: [B, 1, d] projected query; k, v: [B, T, d] projected keys/values.
        key_padding_mask: [B, T] bool, True = ignore that key (filled with -inf).
        Returns: [B, 1, d].
        """
        d = mha.embed_dim
        nhead = mha.num_heads
        head_dim = d // nhead
        B = q.shape[0]

        q_mh = q.view(B, 1, nhead, head_dim).transpose(1, 2)   # [B, nhead, 1, head_dim]
        k_mh = k.view(B, -1, nhead, head_dim).transpose(1, 2)  # [B, nhead, T, head_dim]
        v_mh = v.view(B, -1, nhead, head_dim).transpose(1, 2)

        attn = torch.matmul(q_mh, k_mh.transpose(-2, -1)) / (head_dim ** 0.5)
        if key_padding_mask is not None:
            # [B, T] → [B, 1, 1, T] so it broadcasts over heads and the query
            attn = attn.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v_mh)  # [B, nhead, 1, head_dim]

        out = out.transpose(1, 2).contiguous().view(B, 1, d)
        return F.linear(out, mha.out_proj.weight, mha.out_proj.bias)

    def _cached_mha(self, mha, q_input, kv_input, cache_k, cache_v):
        """Run multi-head attention with KV cache.
        q_input: [B, 1, d], kv_input: [B, 1, d] (new token for K,V)
        cache_k, cache_v: [B, T_prev, d] projected keys/values, or None
        Returns: output [B, 1, d], updated cache_k, updated cache_v

        No mask is needed: the cache only ever holds the current and earlier
        positions, so attention is causal by construction.
        """
        (Wq, Wk, Wv), (bq, bk, bv) = self._in_proj(mha)

        q = F.linear(q_input, Wq, bq)        # [B, 1, d]
        k_new = F.linear(kv_input, Wk, bk)   # [B, 1, d]
        v_new = F.linear(kv_input, Wv, bv)   # [B, 1, d]

        # Append the new position to the cache
        if cache_k is not None and cache_k.shape[1] > 0:
            k = torch.cat([cache_k, k_new], dim=1)
            v = torch.cat([cache_v, v_new], dim=1)
        else:
            k = k_new
            v = v_new

        return self._attend(mha, q, k, v), k, v

    def _cached_cross_attn(self, mha, q_input, memory, cache_k, cache_v, memory_mask=None):
        """Cross-attention with cached memory K,V (computed once).
        q_input: [B, 1, d]; memory: [B, T_text, d]
        cache_k, cache_v: projected memory keys/values from a previous call,
            or None on the first call (they are then computed from ``memory``)
        memory_mask: [B, T_text] bool, True = padding (applied as -inf).
        Returns: output [B, 1, d], cache_k, cache_v
        """
        (Wq, Wk, Wv), (bq, bk, bv) = self._in_proj(mha)

        q = F.linear(q_input, Wq, bq)  # [B, 1, d]

        if cache_k is None:
            k = F.linear(memory, Wk, bk)  # [B, T_text, d]
            v = F.linear(memory, Wv, bv)
        else:
            k, v = cache_k, cache_v

        return self._attend(mha, q, k, v, key_padding_mask=memory_mask), k, v

    def init_cache(self, num_layers=None):
        """Initialize empty KV cache.

        Args:
            num_layers: number of per-layer slots. Defaults to the number of
                decoder layers in this predictor.
        Returns:
            dict with keys 'self_k', 'self_v', 'cross_k', 'cross_v', each a
            list of ``num_layers`` ``None`` entries.
        """
        if num_layers is None:
            num_layers = len(self.transformer.layers)
        return {
            'self_k': [None] * num_layers,
            'self_v': [None] * num_layers,
            'cross_k': [None] * num_layers,
            'cross_v': [None] * num_layers,
        }

    def inference_step(self, new_emb, text_emb, step_idx, cache, text_mask=None):
        """Process one token with KV cache.
        new_emb: [B, 1, d_model] — raw audio embedding (not projected)
        text_emb: [B, T_text, d_model]
        step_idx: position of ``new_emb`` in the sequence (0-based)
        cache: dict from ``init_cache``; its lists are updated in place
        text_mask: [B, T_text] bool, True = padding
        Returns: predicted [B, 1, d_model], updated cache
        Raises:
            ValueError: if ``step_idx`` is outside the positional table.
        """
        max_len = self.pos_embed.pe.shape[1]
        if not 0 <= step_idx < max_len:
            # Without this check the slice below is empty and silently
            # broadcasts x to zero length, failing later with a confusing error.
            raise ValueError(
                f"step_idx {step_idx} is outside the positional encoding range [0, {max_len})"
            )

        x = self.audio_input_proj(new_emb)
        # Add positional encoding for this specific position
        x = x + self.pos_embed.pe[:, step_idx:step_idx + 1]

        for i, layer in enumerate(self.transformer.layers):
            # Self-attention with cache (post-norm)
            sa_out, cache['self_k'][i], cache['self_v'][i] = self._cached_mha(
                layer.self_attn, x, x, cache['self_k'][i], cache['self_v'][i]
            )
            x = layer.norm1(x + layer.dropout1(sa_out))

            # Cross-attention with cached text KV
            ca_out, cache['cross_k'][i], cache['cross_v'][i] = self._cached_cross_attn(
                layer.multihead_attn, x, text_emb, cache['cross_k'][i], cache['cross_v'][i],
                memory_mask=text_mask,
            )
            x = layer.norm2(x + layer.dropout2(ca_out))

            # FFN
            ff_out = layer.linear2(layer.dropout(layer.activation(layer.linear1(x))))
            x = layer.norm3(x + layer.dropout3(ff_out))

        return self.output_proj(x), cache


# ─── Positional Encoding ────────────────────────────────────────────────────

class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sinusoidal position table added to a [B, T, d_model] sequence.

    Even feature indices hold sines and odd indices hold cosines of
    ``position / 10000^(i / d_model)``. The table is stored as the buffer
    ``pe`` with shape [1, max_len, d_model].
    """

    def __init__(self, d_model, max_len=8192):
        """
        Args:
            d_model: feature width of the sequences to be encoded (even or odd).
            max_len: number of positions precomputed; inputs passed to
                ``forward`` must not be longer than this.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        """Add the first ``x.shape[1]`` position vectors to ``x`` ([B, T, d_model])."""
        return x + self.pe[:, :x.shape[1]]


# ─── Mel Decoder ────────────────────────────────────────────────────────────

class FiLMResConvBlock(nn.Module):
    """Residual conv block with FiLM speaker conditioning.

    Computes ``x + FiLM(conv2(gelu(conv1(x))))``. When the block was built
    without ``cond_dim`` or no speaker vector is passed, the FiLM step is
    skipped and the block behaves like a plain residual conv block.
    """

    def __init__(self, channels, kernel_size=3, cond_dim=None):
        """
        Args:
            channels: number of input/output channels.
            kernel_size: width of both convolutions ("same" padding).
            cond_dim: size of the speaker conditioning vector, or None to
                build the block without FiLM.
        """
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2)
        self.has_film = cond_dim is not None
        if self.has_film:
            self.film = FiLMConditioner(cond_dim, channels)

    def forward(self, x, spk_cond=None):
        """x: [B, C, T], spk_cond: [B, d_cond] or None → [B, C, T]"""
        h = F.gelu(self.conv1(x))
        h = self.conv2(h)
        if self.has_film and spk_cond is not None:
            # FiLM on the residual branch. The conditioner infers the layout
            # from the last dimension, which is ambiguous for [B, C, T] when
            # T == C. Handing it a channel-last view removes the ambiguity so
            # the scale/shift always lands on the channel axis.
            h = self.film(h.transpose(1, 2), spk_cond).transpose(1, 2)
        return x + h


class MelDecoder(nn.Module):
    """Maps latent embeddings back to a mel spectrogram.

    A linear projection widens the latents, one or two transposed convolutions
    upsample in time, and a refinement stage produces the mel bins. With more
    than one speaker, FiLM conditioning is applied after the input projection
    and inside every residual block.
    """

    def __init__(self, d_model=256, n_mels=100, upsample_factor=4, n_speakers=1):
        """
        Args:
            d_model: width of the latent embeddings (and of the speaker vector).
            n_mels: number of mel bins produced.
            upsample_factor: 2 builds a single x2 stage; any other value
                builds two x2 stages (x4 in total).
            n_speakers: FiLM conditioning is built only when this exceeds 1.
        """
        super().__init__()
        hidden = d_model * 2  # 512 for the default d_model
        multi_speaker = n_speakers > 1
        cond_dim = d_model if multi_speaker else None

        self.proj = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
        )

        # FiLM on the projected input (before the conv stack)
        self.has_film = multi_speaker
        if multi_speaker:
            self.input_film = FiLMConditioner(d_model, hidden)

        # First x2 stage is always present; the second one is optional.
        self.up1 = nn.ConvTranspose1d(hidden, hidden, kernel_size=4, stride=2, padding=1)
        self.res1 = FiLMResConvBlock(hidden, kernel_size=3, cond_dim=cond_dim)
        if upsample_factor == 2:
            self.up2 = None
            self.res2 = None
        else:
            self.up2 = nn.ConvTranspose1d(hidden, hidden, kernel_size=4, stride=2, padding=1)
            self.res2 = FiLMResConvBlock(hidden, kernel_size=3, cond_dim=cond_dim)

        # Refinement and output head
        self.refine_conv = nn.Conv1d(hidden, hidden, kernel_size=7, padding=3)
        self.refine_res = FiLMResConvBlock(hidden, kernel_size=5, cond_dim=cond_dim)
        self.out_conv = nn.Conv1d(hidden, n_mels, kernel_size=7, padding=3)

    def forward(self, z, spk_cond=None):
        """z: [B, T_down, d_model], spk_cond: [B, d_model] or None → mel: [B, n_mels, T_up]"""
        feats = self.proj(z)  # [B, T_down, hidden]

        if self.has_film and spk_cond is not None:
            feats = self.input_film(feats, spk_cond)  # FiLM on [B, T, hidden]

        feats = feats.transpose(1, 2)  # [B, hidden, T_down]

        # Each stage: transposed conv (x2 in time) → GELU → residual block.
        stages = [(self.up1, self.res1)]
        if self.up2 is not None:
            stages.append((self.up2, self.res2))
        for upsample, res_block in stages:
            feats = res_block(F.gelu(upsample(feats)), spk_cond)

        feats = self.refine_res(F.gelu(self.refine_conv(feats)), spk_cond)
        return self.out_conv(feats)  # [B, n_mels, T_up]


# ─── Multi-Resolution Spectral Loss ─────────────────────────────────────────

class MultiResolutionSpectralLoss(nn.Module):
    """Spectral convergence + log-magnitude loss averaged over several STFT sizes.

    Each mel bin is treated as a 1-D signal over time and compared in the STFT
    domain, which pushes the mel decoder toward sharp, well-defined spectral
    structure rather than blurry averages that Vocos can't decode properly.
    """

    def __init__(self, resolutions=((64, 16, 64), (128, 32, 128), (256, 64, 256))):
        """resolutions: iterable of (n_fft, hop_length, win_length) triples."""
        super().__init__()
        self.resolutions = resolutions

    def _stft_loss(self, pred, target, n_fft, hop_length, win_length):
        """Return spectral convergence + L1 log-magnitude loss at one resolution."""
        window = torch.hann_window(win_length, device=pred.device)
        batch, n_bins, n_frames = pred.shape

        def magnitude(mel):
            # Fold mel bins into the batch axis: [B, n_mels, T] → [B*n_mels, T]
            flat = mel.reshape(batch * n_bins, n_frames)
            spec = torch.stft(flat, n_fft, hop_length, win_length,
                              window=window, return_complex=True)
            return spec.abs()

        pred_mag = magnitude(pred)
        target_mag = magnitude(target)

        # Spectral convergence: ||target - pred||_F / ||target||_F
        diff_norm = torch.norm(target_mag - pred_mag, p="fro")
        ref_norm = torch.norm(target_mag, p="fro")
        sc_loss = diff_norm / (ref_norm + 1e-8)

        # L1 distance between log magnitudes (epsilon keeps the log finite)
        log_mag_loss = F.l1_loss(
            torch.log(pred_mag + 1e-7),
            torch.log(target_mag + 1e-7),
        )

        return sc_loss + log_mag_loss

    def forward(self, pred, target):
        """pred, target: [B, n_mels, T] → scalar loss.

        Resolutions whose n_fft exceeds T are skipped; if none is usable the
        result is a constant zero tensor.
        """
        n_frames = pred.shape[2]
        usable = [res for res in self.resolutions if n_frames >= res[0]]

        total = torch.tensor(0.0, device=pred.device)
        for n_fft, hop, win in usable:
            total = total + self._stft_loss(pred, target, n_fft, hop, win)

        if usable:
            total = total / len(usable)
        return total


# ─── Full LeWM TTS Model ────────────────────────────────────────────────────

class LeWMTTS(nn.Module):
    """
    Complete LeWM TTS model with speaker-preserving JEPA.

    Speaker fidelity is enforced through 5 mechanisms:
      1. FiLM conditioning (scale+shift) on latents and targets
      2. FiLM-conditioned MelDecoder (speaker-aware reconstruction)
      3. Multi-scale EMA targets (retain acoustic detail across layers)
      4. Cosine + MSE prediction loss (preserves direction = speaker identity)
      5. Speaker consistency classifier on raw encoder output
    """

    def __init__(self, config):
        """Build every sub-network and read the loss hyper-parameters.

        Args:
            config: mapping queried with ``config.get(key, default)``; any
                missing key falls back to the default written next to it below.
                The object itself is kept on ``self.config``.
        """
        super().__init__()
        opt = config.get

        # ── Architecture sizes shared by several sub-modules ──
        d_model = opt("d_model", 256)
        nhead = opt("nhead", 4)
        n_mels = opt("n_mels", 100)
        dropout = opt("dropout", 0.1)
        n_speakers = opt("n_speakers", 1)
        downsample_factor = opt("downsample_factor", 4)

        # ── Scalar settings (plain attributes, no parameters) ──
        self.config = config
        self.n_speakers = n_speakers
        # Loss weights
        self.kl_weight = opt("kl_weight", 0.01)
        self.recon_weight = opt("recon_weight", 1.0)
        self.pred_weight = opt("pred_weight", 10.0)          # primary JEPA loss
        self.cosine_weight = opt("cosine_weight", 1.0)       # cosine term of prediction loss
        self.spectral_weight = opt("spectral_weight", 0.5)
        self.speaker_weight = opt("speaker_weight", 1.0)
        # Free-bits KL: KL up to this threshold is not penalized
        self.free_bits = opt("free_bits", 1.0)
        # EMA target encoder settings; ema_target_layers=None means all layers
        self.ema_decay = opt("ema_decay", 0.998)
        self.ema_target_layers = opt("ema_target_layers", None)
        # Train/inference gap bridging: input noise std and the probability of
        # feeding predicted (instead of real) embeddings back as input
        self.input_noise = opt("input_noise", 0.0)
        self.scheduled_sampling_rate = opt("scheduled_sampling_rate", 0.0)

        # ── Sub-modules ──
        self.text_encoder = TextEncoder(
            vocab_size=opt("text_vocab_size", 256),
            d_model=d_model,
            nhead=nhead,
            num_layers=opt("text_encoder_layers", 4),
            dropout=dropout,
        )
        self.audio_encoder = AudioEncoder(
            n_mels=n_mels,
            d_model=d_model,
            nhead=nhead,
            num_layers=opt("audio_encoder_layers", 4),
            downsample_factor=downsample_factor,
            dropout=dropout,
        )
        self.predictor = JEPAPredictor(
            d_model=d_model,
            nhead=nhead,
            num_layers=opt("predictor_layers", 6),
            dropout=dropout,
        )

        # Speaker-conditioned mel reconstruction; upsamples by the same factor
        # the audio encoder downsamples.
        self.mel_decoder = MelDecoder(
            d_model=d_model,
            n_mels=n_mels,
            upsample_factor=downsample_factor,
            n_speakers=n_speakers,
        )

        # EMA target encoder — core JEPA component: a frozen copy of the
        # online audio encoder.
        self.ema_audio_encoder = copy.deepcopy(self.audio_encoder)
        for frozen in self.ema_audio_encoder.parameters():
            frozen.requires_grad = False

        # Learned projection applied to the averaged intermediate targets
        self.ema_target_proj = nn.Linear(d_model, d_model)

        # Speaker embedding, FiLM conditioners and classifier exist only in
        # the multi-speaker setting; otherwise the attributes are None.
        if n_speakers > 1:
            self.speaker_embed = nn.Embedding(n_speakers, d_model)
            # FiLM for the latent space (replaces additive injection)
            self.latent_film = FiLMConditioner(d_model, d_model)
            self.target_film = FiLMConditioner(d_model, d_model)
            # Speaker classifier — forces the encoder to retain speaker info
            self.speaker_classifier = nn.Sequential(
                nn.Linear(d_model, d_model),
                nn.GELU(),
                nn.Linear(d_model, n_speakers),
            )
        else:
            self.speaker_embed = None
            self.latent_film = None
            self.target_film = None
            self.speaker_classifier = None

        # Learnable start-of-sequence embedding for AR inference
        self.start_emb = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)

        # Multi-resolution spectral loss for sharper mel output
        self.spectral_loss_fn = MultiResolutionSpectralLoss()

    def _get_speaker_cond(self, speaker_id):
        """Get speaker conditioning vector. Returns [B, d_model] or None."""
        if self.speaker_embed is not None and speaker_id is not None:
            return self.speaker_embed(speaker_id)  # [B, d_model]
        return None

    def _apply_speaker_film(self, x, spk_cond, film_module):
        """Apply FiLM conditioning to tensor. No-op if no speaker."""
        if film_module is not None and spk_cond is not None:
            return film_module(x, spk_cond)
        return x

    @torch.no_grad()
    def update_ema(self):
        """Update EMA target encoder weights."""
        for p_online, p_ema in zip(self.audio_encoder.parameters(),
                                   self.ema_audio_encoder.parameters()):
            p_ema.data.mul_(self.ema_decay).add_(p_online.data, alpha=1 - self.ema_decay)

    def _compute_multiscale_targets(self, mel, mel_mask):
        """Compute multi-scale EMA targets by averaging intermediate layer outputs.
        data2vec-style: captures both low-level acoustic detail (early layers)
        and high-level semantic structure (late layers)."""
        with torch.no_grad():
            self.ema_audio_encoder.eval()
            _, mu_ema, _, intermediates = self.ema_audio_encoder(
                mel, mel_mask, return_intermediates=True
            )
            self.ema_audio_encoder.train()

            # Select which layers to average
            if self.ema_target_layers is not None:
                selected = [intermediates[i] for i in self.ema_target_layers]
            else:
                selected = intermediates  # all layers

            # Average across selected layers
            target_emb = torch.stack(selected, dim=0).mean(dim=0)  # [B, T, d]

        # Project through learned projection (allows model to adapt target space)
        target_emb = self.ema_target_proj(target_emb)
        return target_emb

    def forward(self, mel, text_tokens, mel_mask=None, text_mask=None, speaker_id=None):
        """
        Args:
            mel: [B, n_mels, T_mel]
            text_tokens: [B, T_text]
            mel_mask: [B, T_mel] bool
            text_mask: [B, T_text] bool
            speaker_id: [B] long tensor of speaker IDs (optional)
        Returns:
            loss_dict with all individual losses
        """
        # Get speaker conditioning
        spk_cond = self._get_speaker_cond(speaker_id)  # [B, d] or None

        # Encode text + FiLM speaker conditioning on text
        text_emb = self.text_encoder(text_tokens, text_mask)
        if spk_cond is not None:
            # FiLM on text embeddings (replaces simple addition)
            text_emb = self._apply_speaker_film(text_emb, spk_cond, self.latent_film)

        # Online encoder → latent embeddings
        z, mu, logvar = self.audio_encoder(mel, mel_mask)
        B = z.shape[0]
        T_down = z.shape[1]

        # Save raw mu before speaker conditioning — for speaker classifier
        mu_raw = mu

        # Apply FiLM speaker conditioning to encoder output
        z = self._apply_speaker_film(z, spk_cond, self.latent_film)

        # Prediction target = encoder z (FiLM'd), shifted by 1
        # CRITICAL: target must be in the SAME space as predictor input (encoder z),
        # not a different space (EMA targets). Otherwise AR inference diverges because
        # predicted outputs (in target space) get fed back as input (in encoder space).
        # We use EMA encoder to get stable targets, but keep them in encoder scale.
        with torch.no_grad():
            self.ema_audio_encoder.eval()
            _, mu_ema, _ = self.ema_audio_encoder(mel, mel_mask)
            self.ema_audio_encoder.train()
            target_emb = mu_ema
        # Apply same FiLM as encoder output so target space matches input space
        target_emb = self._apply_speaker_film(target_emb, spk_cond, self.target_film)

        # Prepend learnable start embedding to input sequence
        start = self.start_emb.expand(B, -1, -1)  # [B, 1, d]
        input_emb = torch.cat([start, z[:, :-1]], dim=1)  # [B, T_down, d]

        # Input noise: only on positions 1+ (real embeddings), not start_emb
        if self.training and self.input_noise > 0:
            noise = torch.zeros_like(input_emb)
            noise[:, 1:] = torch.randn(B, T_down - 1, input_emb.shape[-1],
                                        device=input_emb.device) * self.input_noise
            input_emb = input_emb + noise

        # Adjust masks for downsampled sequence
        if mel_mask is not None:
            ds_mask = mel_mask[:, ::self.audio_encoder.downsample_factor][:, :T_down]
            start_mask = torch.zeros(B, 1, dtype=torch.bool, device=z.device)
            pred_mask = torch.cat([start_mask, ds_mask[:, :-1]], dim=1)
            loss_mask = ds_mask
        else:
            pred_mask = None
            loss_mask = None

        predicted = self.predictor(input_emb, text_emb, pred_mask, text_mask)

        # Scheduled sampling: re-run predictor with some positions replaced by
        # its own predictions. This teaches it to handle imperfect inputs (like AR inference).
        if self.training and self.scheduled_sampling_rate > 0:
            with torch.no_grad():
                # Randomly select positions to replace with predicted values
                mask_ss = torch.rand(B, T_down, device=z.device) < self.scheduled_sampling_rate
                mask_ss[:, 0] = False  # never replace start_emb
                mask_ss_exp = mask_ss.unsqueeze(-1)  # [B, T_down, 1]
                # Mix: use predicted (detached) where mask is True, real where False
                mixed_input = torch.where(mask_ss_exp, predicted.detach(), input_emb)
            # Re-predict with mixed input
            predicted = self.predictor(mixed_input, text_emb, pred_mask, text_mask)

        # ─── Loss 1: Prediction loss (MSE + cosine) against multi-scale EMA targets ───
        if loss_mask is not None:
            valid = (~loss_mask).unsqueeze(-1).float()  # [B, T_down, 1]
            n_valid = valid.sum() * predicted.shape[-1] + 1e-8

            # MSE component
            mse_loss = (((predicted - target_emb) ** 2) * valid).sum() / n_valid

            # Cosine component — preserves direction (speaker identity)
            pred_flat = (predicted * valid).reshape(-1, predicted.shape[-1])
            tgt_flat = (target_emb * valid).reshape(-1, target_emb.shape[-1])
            # Only compute on non-padding positions
            valid_flat = valid.squeeze(-1).reshape(-1) > 0
            if valid_flat.any():
                cosine_loss = 1.0 - F.cosine_similarity(
                    pred_flat[valid_flat], tgt_flat[valid_flat], dim=-1
                ).mean()
            else:
                cosine_loss = torch.tensor(0.0, device=mel.device)
        else:
            mse_loss = F.mse_loss(predicted, target_emb)
            cosine_loss = 1.0 - F.cosine_similarity(
                predicted.reshape(-1, predicted.shape[-1]),
                target_emb.reshape(-1, target_emb.shape[-1]),
                dim=-1,
            ).mean()

        prediction_loss = mse_loss + self.cosine_weight * cosine_loss

        # ─── Loss 2: Free-bits KL ───
        kl_per_dim = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
        kl_free = torch.clamp(kl_per_dim - self.free_bits, min=0.0)
        if mel_mask is not None:
            T_down_kl = mu.shape[1]
            ds_mask_kl = mel_mask[:, ::self.audio_encoder.downsample_factor][:, :T_down_kl]
            valid_kl = (~ds_mask_kl).unsqueeze(-1).float()
            kl_loss = (kl_free * valid_kl).sum() / (valid_kl.sum() * mu.shape[-1] + 1e-8)
        else:
            kl_loss = kl_free.mean()

        # ─── Loss 3: Mel reconstruction (speaker-conditioned decoder) ───
        # Decode from BOTH real encoder output AND predicted embeddings
        # This teaches the mel decoder to handle predicted inputs (not just perfect encoder outputs)
        mel_recon = self.mel_decoder(z, spk_cond)  # from real encoder
        mel_pred_recon = self.mel_decoder(predicted.detach(), spk_cond)  # from predictor output

        T_mel = mel.shape[2]
        T_recon = mel_recon.shape[2]
        # Align all mel tensors to the shortest length
        T_min = min(T_mel, T_recon, mel_pred_recon.shape[2])
        mel = mel[:, :, :T_min]
        mel_recon = mel_recon[:, :, :T_min]
        mel_pred_recon = mel_pred_recon[:, :, :T_min]
        if mel_mask is not None:
            mel_mask = mel_mask[:, :T_min]

        if mel_mask is not None:
            valid_mel = (~mel_mask).unsqueeze(1).float()
            recon_loss = F.l1_loss(
                mel_recon * valid_mel, mel * valid_mel, reduction="sum"
            ) / (valid_mel.sum() * mel.shape[1] + 1e-8)
            # Predicted recon loss — teaches decoder to handle AR-quality inputs
            pred_recon_loss = F.l1_loss(
                mel_pred_recon * valid_mel, mel * valid_mel, reduction="sum"
            ) / (valid_mel.sum() * mel.shape[1] + 1e-8)
        else:
            recon_loss = F.l1_loss(mel_recon, mel)
            pred_recon_loss = F.l1_loss(mel_pred_recon, mel)

        recon_loss = recon_loss + 0.5 * pred_recon_loss

        # ─── Loss 4: Multi-resolution spectral loss ───
        if mel_mask is not None:
            spectral_loss = self.spectral_loss_fn(mel_recon * valid_mel, mel * valid_mel)
        else:
            spectral_loss = self.spectral_loss_fn(mel_recon, mel)

        # ─── Loss 5: Speaker consistency on raw encoder output ───
        speaker_loss = torch.tensor(0.0, device=mel.device)
        if self.speaker_classifier is not None and speaker_id is not None:
            if mel_mask is not None:
                ds_mask_spk = mel_mask[:, ::self.audio_encoder.downsample_factor][:, :T_down]
                valid_spk = (~ds_mask_spk).unsqueeze(-1).float()
                mu_pooled = (mu_raw * valid_spk).sum(dim=1) / (valid_spk.sum(dim=1) + 1e-8)
            else:
                mu_pooled = mu_raw.mean(dim=1)
            spk_logits = self.speaker_classifier(mu_pooled)
            speaker_loss = F.cross_entropy(spk_logits, speaker_id)

        total_loss = (self.pred_weight * prediction_loss
                      + self.kl_weight * kl_loss
                      + self.recon_weight * recon_loss
                      + self.spectral_weight * spectral_loss
                      + self.speaker_weight * speaker_loss)

        return {
            "total_loss": total_loss,
            "prediction_loss": prediction_loss,
            "mse_loss": mse_loss,
            "cosine_loss": cosine_loss,
            "kl_loss": kl_loss,
            "recon_loss": recon_loss,
            "spectral_loss": spectral_loss,
            "speaker_loss": speaker_loss,
        }

    def encode_audio(self, mel):
        """Return the audio encoder's mean embedding for ``mel`` (inference path).

        The encoder call is unpacked into exactly three values; only the
        second one (``mu``) is returned. The first (the encoder's ``z``
        output) and the third are discarded.

        Args:
            mel: Mel-spectrogram input, forwarded unchanged to
                ``self.audio_encoder``.

        Returns:
            The second element of the encoder's output tuple.
        """
        _z, mean_emb, _unused = self.audio_encoder(mel)
        # Inference uses the mean rather than the encoder's ``z`` output.
        return mean_emb

    def predict_next(self, audio_emb, text_emb, text_mask=None):
        """Run the predictor over the full context and keep only the final step.

        Args:
            audio_emb: Audio embedding context, passed positionally to
                ``self.predictor``.
            text_emb: Text embeddings, passed positionally to ``self.predictor``.
            text_mask: Optional text mask, forwarded as the ``text_mask`` keyword.

        Returns:
            The predictor output sliced with ``[:, -1:]`` — the last position
            along the second axis, with that axis kept (length 1).
        """
        # ``-1:`` (rather than ``-1``) keeps the sliced axis in the result.
        return self.predictor(audio_emb, text_emb, text_mask=text_mask)[:, -1:]

    def init_ar_cache(self):
        """Create a fresh predictor cache for autoregressive inference.

        The cache is sized from the number of entries in
        ``self.predictor.transformer.layers`` and built by
        ``self.predictor.init_cache``.

        Returns:
            Whatever ``self.predictor.init_cache(num_layers)`` returns.
        """
        predictor = self.predictor
        depth = len(predictor.transformer.layers)
        return predictor.init_cache(depth)

    def predict_next_cached(self, new_emb, text_emb, step_idx, cache, text_mask=None):
        """Predict the next embedding for one step using the predictor's cache.

        Thin wrapper around ``self.predictor.inference_step``; all arguments
        are forwarded unchanged.

        Args:
            new_emb: Embedding for the current step.
            text_emb: Text embeddings.
            step_idx: Index of the current step.
            cache: Cache object, e.g. as returned by ``init_ar_cache``.
            text_mask: Optional text mask, forwarded as a keyword argument.

        Returns:
            The result of ``self.predictor.inference_step``.
        """
        run_step = self.predictor.inference_step
        return run_step(new_emb, text_emb, step_idx, cache, text_mask=text_mask)


def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Only parameters whose ``requires_grad`` flag is truthy are counted;
    each contributes its ``numel()``.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def build_model(config=None):
    """Construct a ``LeWMTTS`` model and print its trainable-parameter count.

    Args:
        config: Hyperparameter dict passed straight to ``LeWMTTS``. When
            ``None``, the built-in default configuration is used instead.

    Returns:
        A ``(model, config)`` tuple: the constructed model and the config
        dict actually used (the defaults when none was supplied).

    Side effects:
        Prints the trainable-parameter count, in millions, to stdout.
    """
    # Fallback hyperparameters; how each key is interpreted is decided by
    # LeWMTTS.__init__, which lives elsewhere in this file.
    default_config = {
        "d_model": 256,
        "nhead": 4,
        "n_mels": 100,
        "text_vocab_size": 256,
        "text_encoder_layers": 4,
        "audio_encoder_layers": 4,
        "predictor_layers": 6,
        "dropout": 0.1,
        "kl_weight": 0.05,
        "pred_weight": 10.0,
        "cosine_weight": 1.0,
        "ema_decay": 0.998,
        "free_bits": 2.0,
        "input_noise": 0.0,
    }
    chosen = default_config if config is None else config

    model = LeWMTTS(chosen)
    summary = f"LeWM TTS model: {count_parameters(model)/1e6:.2f}M parameters"
    print(summary)
    return model, chosen
