#!/usr/bin/env python3
"""
Build NeMo-compatible tarred dataset manifests from production parquet files.

Outputs:
  data/manifests/tarred/shard_NNNNN.jsonl  - per-shard JSONL (one per tar_path)
  configs/data/stage1_prod_input_cfg_v2.yaml - input_cfg YAML pairing manifests with tars
  data/manifests/stage1_prod_val_v2.jsonl   - flat JSONL for val (extracted audio)
  data/audio_cache/val/                     - extracted val audio files
Also updates configs/train/stage1_prod_8xh200.yaml.

Design:
  - Reads parquet with PyArrow (never pandas)
  - Groups by tar_path (each = one shard)
  - Writes shard JSONLs in parallel with ProcessPoolExecutor (16+ workers)
  - Val: extracts 300 audio files via os.pread + soundfile
"""

import json
import os
import sys
import time
import io
import concurrent.futures
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.compute as pc

# ---------------------------------------------------------------------------
# Paths (relative to repo root; script resolves them against its own location)
# ---------------------------------------------------------------------------
SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = SCRIPT_DIR.parent  # repo root is assumed to be one level above this script

# Input parquet files (train/val splits produced upstream in phase 3)
TRAIN_PARQUET = REPO_ROOT / "artifacts/phase3/production_train_final.parquet"
VAL_PARQUET   = REPO_ROOT / "artifacts/phase3/production_val_final.parquet"

# Outputs (see module docstring for the role of each)
SHARD_MANIFEST_DIR = REPO_ROOT / "data/manifests/tarred"        # per-shard JSONLs
VAL_MANIFEST       = REPO_ROOT / "data/manifests/stage1_prod_val_v2.jsonl"  # flat val JSONL
VAL_AUDIO_DIR      = REPO_ROOT / "data/audio_cache/val"         # extracted val audio
INPUT_CFG_YAML     = REPO_ROOT / "configs/data/stage1_prod_input_cfg_v2.yaml"
TRAIN_YAML         = REPO_ROOT / "configs/train/stage1_prod_8xh200.yaml"    # updated in place

# Worker count for both shard writing and val audio extraction
NUM_WORKERS = 32  # well above requested 16; 128 CPUs available


# ---------------------------------------------------------------------------
# Worker function: write one shard JSONL
# Called in a subprocess via ProcessPoolExecutor
# ---------------------------------------------------------------------------
def write_shard(args):
    """
    Write the JSONL manifest for one tar shard (runs in a worker process).

    args: (shard_idx, tar_path, rows_bytes)
      rows_bytes: bytes of a compact JSON list: [[member_name, transcript, duration, lang], ...]
    Returns (shard_idx, tar_path, n_rows, out_path_str)
    """
    shard_idx, tar_path, rows_bytes = args
    decoded_rows = json.loads(rows_bytes)

    manifest_path = SHARD_MANIFEST_DIR / f"shard_{shard_idx:05d}.jsonl"

    # Build all lines first, then flush with a single writelines call.
    out_lines = []
    for member_name, transcript, duration, lang in decoded_rows:
        record = {
            "audio_filepath": member_name,
            "text": transcript,
            "duration": duration,
            "lang": lang,
        }
        out_lines.append(json.dumps(record, ensure_ascii=False) + "\n")

    with open(manifest_path, "w", encoding="utf-8") as fh:
        fh.writelines(out_lines)

    return (shard_idx, tar_path, len(decoded_rows), str(manifest_path))


