"""
LeWM TTS v6 — predict the token stream of a frozen, pre-trained EnCodec codec.

Key lessons from earlier failed experiments:
- Continuous autoregressive prediction drifts and degenerates into noise.
- A VQ codebook learned from scratch alongside the predictor means too many moving targets.
- Solution: use a FROZEN pre-trained codec (EnCodec). The codebook is fixed and known to be
  good, so the model only has to learn text → token sequence: standard next-token prediction.

Architecture:
  - TextEncoder: byte-level transformer
  - TokenPredictor: causal transformer, cross-attends to text, outputs logits over 1024 EnCodec codes
  - EnCodec decoder (frozen): tokens → waveform (no mel decoder needed!)

This is essentially a small language model that speaks EnCodec.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
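

# ─── Hedged sketch: preparing training targets with the frozen EnCodec encoder ──
# Not part of the original code. The module docstring assumes ground-truth token_ids
# come from a frozen, pre-trained EnCodec model; this illustrative helper shows one
# way to get them, assuming the `encodec` package (facebookresearch/encodec), whose
# encode() API may differ between versions. Only the first (coarsest) codebook is
# kept so the targets match the single 1024-way stream this model predicts.

def extract_encodec_tokens_sketch(wav, encodec_model):
    """wav: [B, 1, T_samples] at the codec's sample rate. Returns token_ids [B, T_frames]."""
    encodec_model.eval()
    with torch.no_grad():
        # encode() returns a list of (codes, scale) frames; codes is [B, n_q, T_frames].
        frames = encodec_model.encode(wav)
        codes = torch.cat([codes_f for codes_f, _ in frames], dim=-1)
    return codes[:, 0, :]  # keep codebook 0 only: [B, T_frames]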


class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=8192):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.shape[1]]


class TextEncoder(nn.Module):
    def __init__(self, vocab_size=256, d_model=256, nhead=4, num_layers=4, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.char_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = SinusoidalPositionalEncoding(d_model, max_len=2048)
        self.dropout = nn.Dropout(dropout)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, tokens, mask=None):
        x = self.char_embed(tokens) * math.sqrt(self.d_model)
        x = self.pos_embed(x)
        x = self.dropout(x)
        x = self.transformer(x, src_key_padding_mask=mask)
        return self.proj(x)


class TokenPredictor(nn.Module):
    """Causal transformer: predicts next EnCodec token given previous tokens + text."""

    def __init__(self, d_model=256, nhead=4, num_layers=8, n_codes=1024, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_codes = n_codes

        # Token embedding (input side — embed EnCodec codes)
        self.token_embed = nn.Embedding(n_codes + 1, d_model)  # +1 for BOS token
        self.bos_id = n_codes  # BOS = last index

        self.pos_embed = SinusoidalPositionalEncoding(d_model, max_len=4096)
        self.dropout = nn.Dropout(dropout)

        layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.transformer = nn.TransformerDecoder(layer, num_layers=num_layers)

        # Output head
        self.output_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, n_codes),  # logits over EnCodec codebook
        )

    def forward(self, token_ids, text_emb, token_mask=None, text_mask=None):
        """
        token_ids: [B, T] — EnCodec token IDs (with BOS prepended)
        text_emb: [B, T_text, d] — from TextEncoder
        Returns: logits [B, T, n_codes]
        """
        x = self.token_embed(token_ids)
        x = self.pos_embed(x)
        x = self.dropout(x)

        T = x.shape[1]
        causal = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)

        x = self.transformer(
            x, text_emb,
            tgt_mask=causal,
            tgt_key_padding_mask=token_mask,
            memory_key_padding_mask=text_mask,
        )
        return self.output_head(x)

    # ─── Cached inference ─────────────────────────────────────────────────

    def init_cache(self):
        n = len(self.transformer.layers)
        return {
            'self_k': [None] * n, 'self_v': [None] * n,
            'cross_k': [None] * n, 'cross_v': [None] * n,
        }

    def inference_step(self, token_id, text_emb, step, cache, text_mask=None):
        """
        token_id: [B, 1] long — single token
        Returns: logits [B, n_codes], updated cache
        """
        x = self.token_embed(token_id)  # [B, 1, d]
        x = x + self.pos_embed.pe[:, step:step + 1]

        for i, layer in enumerate(self.transformer.layers):
            sa_out, cache['self_k'][i], cache['self_v'][i] = self._cached_attn(
                layer.self_attn, x, x, cache['self_k'][i], cache['self_v'][i]
            )
            x = layer.norm1(x + layer.dropout1(sa_out))

            ca_out, cache['cross_k'][i], cache['cross_v'][i] = self._cached_attn(
                layer.multihead_attn, x, text_emb,
                cache['cross_k'][i], cache['cross_v'][i], text_mask
            )
            x = layer.norm2(x + layer.dropout2(ca_out))

            ff = layer.linear2(layer.dropout(layer.activation(layer.linear1(x))))
            x = layer.norm3(x + layer.dropout3(ff))

        logits = self.output_head(x).squeeze(1)  # [B, n_codes]
        return logits, cache

    def _cached_attn(self, mha, q, kv, ck, cv, mask=None):
        """Multi-head attention with a KV cache.

        Self-attention (kv is the same tensor as q): project only the new token and
        append it to the cache. Cross-attention: the text memory is static, so its
        projections are computed once and reused on every later step.
        """
        d, nh = mha.embed_dim, mha.num_heads
        hd = d // nh
        B = q.shape[0]
        Wq, Wk, Wv = mha.in_proj_weight.chunk(3, dim=0)
        bq, bk, bv = mha.in_proj_bias.chunk(3, dim=0)

        qp = F.linear(q, Wq, bq)
        if ck is None:
            # First step: build the cache from scratch.
            k = F.linear(kv, Wk, bk)
            v = F.linear(kv, Wv, bv)
        elif kv is q:
            # Self-attention: append the new token's projections to the cache.
            k = torch.cat([ck, F.linear(kv, Wk, bk)], dim=1)
            v = torch.cat([cv, F.linear(kv, Wv, bv)], dim=1)
        else:
            # Cross-attention: reuse the cached text projections.
            k, v = ck, cv

        qp = qp.view(B, -1, nh, hd).transpose(1, 2)
        km = k.view(B, -1, nh, hd).transpose(1, 2)
        vm = v.view(B, -1, nh, hd).transpose(1, 2)

        a = torch.matmul(qp, km.transpose(-2, -1)) / (hd ** 0.5)
        if mask is not None:
            a = a.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        o = torch.matmul(F.softmax(a, dim=-1), vm)
        o = o.transpose(1, 2).contiguous().view(B, -1, d)
        return F.linear(o, mha.out_proj.weight, mha.out_proj.bias), k, v
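

# ─── Hedged sketch: KV-cache parity check ────────────────────────────────────
# Not part of the original code. A small self-test under the assumption that
# inference_step() replays exactly the post-norm layer math used by forward();
# with dropout disabled, cached step-by-step logits should match teacher-forced
# logits up to floating-point tolerance.

@torch.no_grad()
def check_cache_parity_sketch(predictor, B=2, T=16, T_text=10, atol=1e-4):
    predictor.eval()
    device = next(predictor.parameters()).device
    text_emb = torch.randn(B, T_text, predictor.d_model, device=device)
    bos = torch.full((B, 1), predictor.bos_id, dtype=torch.long, device=device)
    tokens = torch.randint(0, predictor.n_codes, (B, T - 1), device=device)
    input_ids = torch.cat([bos, tokens], dim=1)  # [B, T]

    full = predictor(input_ids, text_emb)  # teacher-forced logits [B, T, n_codes]

    cache = predictor.init_cache()
    step_logits = []
    for t in range(T):
        logits, cache = predictor.inference_step(input_ids[:, t:t + 1], text_emb, t, cache)
        step_logits.append(logits)
    stepped = torch.stack(step_logits, dim=1)  # [B, T, n_codes]

    return torch.allclose(full, stepped, atol=atol)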


# ─── Full Model ──────────────────────────────────────────────────────────────

class LeWMTTSv6(nn.Module):
    """
    Text → EnCodec tokens. Simple next-token prediction.
    No encoder, no VQ, no mel decoder. Just predict the right token.
    EnCodec decoder handles token→waveform at inference.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.get("d_model", 256)
        nhead = config.get("nhead", 4)
        n_codes = config.get("n_codes", 1024)

        self.text_encoder = TextEncoder(
            vocab_size=config.get("text_vocab_size", 256),
            d_model=d_model, nhead=nhead,
            num_layers=config.get("text_encoder_layers", 4),
            dropout=config.get("dropout", 0.1),
        )
        self.predictor = TokenPredictor(
            d_model=d_model, nhead=nhead,
            num_layers=config.get("predictor_layers", 8),
            n_codes=n_codes,
            dropout=config.get("dropout", 0.1),
        )

        self.n_codes = n_codes
        self.bos_id = n_codes  # BOS token
        self.label_smoothing = config.get("label_smoothing", 0.1)
        self.config = config

    def forward(self, token_ids, text_tokens, token_mask=None, text_mask=None, **kwargs):
        """
        token_ids: [B, T] — EnCodec token indices (ground truth)
        text_tokens: [B, T_text] — byte-level text
        """
        B, T = token_ids.shape

        # Encode text
        text_emb = self.text_encoder(text_tokens, text_mask)

        # Prepend BOS, shift right
        bos = torch.full((B, 1), self.bos_id, dtype=torch.long, device=token_ids.device)
        input_ids = torch.cat([bos, token_ids[:, :-1]], dim=1)  # [B, T]

        # Input mask (BOS is never masked)
        if token_mask is not None:
            bos_mask = torch.zeros(B, 1, dtype=torch.bool, device=token_ids.device)
            input_mask = torch.cat([bos_mask, token_mask[:, :-1]], dim=1)
        else:
            input_mask = None

        # Predict
        logits = self.predictor(input_ids, text_emb, input_mask, text_mask)  # [B, T, n_codes]

        # Loss: cross-entropy on valid positions
        if token_mask is not None:
            valid = ~token_mask
            logits_flat = logits[valid]
            targets_flat = token_ids[valid]
        else:
            logits_flat = logits.reshape(-1, self.n_codes)
            targets_flat = token_ids.reshape(-1)

        token_loss = F.cross_entropy(
            logits_flat, targets_flat, label_smoothing=self.label_smoothing
        )

        with torch.no_grad():
            token_acc = (logits_flat.argmax(dim=-1) == targets_flat).float().mean()

        return {
            "total_loss": token_loss,
            "token_loss": token_loss,
            "token_accuracy": token_acc,
        }

    @torch.no_grad()
    def generate(self, text_tokens, max_steps=750, temperature=0.8,
                 top_k=50, text_mask=None):
        """AR generation: text → EnCodec token sequence."""
        text_emb = self.text_encoder(text_tokens, text_mask)
        cache = self.predictor.init_cache()

        B = text_tokens.shape[0]
        token = torch.full((B, 1), self.bos_id, dtype=torch.long, device=text_tokens.device)
        all_tokens = []

        for step in range(max_steps):
            logits, cache = self.predictor.inference_step(token, text_emb, step, cache, text_mask)

            # Temperature + top-k
            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                topk_vals, _ = logits.topk(top_k, dim=-1)
                logits[logits < topk_vals[:, -1:]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            token = torch.multinomial(probs, 1)  # [B, 1]
            all_tokens.append(token)

        return torch.cat(all_tokens, dim=1)  # [B, max_steps]
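

# ─── Hedged sketch: decoding generated tokens with the frozen EnCodec decoder ───
# Not part of the original code. Illustrative only: assumes the `encodec` package
# (facebookresearch/encodec), whose decode() API may differ between versions.
# Generated tokens form a single-codebook stream, so they are reshaped to
# [B, n_q=1, T_frames] before being handed to the codec's decoder.

def decode_tokens_to_waveform_sketch(tokens, encodec_model):
    """tokens: [B, T_frames] from LeWMTTSv6.generate(). Returns waveform [B, channels, T_samples]."""
    encodec_model.eval()
    with torch.no_grad():
        codes = tokens.unsqueeze(1)          # [B, 1, T_frames], a single codebook
        frames = [(codes, None)]             # (codes, scale) pairs, as produced by encode()
        wav = encodec_model.decode(frames)   # [B, channels, T_samples]
    return wav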


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def build_model_v6(config=None):
    if config is None:
        config = {
            "d_model": 256,
            "nhead": 4,
            "n_codes": 1024,
            "text_vocab_size": 256,
            "text_encoder_layers": 4,
            "predictor_layers": 8,
            "dropout": 0.1,
            "label_smoothing": 0.1,
        }
    model = LeWMTTSv6(config)
    print(f"LeWM TTS v6: {count_parameters(model)/1e6:.2f}M parameters")
    return model, config
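

# ─── Hedged sketch: quick smoke test with random tensors ─────────────────────
# Not part of the original code. Exercises the training forward pass and the
# autoregressive generate() loop on random data; it only checks shapes, a finite
# loss, and KV-cache parity, nothing about audio quality.

if __name__ == "__main__":
    model, cfg = build_model_v6()
    model.eval()

    B, T, T_text = 2, 120, 40
    token_ids = torch.randint(0, cfg["n_codes"], (B, T))
    text_tokens = torch.randint(0, cfg["text_vocab_size"], (B, T_text))

    out = model(token_ids, text_tokens)
    print(f"token_loss={out['token_loss'].item():.3f} "
          f"token_accuracy={out['token_accuracy'].item():.3f}")

    codes = model.generate(text_tokens, max_steps=75)  # ~1 s of audio at 75 frames/s
    print("generated codes:", tuple(codes.shape))      # expected (B, 75)

    print("cache parity:", check_cache_parity_sketch(model.predictor))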
