"""Dataset for EnCodec token-based TTS training."""

import json
import torch
from torch.utils.data import Dataset, DataLoader


class TokenTTSDataset(Dataset):
    """Map-style dataset of (EnCodec audio tokens, UTF-8 text bytes) pairs.

    The manifest is a JSON list of entries, each with at least:
    ``token_path`` (path to a saved 1-D LongTensor of audio tokens),
    ``text`` (transcript string), and ``n_tokens`` (token count, used for
    length filtering without loading the tensor).
    """

    def __init__(self, manifest_path, max_tokens=900, max_text_len=512):
        """Load the manifest and drop entries longer than ``max_tokens``.

        Args:
            manifest_path: Path to the JSON manifest file.
            max_tokens: Entries with more audio tokens than this are filtered
                out up front (keeps sequences inside the model context).
            max_text_len: Maximum number of UTF-8 bytes kept per transcript.
        """
        with open(manifest_path, "r", encoding="utf-8") as f:
            self.manifest = json.load(f)
        self.manifest = [e for e in self.manifest if e["n_tokens"] <= max_tokens]
        # Fix: this parameter was previously accepted but ignored —
        # __getitem__ hard-coded a 512-byte truncation. Store and use it.
        self.max_text_len = max_text_len
        print(f"Dataset: {len(self.manifest)} samples loaded")

    def __len__(self):
        return len(self.manifest)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors plus their true lengths.

        ``tokens`` is loaded from disk and cast to long; text is modeled as
        raw UTF-8 bytes (implicit vocab of 256), truncated to max_text_len.
        """
        entry = self.manifest[idx]
        tokens = torch.load(entry["token_path"], weights_only=True).long()
        text = list(entry["text"].encode("utf-8"))[: self.max_text_len]
        text_tokens = torch.tensor(text, dtype=torch.long)
        return {
            "tokens": tokens,
            "text_tokens": text_tokens,
            "n_tokens": tokens.shape[0],  # assumes a 1-D token tensor
            "text_len": len(text),
        }


def collate_fn(batch):
    """Collate samples into zero-padded batches with padding masks.

    Returns a dict with ``tokens``/``text_tokens`` padded along dim 1 and
    boolean ``token_mask``/``text_mask`` where True marks a padding position.
    """
    tok_lens = torch.tensor([s["n_tokens"] for s in batch])
    txt_lens = torch.tensor([s["text_len"] for s in batch])

    # Right-pad every sequence with zeros to the batch maximum.
    tokens = torch.nn.utils.rnn.pad_sequence(
        [s["tokens"] for s in batch], batch_first=True
    )
    texts = torch.nn.utils.rnn.pad_sequence(
        [s["text_tokens"] for s in batch], batch_first=True
    )

    # Positions at or past each sequence's length are padding (mask=True).
    token_mask = torch.arange(tokens.size(1)).unsqueeze(0) >= tok_lens.unsqueeze(1)
    text_mask = torch.arange(texts.size(1)).unsqueeze(0) >= txt_lens.unsqueeze(1)

    return {
        "tokens": tokens,
        "text_tokens": texts,
        "token_mask": token_mask,
        "text_mask": text_mask,
    }


def build_dataloader(manifest_path, batch_size=16, num_workers=4, shuffle=True,
                     max_tokens=900, max_text_len=512):
    """Build a DataLoader over a TokenTTSDataset.

    Generalized: previously the dataset's filtering/truncation knobs could
    not be set from here; they are now pass-through keyword arguments whose
    defaults match ``TokenTTSDataset``, so existing callers are unaffected.

    Args:
        manifest_path: Path to the JSON manifest file.
        batch_size: Samples per batch (incomplete final batch is dropped).
        num_workers: Worker processes for data loading.
        shuffle: Whether to reshuffle each epoch.
        max_tokens: Forwarded to TokenTTSDataset length filtering.
        max_text_len: Forwarded to TokenTTSDataset text truncation.

    Returns:
        A ``(loader, dataset)`` tuple.
    """
    dataset = TokenTTSDataset(
        manifest_path, max_tokens=max_tokens, max_text_len=max_text_len
    )
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, collate_fn=collate_fn,
        # pin_memory speeds up host-to-GPU copies; drop_last keeps batch
        # shapes constant for training.
        pin_memory=True, drop_last=True,
    )
    return loader, dataset
