"""Builds pretraining data from merged vectorized datasets and/or text corpora.

Produces the flat memmap files expected by TtsPretrainingDataset and
TextPretrainingDataset:
  - {split}_pretraining_codes.npy   (flat int32 stream of codec token ids)
  - {split}_pretraining_tokens.npy  (flat int32 stream of text token ids)
  - manifest.json                   (reproducibility manifest)

Audio PT:
  Reads merged SFT-style codes (train_codes.npy + train_codes_index.npy) and
  concatenates them into a single flat stream, optionally shuffling at the
  sample level first.

Text PT:
  Reads a directory of .jsonl files (each line has a "text" field), tokenizes
  with the target LLM tokenizer, and concatenates into a flat stream.

Usage examples:

  # Build audio pretraining codes from an existing vectorized dataset
  python tools/data/build_pretraining_data.py \
      --mode=audio \
      --input_dir=/path/to/merged_vectorized_dataset \
      --output_dir=/path/to/pt_output \
      --split=train

  # Build text pretraining tokens from a jsonl corpus
  python tools/data/build_pretraining_data.py \
      --mode=text \
      --input_dir=/path/to/text_jsonl_dir \
      --output_dir=/path/to/pt_output \
      --tokenizer_name=meta-llama/Llama-3.2-1B-Instruct \
      --max_seq_len=2048 \
      --split=train
"""

import hashlib
import json
import os
import time
from typing import Sequence

import numpy as np
from absl import app, flags, logging

# NOTE(review): FLAGS is never read in this file (values are accessed through
# the DEFINE_* handles below); kept for the conventional absl module shape.
FLAGS = flags.FLAGS

# --- CLI flags -------------------------------------------------------------
# mode/input_dir/output_dir are marked required in the __main__ guard.
_MODE = flags.DEFINE_enum(
    "mode", None, ["audio", "text"],
    "Build mode: 'audio' for codec pretraining codes, 'text' for text tokens.")
_INPUT_DIR = flags.DEFINE_string(
    "input_dir", None, "Input directory with merged codes or text jsonl files.")
_OUTPUT_DIR = flags.DEFINE_string(
    "output_dir", None, "Output directory for pretraining data.")
_SPLIT = flags.DEFINE_string("split", "train", "Split name (train or val).")
# Shuffling applies at the sample level, before concatenation (audio mode only).
_SHUFFLE = flags.DEFINE_bool(
    "shuffle", True, "Shuffle samples before concatenation (audio mode).")
_SEED = flags.DEFINE_integer("seed", 42, "Random seed for shuffling.")
_TOKENIZER_NAME = flags.DEFINE_string(
    "tokenizer_name", "meta-llama/Llama-3.2-1B-Instruct",
    "HuggingFace tokenizer name (text mode).")
_MAX_SEQ_LEN = flags.DEFINE_integer(
    "max_seq_len", 2048, "Maximum sequence length for text tokenization.")
# -1 (the default) disables the cap.
_MAX_SAMPLES = flags.DEFINE_integer(
    "max_samples", -1, "Cap on number of samples to process (-1 = all).")


