"""
Dataset and DataLoader for Cohere Transcribe finetuning.
Reads pre-computed mel shards (tar files whose paths come from the training manifest),
or their pre-extracted per-shard directories when those exist.

Shard format (per utterance triplet):
  {segment_id}.mel.npy  — float16 [n_mels, T] mel spectrogram (T = number of frames)
  {segment_id}.txt      — UTF-8 transcript text
  {segment_id}.json     — metadata (segment_id, language, duration_s, n_mel_frames, prefix, bucket)

Implements:
  - On-the-fly tokenization (transcripts stored as text, not pre-tokenized)
  - Temperature-based multilingual sampling
  - Dynamic bucketed batching by mel frame count
  - Shuffle buffer for within-epoch randomization
  - Automatic filtering of empty transcripts
  - PRE-BATCHING IN WORKERS to eliminate main-process bottleneck
"""

import os
import torch
import numpy as np
import json
import io
import tarfile
import random
import pandas as pd
from pathlib import Path
from torch.utils.data import IterableDataset, DataLoader
from dataclasses import dataclass, field
from typing import List, Optional
import logging

logger = logging.getLogger(__name__)


@dataclass
class DataConfig:
    """Configuration for the mel-shard dataset and its DataLoader.

    Groups the input locations (manifest, shards), per-utterance limits,
    batch budgets, language-sampling temperature and shuffle settings.
    """
    # NOTE(review): not read by the code in this module — shard paths come from
    # the manifest's 'shard_path' column. Confirm whether any caller uses it.
    mel_shards_dir: str = "/workspace/maya-asr/mel_shards/per_shard"
    # Parquet manifest; read for columns segment_id, language, transcript,
    # duration_s and shard_path.
    manifest_path: str = "/workspace/maya-asr/mel_shards/training_manifest.parquet"
    max_mel_frames: int = 3500          # ~35s at 100fps; longer mels are truncated to this length
    max_tokens: int = 448               # max transcript tokens (leave room for special tokens); longer sequences are truncated
    max_batch_mel_frames: int = 1_200_000  # padded-frame budget per batch (batch_size * longest item); confirm how callers pass it to create_dataloader
    max_batch_utterances: int = 0       # optional hard cap; 0 disables the cap
    temperature: float = 5.0            # language sampling temperature: probs ∝ p_lang ** (1/T); applied only when > 0
    shuffle_buffer_size: int = 8000     # utterances in shuffle buffer per worker
    sort_buffer_for_batching: bool = True  # sort by frame count before emitting (low-padding batches)
    # NOTE(review): not read by the code in this module; shuffling uses the
    # global `random` state. Confirm whether the training script seeds with it.
    seed: int = 42


