"""
Ultra-fast dataset for extracted mel files.

Design for maximum GPU saturation:
  - Pre-built index (built offline, ~500MB on disk)
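  - Text files (paths, transcripts, languages) read fully into Python lists at
    startup; forked workers inherit them copy-on-write, although CPython refcount
    updates gradually copy the pages they touch into each worker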
  - Frame counts as memory-mapped numpy (shared across all workers, zero RAM per fork)
  - BucketBatchSampler does all shuffling/batching logic in the main process
  - Workers only do np.load() + normalize (+ optional tokenization); no
    per-worker index structures are built
  - Each np.load() takes ~0.15ms (direct file read, no tar scanning)

Expected: GPU util 80-95%, step time limited by compute not I/O.
"""

import os
import json
import random
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from dataclasses import dataclass
from typing import List


# Default directory of the pre-built training index read by FastMelDataset
# (n_frames.npy, bucket_order.npy, mel_paths.txt, transcripts.txt,
# languages.txt). Override per run via FastDataConfig.index_dir.
INDEX_DIR = "/workspace/maya-asr/mel_extracted/training_index"


@dataclass
class FastDataConfig:
    """Settings for FastMelDataset and the bucketed batch sampler built by create_fast_dataloader."""

    # Directory containing the pre-built index files.
    index_dir: str = INDEX_DIR
    # Mels longer than this many frames are truncated per item.
    max_mel_frames: int = 3500
    # Token sequences longer than this are truncated per item.
    max_tokens: int = 384
    # Per-batch budget on padded frames: batch size * longest item's frame count.
    max_batch_mel_frames: int = 800000
    # Hard cap on utterances per batch.
    max_batch_utterances: int = 64
    # Not read by any code visible in this file. NOTE(review): presumably a
    # sampling temperature (e.g. for language balancing) — confirm at call sites.
    temperature: float = 5.0
    # Base RNG seed; the sampler shuffles with seed + epoch.
    seed: int = 42


class FastMelDataset(Dataset):
    """Map-style dataset over pre-extracted mel files, driven by a pre-built index.

    The index directory must contain:
      - ``n_frames.npy``      per-utterance frame counts (memory-mapped)
      - ``bucket_order.npy``  utterance indices pre-sorted by duration (memory-mapped)
      - ``mel_paths.txt``     one ``.npy`` mel path per line
      - ``transcripts.txt``   one UTF-8 transcript per line
      - ``languages.txt``     one language tag per line

    Line ``i`` of each text file and entry ``i`` of ``n_frames.npy`` describe
    the same utterance.

    Args:
        config: provides ``index_dir``, ``max_mel_frames`` and ``max_tokens``.
        tokenizer: optional object with ``encode(str) -> list[int]``; when
            ``None`` every item gets an empty token array.

    Raises:
        ValueError: if a text file has fewer lines than ``n_frames.npy`` has
            entries (the index is misaligned).
    """

    def __init__(self, config: "FastDataConfig", tokenizer=None):
        self.config = config
        self.tokenizer = tokenizer

        idx_dir = config.index_dir

        # Memory-mapped arrays: pages live in the OS page cache and are shared by
        # all forked workers without per-process copies.
        self.n_frames = np.load(os.path.join(idx_dir, "n_frames.npy"), mmap_mode='r')
        self.total = len(self.n_frames)

        # Text columns are loaded into RAM once (~25s startup, ~25GB for the full
        # index). Forked workers inherit these lists copy-on-write, but CPython
        # refcount updates on the strings they access dirty those pages, so each
        # worker's RSS can grow toward the parent's over an epoch (see the PyTorch
        # DataLoader docs on multi-process memory, pytorch issue #13246).
        self.mel_paths = self._read_lines(idx_dir, "mel_paths.txt")
        self.transcripts = self._read_lines(idx_dir, "transcripts.txt")
        self.languages = self._read_lines(idx_dir, "languages.txt")

        # Utterance indices pre-sorted by duration (used for bucketed batching).
        self.bucket_order = np.load(os.path.join(idx_dir, "bucket_order.npy"), mmap_mode='r')

    def _read_lines(self, idx_dir, name):
        """Read a newline-delimited UTF-8 index file and check it covers every utterance."""
        path = os.path.join(idx_dir, name)
        # Explicit UTF-8 so multilingual transcripts do not depend on the host
        # locale. split('\n') rather than splitlines() so characters such as
        # U+2028 inside a transcript do not shift the line numbering.
        with open(path, encoding="utf-8") as f:
            lines = f.read().split('\n')
        # A trailing newline yields one extra empty entry, which is harmless;
        # fewer entries than n_frames means utterances would be misaligned.
        if len(lines) < self.total:
            raise ValueError(
                f"{path} has {len(lines)} lines but n_frames.npy has "
                f"{self.total} entries"
            )
        return lines

    def __len__(self):
        return self.total

    def __getitem__(self, idx):
        mel_path = self.mel_paths[idx]
        transcript = self.transcripts[idx]
        language = self.languages[idx]

        # Load mel — single direct file read, no archive scanning.
        mel = np.load(mel_path)  # [128, T] float16

        # Per-feature (per-row) normalization over the time axis, computed in
        # float32 and stored back as float16.
        mel = mel.astype(np.float32)
        mean = mel.mean(axis=-1, keepdims=True)
        std = np.maximum(mel.std(axis=-1, keepdims=True), 1e-5)
        mel = ((mel - mean) / std).astype(np.float16)

        # Length = index frame count, clamped to what is actually on disk and to
        # the configured maximum. Slicing to exactly this length guarantees
        # mel.shape[-1] == n_frames, which fast_collate relies on when writing
        # mel_padded[i, :, :n_frames] = mel; a stale index entry would
        # otherwise crash collation with a shape mismatch.
        n_frames = min(int(self.n_frames[idx]), mel.shape[-1],
                       self.config.max_mel_frames)
        mel = mel[:, :n_frames]

        # Tokenize (truncated to max_tokens).
        if self.tokenizer is not None:
            token_ids = self.tokenizer.encode(transcript)
            if len(token_ids) > self.config.max_tokens:
                token_ids = token_ids[:self.config.max_tokens]
            tokens = np.array(token_ids, dtype=np.int32)
        else:
            tokens = np.array([], dtype=np.int32)

        return {
            'mel': mel,
            'tokens': tokens,
            'transcript': transcript,
            'language': language,
            'n_frames': n_frames,
            'n_tokens': len(tokens),
        }


class BucketBatchSampler(Sampler):
    """Duration-bucketed batch sampler with DDP-safe sharding.

    All heavy logic (shuffling, batching, sharding) runs in the main process;
    DataLoader workers only receive lists of indices to fetch.

    Algorithm:
      1. Take pre-sorted indices (by duration).
      2. Split into buckets of ``bucket_size`` consecutive items (similar durations).
      3. If ``shuffle``: shuffle within each bucket, then shuffle bucket order.
      4. Greedily pack the flattened order into batches whose padded frame
         count (``len(batch) * max frames in batch``) stays within
         ``max_batch_mel_frames`` and whose size stays within
         ``max_batch_utterances``. An item that alone exceeds the frame budget
         still forms its own batch.
      5. Shard whole batches round-robin across DDP ranks, dropping the last
         ``len(batches) % world_size`` batches so every rank yields exactly the
         same number of batches. Unequal step counts across ranks make DDP
         collectives hang.

    The RNG is seeded with ``seed + epoch``, so every rank computes the same
    global batch list and ranks agree without communicating. Call
    ``set_epoch`` before each epoch to reshuffle.

    Args:
        n_frames: indexable per-utterance frame counts.
        bucket_order: utterance indices sorted by duration.
        max_batch_mel_frames: padded-frame budget per batch.
        max_batch_utterances: maximum utterances per batch.
        bucket_size: items per duration bucket.
        shuffle: shuffle within and between buckets.
        seed: base RNG seed.
        epoch: initial epoch (added to ``seed``).
        rank: this process's DDP rank.
        world_size: number of DDP ranks.

    Raises:
        ValueError: if ``world_size < 1`` or ``rank`` is not in ``[0, world_size)``.
    """

    def __init__(self, n_frames, bucket_order, max_batch_mel_frames,
                 max_batch_utterances, bucket_size=2000,
                 shuffle=True, seed=42, epoch=0,
                 rank=0, world_size=1):
        if world_size < 1:
            raise ValueError(f"world_size must be >= 1, got {world_size}")
        if not 0 <= rank < world_size:
            raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
        self.n_frames = n_frames
        self.bucket_order = bucket_order
        self.max_batch_mel_frames = max_batch_mel_frames
        self.max_batch_utterances = max_batch_utterances
        self.bucket_size = bucket_size
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = epoch
        self.rank = rank
        self.world_size = world_size

    def set_epoch(self, epoch):
        """Set the epoch that seeds the next iteration's shuffle."""
        self.epoch = epoch

    def _ordered_indices(self):
        """Return every index in bucketed (optionally shuffled) order."""
        rng = random.Random(self.seed + self.epoch)

        # .tolist() yields plain ints in one C-level pass instead of creating
        # one numpy scalar object per element, as list(memmap) would.
        indices = np.asarray(self.bucket_order).tolist()

        buckets = [indices[i:i + self.bucket_size]
                   for i in range(0, len(indices), self.bucket_size)]

        if self.shuffle:
            for bucket in buckets:
                rng.shuffle(bucket)
            rng.shuffle(buckets)

        return [idx for bucket in buckets for idx in bucket]

    def _pack(self, indices):
        """Greedily group consecutive indices under the frame and utterance caps."""
        batches = []
        batch = []
        batch_max_frames = 0
        for idx in indices:
            frames = int(self.n_frames[idx])

            # Padded size if this item joined the current batch.
            new_max = max(batch_max_frames, frames)
            projected_total = (len(batch) + 1) * new_max

            if batch and (len(batch) >= self.max_batch_utterances or
                          projected_total > self.max_batch_mel_frames):
                batches.append(batch)
                batch = [idx]
                batch_max_frames = frames
            else:
                batch.append(idx)
                batch_max_frames = new_max

        if batch:
            batches.append(batch)
        return batches

    def __iter__(self):
        batches = self._pack(self._ordered_indices())

        # Shard whole batches so every rank gets the same batch count.
        if self.world_size > 1:
            usable = len(batches) - len(batches) % self.world_size
            batches = batches[self.rank:usable:self.world_size]

        yield from batches

    def __len__(self):
        """Approximate batches per rank per epoch.

        Assumes every batch holds mean-length utterances and ignores padding,
        so the count actually yielded by ``__iter__`` can differ.
        """
        avg_frames = int(np.mean(self.n_frames))
        utts_per_batch = min(
            self.max_batch_utterances,
            self.max_batch_mel_frames // max(avg_frames, 1)
        )
        return len(self.n_frames) // (self.world_size * max(utts_per_batch, 1))


def fast_collate(batch_items: List[dict]) -> dict:
    """Collate dataset items into padded tensors.

    Mels are zero-padded along time to the longest item in the batch. Token ids
    are padded with ``-100``, which matches the default ``ignore_index`` of
    ``torch.nn.CrossEntropyLoss``, up to the longest token sequence.

    Args:
        batch_items: non-empty list of dicts with keys ``mel``, ``tokens``,
            ``language``, ``n_frames`` and ``n_tokens``. ``mel`` is a 2-D
            ``[n_mels, n_frames]`` array or tensor. ``n_mels`` is taken from the
            first item and must match for every item.

    Returns:
        dict with
          - ``mel``: float16 tensor ``[B, n_mels, max_frames]``
          - ``mel_lengths``: int64 tensor ``[B]``
          - ``labels``: int64 tensor ``[B, max_tokens]``, padded with ``-100``
          - ``token_lengths``: int64 tensor ``[B]``
          - ``languages``: list of ``B`` language values

    Raises:
        ValueError: if ``batch_items`` is empty.
    """
    if not batch_items:
        raise ValueError("fast_collate received an empty batch")

    B = len(batch_items)
    # Derived from the data instead of hard-coding 128, so mels with other bin
    # counts (e.g. 80) collate too.
    n_mels = batch_items[0]['mel'].shape[0]
    max_frames = max(item['n_frames'] for item in batch_items)
    max_tokens = max(item['n_tokens'] for item in batch_items)

    mel_padded = torch.zeros(B, n_mels, max_frames, dtype=torch.float16)
    mel_lengths = torch.zeros(B, dtype=torch.long)
    tokens_padded = torch.full((B, max_tokens), -100, dtype=torch.long)
    token_lengths = torch.zeros(B, dtype=torch.long)
    languages = []

    for i, item in enumerate(batch_items):
        n_f, n_t = item['n_frames'], item['n_tokens']
        # as_tensor accepts numpy arrays (zero-copy) and tensors alike.
        mel_padded[i, :, :n_f] = torch.as_tensor(item['mel'])
        mel_lengths[i] = n_f
        if n_t > 0:
            tokens_padded[i, :n_t] = torch.as_tensor(item['tokens']).long()
        token_lengths[i] = n_t
        languages.append(item['language'])

    return {
        'mel': mel_padded,
        'mel_lengths': mel_lengths,
        'labels': tokens_padded,
        'token_lengths': token_lengths,
        'languages': languages,
    }


class FlatIndexSampler(Sampler):
    """Sampler that emits the indices of pre-computed batches one at a time.

    Handing PyTorch single indices, instead of whole batches via
    ``batch_sampler``, lets the DataLoader spread the items of one batch across
    several workers rather than assigning the entire batch to a single worker.

    Attributes:
        batches: the batch lists passed in, stored unchanged.
        _flat: every index from ``batches``, concatenated in order.
        _batch_boundaries: cumulative end offset in ``_flat`` of each batch.
    """
    def __init__(self, batches):
        self.batches = batches

        concatenated = []
        for group in batches:
            concatenated.extend(group)
        self._flat = concatenated

        ends = []
        running_total = 0
        for group in batches:
            running_total = running_total + len(group)
            ends.append(running_total)
        self._batch_boundaries = ends

    def __len__(self):
        return len(self._flat)

    def __iter__(self):
        return iter(self._flat)


def create_fast_dataloader(config: FastDataConfig, tokenizer=None,
                           num_workers=4, rank=0, world_size=1, epoch=0,
                           shuffle=True, pin_memory=True, prefetch_factor=8):
    """Create a DataLoader with bucketed dynamic batching over extracted mel files.

    Batches come from ``BucketBatchSampler``, passed as ``batch_sampler``. Each
    batch is therefore formed in the main process under the frame and utterance
    budget, then fetched by one worker. With several workers and
    ``prefetch_factor`` batches queued per worker, multiple batches load in
    parallel.

    Args:
        config: dataset and batching settings.
        tokenizer: optional tokenizer forwarded to ``FastMelDataset``.
        num_workers: DataLoader worker processes; 0 loads in the main process.
        rank: DDP rank forwarded to the sampler.
        world_size: DDP world size forwarded to the sampler.
        epoch: initial epoch for the sampler's shuffle seed.
        shuffle: shuffle within and between duration buckets (e.g. disable for eval).
        pin_memory: pin host memory for faster host-to-GPU copies.
        prefetch_factor: batches prefetched per worker; only used when
            ``num_workers > 0``.

    Returns:
        ``(loader, sampler)``. Call ``sampler.set_epoch(e)`` before each epoch
        to reshuffle. The batch sampler is iterated in the main process, so
        the new epoch takes effect even with persistent workers.
    """
    dataset = FastMelDataset(config, tokenizer=tokenizer)

    sampler = BucketBatchSampler(
        n_frames=dataset.n_frames,
        bucket_order=dataset.bucket_order,
        max_batch_mel_frames=config.max_batch_mel_frames,
        max_batch_utterances=config.max_batch_utterances,
        shuffle=shuffle,
        seed=config.seed,
        epoch=epoch,
        rank=rank,
        world_size=world_size,
    )

    use_workers = num_workers > 0
    loader = DataLoader(
        dataset,
        batch_sampler=sampler,
        collate_fn=fast_collate,
        num_workers=num_workers,
        persistent_workers=use_workers,
        # Prefetching only applies when worker processes exist.
        prefetch_factor=prefetch_factor if use_workers else None,
        pin_memory=pin_memory,
    )

    return loader, sampler
