"""
LeWM TTS v2 — DAC-based JEPA Text-to-Speech

Architecture:
  - DAC encoder (frozen): audio → 1024-dim continuous latents @ 75Hz
  - Linear proj: 1024 → 256 (d_model)
  - TextEncoder: byte-level transformer
  - JEPAPredictor: causal transformer decoder with cross-attention to text
  - Linear proj: 256 → 1024 (back to DAC space)
  - DAC decoder (frozen): latents → waveform

  Three losses: next-embedding prediction (MSE) + Gaussian regularizer (KL)
  + projection round-trip reconstruction (MSE)
  NO mel decoder needed — DAC handles audio reconstruction.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=8192):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.shape[1]]


class TextEncoder(nn.Module):
    def __init__(self, vocab_size=256, d_model=256, nhead=4, num_layers=4, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.char_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = SinusoidalPositionalEncoding(d_model, max_len=2048)
        self.dropout = nn.Dropout(dropout)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, text_tokens, text_mask=None):
        x = self.char_embed(text_tokens) * math.sqrt(self.d_model)
        x = self.pos_embed(x)
        x = self.dropout(x)
        x = self.transformer(x, src_key_padding_mask=text_mask)
        return self.proj(x)


class JEPAPredictor(nn.Module):
    def __init__(self, d_model=256, nhead=4, num_layers=6, dropout=0.1):
        super().__init__()
        self.audio_input_proj = nn.Linear(d_model, d_model)
        self.pos_embed = SinusoidalPositionalEncoding(d_model, max_len=4096)
        self.dropout = nn.Dropout(dropout)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, audio_emb, text_emb, audio_mask=None, text_mask=None):
        x = self.audio_input_proj(audio_emb)
        x = self.pos_embed(x)
        x = self.dropout(x)

        T = x.shape[1]
        causal_mask = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
        )

        predicted = self.transformer(
            x, text_emb,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=audio_mask,
            memory_key_padding_mask=text_mask,
        )
        return self.output_proj(predicted)
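

# Sanity sketch (not part of the original module): with the triangular
# tgt_mask above, perturbing future audio frames must leave earlier
# predictor outputs unchanged. All sizes below are illustrative.
def _check_causality():
    torch.manual_seed(0)
    pred = JEPAPredictor(d_model=64, nhead=4, num_layers=2, dropout=0.0)
    pred.eval()
    audio = torch.randn(1, 10, 64)
    text = torch.randn(1, 7, 64)
    audio_b = audio.clone()
    audio_b[:, 5:] += 1.0  # perturb only frames >= 5
    with torch.no_grad():
        out_a = pred(audio, text)
        out_b = pred(audio_b, text)
    # Frames < 5 never attend to the perturbed region, so they must match.
    assert torch.allclose(out_a[:, :5], out_b[:, :5], atol=1e-5)
    print("JEPAPredictor causality check passed")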


class LeWMTTS(nn.Module):
    """
    LeWM TTS v2 with DAC latents.

    Training:
      - DAC encodes audio → 1024-dim latents (frozen)
      - Project to d_model, predict the next embedding, KL-regularize, and
        keep the DAC in/out projections a faithful round-trip (recon loss)
    Inference:
      - Text → predict DAC latents autoregressively → DAC decode → waveform
        (see the generate_latents sketch after this class)
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.get("d_model", 256)
        nhead = config.get("nhead", 4)
        dac_dim = config.get("dac_dim", 1024)
        text_vocab = config.get("text_vocab_size", 256)
        text_layers = config.get("text_encoder_layers", 4)
        predictor_layers = config.get("predictor_layers", 6)
        dropout = config.get("dropout", 0.1)
        self.kl_weight = config.get("kl_weight", 0.1)
        self.recon_weight = config.get("recon_weight", 10.0)

        # DAC space ↔ model space projections
        self.dac_in_proj = nn.Linear(dac_dim, d_model)
        self.dac_out_proj = nn.Linear(d_model, dac_dim)

        # Gaussian regularizer projections
        self.proj_mu = nn.Linear(d_model, d_model)
        self.proj_logvar = nn.Linear(d_model, d_model)

        self.text_encoder = TextEncoder(
            vocab_size=text_vocab, d_model=d_model, nhead=nhead,
            num_layers=text_layers, dropout=dropout,
        )
        self.predictor = JEPAPredictor(
            d_model=d_model, nhead=nhead,
            num_layers=predictor_layers, dropout=dropout,
        )

        self.config = config

    def forward(self, dac_latents, text_tokens, latent_mask=None, text_mask=None):
        """
        Args:
            dac_latents: [B, T, 1024] — continuous DAC encoder output
            text_tokens: [B, T_text] — byte-level text tokens
            latent_mask: [B, T] bool, True = padding
            text_mask: [B, T_text] bool, True = padding
        """
        # Project DAC latents to model space
        h = self.dac_in_proj(dac_latents)  # [B, T, d_model]

        # Gaussian regularizer: project to mu/logvar
        mu = self.proj_mu(h)
        logvar = self.proj_logvar(h)

        # Reparameterize
        if self.training:
            std = torch.exp(0.5 * logvar)
            z = mu + torch.randn_like(std) * std
        else:
            z = mu

        # Encode text
        text_emb = self.text_encoder(text_tokens, text_mask)

        # Predict next embedding: position t predicts the latent at t + 1
        input_emb = z[:, :-1]
        target_emb = z[:, 1:]

        if latent_mask is not None:
            pred_mask = latent_mask[:, :-1]   # padding mask for predictor inputs
            target_mask = latent_mask[:, 1:]  # padding mask for shifted targets
        else:
            pred_mask = None
            target_mask = None

        predicted = self.predictor(input_emb, text_emb, pred_mask, text_mask)

        # Loss 1: Next-embedding prediction (MSE), averaged over valid target
        # positions (using the shifted mask so padded target frames are excluded)
        if target_mask is not None:
            valid = (~target_mask).unsqueeze(-1).float()
            prediction_loss = F.mse_loss(
                predicted * valid, target_emb * valid, reduction="sum"
            ) / (valid.sum() * predicted.shape[-1] + 1e-8)
        else:
            prediction_loss = F.mse_loss(predicted, target_emb)

        # Loss 2: KL(N(mu, sigma²) || N(0, I))
        if latent_mask is not None:
            valid = (~latent_mask).unsqueeze(-1).float()
            kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
            kl_loss = (kl * valid).sum() / (valid.sum() * mu.shape[-1] + 1e-8)
        else:
            kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        # Loss 3: Projection round-trip reconstruction
        # Forces dac_out_proj(dac_in_proj(x)) ≈ x so DAC can decode our outputs
        dac_recon = self.dac_out_proj(mu)  # Use mu (not noisy z) for cleaner target
        if latent_mask is not None:
            valid = (~latent_mask).unsqueeze(-1).float()
            recon_loss = F.mse_loss(
                dac_recon * valid, dac_latents * valid, reduction="sum"
            ) / (valid.sum() * dac_latents.shape[-1] + 1e-8)
        else:
            recon_loss = F.mse_loss(dac_recon, dac_latents)

        total_loss = prediction_loss + self.kl_weight * kl_loss + self.recon_weight * recon_loss

        return {
            "total_loss": total_loss,
            "prediction_loss": prediction_loss,
            "kl_loss": kl_loss,
            "recon_loss": recon_loss,
        }

    def predict_next(self, audio_emb, text_emb, text_mask=None):
        """AR inference: predict next embedding from context."""
        predicted = self.predictor(audio_emb, text_emb, text_mask=text_mask)
        return predicted[:, -1:]

    def latents_to_dac(self, z):
        """Project model-space embeddings back to DAC latent space."""
        return self.dac_out_proj(z)

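
# Autoregressive inference sketch (illustrative, not from the source): roll
# out embeddings with predict_next, then project back to DAC space for the
# frozen DAC decoder. The zero start frame and fixed rollout length are
# assumptions; the source does not specify start/stop handling.
@torch.no_grad()
def generate_latents(model, text_tokens, num_frames=150):
    """Greedy AR rollout: returns [B, num_frames, dac_dim] DAC-space latents."""
    model.eval()
    text_emb = model.text_encoder(text_tokens)
    B, d_model = text_tokens.shape[0], text_emb.shape[-1]
    z = torch.zeros(B, 1, d_model, device=text_tokens.device)  # start frame (assumption)
    for _ in range(num_frames):
        next_z = model.predict_next(z, text_emb)  # [B, 1, d_model]
        z = torch.cat([z, next_z], dim=1)
    # Drop the start frame; the result feeds the frozen DAC decoder.
    return model.latents_to_dac(z[:, 1:])
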

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def build_model(config=None):
    if config is None:
        config = {
            "d_model": 256,
            "nhead": 4,
            "dac_dim": 1024,
            "text_vocab_size": 256,
            "text_encoder_layers": 4,
            "predictor_layers": 6,
            "dropout": 0.1,
            "kl_weight": 0.1,
        }
    model = LeWMTTS(config)
    print(f"LeWM TTS v2 (DAC): {count_parameters(model)/1e6:.2f}M trainable parameters")
    return model, config
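

if __name__ == "__main__":
    # Smoke test (illustrative): random tensors stand in for the frozen DAC
    # encoder's 1024-dim latents described in the module docstring.
    model, config = build_model()
    B, T, T_text = 2, 120, 64
    dac_latents = torch.randn(B, T, config["dac_dim"])
    text_tokens = torch.randint(0, config["text_vocab_size"], (B, T_text))
    losses = model(dac_latents, text_tokens)
    for name, value in losses.items():
        print(f"{name}: {value.item():.4f}")

    _check_causality()
    generated = generate_latents(model, text_tokens, num_frames=20)
    print(f"generated DAC latents: {tuple(generated.shape)}")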