class MelShardDataset(IterableDataset):
    """
    Streams pre-computed mel spectrograms from tar shards (or their extracted dirs).

    Design:
    - Each DataLoader worker gets a disjoint set of shards (no duplication across workers or DDP ranks)
    - Weighted random shard order for inter-shard randomization (training only)
    - STREAMING reads — yields utterances as they are read, not after full shard load
    - In-memory shuffle buffer for within-shard randomization
    - Sorts emitted items by frame count → creates low-padding batches
    - PRE-BATCHES in workers so the main process gets ready-to-go collated batches
    - Skips utterances with empty transcripts and unreadable mel files

    Args:
        config: Data configuration.
        tokenizer: Object with an ``encode(str) -> list[int]`` method. When
            None, every item gets an empty token array.
        split: ``"train"`` enables shuffling and weighted shard ordering.
        prebatch: Yield collated batch dicts instead of individual items.
        max_batch_mel_frames: Padded-frame budget per pre-built batch
            (batch size x longest item). 0 falls back to
            ``config.max_batch_mel_frames``.
        max_batch_utterances: Utterance cap per pre-built batch. 0 falls back
            to ``config.max_batch_utterances`` (where 0 disables the cap).

    The ``_truncated_count`` / ``_skipped_count`` diagnostics are per process:
    values incremented inside DataLoader workers are not seen by the main process.
    """

    # Optional pre-extracted layout: {EXTRACTED_ROOT}/{shard_stem}/ holds the
    # .mel.npy files and {EXTRACTED_INDEX_ROOT}/{shard_stem}.parquet indexes them.
    EXTRACTED_ROOT = "/workspace/maya-asr/mel_extracted"
    EXTRACTED_INDEX_ROOT = "/workspace/maya-asr/mel_extracted_index"
    _MEL_SUFFIX = ".mel.npy"

    def __init__(self, config: DataConfig, tokenizer=None, split: str = "train",
                 prebatch: bool = False, max_batch_mel_frames: int = 0,
                 max_batch_utterances: int = 0):
        self.config = config
        self.split = split
        self.tokenizer = tokenizer
        self.prebatch = prebatch
        # 0 means "not specified": use the config budgets. Passing a 0 frame
        # budget through would make every pre-built batch a single utterance.
        self.max_batch_mel_frames = max_batch_mel_frames or config.max_batch_mel_frames
        self.max_batch_utterances = max_batch_utterances or config.max_batch_utterances

        # Remember the DDP topology seen by the constructing process: workers
        # started with "spawn"/"forkserver" do not inherit the process group.
        self._init_dist = self._dist_rank_world()

        self._load_manifest()

        # Tracking counters for diagnostics (per process, see class docstring)
        self._truncated_count = 0
        self._skipped_count = 0

    # ------------------------------------------------------------------ manifest

    def _cache_path(self):
        """Return the on-disk cache path for the derived manifest lookups."""
        import hashlib

        cfg = self.config
        # The cached shard weights depend on the temperature, so it is part of
        # the key; changing it must not silently reuse stale weights.
        key_src = f"{cfg.manifest_path}|temperature={cfg.temperature}"
        cache_key = hashlib.md5(key_src.encode()).hexdigest()[:12]
        return cfg.manifest_path + f".cache_{cache_key}.pkl"

    def _load_manifest(self):
        """Populate shard list, transcript lookup and shard weights.

        Parsing the full manifest (tens of millions of rows) plus the groupbys
        is slow, so the derived lookups are cached in a pickle next to the
        manifest and reused while it is newer than the manifest.
        """
        import pickle

        manifest_path = self.config.manifest_path
        cache_path = self._cache_path()
        if (os.path.exists(cache_path)
                and os.path.getmtime(cache_path) >= os.path.getmtime(manifest_path)):
            try:
                # NOTE: pickle.load is acceptable only because this cache is
                # written by this class next to a trusted manifest.
                with open(cache_path, "rb") as f:
                    cached = pickle.load(f)
                self.all_shards = cached["all_shards"]
                self._transcript_lookup = cached["transcript_lookup"]
                self._valid_ids_by_shard = cached["valid_ids_by_shard"]
                self._lang_hours = cached["lang_hours"]
                self._shard_weights = cached["shard_weights"]
                return
            except Exception as exc:  # corrupt/partial cache: rebuild instead of crashing
                logger.warning("Ignoring unreadable manifest cache %s (%s); rebuilding",
                               cache_path, exc)

        self._build_from_manifest()
        self._write_cache(cache_path)

    def _build_from_manifest(self):
        """Derive all lookups from the parquet manifest (slow path)."""
        config = self.config
        manifest = pd.read_parquet(
            config.manifest_path,
            columns=['segment_id', 'language', 'transcript', 'duration_s', 'shard_path']
        )
        has_transcript = manifest['transcript'].notna() & (manifest['transcript'] != '')
        manifest = manifest[has_transcript]

        self.all_shards = sorted(manifest['shard_path'].unique().tolist())

        # segment_id -> (transcript, language)
        self._transcript_lookup = dict(zip(
            manifest['segment_id'].values,
            zip(manifest['transcript'].values, manifest['language'].values)
        ))

        self._valid_ids_by_shard = {
            shard_path: set(group['segment_id'].values)
            for shard_path, group in manifest.groupby('shard_path')
        }

        lang_counts = manifest.groupby('language')['segment_id'].count()
        self._lang_hours = manifest.groupby('language')['duration_s'].sum() / 3600

        if config.temperature > 0:
            # Temperature-smoothed language probabilities: p ** (1/T), renormalized.
            # Computed regardless of split; __iter__ only uses them for "train".
            lang_probs = (lang_counts / lang_counts.sum()) ** (1.0 / config.temperature)
            lang_probs = lang_probs / lang_probs.sum()
            # Each shard is labelled with its most frequent language.
            shard_langs = manifest.groupby('shard_path')['language'].agg(lambda x: x.mode().iloc[0])
            fallback = 1.0 / max(len(lang_probs), 1)
            self._shard_weights = {sp: float(lang_probs.get(lang, fallback))
                                   for sp, lang in shard_langs.items()}
        else:
            self._shard_weights = {sp: 1.0 for sp in self.all_shards}

        del manifest

    def _write_cache(self, cache_path):
        """Best-effort atomic write of the derived lookups to ``cache_path``."""
        import pickle
        import tempfile

        payload = {
            "all_shards": self.all_shards,
            "transcript_lookup": self._transcript_lookup,
            "valid_ids_by_shard": self._valid_ids_by_shard,
            "lang_hours": self._lang_hours,
            "shard_weights": self._shard_weights,
        }
        tmp_path = None
        try:
            fd, tmp_path = tempfile.mkstemp(
                dir=os.path.dirname(cache_path) or ".", suffix=".tmp")
            with os.fdopen(fd, "wb") as f:
                pickle.dump(payload, f, protocol=pickle.HIGHEST_PROTOCOL)
            # Atomic rename: concurrent ranks/readers never see a half-written file.
            os.replace(tmp_path, cache_path)
            tmp_path = None
        except Exception as exc:  # cache write failure is not fatal
            logger.warning("Could not write manifest cache %s: %s", cache_path, exc)
        finally:
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    # ------------------------------------------------------------ partitioning

    @staticmethod
    def _dist_rank_world():
        """Return (rank, world_size) of the active process group, or None."""
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank(), dist.get_world_size()
        return None

    def _get_worker_shards(self):
        """Partition shards across DataLoader workers and DDP ranks (round-robin)."""
        rank, world_size = self._dist_rank_world() or self._init_dist or (0, 1)

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker_id, num_workers = worker_info.id, worker_info.num_workers
        else:
            worker_id, num_workers = 0, 1

        global_worker_id = rank * num_workers + worker_id
        total_workers = world_size * num_workers
        return self.all_shards[global_worker_id::total_workers]

    def _weighted_shard_order(self, shards):
        """Return ``shards`` in a weighted random order (sampling without replacement).

        Efraimidis–Spirakis: sorting by Exp(rate=w) draws gives a weighted
        permutation, so higher-weight shards tend to come earlier while shards
        of all languages stay interleaved. Every shard is still read exactly
        once per epoch; weights affect order only, not frequency.
        NOTE(review): the weights are temperature-smoothed language
        probabilities, which still grow with language size — real upsampling
        of low-resource languages would require sampling with replacement.
        """
        def sort_key(shard_path):
            weight = self._shard_weights.get(shard_path, 1.0)
            return random.expovariate(weight) if weight > 0 else float("inf")

        return sorted(shards, key=sort_key)

    # ---------------------------------------------------------------- reading

    def _iter_shard_streaming(self, shard_path):
        """Read utterances from a shard — extracted dir (fast) or tar (fallback).

        If ``EXTRACTED_ROOT/{shard_name}/`` exists, reads individual .npy files
        directly. Otherwise falls back to tar streaming.
        """
        shard_name = Path(shard_path).stem
        extracted_dir = os.path.join(self.EXTRACTED_ROOT, shard_name)

        if os.path.isdir(extracted_dir):
            yield from self._iter_extracted_dir(extracted_dir, shard_path)
        else:
            yield from self._iter_tar_streaming(shard_path)

    def _load_npy(self, path_or_file):
        """np.load a mel; return None (and count a skip) if it cannot be read."""
        try:
            return np.load(path_or_file)
        except (OSError, ValueError, EOFError) as exc:
            self._skipped_count += 1
            logger.debug("Skipping unreadable mel %s: %s", path_or_file, exc)
            return None

    def _iter_extracted_dir(self, extracted_dir, shard_path):
        """Fast path: read the per-shard parquet index + direct .npy files.

        The index has one row per utterance with (at least) transcript,
        mel_path and language. Memory use is bounded per shard.
        """
        shard_name = Path(shard_path).stem
        index_path = os.path.join(self.EXTRACTED_INDEX_ROOT, f"{shard_name}.parquet")

        if not os.path.exists(index_path):
            # Fallback: use global transcript lookup
            yield from self._iter_extracted_dir_fallback(extracted_dir, shard_path)
            return

        try:
            shard_df = pd.read_parquet(index_path, columns=['transcript', 'mel_path', 'language'])
        except Exception as exc:
            self._skipped_count += 1
            logger.warning("Skipping shard %s: cannot read index %s (%s)",
                           shard_path, index_path, exc)
            return

        for transcript, mel_path, language in zip(
                shard_df['transcript'], shard_df['mel_path'], shard_df['language']):
            # isinstance rejects None and NaN (which str() would turn into "nan").
            if not isinstance(transcript, str) or not transcript.strip():
                continue
            mel = self._load_npy(mel_path)
            if mel is None:
                continue
            yield from self._process_mel(mel, transcript, language)

    def _iter_extracted_dir_fallback(self, extracted_dir, shard_path):
        """Fallback: iterate files in extracted dir using global transcript lookup."""
        if not self._valid_ids_by_shard.get(shard_path):
            return

        # sorted() makes the read order deterministic across filesystems.
        for fname in sorted(os.listdir(extracted_dir)):
            if not fname.endswith(self._MEL_SUFFIX):
                continue
            seg_id = fname[:-len(self._MEL_SUFFIX)]
            info = self._transcript_lookup.get(seg_id)
            if info is None:
                info = self._transcript_lookup.get(seg_id + '.flac')
            if info is None:
                continue
            transcript, language = info
            if not transcript or not transcript.strip():
                continue
            mel = self._load_npy(os.path.join(extracted_dir, fname))
            if mel is None:
                continue
            yield from self._process_mel(mel, transcript, language)

    def _iter_tar_streaming(self, shard_path):
        """Fallback: stream .mel.npy members from the tar shard."""
        valid_ids = self._valid_ids_by_shard.get(shard_path)
        if not valid_ids:
            return

        try:
            with tarfile.open(shard_path, 'r') as tar:
                for member in tar:
                    if not member.name.endswith(self._MEL_SUFFIX):
                        continue

                    seg_id = member.name[:-len(self._MEL_SUFFIX)]
                    if seg_id not in valid_ids:
                        continue

                    info = self._transcript_lookup.get(seg_id)
                    if info is None:
                        continue
                    transcript, language = info
                    if not transcript or not transcript.strip():
                        continue

                    f = tar.extractfile(member)
                    if f is None:
                        continue
                    # np.load needs a seekable file; buffer the member in memory.
                    mel = np.load(io.BytesIO(f.read()))

                    yield from self._process_mel(mel, transcript, language)
        except Exception as exc:
            # A damaged shard ends early instead of killing the worker.
            self._skipped_count += 1
            logger.warning("Stopped reading shard %s early: %s", shard_path, exc)

    def _process_mel(self, mel, transcript, language):
        """Normalize mel per bin over time, tokenize transcript, yield one item dict."""
        n_frames = mel.shape[1]
        if n_frames == 0:
            # Mean/std over an empty time axis would produce NaNs.
            self._skipped_count += 1
            return

        mel_f32 = mel.astype(np.float32)
        mean = mel_f32.mean(axis=-1, keepdims=True)
        std = np.maximum(mel_f32.std(axis=-1, keepdims=True), 1e-5)
        mel = ((mel_f32 - mean) / std).astype(np.float16)

        max_frames = self.config.max_mel_frames
        if n_frames > max_frames:
            # NOTE(review): the audio is truncated but the transcript is not, so
            # the label may describe speech that was cut off.
            mel = mel[:, :max_frames]
            n_frames = max_frames

        if self.tokenizer is not None:
            token_ids = self.tokenizer.encode(transcript)
            if len(token_ids) > self.config.max_tokens:
                self._truncated_count += 1
                token_ids = token_ids[:self.config.max_tokens]
            tokens = np.array(token_ids, dtype=np.int32)
        else:
            tokens = np.array([], dtype=np.int32)

        yield {
            'mel': mel,
            'tokens': tokens,
            'transcript': transcript,
            'language': language,
            'n_frames': n_frames,
            'n_tokens': len(tokens),
        }

    # ---------------------------------------------------------------- batching

    def _emit_from_buffer(self, buffer, fraction=0.5):
        """Split ``buffer`` into (items to emit now, items to keep).

        The first ``fraction`` of the buffer is emitted — a random subset for
        training, since the caller shuffles first, or the oldest items
        otherwise. When ``sort_buffer_for_batching`` is on, the emitted items
        are sorted by frame count so neighbours have similar lengths. Sorting
        the whole buffer and always emitting its shortest part would strand
        long utterances in the buffer until the final flush.
        """
        n_emit = min(len(buffer), max(1, int(len(buffer) * fraction)))
        emit_items = buffer[:n_emit]
        remaining = buffer[n_emit:]
        if self.config.sort_buffer_for_batching:
            emit_items.sort(key=lambda x: x['n_frames'])
        return emit_items, remaining

    def _make_batch_from_items(self, items):
        """Group ``items`` (in order) into collated batches within the batch budgets."""
        sampler = DynamicBatchSampler(
            max_batch_mel_frames=self.max_batch_mel_frames,
            max_batch_utterances=self.max_batch_utterances,
        )
        return list(sampler.batch_iter(items))

    def _yield_items(self, items):
        """Yield ``items`` as collated batches (prebatch) or one by one."""
        if self.prebatch:
            yield from self._make_batch_from_items(items)
        else:
            yield from items

    def __iter__(self):
        is_train = self.split == "train"
        my_shards = self._get_worker_shards()
        if is_train:
            my_shards = self._weighted_shard_order(my_shards)

        buffer = []
        for shard_path in my_shards:
            # STREAMING: items enter the buffer as they are read
            for utterance in self._iter_shard_streaming(shard_path):
                buffer.append(utterance)
                if len(buffer) >= self.config.shuffle_buffer_size:
                    if is_train:
                        random.shuffle(buffer)
                    emit_items, buffer = self._emit_from_buffer(buffer)
                    yield from self._yield_items(emit_items)

        # Flush remaining
        if is_train:
            random.shuffle(buffer)
        if self.config.sort_buffer_for_batching:
            buffer.sort(key=lambda x: x['n_frames'])
        yield from self._yield_items(buffer)