def _md5_file(path: str) -> str:
    """Returns hex MD5 digest for a file (reads in chunks for large files)."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()


def _build_audio_pretraining(
    input_dir: str,
    output_dir: str,
    split: str,
    shuffle: bool,
    seed: int,
    max_samples: int,
) -> dict:
    """Converts merged SFT codes into a flat pretraining stream.

    Reads {split}_codes.npy (flat int32 codec ids) and {split}_codes_index.npy
    (per-sample start offsets into that flat stream), optionally shuffles the
    sample order, and streams the selected samples directly into
    {split}_pretraining_codes.npy without materializing the full output in RAM.

    Args:
        input_dir: Directory containing the merged vectorized dataset.
        output_dir: Destination directory (must already exist).
        split: Split prefix for input/output filenames (e.g. "train").
        shuffle: Whether to shuffle sample order before concatenation.
        seed: RNG seed used when shuffle is True.
        max_samples: If > 0, keep only the first max_samples of the
            (possibly shuffled) order.

    Returns:
        A manifest dict with source paths, md5, counts, and output path.

    Raises:
        FileNotFoundError: If the codes or index file is missing.
        ValueError: If zero samples are selected (an empty memmap cannot
            be written).
    """
    codes_path = os.path.join(input_dir, f"{split}_codes.npy")
    index_path = os.path.join(input_dir, f"{split}_codes_index.npy")
    samples_path = os.path.join(input_dir, f"{split}_samples.jsonl")

    for p in (codes_path, index_path):
        if not os.path.exists(p):
            raise FileNotFoundError(f"Required file not found: {p}")

    codes = np.memmap(codes_path, dtype=np.int32, mode="r")
    index = np.load(index_path)
    num_samples = len(index)

    # Optional per-language counts from the sidecar metadata file; malformed
    # lines are skipped rather than failing the whole build.
    lang_counts: dict[str, int] = {}
    if os.path.exists(samples_path):
        with open(samples_path, encoding="utf-8") as f:
            for line in f:
                try:
                    lang = json.loads(line).get("language", "unknown")
                except json.JSONDecodeError:
                    continue
                lang_counts[lang] = lang_counts.get(lang, 0) + 1

    order = np.arange(num_samples)
    if shuffle:
        rng = np.random.default_rng(seed)
        rng.shuffle(order)
    if max_samples > 0:
        order = order[:max_samples]
    if len(order) == 0:
        raise ValueError(
            "No samples selected; cannot write an empty pretraining stream.")

    # index holds per-sample start offsets; appending the total length makes
    # bounds[i]:bounds[i + 1] the [start, end) span of sample i.
    bounds = np.append(index, codes.shape[0])
    lengths = bounds[1:] - bounds[:-1]
    total_out = int(lengths[order].sum())

    out_path = os.path.join(output_dir, f"{split}_pretraining_codes.npy")
    out = np.memmap(out_path, dtype=np.int32, mode="w+", shape=(total_out,))

    # Stream each selected sample directly into the output memmap so the
    # dataset is never held twice in memory (the original built a full
    # in-RAM copy via np.concatenate before writing).
    pos = 0
    for i in order:
        left, right = int(bounds[i]), int(bounds[i + 1])
        out[pos:pos + (right - left)] = codes[left:right]
        pos += right - left
    out.flush()

    logging.info(
        "Wrote %d pretraining codes (%d samples) to %s",
        total_out, len(order), out_path,
    )

    return {
        "mode": "audio",
        "source_codes_path": codes_path,
        "source_codes_md5": _md5_file(codes_path),
        "num_source_samples": num_samples,
        "num_selected_samples": len(order),
        "total_codes": total_out,
        # Assumes 50 codec frames per second — TODO confirm codec frame rate.
        "total_hours": float(total_out) / (50 * 3600),
        "shuffle": shuffle,
        "seed": seed,
        "language_counts": lang_counts,
        "output_path": out_path,
    }


def _build_text_pretraining(
    input_dir: str,
    output_dir: str,
    split: str,
    tokenizer_name: str,
    max_seq_len: int,
    max_samples: int,
) -> dict:
    """Tokenizes a .jsonl text corpus into a flat pretraining token stream.

    Every line of every *.jsonl file in input_dir is expected to be a JSON
    object with a "text" field. Lines with malformed JSON and documents
    shorter than 128 characters are counted as skipped; documents longer
    than max_seq_len tokens are truncated to the prefix. The concatenated
    token ids are written to {split}_pretraining_tokens.npy (flat int32).

    Args:
        input_dir: Directory containing the .jsonl corpus files.
        output_dir: Destination directory (must already exist).
        split: Split prefix for the output filename.
        tokenizer_name: HuggingFace tokenizer identifier.
        max_seq_len: Per-document token truncation length.
        max_samples: If > 0, stop after this many processed documents.

    Returns:
        A manifest dict with source files, tokenizer info, and counts.

    Raises:
        FileNotFoundError: If input_dir contains no .jsonl files.
        ValueError: If no tokens were produced (an empty memmap cannot
            be written).
    """
    import transformers  # deferred to avoid import cost when unused

    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        # Some LLM tokenizers ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    jsonl_files = sorted(
        f for f in os.listdir(input_dir) if f.endswith(".jsonl")
    )
    if not jsonl_files:
        raise FileNotFoundError(f"No .jsonl files found in {input_dir}")

    all_tokens: list[int] = []
    docs_processed = 0
    docs_skipped = 0
    cap_reached = False

    for jf in jsonl_files:
        if cap_reached:
            # Fix: the original break only exited the per-line loop, so every
            # remaining file was still opened after the cap was hit.
            break
        jf_path = os.path.join(input_dir, jf)
        with open(jf_path, encoding="utf-8") as f:
            for line in f:
                if max_samples > 0 and docs_processed >= max_samples:
                    cap_reached = True
                    break
                try:
                    text = json.loads(line).get("text", "").strip()
                except json.JSONDecodeError:
                    docs_skipped += 1
                    continue

                # Skip very short documents (< 128 characters).
                if len(text) < 128:
                    docs_skipped += 1
                    continue

                tokens = tokenizer.encode(text, add_special_tokens=False)
                if len(tokens) > max_seq_len:
                    tokens = tokens[:max_seq_len]  # keep the prefix only
                all_tokens.extend(tokens)
                docs_processed += 1

                if docs_processed % 100000 == 0:
                    logging.info(
                        "Tokenized %d documents (%d tokens so far).",
                        docs_processed, len(all_tokens),
                    )

    if not all_tokens:
        raise ValueError(
            "No tokens produced: every document was skipped or the corpus "
            "is empty.")

    flat = np.array(all_tokens, dtype=np.int32)
    out_path = os.path.join(output_dir, f"{split}_pretraining_tokens.npy")
    out = np.memmap(out_path, dtype=np.int32, mode="w+", shape=flat.shape)
    out[:] = flat
    out.flush()

    logging.info(
        "Wrote %d pretraining tokens (%d docs) to %s",
        flat.shape[0], docs_processed, out_path,
    )

    return {
        "mode": "text",
        "source_dir": input_dir,
        "source_files": jsonl_files,
        "tokenizer_name": tokenizer_name,
        "max_seq_len": max_seq_len,
        "docs_processed": docs_processed,
        "docs_skipped": docs_skipped,
        "total_tokens": int(flat.shape[0]),
        "output_path": out_path,
    }


def main(argv: Sequence[str]) -> None:
    """CLI entry point: builds one split and appends its build manifest."""
    del argv  # positional args unused; absl handles flag parsing

    build_mode = _MODE.value
    src_dir = _INPUT_DIR.value
    dst_dir = _OUTPUT_DIR.value
    split_name = _SPLIT.value

    os.makedirs(dst_dir, exist_ok=True)
    t0 = time.time()

    if build_mode == "audio":
        manifest = _build_audio_pretraining(
            input_dir=src_dir,
            output_dir=dst_dir,
            split=split_name,
            shuffle=_SHUFFLE.value,
            seed=_SEED.value,
            max_samples=_MAX_SAMPLES.value,
        )
    else:
        manifest = _build_text_pretraining(
            input_dir=src_dir,
            output_dir=dst_dir,
            split=split_name,
            tokenizer_name=_TOKENIZER_NAME.value,
            max_seq_len=_MAX_SEQ_LEN.value,
            max_samples=_MAX_SAMPLES.value,
        )

    manifest["split"] = split_name
    manifest["build_time_sec"] = round(time.time() - t0, 2)
    manifest["timestamp"] = time.strftime("%Y-%m-%dT%H:%M:%S%z")

    # manifest.json is append-only history: normalize a pre-existing single
    # dict into a one-element list, and reset on unreadable JSON.
    manifest_path = os.path.join(dst_dir, "manifest.json")
    history = []
    if os.path.exists(manifest_path):
        with open(manifest_path, encoding="utf-8") as f:
            try:
                loaded = json.load(f)
                history = loaded if isinstance(loaded, list) else [loaded]
            except json.JSONDecodeError:
                history = []

    history.append(manifest)
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2, ensure_ascii=False)

    logging.info("Manifest appended to %s", manifest_path)
    logging.info("Build completed in %.1f seconds.", manifest["build_time_sec"])


if __name__ == "__main__":
    # These three flags default to None; absl exits with a usage error if
    # any of them is missing on the command line.
    flags.mark_flags_as_required(["mode", "input_dir", "output_dir"])
    app.run(main)