# ---------------------------------------------------------------------------
# Val extraction: os.pread + soundfile  → write FLAC to disk
# ---------------------------------------------------------------------------
def extract_val_audio(row):
    """
    Extract one validation audio file from its source tar and write it into
    VAL_AUDIO_DIR (runs in a worker process).

    row: dict with tar_path, tar_member_name, tar_offset_data, tar_nbytes,
         duration_s, transcript, language
    Returns (out_path, duration_s, transcript, language) or raises.
    """
    import soundfile as sf

    tar_path      = row["tar_path"]
    member_name   = row["tar_member_name"]
    offset        = row["tar_offset_data"]
    nbytes        = row["tar_nbytes"]
    duration_s    = row["duration_s"]
    transcript    = row["transcript"]
    language      = row["language"]

    # Read raw audio bytes from the tar file at the stored offset.
    # POSIX pread is allowed to return fewer bytes than requested, so loop
    # until the full payload is read; a premature EOF means the stored
    # offset/length and the tar file disagree, which we surface explicitly
    # rather than passing truncated bytes to the audio decoder.
    fd = os.open(tar_path, os.O_RDONLY)
    try:
        chunks = []
        remaining = nbytes
        pos = offset
        while remaining > 0:
            chunk = os.pread(fd, remaining, pos)
            if not chunk:
                raise EOFError(
                    f"short read in {tar_path}: expected {nbytes} bytes at "
                    f"offset {offset}, got {nbytes - remaining}"
                )
            chunks.append(chunk)
            remaining -= len(chunk)
            pos += len(chunk)
        raw = b"".join(chunks)
    finally:
        os.close(fd)

    # Parse with soundfile (handles FLAC, WAV, OGG, etc.)
    buf = io.BytesIO(raw)
    audio_data, sample_rate = sf.read(buf)

    # Write to val audio cache dir; use member_name as filename (flatten any
    # /path/to/ prefix so the cache directory stays flat).
    safe_name = member_name.replace("/", "_").replace("\\", "_")
    out_path = VAL_AUDIO_DIR / safe_name
    sf.write(str(out_path), audio_data, sample_rate)

    return (str(out_path), duration_s, transcript, language)


