#!/usr/bin/env python3
"""Compute duration-bucket statistics and batch-size calibration for Phase 3 training data."""

import json
import math
from pathlib import Path

import pandas as pd

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Input manifest; the script reads only its "duration_s" column (seconds per clip).
# Path is resolved relative to this file: <repo>/artifacts/phase3/train_manifest.parquet.
MANIFEST = Path(__file__).resolve().parents[1] / "artifacts" / "phase3" / "train_manifest.parquet"
# Output: one JSON record per bucket (count, percentage, hours, suggested batch size).
OUTPUT_JSON = MANIFEST.parent / "bucket_batchsize_table.json"

GPU_MEM_GB = 80  # A100 / H100

# Bucket edges (seconds). Last bucket captures everything >= 16 up to 32.
# Buckets are half-open (lo <= duration < hi); clips >= 32 s match no bucket
# and are excluded in main().
BUCKET_EDGES = [
    (0, 4),
    (4, 8),
    (8, 16),
    (16, 32),
]

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def suggest_batch_size(max_dur_s: float, gpu_mem_gb: float = GPU_MEM_GB) -> int:
    """Rule-of-thumb micro-batch size for a given max duration and GPU memory.

    Formula: batch_size = floor(gpu_mem_gb * 0.7 / (max_dur_s * 16000 * 2 / 1e9))
    Clamped to [2, 64].
    """
    mem_per_sample_gb = max_dur_s * 16000 * 2 / 1e9
    bs = math.floor(gpu_mem_gb * 0.7 / mem_per_sample_gb)
    return max(2, min(64, bs))


def assign_bucket(dur: float) -> str | None:
    for lo, hi in BUCKET_EDGES:
        if lo <= dur < hi:
            return f"{lo}-{hi}s"
    return None


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> None:
    """Bucket the training manifest by clip duration and emit a batch-size table.

    Reads MANIFEST (only the "duration_s" column), assigns each row to a
    duration bucket, prints a per-bucket summary table to stdout, and writes
    the same records to OUTPUT_JSON.

    Raises:
        SystemExit: If the manifest contains no rows (nothing to compute).
    """
    print(f"Reading {MANIFEST} ...")
    # Only the duration column is needed; skip loading the rest of the manifest.
    df = pd.read_parquet(MANIFEST, columns=["duration_s"])
    total_rows = len(df)
    print(f"  Total rows: {total_rows:,}")
    if total_rows == 0:
        # Without this guard the percentage math below divides by zero.
        raise SystemExit(f"No rows in {MANIFEST}; nothing to compute.")

    # Assign buckets
    df["bucket"] = df["duration_s"].apply(assign_bucket)

    # Drop rows outside all buckets (e.g. >= 32s)
    outside = df["bucket"].isna().sum()
    if outside:
        print(f"  Rows outside 0-32s range (excluded): {outside:,}")
    df = df.dropna(subset=["bucket"])

    # Compute per-bucket stats. Percentages are relative to total_rows
    # (including excluded rows), so the TOTAL row may print < 100%.
    results = []
    for lo, hi in BUCKET_EDGES:
        label = f"{lo}-{hi}s"
        mask = df["bucket"] == label
        subset = df.loc[mask, "duration_s"]
        count = int(len(subset))
        pct = count / total_rows * 100
        total_hours = subset.sum() / 3600.0
        max_dur = float(hi)  # use bucket ceiling for batch-size calc
        bs = suggest_batch_size(max_dur)
        results.append({
            "bucket": label,
            "count": count,
            "percentage": round(pct, 2),
            "total_hours": round(total_hours, 2),
            "max_duration_s": max_dur,
            "suggested_micro_batch_size": bs,
        })

    # -----------------------------------------------------------------------
    # Pretty-print summary table
    # -----------------------------------------------------------------------
    hdr = f"{'Bucket':>10s} | {'Count':>12s} | {'%':>7s} | {'Hours':>10s} | {'MaxDur':>7s} | {'uBatch':>6s}"
    sep = "-" * len(hdr)
    print()
    print(sep)
    print(hdr)
    print(sep)
    for r in results:
        print(
            f"{r['bucket']:>10s} | "
            f"{r['count']:>12,d} | "
            f"{r['percentage']:>6.2f}% | "
            f"{r['total_hours']:>10,.1f} | "
            f"{r['max_duration_s']:>6.0f}s | "
            f"{r['suggested_micro_batch_size']:>6d}"
        )
    print(sep)

    # Total row (percentage is the sum of rounded per-bucket values)
    total_count = sum(r["count"] for r in results)
    total_pct = sum(r["percentage"] for r in results)
    total_hrs = sum(r["total_hours"] for r in results)
    print(
        f"{'TOTAL':>10s} | "
        f"{total_count:>12,d} | "
        f"{total_pct:>6.2f}% | "
        f"{total_hrs:>10,.1f} |        |       "
    )
    print(sep)
    print()

    # -----------------------------------------------------------------------
    # Write JSON
    # -----------------------------------------------------------------------
    OUTPUT_JSON.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding so the output is UTF-8 regardless of platform locale.
    with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"Wrote {OUTPUT_JSON}")


if __name__ == "__main__":
    main()