def dynamic_batch_collate(batch_items: List[dict]) -> dict:
    """Collate a dynamic batch: pad mels and tokens to the batch's max lengths.

    Args:
        batch_items: Non-empty list of item dicts with keys 'mel'
            ([n_mels, T] numpy array or tensor), 'tokens' (1-D int array or
            tensor), 'n_frames', 'n_tokens' and 'language'. All items must
            share the same n_mels.

    Returns:
        dict with
          'mel':           float16 [B, n_mels, max_frames], zero padded
          'mel_lengths':   long [B], valid frames per item
          'labels':        long [B, max_tokens], padded with -100
          'token_lengths': long [B], valid tokens per item
          'languages':     list of B language labels

    Raises:
        ValueError: if ``batch_items`` is empty.
    """
    if not batch_items:
        raise ValueError("dynamic_batch_collate() requires at least one item")

    max_frames = max(item['n_frames'] for item in batch_items)
    max_tokens = max(item['n_tokens'] for item in batch_items)
    batch_size = len(batch_items)
    # Take the mel-bin count from the data rather than hard-coding it, so the
    # collate works for any feature size (e.g. 80 or 128 bins).
    n_mels = batch_items[0]['mel'].shape[0]

    # Pad mel spectrograms [B, n_mels, max_frames]
    mel_padded = torch.zeros(batch_size, n_mels, max_frames, dtype=torch.float16)
    mel_lengths = torch.zeros(batch_size, dtype=torch.long)

    # Pad token sequences [B, max_tokens]; -100 is ignored by the loss
    tokens_padded = torch.full((batch_size, max_tokens), -100, dtype=torch.long)
    token_lengths = torch.zeros(batch_size, dtype=torch.long)

    languages = []

    for i, item in enumerate(batch_items):
        n_f = item['n_frames']
        n_t = item['n_tokens']
        mel = item['mel'] if isinstance(item['mel'], torch.Tensor) else torch.from_numpy(item['mel'])
        mel_padded[i, :, :n_f] = mel
        mel_lengths[i] = n_f
        if n_t > 0:
            tok = item['tokens'] if isinstance(item['tokens'], torch.Tensor) else torch.from_numpy(item['tokens'])
            tokens_padded[i, :n_t] = tok.long()
        token_lengths[i] = n_t
        languages.append(item['language'])

    return {
        'mel': mel_padded,
        'mel_lengths': mel_lengths,
        'labels': tokens_padded,        # named 'labels' for HF model API
        'token_lengths': token_lengths,
        'languages': languages,
    }


