"""
LeWM TTS v5 — JEPA with EnCodec backend.

Architecture:
  - EnCodec encoder/decoder (frozen): audio ↔ 128d embeddings @ 75Hz
  - Simple linear projections: 128d ↔ d_model (no lossy compression)
  - TextEncoder: byte-level transformer for Hindi text
  - DurationPredictor: predicts per-character duration
  - JEPAPredictor: causal transformer with frame-level text conditioning
"""

import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F


# ─── Duration Predictor ─────────────────────────────────────────────────────

class DurationPredictor(nn.Module):
    """Predicts log-duration per text token. Similar to FastSpeech."""
    def __init__(self, d_model=256, kernel_size=3, num_layers=2, dropout=0.1):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.extend([
                nn.Conv1d(d_model, d_model, kernel_size, padding=kernel_size // 2),
                nn.ReLU(),
                nn.LayerNorm(d_model),
                nn.Dropout(dropout),
            ])
        self.convs = nn.ModuleList(layers)
        self.proj = nn.Linear(d_model, 1)

    def forward(self, text_emb, text_mask=None):
        """
        text_emb: [B, T_text, d_model]
        Returns: log_durations [B, T_text]
        """
        x = text_emb.transpose(1, 2)  # [B, d_model, T]
        for i in range(0, len(self.convs), 4):
            x = self.convs[i](x)      # Conv1d
            x = self.convs[i+1](x)    # ReLU
            x = x.transpose(1, 2)     # [B, T, d_model] for LayerNorm
            x = self.convs[i+2](x)    # LayerNorm
            x = self.convs[i+3](x)    # Dropout
            x = x.transpose(1, 2)     # back to [B, d_model, T]
        x = x.transpose(1, 2)  # [B, T, d_model]
        log_dur = self.proj(x).squeeze(-1)  # [B, T_text]
        if text_mask is not None:
            log_dur = log_dur.masked_fill(text_mask, 0.0)
        return log_dur


def length_regulate(text_emb, durations, text_mask=None):
    """
    Expand text embeddings by repeating each token according to its duration.
    text_emb: [B, T_text, d_model]
    durations: [B, T_text] integer durations
    Returns: expanded [B, T_audio, d_model], where T_audio = max(sum(durations))
    """
    B, T_text, D = text_emb.shape
    if text_mask is not None:
        durations = durations.masked_fill(text_mask, 0)
    max_audio_len = durations.sum(dim=1).max().item()
    expanded = torch.zeros(B, max_audio_len, D, device=text_emb.device)
    for b in range(B):
        durs_b = durations[b]  # [T_text]
        total = durs_b.sum().item()
        if total == 0:
            continue
        # Use repeat_interleave — vectorized per sample
        exp_b = torch.repeat_interleave(text_emb[b], durs_b, dim=0)  # [total, D]
        expanded[b, :exp_b.shape[0]] = exp_b
    return expanded
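
# Worked example for length_regulate (illustrative, not executed at import):
#   emb = torch.arange(3).float().view(1, 3, 1)   # token embeddings a=0, b=1, c=2
#   dur = torch.tensor([[2, 0, 3]])
#   length_regulate(emb, dur)
#   → [[[0], [0], [2], [2], [2]]]   # a a c c c; zero-duration tokens are dropped,
#                                   # shorter batch items are right-padded with zeros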


def compute_uniform_durations(text_lengths, audio_lengths):
    """
    Compute uniform durations: spread audio frames evenly across text characters.
    text_lengths: [B] number of valid text tokens
    audio_lengths: [B] number of valid audio frames
    Returns: durations [B, max_text_len] integer
    """
    B = text_lengths.shape[0]
    max_text_len = text_lengths.max().item()
    durations = torch.zeros(B, max_text_len, dtype=torch.long, device=text_lengths.device)
    # Per-sample Python loop (text lengths vary across the batch)
    for b in range(B):
        tl = text_lengths[b].item()
        al = audio_lengths[b].item()
        if tl == 0:
            continue
        base = al // tl
        remainder = al % tl
        durations[b, :tl] = base
        durations[b, :remainder] += 1
    return durations
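
# Worked example for compute_uniform_durations (illustrative): with
# text_lengths=[3] and audio_lengths=[10], base = 10 // 3 = 3 and remainder = 1,
# so durations = [4, 3, 3]; the first `remainder` tokens get one extra frame and
# the per-sample sum always equals the audio length (here 10).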


# ─── JEPA Predictor ─────────────────────────────────────────────────────────

class JEPAPredictor(nn.Module):
    """Causal predictor with frame-level text added to audio input."""
    def __init__(self, d_model=256, nhead=4, num_layers=6, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.audio_input_proj = nn.Linear(d_model, d_model)
        self.text_input_proj = nn.Linear(d_model, d_model)
        self.pos_embed = SinusoidalPositionalEncoding(d_model, max_len=4096)
        self.dropout = nn.Dropout(dropout)

        # Use encoder (self-attention only) since text is now part of input
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_proj = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, audio_emb, text_emb_expanded, audio_mask=None):
        """
        audio_emb: [B, T, d_model] — shifted audio embeddings
        text_emb_expanded: [B, T, d_model] — frame-level text (already expanded)
        """
        x = self.audio_input_proj(audio_emb) + self.text_input_proj(text_emb_expanded)
        x = self.pos_embed(x)
        x = self.dropout(x)
        T = x.shape[1]
        causal_mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        predicted = self.transformer(x, mask=causal_mask, src_key_padding_mask=audio_mask)
        return self.output_proj(predicted)

    # Keep old cross-attention interface for backward compat loading
    def forward_cross_attn(self, audio_emb, text_emb, audio_mask=None, text_mask=None):
        """Legacy cross-attention forward (not used in duration-predictor mode)."""
        x = self.audio_input_proj(audio_emb)
        x = self.pos_embed(x)
        x = self.dropout(x)
        return self.output_proj(x)

    def _cached_mha(self, mha, q_input, kv_input, cache_k, cache_v):
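        """
        One decoding step of self-attention with a key/value cache.
        q_input / kv_input: [B, 1, d_model] for the current frame only;
        cache_k / cache_v hold the keys/values of all earlier frames
        ([B, T_past, d_model]) or None on the first step.
        Returns (attn_output [B, 1, d_model], updated_k, updated_v).
        """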
        d = mha.embed_dim
        nhead = mha.num_heads
        head_dim = d // nhead
        B = q_input.shape[0]
        Wq, Wk, Wv = mha.in_proj_weight.chunk(3, dim=0)
        bq, bk, bv = mha.in_proj_bias.chunk(3, dim=0)
        q = F.linear(q_input, Wq, bq)
        k_new = F.linear(kv_input, Wk, bk)
        v_new = F.linear(kv_input, Wv, bv)
        if cache_k is not None and cache_k.shape[1] > 0:
            k = torch.cat([cache_k, k_new], dim=1)
            v = torch.cat([cache_v, v_new], dim=1)
        else:
            k, v = k_new, v_new
        q = q.view(B, 1, nhead, head_dim).transpose(1, 2)
        k_mh = k.view(B, -1, nhead, head_dim).transpose(1, 2)
        v_mh = v.view(B, -1, nhead, head_dim).transpose(1, 2)
        attn = torch.matmul(q, k_mh.transpose(-2, -1)) / (head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v_mh)
        out = out.transpose(1, 2).contiguous().view(B, 1, d)
        out = F.linear(out, mha.out_proj.weight, mha.out_proj.bias)
        return out, k, v

    def init_cache(self, num_layers):
        return {
            'self_k': [None] * num_layers,
            'self_v': [None] * num_layers,
        }

    def inference_step(self, new_emb, text_emb_frame, step_idx, cache):
        """
        new_emb: [B, 1, d_model] — audio embedding for this step
        text_emb_frame: [B, 1, d_model] — expanded text embedding for this step
        """
        x = self.audio_input_proj(new_emb) + self.text_input_proj(text_emb_frame)
        x = x + self.pos_embed.pe[:, step_idx:step_idx + 1]
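        # Note: this manual unrolling matches nn.TransformerEncoderLayer with the
        # default post-norm layout (norm_first=False) and GELU activation used in
        # __init__; it would need adjusting if the layer configuration changed.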
        for i, layer in enumerate(self.transformer.layers):
            sa_out, cache['self_k'][i], cache['self_v'][i] = self._cached_mha(
                layer.self_attn, x, x, cache['self_k'][i], cache['self_v'][i]
            )
            x = layer.norm1(x + layer.dropout1(sa_out))
            ff_out = layer.linear2(layer.dropout(layer.activation(layer.linear1(x))))
            x = layer.norm2(x + layer.dropout2(ff_out))
        return self.output_proj(x), cache


# ─── Positional Encoding ────────────────────────────────────────────────────

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=8192):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.shape[1]]


# ─── Text Encoder ───────────────────────────────────────────────────────────

class TextEncoder(nn.Module):
    def __init__(self, vocab_size=256, d_model=256, nhead=4, num_layers=4, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.char_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = SinusoidalPositionalEncoding(d_model, max_len=2048)
        self.dropout = nn.Dropout(dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True, activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, text_tokens, text_mask=None):
        x = self.char_embed(text_tokens) * math.sqrt(self.d_model)
        x = self.pos_embed(x)
        x = self.dropout(x)
        x = self.transformer(x, src_key_padding_mask=text_mask)
        return self.proj(x)


# ─── Full v5 Model ──────────────────────────────────────────────────────────

class LeWMTTSv5(nn.Module):
    """
    JEPA TTS with EnCodec backend + duration predictor.
    Text is expanded to frame level and added to the audio input, so the
    predictor cannot ignore the text conditioning.

    Training: codec_emb → proj_in → predict(audio + aligned_text) → loss vs targets
              Duration predictor trained with uniform alignment as supervision.
    Inference: text → duration predict → expand text → AR predict → proj_out → decode
               (a runnable sketch of this loop is under __main__ at the end of the file)
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.get("d_model", 256)
        nhead = config.get("nhead", 4)
        codec_dim = config.get("codec_dim", 128)
        text_vocab = config.get("text_vocab_size", 256)
        text_layers = config.get("text_encoder_layers", 4)
        predictor_layers = config.get("predictor_layers", 6)
        dropout = config.get("dropout", 0.1)

        self.codec_dim = codec_dim

        self.text_encoder = TextEncoder(
            vocab_size=text_vocab, d_model=d_model, nhead=nhead,
            num_layers=text_layers, dropout=dropout,
        )

        # Duration predictor
        self.duration_predictor = DurationPredictor(
            d_model=d_model, kernel_size=3, num_layers=2, dropout=dropout,
        )

        # Simple linear projections — no lossy compression
        self.proj_in = nn.Sequential(
            nn.Linear(codec_dim, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        self.proj_out = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, codec_dim),
        )

        self.predictor = JEPAPredictor(
            d_model=d_model, nhead=nhead,
            num_layers=predictor_layers, dropout=dropout,
        )

        # EMA on proj_in (provides stable prediction targets)
        self.ema_proj_in = copy.deepcopy(self.proj_in)
        for p in self.ema_proj_in.parameters():
            p.requires_grad = False
        self.ema_decay = config.get("ema_decay", 0.998)

        # Learnable start embedding
        self.start_emb = nn.Parameter(torch.randn(1, 1, d_model) * 0.02)

        # Loss weights
        self.pred_weight = config.get("pred_weight", 10.0)
        self.roundtrip_weight = config.get("roundtrip_weight", 1.0)
        self.dur_weight = config.get("dur_weight", 1.0)

        # Input noise for train/inference gap
        self.input_noise = config.get("input_noise", 0.0)

        self.config = config

    @torch.no_grad()
    def update_ema(self):
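        """Update the EMA copy of proj_in toward the online weights.

        Typically invoked by the training loop once per optimizer step
        (the loop itself lives outside this file).
        """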
        for p_online, p_ema in zip(self.proj_in.parameters(),
                                   self.ema_proj_in.parameters()):
            p_ema.data.mul_(self.ema_decay).add_(p_online.data, alpha=1 - self.ema_decay)

    def forward(self, codec_emb, text_tokens, codec_mask=None, text_mask=None):
        """
        Args:
            codec_emb: [B, 128, T] — continuous EnCodec embeddings
            text_tokens: [B, T_text]
            codec_mask: [B, T] bool, True = padding
            text_mask: [B, T_text] bool
        """
        B = codec_emb.shape[0]
        T = codec_emb.shape[2]

        # codec_emb is [B, 128, T] → transpose to [B, T, 128]
        codec_seq = codec_emb.transpose(1, 2)

        # Text encoding
        text_emb = self.text_encoder(text_tokens, text_mask)

        # Duration prediction + ground truth durations (uniform alignment)
        log_dur_pred = self.duration_predictor(text_emb, text_mask)

        # Compute actual text/audio lengths from masks
        if text_mask is not None:
            text_lengths = (~text_mask).sum(dim=1)  # [B]
        else:
            text_lengths = torch.full((B,), text_tokens.shape[1], device=text_tokens.device)
        if codec_mask is not None:
            audio_lengths = (~codec_mask).sum(dim=1)  # [B]
        else:
            audio_lengths = torch.full((B,), T, device=codec_emb.device)

        # Ground truth durations (uniform distribution)
        gt_durations = compute_uniform_durations(text_lengths, audio_lengths)
        log_dur_gt = torch.log(gt_durations.float().clamp(min=1))

        # Duration loss (MSE on log-durations)
        if text_mask is not None:
            valid_text = (~text_mask).float()
            dur_loss = F.mse_loss(
                log_dur_pred * valid_text, log_dur_gt * valid_text, reduction="sum"
            ) / (valid_text.sum() + 1e-8)
        else:
            dur_loss = F.mse_loss(log_dur_pred, log_dur_gt)

        # Expand text to frame level using GT durations (training uses GT)
        text_emb_expanded = length_regulate(text_emb, gt_durations, text_mask)
        # Trim/pad to match audio length T
        if text_emb_expanded.shape[1] > T:
            text_emb_expanded = text_emb_expanded[:, :T]
        elif text_emb_expanded.shape[1] < T:
            pad = torch.zeros(B, T - text_emb_expanded.shape[1], text_emb.shape[-1],
                            device=text_emb.device)
            text_emb_expanded = torch.cat([text_emb_expanded, pad], dim=1)

        # Project codec embeddings to predictor space
        z = self.proj_in(codec_seq)  # [B, T, d_model]

        # EMA targets (deterministic)
        with torch.no_grad():
            self.ema_proj_in.eval()
            target_emb = self.ema_proj_in(codec_seq)  # [B, T, d_model]
            self.ema_proj_in.train()

        # Prepend start embedding, shift right
        start = self.start_emb.expand(B, -1, -1)
        input_emb = torch.cat([start, z[:, :-1]], dim=1)  # [B, T, d_model]

        # Input noise
        if self.training and self.input_noise > 0:
            noise = torch.zeros_like(input_emb)
            noise[:, 1:] = torch.randn(B, T - 1, input_emb.shape[-1],
                                        device=input_emb.device) * self.input_noise
            input_emb = input_emb + noise

        # Masks
        if codec_mask is not None:
            start_mask = torch.zeros(B, 1, dtype=torch.bool, device=codec_emb.device)
            pred_mask = torch.cat([start_mask, codec_mask[:, :-1]], dim=1)
            loss_mask = codec_mask
        else:
            pred_mask = None
            loss_mask = None

        # JEPA prediction with frame-level text
        predicted = self.predictor(input_emb, text_emb_expanded, pred_mask)

        # ─── Loss 1: Prediction (MSE vs EMA targets) ───
        if loss_mask is not None:
            valid = (~loss_mask).unsqueeze(-1)
            prediction_loss = F.mse_loss(
                predicted * valid, target_emb * valid, reduction="sum"
            ) / (valid.sum() * predicted.shape[-1] + 1e-8)
        else:
            prediction_loss = F.mse_loss(predicted, target_emb)

        # ─── Loss 2: Roundtrip (proj_in → proj_out → match original codec_emb) ───
        codec_recon = self.proj_out(z)  # [B, T, 128]

        if codec_mask is not None:
            valid_rt = (~codec_mask).unsqueeze(-1).float()
            roundtrip_loss = F.l1_loss(
                codec_recon * valid_rt, codec_seq * valid_rt, reduction="sum"
            ) / (valid_rt.sum() * self.codec_dim + 1e-8)
        else:
            roundtrip_loss = F.l1_loss(codec_recon, codec_seq)

        total_loss = (self.pred_weight * prediction_loss
                     + self.roundtrip_weight * roundtrip_loss
                     + self.dur_weight * dur_loss)

        return {
            "total_loss": total_loss,
            "prediction_loss": prediction_loss,
            "roundtrip_loss": roundtrip_loss,
            "dur_loss": dur_loss,
        }

    def predict_durations(self, text_emb, text_mask=None):
        """Predict durations from text embeddings (for inference)."""
        log_dur = self.duration_predictor(text_emb, text_mask)
        durations = torch.exp(log_dur).round().long().clamp(min=1)
        if text_mask is not None:
            durations = durations.masked_fill(text_mask, 0)
        return durations

    def init_ar_cache(self):
        num_layers = len(self.predictor.transformer.layers)
        return self.predictor.init_cache(num_layers)

    def predict_next_cached(self, new_emb, text_emb_frame, step_idx, cache):
        """
        new_emb: [B, 1, d_model]
        text_emb_frame: [B, 1, d_model] — expanded text for this frame
        """
        return self.predictor.inference_step(new_emb, text_emb_frame, step_idx, cache)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def build_model(config=None):
    if config is None:
        config = {
            "d_model": 256, "nhead": 4, "codec_dim": 128,
            "text_vocab_size": 256, "text_encoder_layers": 4,
            "predictor_layers": 6, "dropout": 0.1,
            "pred_weight": 10.0, "roundtrip_weight": 1.0,
            "ema_decay": 0.998, "input_noise": 0.0,
        }
    model = LeWMTTSv5(config)
    print(f"LeWM TTS v5: {count_parameters(model)/1e6:.2f}M trainable parameters")
    return model, config
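

# ─── Usage sketch ───────────────────────────────────────────────────────────
# Minimal smoke test of the training forward pass and the autoregressive
# inference loop. Shapes and random tensors are illustrative only: a real
# pipeline would take codec_emb from the frozen EnCodec encoder (128-d
# embeddings at 75 Hz) and text_tokens from UTF-8 bytes of the input text.
# Feeding each prediction straight back into the predictor is one plausible
# AR loop under this design; re-encoding via proj_in(proj_out(pred)) is an
# alternative not shown here.

if __name__ == "__main__":
    torch.manual_seed(0)
    model, cfg = build_model()
    model.eval()

    B, T_audio, T_text = 2, 150, 40
    codec_emb = torch.randn(B, cfg["codec_dim"], T_audio)               # [B, 128, T]
    text_tokens = torch.randint(0, cfg["text_vocab_size"], (B, T_text))

    with torch.no_grad():
        # Training-style forward: returns the weighted losses.
        losses = model(codec_emb, text_tokens)
        print({k: round(v.item(), 4) for k, v in losses.items()})

        # Autoregressive inference for a single utterance.
        text_emb = model.text_encoder(text_tokens[:1])                  # [1, T_text, d_model]
        durations = model.predict_durations(text_emb)                   # [1, T_text]
        text_frames = length_regulate(text_emb, durations)              # [1, T_out, d_model]

        cache = model.init_ar_cache()
        prev = model.start_emb.expand(1, -1, -1)                        # learned start frame
        outputs = []
        for t in range(text_frames.shape[1]):
            pred, cache = model.predict_next_cached(
                prev, text_frames[:, t:t + 1], t, cache
            )
            outputs.append(pred)
            prev = pred                                                 # feed prediction back
        pred_seq = torch.cat(outputs, dim=1)                            # [1, T_out, d_model]
        codec_out = model.proj_out(pred_seq).transpose(1, 2)            # [1, 128, T_out] → EnCodec decoder
        print("generated codec embeddings:", tuple(codec_out.shape))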
