"""
Dataset and DataLoader for LeWM TTS training.
"""

import json
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path


class TTSDataset(Dataset):
    """Loads preprocessed mel spectrograms and text for TTS training.

    Each manifest entry is a dict with keys (as read by this class):
      - "mel_path":   path to a torch-saved mel tensor of shape [n_mels, T]
      - "mel_frames": number of mel frames T for that entry
      - "text":       transcript string (any Unicode; tokenized as UTF-8 bytes)
      - "speaker_id": optional integer speaker label for multi-speaker corpora
    """

    def __init__(self, manifest_path, max_mel_frames=1200, max_text_len=512):
        """Load the JSON manifest and drop entries exceeding the length limits.

        Args:
            manifest_path: path to a JSON file containing a list of entry dicts.
            max_mel_frames: entries with more mel frames than this are discarded.
            max_text_len: entries whose UTF-8 byte length exceeds this are
                discarded, so every surviving sample's audio and transcript
                stay aligned (no silent truncation of the text at load time).
        """
        with open(manifest_path, "r", encoding="utf-8") as f:
            self.manifest = json.load(f)

        self.max_mel_frames = max_mel_frames
        self.max_text_len = max_text_len

        # Filter out entries that are too long on either axis. Filtering the
        # text here (rather than truncating it later in __getitem__) prevents
        # pairing full-length audio with a cut-off transcript, which would
        # produce misaligned training pairs.
        self.manifest = [
            e for e in self.manifest
            if e["mel_frames"] <= max_mel_frames
            and len(e["text"].encode("utf-8")) <= max_text_len
        ]
        print(f"Dataset: {len(self.manifest)} samples loaded")

    def __len__(self):
        return len(self.manifest)

    def text_to_tokens(self, text):
        """Convert text to byte-level tokens (handles any Unicode including Devanagari)."""
        # UTF-8 byte encoding — simple, universal, no tokenizer needed
        return list(text.encode("utf-8"))

    def __getitem__(self, idx):
        """Return one sample: mel tensor, byte tokens, and their lengths."""
        entry = self.manifest[idx]

        # Load mel spectrogram: [n_mels, T]
        mel = torch.load(entry["mel_path"], weights_only=True)

        # Tokenize text. The slice is only a safety guard: filtered entries
        # already fit within max_text_len, so nothing is normally cut here.
        tokens = self.text_to_tokens(entry["text"])[:self.max_text_len]
        text_tokens = torch.tensor(tokens, dtype=torch.long)

        result = {
            "mel": mel,
            "text_tokens": text_tokens,
            "mel_frames": mel.shape[1],
            "text_len": len(tokens),
        }

        # Multi-speaker: pass speaker_id through if present in the manifest
        if "speaker_id" in entry:
            result["speaker_id"] = entry["speaker_id"]

        return result


def collate_fn(batch):
    """Pad variable-length samples and stack them into dense batch tensors.

    Returns a dict with zero-padded "mel" [B, n_mels, T_max] and
    "text_tokens" [B, L_max], plus boolean "mel_mask"/"text_mask" where
    True marks padding positions. "speaker_id" is batched when present.
    """
    batch_size = len(batch)
    n_mels = batch[0]["mel"].shape[0]
    mel_lens = [sample["mel_frames"] for sample in batch]
    text_lens = [sample["text_len"] for sample in batch]
    longest_mel = max(mel_lens)
    longest_text = max(text_lens)

    # Zero-filled targets; masks start as all-padding (True) and real
    # positions are cleared below.
    mels = torch.zeros(batch_size, n_mels, longest_mel)
    texts = torch.zeros(batch_size, longest_text, dtype=torch.long)
    mel_pad = torch.ones(batch_size, longest_mel, dtype=torch.bool)
    text_pad = torch.ones(batch_size, longest_text, dtype=torch.bool)

    for i, (sample, lm, lt) in enumerate(zip(batch, mel_lens, text_lens)):
        mels[i, :, :lm] = sample["mel"]
        texts[i, :lt] = sample["text_tokens"]
        mel_pad[i, :lm] = False
        text_pad[i, :lt] = False

    out = {
        "mel": mels,
        "text_tokens": texts,
        "mel_mask": mel_pad,
        "text_mask": text_pad,
    }

    # Multi-speaker: batch speaker_ids if present
    if "speaker_id" in batch[0]:
        out["speaker_id"] = torch.tensor(
            [sample["speaker_id"] for sample in batch], dtype=torch.long
        )

    return out


def build_dataloader(manifest_path, batch_size=16, num_workers=4, shuffle=True,
                     max_mel_frames=1200, max_text_len=512, drop_last=True):
    """Build a TTSDataset and a padding DataLoader over it.

    Args:
        manifest_path: path to the JSON manifest (see TTSDataset).
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
        shuffle: shuffle sample order each epoch (disable for evaluation).
        max_mel_frames: forwarded to TTSDataset; entries longer than this
            are dropped (previously hard-coded to the dataset's default).
        max_text_len: forwarded to TTSDataset (previously hard-coded).
        drop_last: drop the final incomplete batch. Set False for eval
            loaders so no samples are silently skipped.

    Returns:
        (loader, dataset) tuple.
    """
    dataset = TTSDataset(
        manifest_path,
        max_mel_frames=max_mel_frames,
        max_text_len=max_text_len,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
        drop_last=drop_last,
    )
    return loader, dataset