class DynamicBatchSampler:
    """
    Groups a stream of items into collated batches under a padded-frame budget.

    The cost of a batch is ``len(batch) * max(n_frames in batch)``, i.e. the
    number of mel frames after padding. An item that would push that cost past
    ``max_batch_mel_frames``, or push the batch past ``max_batch_utterances``
    when that cap is positive, closes the current batch and starts a new one.
    An empty batch always accepts its first item, even one over budget.

    The dataset emits items sorted by frame count, so neighbouring items have
    similar lengths and padding stays small.
    """

    def __init__(self, max_batch_mel_frames: int = 1_200_000, max_batch_utterances: int = 0):
        self.max_batch_mel_frames = max_batch_mel_frames
        self.max_batch_utterances = max_batch_utterances

    def _fits(self, count, longest):
        """True if ``count`` items padded to ``longest`` frames stay within both limits."""
        if self.max_batch_utterances > 0 and count > self.max_batch_utterances:
            return False
        return count * longest <= self.max_batch_mel_frames

    def batch_iter(self, data_iter):
        """Consume ``data_iter`` item by item and yield collated batch dicts."""
        pending = []
        longest = 0

        for item in data_iter:
            frames = item['n_frames']
            grown_longest = frames if frames > longest else longest

            if pending and not self._fits(len(pending) + 1, grown_longest):
                yield dynamic_batch_collate(pending)
                pending, longest = [item], frames
                continue

            pending.append(item)
            longest = grown_longest

        if pending:
            yield dynamic_batch_collate(pending)


