#!/usr/bin/env python3
"""
Phase 3: Build final train/val split from all phase3 manifests.

Pools train + dev + test manifests, picks 200-300 MEDIUM/LONG validation
samples (duration_s > 5s, ~25 per language, diverse durations), and puts
everything else into training.

Outputs:
  artifacts/phase3/final_train.parquet
  artifacts/phase3/final_val.parquet
"""

import time
import numpy as np
import pyarrow.parquet as pq
import pyarrow as pa

ARTIFACTS = "artifacts/phase3"  # directory holding input manifests and output parquets
SOURCES = ["train_manifest.parquet", "dev_manifest.parquet", "test_manifest.parquet"]  # manifests pooled before re-splitting
VAL_PER_LANG = 25          # target number of validation samples per language
MIN_DURATION = 5.0         # minimum duration (seconds) for val candidates
MAX_DURATION = 20.0        # prefer durations up to this (seconds) for diversity
SEED = 42                  # RNG seed so the split is reproducible

def _read_pool(verbose=False):
    """Read every manifest in SOURCES and concatenate them into one table.

    Row order is deterministic: SOURCES order, then each file's own row
    order. That invariant is what allows positional indices computed from
    one read of the pool to be applied to a later re-read of the pool.

    Args:
        verbose: when True, print per-file progress and row counts.

    Returns:
        A single pyarrow Table containing all pooled rows.
    """
    tables = []
    for src in SOURCES:
        path = f"{ARTIFACTS}/{src}"
        if verbose:
            print(f"Reading {path} ...")
        t = pq.read_table(path)
        if verbose:
            print(f"  -> {t.num_rows:,} rows")
        tables.append(t)
    # promote_options="default" reconciles minor schema differences
    # (e.g. nullable vs non-nullable) across the three manifests.
    return pa.concat_tables(tables, promote_options="default")


