"""Split train.parquet into per-bucket parquets, shuffled within each bucket."""
import os, sys
os.environ["OMP_NUM_THREADS"] = "1"

import pyarrow.parquet as pq
import pandas as pd
import numpy as np
from pathlib import Path
import json

# Duration buckets as (bucket_id, lo_seconds, hi_seconds). Membership is the
# half-open interval lo <= duration_s < hi, so buckets are non-overlapping and
# anything >= 30s (or < 0s) falls outside every bucket.
BUCKETS = [
    ("b_0_3",   0.0,  3.0),
    ("b_3_5",   3.0,  5.0),
    ("b_5_7",   5.0,  7.0),
    ("b_7_10",  7.0, 10.0),
    ("b_10_15", 10.0, 15.0),
    ("b_15_20", 15.0, 20.0),
    ("b_20_30", 20.0, 30.0),
]

SEED = 42  # RNG seed for the within-bucket shuffles (reproducibility)
ARTIFACTS = Path("/root/gemini-asr/lf_asr/artifacts/phase2")  # phase-2 artifact root
BUCKET_DIR = ARTIFACTS / "buckets"  # per-bucket parquets + bucket_config.json written here

def main():
    """Split train.parquet into per-duration-bucket parquets, shuffled within
    each bucket, and write a bucket_config.json summary alongside them.

    Reads:   ARTIFACTS / "train.parquet"
    Writes:  BUCKET_DIR / "<bucket_id>.parquet" for each bucket in BUCKETS,
             BUCKET_DIR / "bucket_config.json"
    """
    BUCKET_DIR.mkdir(parents=True, exist_ok=True)

    print("Loading train.parquet...")
    df = pd.read_parquet(ARTIFACTS / "train.parquet")
    print(f"  {len(df):,} rows")

    # Derived column: each tar shard is assumed to have a sibling JSON index.
    df["tar_index_path"] = df["tar_path"] + ".index.json"

    # Drop rows with missing or whitespace-only transcripts. Test notna()
    # first so the .str accessor never operates on NaN values (the original
    # order only worked because pandas evaluates NaN != "" as True).
    non_empty = df["transcript"].notna() & df["transcript"].str.strip().ne("")
    # max(len(df), 1) guards the percentage against an empty input frame.
    print(f"  Non-empty transcripts: {non_empty.sum():,} ({non_empty.sum()/max(len(df), 1)*100:.1f}%)")
    df = df[non_empty].copy()

    stats = []
    # One shared RandomState: deterministic overall, distinct shuffle per bucket.
    rng = np.random.RandomState(SEED)

    for bucket_id, lo, hi in BUCKETS:
        mask = (df["duration_s"] >= lo) & (df["duration_s"] < hi)
        bdf = df[mask].copy()
        bdf = bdf.sample(frac=1.0, random_state=rng).reset_index(drop=True)
        bdf["bucket_id"] = bucket_id
        bdf["bucket_max_duration"] = hi

        out_path = BUCKET_DIR / f"{bucket_id}.parquet"
        bdf.to_parquet(out_path, index=False)

        hours = bdf["duration_s"].sum() / 3600
        n = len(bdf)
        # Guard empty buckets: mean()/max() of an empty Series is NaN, which
        # json.dump would emit as bare NaN -- invalid JSON.
        mean_dur = float(bdf["duration_s"].mean()) if n else 0.0
        max_dur = float(bdf["duration_s"].max()) if n else 0.0
        stats.append({
            "bucket_id": bucket_id,
            "range": f"{lo}-{hi}s",
            "samples": n,
            "hours": float(hours),
            "pct_samples": float(n / max(len(df), 1) * 100),
            "mean_duration": mean_dur,
            "max_duration": max_dur,
        })
        print(f"  {bucket_id}: {n:>10,} samples, {hours:>8,.0f}h, "
              f"mean={mean_dur:.1f}s, max={max_dur:.1f}s")

    total_hours = sum(s["hours"] for s in stats)
    for s in stats:
        # total_hours == 0 only when every bucket is empty; avoid ZeroDivisionError.
        s["pct_hours"] = float(s["hours"] / total_hours * 100) if total_hours else 0.0

    # Report rows excluded by the 0-30s bucketing window.
    over_30 = (df["duration_s"] >= 30.0).sum()
    under_0 = (df["duration_s"] < 0.0).sum()
    total_bucketed = sum(s["samples"] for s in stats)
    print(f"\n  Samples >30s (excluded): {over_30:,}")
    print(f"  Samples <0s (excluded): {under_0:,}")
    print(f"  Total bucketed: {total_bucketed:,} / {len(df):,}")

    config = {
        "buckets": stats,
        "total_samples": total_bucketed,
        "total_hours": float(total_hours),
        "seed": SEED,
    }
    with open(BUCKET_DIR / "bucket_config.json", "w") as f:
        json.dump(config, f, indent=2)

    print(f"\nTotal: {total_bucketed:,} samples, {total_hours:,.0f}h across {len(BUCKETS)} buckets")
    print(f"Bucket parquets written to: {BUCKET_DIR}")

# Script entry point: run the bucketing pipeline when executed directly.
if __name__ == "__main__":
    main()
