"""Dataset for DAC-based LeWM TTS."""

import json
import torch
from torch.utils.data import Dataset, DataLoader


class TTSDataset(Dataset):
    """Map-style dataset pairing DAC latent sequences with byte-level text tokens."""

    def __init__(self, manifest_path, max_latent_frames=900):
        """Load the JSON manifest, keeping only entries within the frame budget.

        Args:
            manifest_path: path to a JSON list of entries; each entry must
                provide "latent_path", "text", and "latent_frames".
            max_latent_frames: entries longer than this are filtered out.
        """
        with open(manifest_path, "r", encoding="utf-8") as f:
            entries = json.load(f)
        # Drop clips whose latent sequence exceeds the model's length budget.
        self.manifest = [
            entry for entry in entries
            if entry["latent_frames"] <= max_latent_frames
        ]
        print(f"Dataset: {len(self.manifest)} samples")

    def __len__(self):
        return len(self.manifest)

    def __getitem__(self, idx):
        """Return one sample dict: latent [T, 1024], text tokens, and lengths."""
        entry = self.manifest[idx]
        # Latents are stored as [1024, T]; transpose to time-major [T, 1024].
        latent = (
            torch.load(entry["latent_path"], weights_only=True)
            .float()
            .transpose(0, 1)
        )
        # Text is tokenized as raw UTF-8 bytes (vocab of at most 256).
        token_ids = torch.tensor(
            list(entry["text"].encode("utf-8")), dtype=torch.long
        )
        return {
            "latent": latent,
            "text_tokens": token_ids,
            "latent_frames": latent.shape[0],
            "text_len": token_ids.numel(),
        }


def collate_fn(batch):
    """Pad a list of TTSDataset samples into fixed-size batched tensors.

    Args:
        batch: list of sample dicts with keys "latent" ([T, D] tensor),
            "text_tokens" ([L] long tensor), "latent_frames" (int T),
            and "text_len" (int L).

    Returns:
        Dict with:
            latent:      [B, T_max, D] tensor, zero-padded on the right.
            text_tokens: [B, L_max] long tensor, zero-padded on the right.
            latent_mask: [B, T_max] bool — True at PADDING positions.
            text_mask:   [B, L_max] bool — True at PADDING positions.
    """
    max_lat = max(s["latent_frames"] for s in batch)
    max_text = max(s["text_len"] for s in batch)
    dac_dim = batch[0]["latent"].shape[1]
    B = len(batch)

    # Allocate with the latents' own dtype so non-float32 features (e.g. fp16)
    # are not silently converted by the default float32 buffer.
    lat_padded = torch.zeros(B, max_lat, dac_dim, dtype=batch[0]["latent"].dtype)
    text_padded = torch.zeros(B, max_text, dtype=torch.long)
    # Masks start all-True (everything padding) and are cleared on valid spans.
    lat_mask = torch.ones(B, max_lat, dtype=torch.bool)
    text_mask = torch.ones(B, max_text, dtype=torch.bool)

    for i, s in enumerate(batch):
        t_lat = s["latent_frames"]
        t_text = s["text_len"]
        lat_padded[i, :t_lat] = s["latent"]
        text_padded[i, :t_text] = s["text_tokens"]
        lat_mask[i, :t_lat] = False
        text_mask[i, :t_text] = False

    return {
        "latent": lat_padded,
        "text_tokens": text_padded,
        "latent_mask": lat_mask,
        "text_mask": text_mask,
    }


def build_dataloader(manifest_path, batch_size=16, num_workers=4, shuffle=True,
                     max_latent_frames=900):
    """Construct a TTSDataset and a padded-batch DataLoader over it.

    Args:
        manifest_path: path to the JSON manifest consumed by TTSDataset.
        batch_size: samples per batch; the incomplete final batch is dropped
            (drop_last=True) so batch shapes stay uniform during training.
        num_workers: number of DataLoader worker processes.
        shuffle: whether to reshuffle samples every epoch.
        max_latent_frames: length-filter threshold forwarded to TTSDataset
            (previously unconfigurable from this entry point; default matches
            the dataset's own default, so existing callers are unchanged).

    Returns:
        (loader, dataset) tuple.
    """
    dataset = TTSDataset(manifest_path, max_latent_frames=max_latent_frames)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, collate_fn=collate_fn,
        pin_memory=True, drop_last=True,
    )
    return loader, dataset
