#!/usr/bin/env python3
"""Phase 2: Build global training manifest from all converted shards.

Joins each shard's converted audio index with its metadata into a unified parquet manifest.
Splits into train/dev/test, grouped by video_id so no video appears in more than one split.

Usage:
  python3 tools/phase2_build_global_manifest.py
  python3 tools/phase2_build_global_manifest.py --val-ratio 0.02 --test-ratio 0.02
"""

import argparse
import hashlib
import json
import sys
import time
from pathlib import Path

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

ARTIFACTS_DIR = Path("/workspace/maya-asr/artifacts/phase2")
INVENTORY_PATH = ARTIFACTS_DIR / "shard_inventory.parquet"


def hash_group(video_id: str, seed: int = 42) -> int:
    """Deterministic hash for split assignment. Returns 0-99."""
    h = hashlib.md5(f"{seed}:{video_id}".encode()).hexdigest()
    return int(h[:8], 16) % 100
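

# NaN-safe cell accessors for the merged rows in main(). pandas' left join
# leaves NaN in unmatched metadata cells, and a bare str()/float() would turn
# those into the literal string "nan" / float("nan"), corrupting transcripts
# and duration sums. These helper names are ours (a minimal sketch), not part
# of the wider pipeline.
def _clean_str(val, default: str = "") -> str:
    """Coerce a scalar cell to str, mapping None/NaN to `default`."""
    return default if pd.isna(val) else str(val)


def _clean_float(val, default: float = 0.0) -> float:
    """Coerce a scalar cell to float, mapping None/NaN/unparseable to `default`."""
    if pd.isna(val):
        return default
    try:
        return float(val)
    except (TypeError, ValueError):
        return default


def _first_value(row, keys):
    """Return the first present, non-null value among row[keys], else None."""
    for key in keys:
        val = row.get(key)
        if val is not None and not pd.isna(val):
            return val
    return None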


def main():
    parser = argparse.ArgumentParser(description="Build global manifest")
    parser.add_argument("--val-ratio", type=float, default=0.02)
    parser.add_argument("--test-ratio", type=float, default=0.02)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
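
    # The ratios become integer percentages of 100 below, so they must leave
    # room for a non-empty train split.
    if args.val_ratio + args.test_ratio >= 1.0:
        parser.error("--val-ratio + --test-ratio must sum to < 1.0")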

    if not INVENTORY_PATH.exists():
        print("ERROR: Run phase2_scan_inventory.py first", file=sys.stderr)
        sys.exit(1)

    inv_df = pq.read_table(INVENTORY_PATH).to_pandas()
    print(f"Inventory: {len(inv_df)} shards")

    t0 = time.time()
    all_rows = []
    skipped_shards = 0
    empty_transcript = 0

    for _, shard_row in inv_df.iterrows():
        shard_dir = Path(shard_row["shard_dir"])
        prefix = shard_row["prefix"]
        language = shard_row["language"]
        shard_id = shard_row["shard_id"]

        # Read audio index (prefer the 16k index if the shard was converted,
        # else the original). A missing path cell may come back from parquet
        # as NaN, which is truthy and would crash Path(), so only accept
        # non-empty strings.
        idx_16k = shard_dir / "audio_16k_index.parquet"
        orig_val = shard_row["audio_index_path"]
        idx_orig = Path(orig_val) if isinstance(orig_val, str) and orig_val else None

        if idx_16k.exists():
            idx_df = pq.read_table(idx_16k).to_pandas()
        elif idx_orig and idx_orig.exists():
            idx_df = pq.read_table(idx_orig).to_pandas()
        else:
            skipped_shards += 1
            continue
        # Either index describes members of the shard's audio.tar.
        audio_tar = str(shard_dir / "audio.tar")

        # Read metadata, requesting only the needed columns that actually
        # exist in the file's schema (read the schema once, not per column).
        meta_val = shard_row["metadata_path"]
        meta_path = Path(meta_val) if isinstance(meta_val, str) and meta_val else None
        if meta_path and meta_path.exists():
            meta_cols = [
                "segment_id", "video_id", "transcription_native",
                "duration_s", "tx_quality_score", "final_bucket",
                "segment_language",
            ]
            try:
                available = set(pq.read_schema(meta_path).names)
                meta_df = pq.read_table(
                    meta_path,
                    columns=[c for c in meta_cols if c in available],
                ).to_pandas()
            except Exception:
                # Unreadable metadata: fall back to index-only rows.
                meta_df = pd.DataFrame()
        else:
            meta_df = pd.DataFrame()

        # Join index with metadata on segment_id / member_name.
        if not meta_df.empty and "segment_id" in meta_df.columns:
            # Pick the index-side join key, preferring the tar member name.
            for candidate in ("member_name", "segment_id", "tar_member_name"):
                if candidate in idx_df.columns:
                    join_key_idx = candidate
                    break
            else:
                join_key_idx = idx_df.columns[0]

            # suffixes=("", "_meta") keeps index columns unsuffixed if both
            # sides carry a same-named column (e.g. duration_s), instead of
            # pandas' default _x/_y renames that the lookups below would miss.
            merged = idx_df.merge(
                meta_df, left_on=join_key_idx, right_on="segment_id",
                how="left", suffixes=("", "_meta"),
            )
        else:
            merged = idx_df.copy()

        for _, row in merged.iterrows():
            # row.get() only falls back when a column is absent; the left join
            # instead leaves NaN in unmatched cells, so route every lookup
            # through the NaN-safe helpers defined above.
            member = _clean_str(_first_value(row, ["member_name", "tar_member_name", "segment_id"]))
            transcript = _clean_str(row.get("transcription_native"))
            video_id = _clean_str(row.get("video_id"))
            segment_id = _clean_str(row.get("segment_id"), default=member)
            duration = _clean_float(_first_value(row, ["duration_s", "duration_s_16k", "audio_duration_s"]))
            quality = _clean_str(row.get("final_bucket"))
            qscore = _clean_float(row.get("tx_quality_score"))

            if not transcript.strip():
                empty_transcript += 1

            all_rows.append(
                {
                    "sample_id": f"{prefix}/{language}/{shard_id}/{member}",
                    "prefix": prefix,
                    "language": language,
                    "shard_id": shard_id,
                    "transcript": transcript,
                    "tar_path": audio_tar,
                    "tar_member_name": member,
                    "duration_s": round(duration, 4),
                    "quality_bucket": quality,
                    "quality_score": round(qscore, 3),
                    "video_id": video_id,
                    "segment_id": segment_id,
                }
            )

    elapsed = time.time() - t0
    print(f"Collected {len(all_rows):,} rows in {elapsed:.1f}s (skipped {skipped_shards} shards)")
    print(f"Empty transcripts: {empty_transcript:,} (kept, not dropped)")

    if not all_rows:
        print("ERROR: No rows collected", file=sys.stderr)
        sys.exit(1)

    # Build dataframe
    df = pd.DataFrame(all_rows)

    # Assign splits deterministically by video_id hash bucket (0-99).
    # round() avoids float truncation (e.g. int(0.29 * 100) == 28).
    val_pct = round(args.val_ratio * 100)
    test_pct = round(args.test_ratio * 100)
    train_pct = 100 - val_pct - test_pct

    def assign_split(vid):
        # Every row with the same video_id hashes to the same bucket, so a
        # video can never straddle splits; rows with an empty video_id share
        # one bucket and land together in a single split.
        h = hash_group(vid, args.seed)
        if h < train_pct:
            return "train"
        elif h < train_pct + val_pct:
            return "dev"
        else:
            return "test"

    df["split"] = df["video_id"].apply(assign_split)

    # Write global manifest
    global_path = ARTIFACTS_DIR / "global_manifest.parquet"
    pq.write_table(pa.Table.from_pandas(df, preserve_index=False), global_path)
    print(f"\nGlobal manifest: {global_path} ({len(df):,} rows)")

    # Write split files
    for split_name in ["train", "dev", "test"]:
        split_df = df[df["split"] == split_name]
        split_path = ARTIFACTS_DIR / f"{split_name}.parquet"
        # preserve_index=False stops the filtered frame's pandas index from
        # being written as an extra __index_level_0__ column.
        pq.write_table(pa.Table.from_pandas(split_df, preserve_index=False), split_path)
        hours = split_df["duration_s"].sum() / 3600
        langs = split_df["language"].nunique()
        print(f"  {split_name}: {len(split_df):,} rows, {hours:.1f}h, {langs} languages")

    # Verify no video_id leakage
    train_vids = set(df[df["split"] == "train"]["video_id"])
    dev_vids = set(df[df["split"] == "dev"]["video_id"])
    test_vids = set(df[df["split"] == "test"]["video_id"])
    leaks = (train_vids & dev_vids) | (train_vids & test_vids) | (dev_vids & test_vids)
    if leaks:
        print(f"\nWARNING: {len(leaks)} video_ids leak across splits!")
    else:
        print("\nSplit integrity: PASS (no video_id leakage)")

    # Summary stats
    summary = {
        "total_rows": len(df),
        "total_hours": round(df["duration_s"].sum() / 3600, 1),
        "empty_transcript_count": empty_transcript,
        "skipped_shards": skipped_shards,
        "split_counts": {s: int((df["split"] == s).sum()) for s in ["train", "dev", "test"]},
        "per_language": {
            lang: {
                "total": int((df["language"] == lang).sum()),
                "hours": round(df[df["language"] == lang]["duration_s"].sum() / 3600, 1),
            }
            for lang in sorted(df["language"].unique())
        },
        "per_prefix": {
            p: int((df["prefix"] == p).sum()) for p in sorted(df["prefix"].unique())
        },
        "video_id_leakage": len(leaks),
    }
    with open(ARTIFACTS_DIR / "manifest_summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\n{'Language':<8} {'Total':>10} {'Hours':>8}")
    print("-" * 30)
    for lang, info in sorted(summary["per_language"].items()):
        print(f"{lang:<8} {info['total']:>10,} {info['hours']:>8.1f}")


if __name__ == "__main__":
    main()
