#!/usr/bin/env python3
"""Phase 3: Build tar byte-offset index for every shard.

Scans tar headers to extract (member_name → data_offset, data_size) without
reading audio data. This is fast — pure header parsing.

Usage:
  python3 tools/phase3_build_tar_offsets.py --workers 32
  python3 tools/phase3_build_tar_offsets.py --workers 32 --max-shards 10
"""

import argparse
import json
import os
import sys
import tarfile
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq

# Keep native math libraries (OpenMP/BLAS) single-threaded inside each
# worker process; parallelism comes from the process pool instead.
os.environ["OMP_NUM_THREADS"] = "1"

# Phase 2 artifacts: inventory parquet listing shard directories
# (read via its "shard_dir" column in main()).
ARTIFACTS_DIR = Path("/workspace/maya-asr/artifacts/phase2")
INVENTORY_PATH = ARTIFACTS_DIR / "shard_inventory.parquet"
# Phase 3 artifacts directory (created in main(); per-shard indices are
# written next to each shard's audio.tar, not here).
OUTPUT_DIR = Path("/workspace/maya-asr/artifacts/phase3")


def build_offset_index(shard_dir_str: str) -> dict:
    """Scan tar headers and build a byte-offset index for one shard.

    Reads only tar member headers (never audio payloads), so this is
    I/O-light. Writes <shard_dir>/tar_offset_index.parquet with columns
    member_name / tar_offset_data / tar_nbytes.

    Args:
        shard_dir_str: Path to the shard directory containing audio.tar.

    Returns:
        Dict with "status" in {"success", "skipped", "error"} and "shard"
        (the input path); on success adds "count" (files indexed), on
        error adds "error" (message).
    """
    shard_dir = Path(shard_dir_str)
    tar_path = shard_dir / "audio.tar"
    output_path = shard_dir / "tar_offset_index.parquet"

    # Skip if already built (resume support). The atomic rename below
    # guarantees any existing index file is complete.
    if output_path.exists():
        return {"status": "skipped", "shard": shard_dir_str}

    if not tar_path.exists():
        return {"status": "error", "shard": shard_dir_str, "error": "no audio.tar"}

    # Build the three columns directly instead of a list of dicts that
    # would be re-traversed per column when constructing the table.
    names: list = []
    offsets: list = []
    sizes: list = []
    try:
        # "r:" requires an uncompressed tar. Plain "r" would transparently
        # accept gzip/bz2 archives, for which offset_data is meaningless
        # and the resulting index would be silently wrong.
        with tarfile.open(tar_path, "r:") as tf:
            for member in tf:
                if not member.isfile():
                    continue
                # member.offset_data is the byte offset of the file DATA
                # (not the header) within the tar.
                names.append(member.name)
                offsets.append(member.offset_data)
                sizes.append(member.size)

        if not names:
            return {"status": "error", "shard": shard_dir_str, "error": "no files in tar"}

        table = pa.table({
            "member_name": names,
            "tar_offset_data": offsets,
            "tar_nbytes": sizes,
        })
        # Write to a temp path then rename atomically: a crash mid-write
        # must not leave a partial parquet that a resume run would "skip"
        # as if it were complete.
        tmp_path = output_path.with_suffix(".parquet.tmp")
        pq.write_table(table, tmp_path)
        os.replace(tmp_path, output_path)

        return {
            "status": "success",
            "shard": shard_dir_str,
            "count": len(names),
        }
    except Exception as e:
        return {"status": "error", "shard": shard_dir_str, "error": str(e)}


def main():
    """Fan out per-shard offset-index builds across a process pool.

    Reads the shard list from the phase-2 inventory, optionally truncates
    it (--max-shards) or wipes existing indices (--force), then builds the
    indices in parallel with periodic progress reporting.
    """
    parser = argparse.ArgumentParser(description="Build tar offset indices")
    parser.add_argument("--workers", type=int, default=32,
                        help="Number of worker processes")
    parser.add_argument("--max-shards", type=int, default=0,
                        help="If > 0, process only the first N shards")
    parser.add_argument("--force", action="store_true",
                        help="Delete existing offset indices and rebuild")
    args = parser.parse_args()

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Only the shard_dir column is needed: read just that column instead
    # of materializing the full inventory table through pandas.
    shard_dirs = (
        pq.read_table(INVENTORY_PATH, columns=["shard_dir"])
        .column("shard_dir")
        .to_pylist()
    )

    if args.max_shards > 0:
        shard_dirs = shard_dirs[:args.max_shards]

    if args.force:
        # Remove existing offset indices so workers rebuild from scratch.
        for sd in shard_dirs:
            p = Path(sd) / "tar_offset_index.parquet"
            if p.exists():
                p.unlink()

    print(f"Building tar offset indices for {len(shard_dirs)} shards ({args.workers} workers)")
    t0 = time.time()
    done = 0
    skipped = 0
    errors = 0
    total_files = 0

    with ProcessPoolExecutor(max_workers=args.workers) as executor:
        # A list suffices here; the original mapped future -> shard_dir
        # but never read the values.
        futures = [executor.submit(build_offset_index, sd) for sd in shard_dirs]
        last_report = time.time()

        for future in as_completed(futures):
            result = future.result()
            if result["status"] == "success":
                done += 1
                total_files += result.get("count", 0)
            elif result["status"] == "skipped":
                skipped += 1
            else:
                errors += 1
                print(f"  ERROR: {result['shard']}: {result.get('error')}", file=sys.stderr)

            # Throttled progress line: at most once every 30 seconds.
            now = time.time()
            if now - last_report >= 30:
                elapsed = now - t0
                print(f"  [{elapsed:.0f}s] done={done} skipped={skipped} errors={errors} files={total_files:,}")
                last_report = now

    elapsed = time.time() - t0
    print(f"\nComplete in {elapsed:.0f}s: {done} built, {skipped} skipped, {errors} errors")
    print(f"Total file entries: {total_files:,}")


# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()
