"""
Dataset for codec-based JEPA TTS v5.
Loads precomputed EnCodec embeddings + text.
"""

import json
import torch
from torch.utils.data import Dataset, DataLoader


class CodecDataset(Dataset):
    """Dataset of precomputed codec embeddings paired with byte-level text.

    The manifest is a JSON list of entries. Three keys are read from each
    entry: ``emb_path`` (file handed to ``torch.load``), ``emb_frames``
    (frame count used for length filtering) and ``text`` (string that is
    tokenised into its UTF-8 bytes).
    """

    def __init__(self, manifest_path, max_codec_frames=900):
        """Load the manifest and drop entries longer than ``max_codec_frames``.

        Args:
            manifest_path: Path to the JSON manifest file.
            max_codec_frames: Entries whose ``emb_frames`` exceeds this value
                are discarded.
        """
        # Decode explicitly as UTF-8: the transcripts are later turned into
        # UTF-8 byte tokens, so relying on the platform's default encoding
        # could corrupt non-ASCII text or fail outright on some systems.
        with open(manifest_path, "r", encoding="utf-8") as f:
            entries = json.load(f)
        self.max_codec_frames = max_codec_frames
        # Filter out too-long samples
        self.manifest = [e for e in entries if e["emb_frames"] <= max_codec_frames]
        print(f"CodecDataset: {len(self.manifest)} samples (max {max_codec_frames} frames)")

    def __len__(self):
        """Return the number of entries that survived length filtering."""
        return len(self.manifest)

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Index into the filtered manifest.

        Returns:
            dict with ``"codec_emb"`` (the object stored at ``emb_path``;
            expected to be a ``[128, T]`` tensor) and ``"text_tokens"``
            (1-D ``torch.long`` tensor holding the UTF-8 byte values of the
            entry's text).
        """
        entry = self.manifest[idx]
        # weights_only=True restricts unpickling to tensor data.
        emb = torch.load(entry["emb_path"], weights_only=True)  # [128, T]
        # Byte-level tokenisation: one token per UTF-8 byte (values 0-255).
        text_tokens = torch.tensor(list(entry["text"].encode("utf-8")), dtype=torch.long)
        return {"codec_emb": emb, "text_tokens": text_tokens}


def collate_fn(batch):
    """Merge per-sample dicts into one zero-padded batch with padding masks.

    Args:
        batch: Sequence of dicts, each holding ``"codec_emb"`` (2-D tensor
            indexed as channels x frames) and ``"text_tokens"`` (1-D tensor).

    Returns:
        dict with
        ``"codec_emb"``: ``[B, C, T_max]`` tensor, zero-padded along frames;
        ``"text_tokens"``: ``[B, L_max]`` ``torch.long`` tensor, zero-padded;
        ``"codec_mask"`` / ``"text_mask"``: bool tensors of shape
        ``[B, T_max]`` / ``[B, L_max]`` that are ``True`` at padded positions
        and ``False`` where real data is present.
    """
    n_items = len(batch)
    embeddings = [sample["codec_emb"] for sample in batch]
    token_seqs = [sample["text_tokens"] for sample in batch]

    # Codec side: channel count is taken from the first sample.
    n_channels = embeddings[0].shape[0]
    frame_counts = [emb.shape[1] for emb in embeddings]
    longest_codec = max(frame_counts)
    codec_out = torch.zeros(n_items, n_channels, longest_codec)
    for dst, emb, n_frames in zip(codec_out, embeddings, frame_counts):
        # ``dst`` is a view into ``codec_out``, so this writes in place.
        dst[:, :n_frames] = emb
    # Position j of row i is padding exactly when j >= length of sample i.
    codec_pad = (
        torch.arange(longest_codec).unsqueeze(0)
        >= torch.tensor(frame_counts).unsqueeze(1)
    )

    # Text side: same scheme on 1-D token sequences.
    token_counts = [seq.shape[0] for seq in token_seqs]
    longest_text = max(token_counts)
    text_out = torch.zeros(n_items, longest_text, dtype=torch.long)
    for dst, seq, n_tokens in zip(text_out, token_seqs, token_counts):
        dst[:n_tokens] = seq
    text_pad = (
        torch.arange(longest_text).unsqueeze(0)
        >= torch.tensor(token_counts).unsqueeze(1)
    )

    return {
        "codec_emb": codec_out,
        "text_tokens": text_out,
        "codec_mask": codec_pad,
        "text_mask": text_pad,
    }


def build_dataloader(manifest_path, batch_size=64, num_workers=4, max_codec_frames=900,
                     *, shuffle=True, drop_last=True, pin_memory=True):
    """Create a ``CodecDataset`` and a ``DataLoader`` that batches it.

    The defaults reproduce the training configuration (shuffled, incomplete
    final batch dropped, pinned memory). Pass ``shuffle=False`` and
    ``drop_last=False`` to build an evaluation loader that visits every
    sample in manifest order.

    Args:
        manifest_path: Path to the JSON manifest, forwarded to ``CodecDataset``.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        max_codec_frames: Length cutoff, forwarded to ``CodecDataset``.
        shuffle: Whether the loader reshuffles the samples.
        drop_last: Whether a final batch smaller than ``batch_size`` is
            dropped. With ``True`` and fewer than ``batch_size`` samples the
            loader yields no batches at all.
        pin_memory: Whether the loader places batches in pinned memory.

    Returns:
        Tuple ``(loader, dataset)``.
    """
    dataset = CodecDataset(manifest_path, max_codec_frames=max_codec_frames)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
    return loader, dataset