def _select_val_for_lang(df, lang, rng):
    """Pick up to VAL_PER_LANG validation row indices for one language.

    Candidates must have duration_s > MIN_DURATION and a non-empty
    transcript. Picks are stratified across duration buckets
    (MIN_DURATION-8, 8-12, 12-16, 16-MAX_DURATION, MAX_DURATION+) so the
    validation set covers a diverse range of clip lengths.

    Args:
        df: pooled dataframe with a "language_canonical" column and a
            positional RangeIndex over the pooled rows.
        lang: canonical language code to select for.
        rng: np.random.RandomState used for all sampling (call order is
            deterministic for a fixed SEED).

    Returns:
        List of df index labels (== row positions in the pool) selected
        for validation; empty if the language has no candidates.
    """
    mask = (
        (df["language_canonical"] == lang)
        & (df["duration_s"] > MIN_DURATION)
        & (df["transcript"].str.len() > 0)
    )
    candidates = df.loc[mask]
    print(f"  {lang}: {len(candidates):,} candidates (dur > {MIN_DURATION}s, non-empty transcript)")

    if len(candidates) == 0:
        print(f"    WARNING: no candidates for {lang}")
        return []

    # Stratify by duration buckets for diversity. Edges are derived from
    # the module constants rather than re-hard-coding 5 and 20.
    edges = [MIN_DURATION, 8, 12, 16, MAX_DURATION, float("inf")]
    buckets = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        bmask = (candidates["duration_s"] >= lo) & (candidates["duration_s"] < hi)
        idx = candidates.index[bmask].tolist()
        if idx:  # keep only non-empty buckets
            buckets.append(idx)

    # Distribute ~VAL_PER_LANG evenly across the non-empty buckets.
    per_bucket = max(1, VAL_PER_LANG // len(buckets))
    selected = []
    for b in buckets:
        n = min(per_bucket, len(b))
        selected.extend(rng.choice(b, size=n, replace=False).tolist())

    # Top up from any remaining candidates if we are under target.
    shortfall = VAL_PER_LANG - len(selected)
    if shortfall > 0:
        chosen = set(selected)  # hoisted: build the set once, not per element
        leftovers = [i for i in candidates.index if i not in chosen]
        if leftovers:
            extra = min(shortfall, len(leftovers))
            selected.extend(rng.choice(leftovers, size=extra, replace=False).tolist())

    # Cap at VAL_PER_LANG (integer division above can overshoot slightly).
    return selected[:VAL_PER_LANG]


def main():
    """Build the final train/val split from the pooled phase-3 manifests.

    Pools the three source manifests, selects a duration-stratified
    validation set per language, then re-reads the full-schema pool and
    writes final_train.parquet / final_val.parquet under ARTIFACTS.
    """
    t0 = time.time()

    # ── 1. Read and pool all three manifests ──────────────────────────
    pool = _read_pool(verbose=True)
    print(f"\nPooled total: {pool.num_rows:,} rows")

    # Convert to pandas for selection logic (only cols we need). The
    # resulting RangeIndex is positional over the pooled row order.
    print("Converting to pandas (selected columns) ...")
    df = pool.select(["sample_id", "transcript", "language", "duration_s"]).to_pandas()
    del pool  # free arrow memory

    # Normalize language codes: strip "lang=" prefix so we have 12 canonical languages
    raw_langs = sorted(df["language"].unique())
    print(f"Raw language values ({len(raw_langs)}): {raw_langs}")
    df["language_canonical"] = df["language"].str.replace(r"^lang=", "", regex=True)
    languages = sorted(df["language_canonical"].unique())
    print(f"Canonical languages ({len(languages)}): {languages}")

    # ── 2. Select validation samples ──────────────────────────────────
    rng = np.random.RandomState(SEED)
    val_indices = []
    for lang in languages:
        val_indices.extend(_select_val_for_lang(df, lang, rng))

    val_set = set(val_indices)
    print(f"\nTotal val samples selected: {len(val_set)}")

    # ── 3. Re-read full tables and split ──────────────────────────────
    # Re-read pooled data (full schema) in arrow for efficient write.
    # Row order matches the first read (see _read_pool), so val_set's
    # positional indices remain valid against this pool.
    print("\nRe-reading full tables for final write ...")
    pool = _read_pool()

    # Normalize language column in the output (strip "lang=" prefix)
    lang_series = pool.column("language").to_pandas().str.replace(r"^lang=", "", regex=True)
    lang_arr = pa.array(lang_series, type=pa.large_string())
    col_idx = pool.schema.get_field_index("language")
    pool = pool.set_column(col_idx, "language", lang_arr)

    # Build a boolean row mask: True => validation row.
    n = pool.num_rows
    val_mask = np.zeros(n, dtype=bool)
    val_mask[list(val_set)] = True

    val_table = pool.filter(pa.array(val_mask))
    train_table = pool.filter(pa.array(~val_mask))
    del pool

    # ── 4. Write outputs ──────────────────────────────────────────────
    out_train = f"{ARTIFACTS}/final_train.parquet"
    out_val = f"{ARTIFACTS}/final_val.parquet"

    print(f"\nWriting {out_val} ({val_table.num_rows:,} rows) ...")
    pq.write_table(val_table, out_val)

    print(f"Writing {out_train} ({train_table.num_rows:,} rows) ...")
    # Large row groups keep the (much bigger) train file efficient to scan.
    pq.write_table(train_table, out_train, row_group_size=1_000_000)

    # ── 5. Summary ────────────────────────────────────────────────────
    val_df = val_table.select(["language", "duration_s"]).to_pandas()
    print("\n" + "=" * 60)
    print("VALIDATION SET -- per-language breakdown")
    print("=" * 60)
    for lang in languages:
        sub = val_df[val_df["language"] == lang]
        if len(sub) > 0:
            print(f"  {lang}: {len(sub):>4} samples | "
                  f"dur range [{sub['duration_s'].min():.1f}s – {sub['duration_s'].max():.1f}s] | "
                  f"mean {sub['duration_s'].mean():.1f}s")
        else:
            print(f"  {lang}:    0 samples")
    print(f"  {'TOTAL':>2}: {len(val_df):>4} samples")

    print(f"\nTRAINING SET: {train_table.num_rows:,} rows")
    print(f"\nDone in {time.time() - t0:.1f}s")


# Script entry point: run the split only when executed directly, not on import.
if __name__ == "__main__":
    main()
