"""
LeWM TTS v7 — AR (level 1) + NAR (levels 2-8) on frozen EnCodec tokens.
Multi-speaker ready via speaker embeddings.

Inference: text + speaker_id → AR tokens → NAR tokens → EnCodec decode → waveform
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalPE(nn.Module):
    """Fixed sinusoidal positional encoding, added to the input (no learned parameters)."""

    def __init__(self, d, max_len=8192):
        super().__init__()
        pe = torch.zeros(max_len, d)
        pos = torch.arange(0, max_len).float().unsqueeze(1)
        div = torch.exp(torch.arange(0, d, 2).float() * (-math.log(10000.0) / d))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.shape[1]]


class TextEncoder(nn.Module):
    def __init__(self, vocab=256, d=256, nhead=4, layers=4, dropout=0.1):
        super().__init__()
        self.d = d
        self.embed = nn.Embedding(vocab, d)
        self.pe = SinusoidalPE(d)
        self.drop = nn.Dropout(dropout)
        layer = nn.TransformerEncoderLayer(d, nhead, d * 4, dropout, batch_first=True, activation="gelu")
        self.tf = nn.TransformerEncoder(layer, layers)
        self.proj = nn.Linear(d, d)

    def forward(self, tokens, mask=None):
        """tokens: [B, T_text]; mask: True at padded positions. Returns [B, T_text, d]."""
        x = self.embed(tokens) * math.sqrt(self.d)
        x = self.drop(self.pe(x))
        return self.proj(self.tf(x, src_key_padding_mask=mask))


# ─── AR Model (Level 1) ──────────────────────────────────────────────────────

class ARModel(nn.Module):
    """Causal transformer predicting level-1 EnCodec tokens."""

    def __init__(self, d=256, nhead=4, layers=8, n_codes=1024, dropout=0.1):
        super().__init__()
        self.d = d
        self.n_codes = n_codes
        self.bos_id = n_codes

        self.tok_embed = nn.Embedding(n_codes + 1, d)  # +1 for BOS
        self.pe = SinusoidalPE(d)
        self.drop = nn.Dropout(dropout)

        layer = nn.TransformerDecoderLayer(d, nhead, d * 4, dropout, batch_first=True, activation="gelu")
        self.tf = nn.TransformerDecoder(layer, layers)
        self.head = nn.Sequential(nn.Linear(d, d), nn.GELU(), nn.Linear(d, n_codes))

    def forward(self, token_ids, text_emb, tok_mask=None, text_mask=None):
        """token_ids: [B, T] with BOS prepended. Returns logits [B, T, n_codes]."""
        x = self.drop(self.pe(self.tok_embed(token_ids)))
        T = x.shape[1]
        causal = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        x = self.tf(x, text_emb, tgt_mask=causal, tgt_key_padding_mask=tok_mask,
                     memory_key_padding_mask=text_mask)
        return self.head(x)

    # ─── Cached inference ─────────────────────────────────────────────

    def init_cache(self):
        """Per-layer K/V caches: sk/sv for causal self-attention, ck/cv for text cross-attention."""
        n = len(self.tf.layers)
        return {'sk': [None]*n, 'sv': [None]*n, 'ck': [None]*n, 'cv': [None]*n}

    def step(self, tok_id, text_emb, step_idx, cache, text_mask=None):
        """Single-token decode step. tok_id: [B, 1]. Returns logits [B, n_codes], cache."""
        x = self.tok_embed(tok_id) + self.pe.pe[:, step_idx:step_idx + 1]
        # Re-implements the post-norm decoder layer (PyTorch default, norm_first=False) with
        # explicit K/V caching; no causal mask is needed since only past keys are in the cache.
        for i, layer in enumerate(self.tf.layers):
            sa, cache['sk'][i], cache['sv'][i] = self._attn(
                layer.self_attn, x, x, cache['sk'][i], cache['sv'][i])
            x = layer.norm1(x + layer.dropout1(sa))
            ca, cache['ck'][i], cache['cv'][i] = self._attn(
                layer.multihead_attn, x, text_emb, cache['ck'][i], cache['cv'][i], text_mask)
            x = layer.norm2(x + layer.dropout2(ca))
            ff = layer.linear2(layer.dropout(layer.activation(layer.linear1(x))))
            x = layer.norm3(x + layer.dropout3(ff))
        return self.head(x).squeeze(1), cache

    def _attn(self, mha, q, kv, ck, cv, mask=None):
        """Manual multi-head attention with K/V caching.

        Self-attention passes the single new token as kv (projected and appended to the
        cache); cross-attention passes the full text memory (projected once, then reused).
        """
        d, nh, hd = mha.embed_dim, mha.num_heads, mha.embed_dim // mha.num_heads
        B = q.shape[0]
        Wq, Wk, Wv = mha.in_proj_weight.chunk(3, 0)
        bq, bk, bv = mha.in_proj_bias.chunk(3, 0)
        qp = F.linear(q, Wq, bq)
        if ck is None:
            # First step: project kv and start the cache
            k, v = F.linear(kv, Wk, bk), F.linear(kv, Wv, bv)
        elif kv.shape[1] == 1:
            # Later self-attention steps: project only the new token and append
            k = torch.cat([ck, F.linear(kv, Wk, bk)], 1)
            v = torch.cat([cv, F.linear(kv, Wv, bv)], 1)
        else:
            # Later cross-attention steps: the text memory is static, reuse the cache
            k, v = ck, cv
        qp = qp.view(B, -1, nh, hd).transpose(1, 2)
        km = k.view(B, -1, nh, hd).transpose(1, 2)
        vm = v.view(B, -1, nh, hd).transpose(1, 2)
        a = torch.matmul(qp, km.transpose(-2, -1)) / (hd ** 0.5)
        if mask is not None:
            # mask: [B, S] with True at padded positions, broadcast over heads and queries
            a = a.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        o = torch.matmul(F.softmax(a, -1), vm).transpose(1, 2).contiguous().view(B, -1, d)
        return F.linear(o, mha.out_proj.weight, mha.out_proj.bias), k, v


# ─── NAR Model (Levels 2-8) ──────────────────────────────────────────────────

class NARModel(nn.Module):
    """Bidirectional transformer: given level-1 tokens + text, predict levels 2-8 in parallel."""

    def __init__(self, d=256, nhead=4, layers=6, n_codes=1024, n_rvq=8, dropout=0.1):
        super().__init__()
        self.n_codes = n_codes
        self.n_predict = n_rvq - 1  # levels 2-8 = 7 levels

        self.tok_embed = nn.Embedding(n_codes, d)
        self.pe = SinusoidalPE(d)
        self.drop = nn.Dropout(dropout)

        # Level embedding — tells the model which RVQ level to predict
        self.level_embed = nn.Embedding(n_rvq, d)

        layer = nn.TransformerDecoderLayer(d, nhead, d * 4, dropout, batch_first=True, activation="gelu")
        self.tf = nn.TransformerDecoder(layer, layers)

        # Shared output head (level embedding differentiates levels)
        self.head = nn.Sequential(nn.Linear(d, d), nn.GELU(), nn.Linear(d, n_codes))

    def forward(self, prev_level_tokens, text_emb, target_level, text_mask=None, tok_mask=None):
        """
        prev_level_tokens: [B, L, T] codebook indices of the conditioning levels 1..L
        text_emb: [B, T_text, d]
        target_level: int 0-6 (0 predicts level 2, ..., 6 predicts level 8)
        Returns: logits [B, T, n_codes]
        """
        # Condition on the summed embeddings of all preceding levels, mirroring
        # the accumulation used in generate_all_levels
        x = self.tok_embed(prev_level_tokens).sum(dim=1)  # [B, T, d]
        x = x + self.level_embed.weight[target_level + 1].unsqueeze(0).unsqueeze(0)
        x = self.drop(self.pe(x))
        x = self.tf(x, text_emb, tgt_key_padding_mask=tok_mask,
                    memory_key_padding_mask=text_mask)
        return self.head(x)

    @torch.no_grad()
    def generate_all_levels(self, level1_tokens, text_emb, text_mask=None, tok_mask=None):
        """Predict all 7 levels in 7 forward passes. Returns [B, 8, T]."""
        B, T = level1_tokens.shape
        all_codes = torch.zeros(B, self.n_predict + 1, T, dtype=torch.long, device=level1_tokens.device)
        all_codes[:, 0] = level1_tokens

        # Accumulate: each level conditions on sum of previous level embeddings
        accum = self.tok_embed(level1_tokens)

        for lvl in range(self.n_predict):  # predict levels 2..n_rvq (index 1..n_rvq-1)
            x = accum + self.level_embed.weight[lvl + 1].unsqueeze(0).unsqueeze(0)
            x = self.pe(x)
            x = self.tf(x, text_emb, tgt_key_padding_mask=tok_mask,
                         memory_key_padding_mask=text_mask)
            logits = self.head(x)
            tokens = logits.argmax(dim=-1)  # greedy
            all_codes[:, lvl + 1] = tokens
            # Add this level's embeddings for next level's conditioning
            accum = accum + self.tok_embed(tokens)

        return all_codes


# ─── Full Model ──────────────────────────────────────────────────────────────

class LeWMTTSv7(nn.Module):
    def __init__(self, config):
        super().__init__()
        d = config.get("d_model", 256)
        nhead = config.get("nhead", 4)
        n_codes = config.get("n_codes", 1024)
        n_rvq = config.get("n_rvq", 8)
        n_speakers = config.get("n_speakers", 1)
        dropout = config.get("dropout", 0.1)

        self.text_encoder = TextEncoder(
            vocab=config.get("text_vocab_size", 256), d=d, nhead=nhead,
            layers=config.get("text_encoder_layers", 4), dropout=dropout,
        )
        self.ar = ARModel(d=d, nhead=nhead, layers=config.get("ar_layers", 8),
                          n_codes=n_codes, dropout=dropout)
        self.nar = NARModel(d=d, nhead=nhead, layers=config.get("nar_layers", 6),
                            n_codes=n_codes, n_rvq=n_rvq, dropout=dropout)

        # Speaker conditioning
        self.n_speakers = n_speakers
        if n_speakers > 1:
            self.speaker_embed = nn.Embedding(n_speakers, d)
        else:
            self.speaker_embed = None

        self.n_codes = n_codes
        self.n_rvq = n_rvq
        self.label_smoothing = config.get("label_smoothing", 0.1)
        self.config = config

    def _add_speaker(self, text_emb, speaker_id):
        if self.speaker_embed is not None and speaker_id is not None:
            spk = self.speaker_embed(speaker_id).unsqueeze(1)
            text_emb = text_emb + spk
        return text_emb

    def forward(self, all_tokens, text_tokens, token_mask=None, text_mask=None,
                speaker_id=None):
        """
        all_tokens: [B, n_rvq, T] — all 8 RVQ levels
        text_tokens: [B, T_text]
        """
        B, n_rvq, T = all_tokens.shape
        level1 = all_tokens[:, 0]  # [B, T]

        text_emb = self.text_encoder(text_tokens, text_mask)
        text_emb = self._add_speaker(text_emb, speaker_id)

        # ─── AR loss (level 1) ───
        bos = torch.full((B, 1), self.ar.bos_id, dtype=torch.long, device=level1.device)
        ar_input = torch.cat([bos, level1[:, :-1]], dim=1)
        if token_mask is not None:
            bos_mask = torch.zeros(B, 1, dtype=torch.bool, device=level1.device)
            ar_mask = torch.cat([bos_mask, token_mask[:, :-1]], dim=1)
        else:
            ar_mask = None

        ar_logits = self.ar(ar_input, text_emb, ar_mask, text_mask)

        if token_mask is not None:
            valid = ~token_mask
            ar_loss = F.cross_entropy(ar_logits[valid], level1[valid],
                                       label_smoothing=self.label_smoothing)
        else:
            ar_loss = F.cross_entropy(ar_logits.reshape(-1, self.n_codes), level1.reshape(-1),
                                       label_smoothing=self.label_smoothing)

        # ─── NAR loss (levels 2-8, random level per batch) ───
        lvl = torch.randint(0, n_rvq - 1, (1,)).item()  # 0-6 → predicts level 2-8
        target = all_tokens[:, lvl + 1]  # [B, T]

        # NAR conditioning: ground-truth tokens of levels 1..lvl+1; their embeddings are
        # summed inside NARModel, matching the accumulation used in generate_all_levels
        prev_levels = all_tokens[:, :lvl + 1]  # [B, lvl+1, T]
        nar_logits = self.nar(prev_levels, text_emb, lvl, text_mask, token_mask)

        if token_mask is not None:
            valid = ~token_mask
            nar_loss = F.cross_entropy(nar_logits[valid], target[valid],
                                        label_smoothing=self.label_smoothing)
        else:
            nar_loss = F.cross_entropy(nar_logits.reshape(-1, self.n_codes), target.reshape(-1),
                                        label_smoothing=self.label_smoothing)

        with torch.no_grad():
            if token_mask is not None:
                ar_acc = (ar_logits[valid].argmax(-1) == level1[valid]).float().mean()
                nar_acc = (nar_logits[valid].argmax(-1) == target[valid]).float().mean()
            else:
                ar_acc = (ar_logits.argmax(-1) == level1).float().mean()
                nar_acc = (nar_logits.argmax(-1) == target).float().mean()

        total = ar_loss + nar_loss
        return {
            "total_loss": total, "ar_loss": ar_loss, "nar_loss": nar_loss,
            "ar_acc": ar_acc, "nar_acc": nar_acc,
        }

    # ─── Inference ────────────────────────────────────────────────────

    @torch.no_grad()
    def generate(self, text_tokens, max_steps=750, temperature=0.8, top_k=50,
                 text_mask=None, speaker_id=None):
        """Full pipeline: text → 8-level tokens."""
        text_emb = self.text_encoder(text_tokens, text_mask)
        text_emb = self._add_speaker(text_emb, speaker_id)
        B = text_tokens.shape[0]

        # AR: generate level 1
        cache = self.ar.init_cache()
        tok = torch.full((B, 1), self.ar.bos_id, dtype=torch.long, device=text_tokens.device)
        level1 = []

        for step in range(max_steps):
            logits, cache = self.ar.step(tok, text_emb, step, cache, text_mask)
            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                topk_v, _ = logits.topk(top_k, dim=-1)
                logits[logits < topk_v[:, -1:]] = float('-inf')
            tok = torch.multinomial(F.softmax(logits, -1), 1)
            level1.append(tok)

        level1 = torch.cat(level1, dim=1)  # [B, max_steps]

        # NAR: generate levels 2-8
        all_codes = self.nar.generate_all_levels(level1, text_emb, text_mask)
        return all_codes  # [B, n_rvq, max_steps]
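

# Decoding the generated codes back to audio happens outside this module, with a frozen
# EnCodec decoder. A minimal sketch, assuming the `encodec` package's 24 kHz model at
# 6 kbps (8 codebooks of 1024 entries, matching the n_rvq / n_codes defaults in this file):
def decode_codes_to_waveform(codes):
    """codes: [B, n_rvq, T] LongTensor of EnCodec indices → waveform [B, 1, n_samples]."""
    from encodec import EncodecModel  # lazy import; only needed for decoding
    codec = EncodecModel.encodec_model_24khz()
    codec.set_target_bandwidth(6.0)  # 6 kbps → 8 RVQ codebooks of 1024 entries
    codec.eval()
    with torch.no_grad():
        # decode() expects a list of (codes, scale) frames; scale is None here
        return codec.decode([(codes, None)])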


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def build_model_v7(config=None):
    if config is None:
        config = {
            "d_model": 256, "nhead": 4, "n_codes": 1024, "n_rvq": 8,
            "text_vocab_size": 256, "text_encoder_layers": 4,
            "ar_layers": 8, "nar_layers": 6,
            "dropout": 0.1, "label_smoothing": 0.1,
            "n_speakers": 1,
        }
    model = LeWMTTSv7(config)
    print(f"LeWM TTS v7: {count_parameters(model)/1e6:.2f}M params")
    return model, config
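

if __name__ == "__main__":
    # Usage sketch with dummy data (illustrative only): build the default model, run one
    # training-style forward pass, then a short generation. Padding masks, when used,
    # are [B, T] bool tensors with True at padded positions.
    model, cfg = build_model_v7()
    model.eval()

    B, T_text, T_audio = 2, 32, 120
    text = torch.randint(0, cfg["text_vocab_size"], (B, T_text))
    codes = torch.randint(0, cfg["n_codes"], (B, cfg["n_rvq"], T_audio))

    # Training-style forward: returns AR/NAR losses and token accuracies
    out = model(codes, text)
    print({k: round(float(v), 4) for k, v in out.items()})

    # Inference: AR rollout of level 1, then NAR fill-in of the remaining levels
    with torch.no_grad():
        gen = model.generate(text, max_steps=50)
    print("generated codes:", tuple(gen.shape))  # (B, n_rvq, 50)
    # waveform = decode_codes_to_waveform(gen)  # requires the `encodec` package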
