#!/usr/bin/env python3
"""Repack tar archives into webdataset format with sequential naming.

Reads audio bytes via pread (no tar extraction) and writes new tars with
sequential member names: 000000.flac, 000001.flac, etc.

Also writes aligned JSONL manifests for NeMo's TarredAudioToBPEDataset.

Usage:
  python3 tools/repack_tars_webdataset.py --workers 16
"""

import io
import json
import os
import sys
import tarfile
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import yaml


def repack_shard(args):
    """Repack one shard's tar into webdataset format.

    Args:
        args: Tuple of (shard_idx, manifest_path, tar_path, out_tar_path,
            out_manifest_path). Packed as one tuple so the function can be
            submitted directly to a ProcessPoolExecutor.

    Returns:
        Tuple of (shard_idx, samples_written, status) where status is "ok",
        "empty", or a short error description. Never raises: all failures
        are reported through the status string so the parent pool keeps going.
    """
    shard_idx, manifest_path, tar_path, out_tar_path, out_manifest_path = args

    # Read the shard's JSONL manifest: it provides both the member list
    # (audio_filepath) and the metadata (text, duration, lang) to carry over.
    entries = []
    with open(manifest_path) as f:
        for line in f:
            if line.strip():
                entries.append(json.loads(line))

    if not entries:
        return shard_idx, 0, "empty"

    try:
        # Scan tar headers once to build member name -> (data offset, size).
        # This lets us later pread() the raw audio bytes without extracting.
        member_map = {}
        with tarfile.open(tar_path, 'r') as tf:
            for member in tf:
                if member.isfile():
                    member_map[member.name] = (member.offset_data, member.size)
    except Exception as e:
        return shard_idx, 0, f"tar_scan_error: {e}"

    manifest_lines = []
    written = 0

    try:
        fd = os.open(tar_path, os.O_RDONLY)
        try:
            with tarfile.open(out_tar_path, 'w') as out_tf:
                for entry in entries:
                    member_name = entry['audio_filepath']
                    if member_name not in member_map:
                        # Member listed in manifest but absent from tar: skip.
                        continue

                    offset, size = member_map[member_name]
                    raw = os.pread(fd, size, offset)
                    if len(raw) != size:
                        # Short read (truncated tar): skip rather than write
                        # a corrupt member.
                        continue

                    # Name by the count of members actually written, not the
                    # manifest index, so names stay gapless (000000, 000001,
                    # ...) even when entries are skipped.
                    seq_name = f"{written:06d}.flac"
                    info = tarfile.TarInfo(name=seq_name)
                    info.size = size
                    out_tf.addfile(info, io.BytesIO(raw))

                    # Manifest line i describes tar member i — NeMo's tarred
                    # dataset requires this alignment.
                    manifest_lines.append(json.dumps({
                        "audio_filepath": seq_name,
                        "text": entry["text"],
                        "duration": entry["duration"],
                        "lang": entry.get("lang", "en"),
                    }, ensure_ascii=False))
                    written += 1
        finally:
            # Close the source fd on every path; os.open is outside this
            # inner try so fd is always defined here.
            os.close(fd)

        # Write the aligned manifest only after the tar completed cleanly.
        with open(out_manifest_path, 'w') as f:
            f.write('\n'.join(manifest_lines) + '\n')

        return shard_idx, written, "ok"
    except Exception as e:
        return shard_idx, written, f"error: {e}"


def main():
    """Entry point: repack every shard from the input config in parallel.

    Reads the shard list from the stage-1 input config YAML, fans the
    per-shard repack work out over a process pool, reports progress, and
    finally dumps a NeMo-style config snippet pointing at the new tars.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=16)
    parser.add_argument("--max-shards", type=int, default=0, help="0=all")
    opts = parser.parse_args()

    # The input config maps each shard to its source manifest + tar.
    with open('configs/data/stage1_prod_input_cfg_v2.yaml') as cfg_file:
        cfg = yaml.safe_load(cfg_file)

    shard_entries = cfg['input_cfg']
    if opts.max_shards > 0:
        shard_entries = shard_entries[:opts.max_shards]

    out_tar_dir = Path('data/wds_tars')
    out_manifest_dir = Path('data/manifests/wds')
    for directory in (out_tar_dir, out_manifest_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # One task tuple per shard, matching repack_shard's expected layout.
    tasks = [
        (
            idx,
            shard['manifest_filepath'],
            shard['tarred_audio_filepaths'],
            str(out_tar_dir / f'shard_{idx:05d}.tar'),
            str(out_manifest_dir / f'shard_{idx:05d}.jsonl'),
        )
        for idx, shard in enumerate(shard_entries)
    ]

    print(f"Repacking {len(tasks)} shards with {opts.workers} workers...")
    print(f"  Output tars: {out_tar_dir}")
    print(f"  Output manifests: {out_manifest_dir}")

    start = time.time()
    total_written = 0
    errors = []
    done = 0

    with ProcessPoolExecutor(max_workers=opts.workers) as pool:
        pending = {pool.submit(repack_shard, task): task[0] for task in tasks}
        for fut in as_completed(pending):
            shard_idx, written, status = fut.result()
            total_written += written
            done += 1
            if status != "ok":
                errors.append((shard_idx, status))
            # Progress line every 100 shards and at the very end.
            if done % 100 == 0 or done == len(tasks):
                elapsed = time.time() - start
                rate = done / elapsed
                eta = (len(tasks) - done) / max(rate, 0.01)
                print(f"  {done}/{len(tasks)} shards ({total_written:,} samples), "
                      f"{elapsed:.0f}s elapsed, ~{eta:.0f}s remaining")

    elapsed = time.time() - start
    print(f"\nDone in {elapsed:.0f}s")
    print(f"  Shards: {len(tasks)}")
    print(f"  Samples: {total_written:,}")
    print(f"  Errors: {len(errors)}")
    # Only the first ten failures are shown to keep the summary short.
    for idx, err in errors[:10]:
        print(f"    shard_{idx:05d}: {err}")

    # Emit a ready-to-paste NeMo config snippet with brace-expansion paths.
    n = len(tasks)
    nemo_cfg = {
        'is_tarred': True,
        'tarred_audio_filepaths': f'data/wds_tars/shard_{{0..{n-1}}}.tar',
        'manifest_filepath': f'data/manifests/wds/shard_{{0..{n-1}}}.jsonl',
        'note': f'Generated by repack_tars_webdataset.py, {total_written:,} samples across {n} shards',
    }
    with open('data/wds_tars/config_snippet.yaml', 'w') as snippet_file:
        yaml.dump(nemo_cfg, snippet_file, default_flow_style=False)
    print(f"\n  Config snippet: data/wds_tars/config_snippet.yaml")


if __name__ == '__main__':
    main()