# ---------------------------------------------------------------------------
# Step 1: Build train shard manifests
# ---------------------------------------------------------------------------
def build_train_manifests():
    """
    Step 1: group the train parquet's rows by tar_path and write one JSONL
    shard manifest per tar.

    Streams the parquet row-group by row-group so memory stays bounded by one
    row-group plus the accumulated compact row lists.

    Returns:
        shard_meta: list of (shard_idx, tar_path, n_rows, out_path_str),
        sorted by shard_idx.
    """
    t0 = time.time()
    print(f"[train] Reading parquet: {TRAIN_PARQUET}")

    pf = pq.ParquetFile(str(TRAIN_PARQUET))
    total_rows = pf.metadata.num_rows
    total_rgs  = pf.metadata.num_row_groups
    print(f"[train] {total_rows:,} rows  |  {total_rgs} row-groups")

    # -----------------------------------------------------------------------
    # Pass 1: stream through all row groups, accumulate per-tar_path rows.
    # We store rows as compact lists to minimise memory: [member, text, dur, lang]
    # -----------------------------------------------------------------------
    print("[train] Pass 1/2 – grouping rows by tar_path …")
    shard_rows: dict[str, list] = {}  # tar_path -> list of [member, text, dur, lang]

    cols = ["tar_path", "tar_member_name", "duration_s", "transcript", "language"]
    rows_processed = 0
    rg_t0 = time.time()

    for rg_idx in range(total_rgs):
        batch = pf.read_row_group(rg_idx, columns=cols)
        n = len(batch)

        tar_paths    = batch.column("tar_path").to_pylist()
        member_names = batch.column("tar_member_name").to_pylist()
        durations    = batch.column("duration_s").to_pylist()
        transcripts  = batch.column("transcript").to_pylist()
        languages    = batch.column("language").to_pylist()

        # zip replaces the manual index loop; setdefault replaces the
        # per-row "if key not in dict" membership test.
        for tp, member, text, dur, lang in zip(
            tar_paths, member_names, transcripts, durations, languages
        ):
            shard_rows.setdefault(tp, []).append([member, text, dur, lang])

        rows_processed += n
        if (rg_idx + 1) % 20 == 0 or rg_idx == total_rgs - 1:
            elapsed = time.time() - rg_t0
            rate = rows_processed / elapsed if elapsed > 0 else 0
            print(
                f"  rg {rg_idx+1:3d}/{total_rgs}  |  {rows_processed:>11,} rows  |  "
                f"{len(shard_rows):,} shards seen  |  {rate/1e6:.2f}M rows/s"
            )

    n_shards = len(shard_rows)
    print(f"[train] Grouped into {n_shards:,} shards in {time.time()-t0:.1f}s")

    # -----------------------------------------------------------------------
    # Pass 2: assign sequential shard index, write JSONLs in parallel
    # -----------------------------------------------------------------------
    print(f"[train] Pass 2/2 – writing {n_shards:,} JSONL shards with {NUM_WORKERS} workers …")
    SHARD_MANIFEST_DIR.mkdir(parents=True, exist_ok=True)

    # Build work items; serialise rows as JSON bytes to pass across processes.
    # (Shard indices are assigned by dict insertion order, which follows the
    # row-group scan order and is therefore deterministic for a given parquet.)
    work_items = []
    for shard_idx, (tar_path, rows) in enumerate(shard_rows.items()):
        rows_bytes = json.dumps(rows, ensure_ascii=False).encode("utf-8")
        work_items.append((shard_idx, tar_path, rows_bytes))

    shard_meta: list[tuple[int, str, int, str]] = []  # (idx, tar_path, n_rows, out_path)
    completed = 0
    write_t0 = time.time()
    report_every = max(1, n_shards // 20)

    with concurrent.futures.ProcessPoolExecutor(max_workers=NUM_WORKERS) as executor:
        futures = {executor.submit(write_shard, item): item[0] for item in work_items}
        for fut in concurrent.futures.as_completed(futures):
            result = fut.result()  # raises on worker error
            shard_meta.append(result)
            completed += 1
            if completed % report_every == 0 or completed == n_shards:
                elapsed = time.time() - write_t0
                rate = completed / elapsed if elapsed > 0 else 0
                print(f"  written {completed:>5,}/{n_shards:,} shards  |  {rate:.0f} shards/s")

    shard_meta.sort(key=lambda x: x[0])  # sort by shard_idx for deterministic YAML
    total_written = sum(r[2] for r in shard_meta)
    print(
        f"[train] All shards written: {completed:,} JSONLs, {total_written:,} rows  "
        f"in {time.time()-t0:.1f}s total"
    )
    return shard_meta


# ---------------------------------------------------------------------------
# Step 2: Build input_cfg YAML
# ---------------------------------------------------------------------------
def build_input_cfg(shard_meta: list):
    """
    Step 2: write the input_cfg YAML pairing each shard manifest with its tar.

    shard_meta: list of (shard_idx, tar_path, n_rows, out_path) tuples, as
    returned by build_train_manifests(); one YAML entry is emitted per tuple,
    in list order.
    """
    print(f"[input_cfg] Writing {INPUT_CFG_YAML}")
    INPUT_CFG_YAML.parent.mkdir(parents=True, exist_ok=True)

    lines = ["input_cfg:\n"]
    for shard_idx, tar_path, n_rows, out_path in shard_meta:
        # Store manifest path relative to repo root for portability
        rel_manifest = os.path.relpath(out_path, str(REPO_ROOT))
        lines.extend(
            (
                f"- manifest_filepath: {rel_manifest}\n",
                f"  tarred_audio_filepaths: {tar_path}\n",
                # constant lines: plain strings, no f-prefix needed (F541)
                "  type: nemo_tarred\n",
                "  weight: 1\n",
            )
        )

    with open(INPUT_CFG_YAML, "w", encoding="utf-8") as f:
        f.writelines(lines)

    print(f"[input_cfg] Done – {len(shard_meta):,} entries written")


# ---------------------------------------------------------------------------
# Step 3: Build val manifest (extract audio from tars)
# ---------------------------------------------------------------------------
def build_val_manifest():
    """
    Step 3: pull every val row's audio out of its source tar into
    VAL_AUDIO_DIR (in parallel) and write a flat JSONL manifest that points
    at the extracted files. Failed extractions are logged and skipped.
    """
    t0 = time.time()
    print(f"[val] Reading {VAL_PARQUET}")
    pf = pq.ParquetFile(str(VAL_PARQUET))
    tbl = pf.read()
    n_rows = len(tbl)
    print(f"[val] {n_rows} rows to extract")

    VAL_AUDIO_DIR.mkdir(parents=True, exist_ok=True)

    # Materialise each row as a plain dict so it can be pickled to workers
    cols = ["tar_path", "tar_member_name", "tar_offset_data", "tar_nbytes",
            "duration_s", "transcript", "language"]
    col_values = [tbl.column(c).to_pylist() for c in cols]
    row_dicts = [dict(zip(cols, values)) for values in zip(*col_values)]

    print(f"[val] Extracting {n_rows} audio files with {NUM_WORKERS} workers …")
    results = []
    completed = 0

    with concurrent.futures.ProcessPoolExecutor(max_workers=NUM_WORKERS) as pool:
        pending = {pool.submit(extract_val_audio, rd): idx
                   for idx, rd in enumerate(row_dicts)}
        for done in concurrent.futures.as_completed(pending):
            try:
                results.append(done.result())
            except Exception as e:
                print(f"  [val] WARNING: extraction failed for row {pending[done]}: {e}")
                results.append(None)
            completed += 1
            if completed % 50 == 0 or completed == n_rows:
                print(f"  extracted {completed}/{n_rows}")

    # Keep only the successful extractions and emit the flat JSONL manifest
    VAL_MANIFEST.parent.mkdir(parents=True, exist_ok=True)
    n_ok = 0
    with open(VAL_MANIFEST, "w", encoding="utf-8") as f:
        for entry in results:
            if entry is None:
                continue
            out_path, duration_s, transcript, language = entry
            record = {
                "audio_filepath": out_path,
                "text": transcript,
                "duration": duration_s,
                "lang": language,
            }
            f.write(json.dumps(record, ensure_ascii=False))
            f.write("\n")
            n_ok += 1

    n_failed = n_rows - n_ok
    print(
        f"[val] Done: {n_ok} extracted, {n_failed} failed  "
        f"in {time.time()-t0:.1f}s  →  {VAL_MANIFEST}"
    )


# ---------------------------------------------------------------------------
# Step 4: Update stage1_prod_8xh200.yaml
# ---------------------------------------------------------------------------
def update_train_yaml():
    """
    Step 4: rewrite TRAIN_YAML in place so train_ds uses the generated
    input_cfg (manifest_filepath: null) and validation_ds points at the new
    flat val manifest.

    Uses regex text surgery instead of a YAML parser so the config file's
    comments and formatting survive untouched.
    """
    print(f"[yaml] Updating {TRAIN_YAML}")
    with open(TRAIN_YAML, "r", encoding="utf-8") as f:
        content = f.read()

    # 4a. Change train_ds.manifest_filepath to null
    #     Anchors on "train_ds:" and non-greedily scans indented lines until
    #     the first "manifest_filepath:" inside that block, replacing its
    #     value with "null".
    import re
    content = re.sub(
        r"(train_ds:\s*\n(?:[ \t]+.*\n)*?[ \t]+manifest_filepath:)\s*\S+",
        r"\1 null",
        content,
    )

    # 4b. Add train_ds.input_cfg after manifest_filepath: null
    #     Only add if not already present
    #     NOTE(review): the split() assumes the YAML contains "train_ds:"
    #     followed later by "validation_ds:" — it raises IndexError otherwise.
    if "input_cfg:" not in content.split("train_ds:")[1].split("validation_ds:")[0]:
        rel_input_cfg = os.path.relpath(str(INPUT_CFG_YAML), str(REPO_ROOT))
        # NOTE(review): the inserted line hard-codes a 4-space indent and relies
        # on 4a having already produced "manifest_filepath: null" — confirm the
        # config's actual indentation matches.
        content = re.sub(
            r"(train_ds:.*?manifest_filepath: null\n)",
            rf"\1    input_cfg: {rel_input_cfg}\n",
            content,
            count=1,
            flags=re.DOTALL,
        )

    # 4c. Change validation_ds.manifest_filepath
    #     Same anchored-block pattern as 4a, but substitutes the repo-relative
    #     path of the new val manifest instead of null.
    rel_val = os.path.relpath(str(VAL_MANIFEST), str(REPO_ROOT))
    content = re.sub(
        r"(validation_ds:\s*\n(?:[ \t]+.*\n)*?[ \t]+manifest_filepath:)\s*\S+",
        rf"\1 {rel_val}",
        content,
    )

    # Write the patched config back in place.
    with open(TRAIN_YAML, "w", encoding="utf-8") as f:
        f.write(content)
    print(f"[yaml] Done")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Run all four build steps in order and print a final summary."""
    started = time.time()
    banner = "=" * 70

    print(banner)
    print("build_nemo_tarred_manifests.py")
    print(banner)

    shard_meta = build_train_manifests()  # Step 1: per-tar shard JSONLs
    build_input_cfg(shard_meta)           # Step 2: input_cfg YAML
    build_val_manifest()                  # Step 3: val audio + flat manifest
    update_train_yaml()                   # Step 4: patch training config

    elapsed = time.time() - started
    print()
    print(banner)
    print(f"ALL DONE in {elapsed:.1f}s  ({elapsed/60:.1f} min)")
    print(f"  Train shards : {len(shard_meta):,}  →  {SHARD_MANIFEST_DIR}")
    print(f"  input_cfg    : {INPUT_CFG_YAML}")
    print(f"  Val manifest : {VAL_MANIFEST}")
    print(f"  Val audio    : {VAL_AUDIO_DIR}")
    print(f"  Updated YAML : {TRAIN_YAML}")
    print(banner)


# CLI entry point: run all four build steps.
if __name__ == "__main__":
    main()
