#!/usr/bin/env python3
"""Create a 1000-sample validation holdout from medium-to-long duration buckets.

Samples 334 + 333 + 333 rows from b_5_7, b_7_10, b_10_15 respectively,
removes them from the training parquets, and writes a combined val_holdout.parquet.
Updates bucket_config.json with the new counts.
"""

import json
import pandas as pd
from pathlib import Path

# Root directory holding the per-bucket training parquets and bucket_config.json.
BUCKET_DIR = Path("/root/gemini-asr/lf_asr/artifacts/phase2/buckets")
# Fixed RNG seed so the holdout sample is reproducible across runs.
SEED = 42

# Buckets to sample from and how many to take from each
# (334 + 333 + 333 = 1000 total holdout samples, per the module docstring).
SAMPLE_SPEC = [
    ("b_5_7", 334),
    ("b_7_10", 333),
    ("b_10_15", 333),
]


def _print_stats(val_holdout: pd.DataFrame) -> None:
    """Print summary statistics for the combined holdout DataFrame."""
    print("\n" + "=" * 60)
    print("Validation holdout stats")
    print("=" * 60)
    print(f"Total samples:  {len(val_holdout)}")
    print(f"Total duration: {val_holdout['duration_s'].sum():.1f}s  "
          f"({val_holdout['duration_s'].sum()/3600:.2f}h)")
    print(f"Duration range: {val_holdout['duration_s'].min():.2f}s - "
          f"{val_holdout['duration_s'].max():.2f}s")
    print(f"Mean duration:  {val_holdout['duration_s'].mean():.2f}s")

    print("\nPer-bucket breakdown:")
    for bucket_id, n_sample in SAMPLE_SPEC:
        subset = val_holdout[val_holdout["bucket_id"] == bucket_id]
        langs = subset["language"].nunique()
        print(f"  {bucket_id}: {len(subset)} samples, "
              f"{subset['duration_s'].sum():.1f}s, "
              f"{langs} languages")

    print("\nLanguage distribution in holdout:")
    lang_counts = val_holdout["language"].value_counts()
    for lang, count in lang_counts.head(15).items():
        print(f"  {lang}: {count}")
    if len(lang_counts) > 15:
        print(f"  ... and {len(lang_counts) - 15} more languages")


def main():
    """Carve the validation holdout out of the training buckets.

    For each (bucket_id, n) in SAMPLE_SPEC: samples n rows, rewrites the
    bucket parquet without them, and accumulates them into a combined
    val_holdout.parquet. Updates per-bucket and total counts/hours in
    bucket_config.json.

    Raises:
        SystemExit: if val_holdout.parquet already exists (re-run guard).
        RuntimeError: if dropping the sampled rows removes an unexpected
            number of rows from a bucket.
    """
    config_path = BUCKET_DIR / "bucket_config.json"
    with open(config_path) as f:
        config = json.load(f)

    # Guard against accidental re-runs: a second pass would remove another
    # 1000 rows from the training buckets and silently corrupt the split.
    val_path = BUCKET_DIR / "val_holdout.parquet"
    if val_path.exists():
        raise SystemExit(
            f"{val_path} already exists; refusing to re-sample. "
            "Restore the bucket parquets and delete it to redo the split."
        )

    # bucket_id -> config entry; entries are mutated in place, so the
    # updates below land in `config` before it is re-serialized.
    bucket_cfg = {b["bucket_id"]: b for b in config["buckets"]}

    holdout_parts = []

    for bucket_id, n_sample in SAMPLE_SPEC:
        pq_path = BUCKET_DIR / f"{bucket_id}.parquet"
        print(f"\nLoading {pq_path.name} ...")
        df = pd.read_parquet(pq_path)
        orig_len = len(df)

        # Deterministic sample (fixed SEED) so the holdout is reproducible.
        sampled = df.sample(n=n_sample, random_state=SEED)
        holdout_parts.append(sampled)

        # Remove sampled rows from the training set. Explicit check rather
        # than `assert`, which is stripped under `python -O` — a duplicated
        # index would otherwise drop the wrong number of rows silently.
        remaining = df.drop(sampled.index)
        if len(remaining) != orig_len - n_sample:
            raise RuntimeError(
                f"{bucket_id}: expected {orig_len - n_sample} rows after "
                f"drop, got {len(remaining)} (non-unique index values?)"
            )

        # Rewrite the bucket parquet without the held-out rows.
        print(f"  {bucket_id}: {orig_len} -> {len(remaining)} (held out {n_sample})")
        remaining.to_parquet(pq_path, index=False)

        # Update the per-bucket sample count and hours in the config.
        entry = bucket_cfg[bucket_id]
        entry["samples"] = len(remaining)
        held_hours = sampled["duration_s"].sum() / 3600.0
        entry["hours"] = round(entry["hours"] - held_hours, 6)

    # Combine holdout parts into a single parquet.
    val_holdout = pd.concat(holdout_parts, ignore_index=True)
    val_holdout.to_parquet(val_path, index=False)
    print(f"\nWrote {val_path}  ({len(val_holdout)} samples)")

    # Update totals in config to reflect the removed rows/hours.
    total_held = sum(n for _, n in SAMPLE_SPEC)
    config["total_samples"] -= total_held
    held_hours_total = val_holdout["duration_s"].sum() / 3600.0
    config["total_hours"] = round(config["total_hours"] - held_hours_total, 6)

    # Write updated config back out.
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    print(f"Updated {config_path.name}")

    _print_stats(val_holdout)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
