"""Build byte-offset index files for all audio_16k.tar files."""
import os, json, tarfile, sys
from pathlib import Path
from multiprocessing import Pool
import time

os.environ["OMP_NUM_THREADS"] = "1"

def build_index_for_tar(tar_path: str) -> dict:
    """Open tar, extract member offsets, write index JSON alongside tar."""
    index_path = tar_path + ".index.json"
    if os.path.exists(index_path):
        return {"tar": tar_path, "status": "skipped", "members": 0}
    try:
        # "r:" forces an uncompressed tar; byte offsets are only meaningful
        # when no compression layer sits between us and the file data.
        with tarfile.open(tar_path, "r:") as tf:
            members = tf.getmembers()
            index = {}
            for m in members:
                if m.isfile():
                    index[m.name] = {"offset": m.offset_data, "size": m.size}
                    # Members are often stored with a leading "./"; also index
                    # the bare name so lookups work with either form.
                    bare = m.name[2:] if m.name.startswith("./") else m.name
                    if bare != m.name:
                        index[bare] = {"offset": m.offset_data, "size": m.size}
        # Write atomically so an interrupted run cannot leave a truncated
        # index that the exists-check above would treat as finished.
        tmp_path = index_path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(index, f)
        os.replace(tmp_path, index_path)
        return {"tar": tar_path, "status": "ok", "members": len(members)}
    except Exception as e:
        return {"tar": tar_path, "status": f"error: {e}", "members": 0}

def main():
    import pyarrow.parquet as pq
    print("Loading manifest to find all tar paths...", flush=True)
    df = pq.read_table("/root/gemini-asr/lf_asr/artifacts/phase2/train.parquet",
                        columns=["tar_path"]).to_pandas()
    tar_paths = sorted(df["tar_path"].unique())
    print(f"Found {len(tar_paths)} unique tar files", flush=True)

    for split in ["dev", "test"]:
        sp = pq.read_table(f"/root/gemini-asr/lf_asr/artifacts/phase2/{split}.parquet",
                           columns=["tar_path"]).to_pandas()
        tar_paths = sorted(set(tar_paths) | set(sp["tar_path"].unique()))
    print(f"Total unique tars (train+dev+test): {len(tar_paths)}", flush=True)

    existing = sum(1 for p in tar_paths if os.path.exists(p + ".index.json"))
    print(f"Already indexed: {existing}, need to build: {len(tar_paths) - existing}", flush=True)

    workers = 64  # header scans are mostly I/O-bound; tune to the host
    done = 0
    errors = 0
    t0 = time.time()

    with Pool(processes=workers) as pool:
        # Already-indexed tars return "skipped" almost immediately, so the
        # full list can be fed to the pool on every run.
        for result in pool.imap_unordered(build_index_for_tar, tar_paths, chunksize=4):
            done += 1
            if result["status"].startswith("error"):
                errors += 1
                print(f"  ERROR: {result['tar']}: {result['status']}", flush=True)
            if done % 100 == 0:
                elapsed = time.time() - t0
                rate = done / elapsed
                eta = (len(tar_paths) - done) / rate
                print(f"  {done}/{len(tar_paths)} ({rate:.1f}/s, ETA {eta:.0f}s) errors={errors}", flush=True)

    elapsed = time.time() - t0
    print(f"\nDone: {done} tars indexed in {elapsed:.1f}s, {errors} errors", flush=True)

    # Final sanity check: every tar should now have an index file.
    missing = [p for p in tar_paths if not os.path.exists(p + ".index.json")]
    if missing:
        print(f"WARNING: {len(missing)} tars missing indices!", flush=True)
        for p in missing[:10]:
            print(f"  {p}", flush=True)
    else:
        print("All tars have index files.", flush=True)

if __name__ == "__main__":
    main()