def create_dataloader(config: DataConfig, tokenizer=None, split: str = "train",
                      num_workers: int = 8, prebatch: bool = True,
                      max_batch_mel_frames: int = 0, max_batch_utterances: int = 0):
    """Create a DataLoader over ``MelShardDataset``.

    When prebatch=True (default), batching happens INSIDE workers — each worker
    yields fully collated batch dicts. This eliminates the main-process
    DynamicBatchSampler bottleneck and lets workers do collation in parallel.

    When prebatch=False, yields individual items (legacy mode for DynamicBatchSampler).

    Args:
        config: Data configuration.
        tokenizer: Passed through to the dataset.
        split: Dataset split name; "train" enables shuffling.
        num_workers: DataLoader worker processes (0 = load in the main process).
        prebatch: Build batches inside the dataset/workers.
        max_batch_mel_frames: Padded-frame budget per batch; 0 uses
            ``config.max_batch_mel_frames``.
        max_batch_utterances: Utterance cap per batch; 0 uses
            ``config.max_batch_utterances`` (where 0 disables the cap).

    Returns:
        A ``DataLoader`` with ``batch_size=None``: the dataset decides batching.
    """
    # 0 means "use the config value". A 0 frame budget passed through would
    # make every pre-built batch a single utterance.
    max_batch_mel_frames = max_batch_mel_frames or config.max_batch_mel_frames
    max_batch_utterances = max_batch_utterances or config.max_batch_utterances

    dataset = MelShardDataset(
        config, tokenizer=tokenizer, split=split,
        prebatch=prebatch,
        max_batch_mel_frames=max_batch_mel_frames,
        max_batch_utterances=max_batch_utterances,
    )

    worker_kwargs = {}
    if num_workers > 0:
        # Only meaningful with worker processes; older torch versions reject
        # prefetch_factor when num_workers == 0.
        worker_kwargs = {"persistent_workers": True, "prefetch_factor": 4}

    return DataLoader(
        dataset,
        batch_size=None,
        num_workers=num_workers,
        # Pinning only helps host->GPU copies; skip it on CPU-only machines.
        pin_memory=torch.cuda.is_available(),
        **worker_kwargs,
    )
