"""Dataset for v7 — 8-level EnCodec tokens + text."""
import json, torch
from torch.utils.data import Dataset, DataLoader


class TokenDataset(Dataset):
    """Manifest-backed dataset pairing saved token tensors with byte-level text.

    The manifest is a JSON list of entries. Each entry is read for the keys
    ``n_tokens``, ``token_path`` and ``text``, and optionally
    ``speaker_id_num`` / ``speaker_id``.
    """

    def __init__(self, manifest_path, max_tokens=750, max_text=512):
        """Load the manifest and keep entries with at most ``max_tokens`` tokens.

        Args:
            manifest_path: Path to the JSON manifest file.
            max_tokens: Entries whose ``n_tokens`` field exceeds this are dropped.
            max_text: Maximum number of UTF-8 bytes of text kept per sample.
        """
        # Kept on the instance so __getitem__ honours the caller's limit
        # instead of a hard-coded constant.
        self.max_text = max_text
        with open(manifest_path, "r", encoding="utf-8") as f:
            self.data = [e for e in json.load(f) if e["n_tokens"] <= max_tokens]
        print(f"Dataset: {len(self.data)} samples")

    def __len__(self):
        """Return the number of manifest entries that passed the length filter."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A dict with ``tokens`` (the tensor loaded from ``token_path``, cast
            to long), ``text_tokens`` (the text's UTF-8 byte values as a long
            tensor, truncated to ``max_text``), ``n_tokens`` (second dimension
            of ``tokens``), ``text_len``, and ``speaker_id`` when the entry
            provides a usable one.
        """
        e = self.data[idx]
        tokens = torch.load(e["token_path"], weights_only=True).long()  # [n_rvq, T]
        # Text is tokenised at byte level; truncate the bytes before building
        # the list so long transcripts are not fully materialised first.
        text_bytes = e["text"].encode("utf-8")[:self.max_text]
        text = torch.tensor(list(text_bytes), dtype=torch.long)
        result = {"tokens": tokens, "text_tokens": text,
                  "n_tokens": tokens.shape[1], "text_len": text.shape[0]}
        # Prefer the explicit numeric id; fall back to "speaker_id" only when
        # it is already an int (non-int ids are skipped).
        if "speaker_id_num" in e:
            result["speaker_id"] = e["speaker_id_num"]
        elif "speaker_id" in e and isinstance(e["speaker_id"], int):
            result["speaker_id"] = e["speaker_id"]
        return result


def collate_fn(batch):
    """Pad a list of samples into batched tensors with padding masks.

    Args:
        batch: Non-empty list of sample dicts carrying ``tokens``
            (indexed as ``[n_rvq, T]``), ``text_tokens``, ``n_tokens`` and
            ``text_len``, and optionally ``speaker_id``.

    Returns:
        A dict with ``tokens`` ``[B, n_rvq, max_t]`` and ``text_tokens``
        ``[B, max_tx]`` (both zero-padded, long), plus boolean ``token_mask``
        ``[B, max_t]`` and ``text_mask`` ``[B, max_tx]`` where ``True`` marks
        padding. ``speaker_id`` (long, ``[B]``) is included only when every
        sample in the batch provides one.
    """
    max_t = max(s["n_tokens"] for s in batch)
    max_tx = max(s["text_len"] for s in batch)
    B = len(batch)
    # The codebook count is taken from the first sample.
    n_rvq = batch[0]["tokens"].shape[0]

    tokens = torch.zeros(B, n_rvq, max_t, dtype=torch.long)
    text = torch.zeros(B, max_tx, dtype=torch.long)
    # Masks start all-True (everything is padding); the valid prefix of each
    # row is cleared to False below.
    tok_mask = torch.ones(B, max_t, dtype=torch.bool)
    text_mask = torch.ones(B, max_tx, dtype=torch.bool)

    for i, s in enumerate(batch):
        t, tx = s["n_tokens"], s["text_len"]
        tokens[i, :, :t] = s["tokens"]
        text[i, :tx] = s["text_tokens"]
        tok_mask[i, :t] = False
        text_mask[i, :tx] = False

    result = {"tokens": tokens, "text_tokens": text,
              "token_mask": tok_mask, "text_mask": text_mask}
    # Samples carry speaker_id individually, so check all of them: inspecting
    # only the first made the outcome depend on batch order (KeyError or a
    # silent drop for mixed batches).
    if all("speaker_id" in s for s in batch):
        result["speaker_id"] = torch.tensor([s["speaker_id"] for s in batch], dtype=torch.long)
    return result


def build_dataloader(manifest_path, batch_size=16, num_workers=0, shuffle=True,
                     *, max_tokens=750, max_text=512):
    """Build a ``TokenDataset`` and a ``DataLoader`` over it.

    Args:
        manifest_path: Path to the JSON manifest passed to ``TokenDataset``.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        shuffle: Whether to reshuffle the data every epoch.
        max_tokens: Forwarded to ``TokenDataset``; the default matches the
            dataset's own default.
        max_text: Forwarded to ``TokenDataset``; the default matches the
            dataset's own default.

    Returns:
        A ``(dataloader, dataset)`` tuple.
    """
    ds = TokenDataset(manifest_path, max_tokens=max_tokens, max_text=max_text)
    # drop_last=True discards a trailing partial batch, so a dataset smaller
    # than batch_size yields no batches at all.
    dl = DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                    collate_fn=collate_fn, pin_memory=True, drop_last=True)
    return dl, ds
