#!/usr/bin/env python3
"""Split a NeMo JSONL manifest into deterministic train/val sets.

Stratifies by language so both splits preserve the language distribution.

Usage:
  python3 scripts/split_manifest.py \
    --input data/manifests/smoke.jsonl \
    --train-output data/manifests/smoke_train.jsonl \
    --val-output data/manifests/smoke_val.jsonl \
    --val-ratio 0.01 --seed 42
"""

import argparse
import json
import random
import sys
from collections import defaultdict
from datetime import datetime, timezone
from pathlib import Path

# Add src to path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
from maya_asr.config import file_sha256


def _read_manifest_by_lang(path: Path) -> tuple[dict[str, list[str]], int]:
    """Read a JSONL manifest, grouping its raw lines by their "lang" field.

    Blank lines are skipped. Rows without a "lang" key fall into the
    "unknown" group.

    Returns:
        (by_lang, total): raw JSONL lines keyed by language, and the total
        number of non-blank rows read.
    """
    by_lang: dict[str, list[str]] = defaultdict(list)
    total = 0
    # Explicit encoding: manifests may contain non-ASCII transcripts and the
    # platform default locale is not guaranteed to be UTF-8.
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            by_lang[row.get("lang", "unknown")].append(line)
            total += 1
    return by_lang, total


def _split_group(
    lines: list[str], val_ratio: float, rng: random.Random
) -> tuple[list[str], list[str]]:
    """Shuffle *lines* in place and split off a validation slice.

    At least one line always goes to validation so tiny groups are never
    absent from val; a group of size 1 therefore contributes nothing to
    train.

    Returns:
        (train_lines, val_lines)
    """
    rng.shuffle(lines)
    n_val = max(1, int(len(lines) * val_ratio))
    return lines[n_val:], lines[:n_val]


def _write_jsonl(path: Path, lines: list[str]) -> None:
    """Write *lines* to *path* as JSONL, creating parent dirs as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(line + "\n" for line in lines)


def _lang_stats(lines: list[str]) -> dict[str, dict]:
    """Per-language row counts ("n") and audio hours ("h") for JSONL lines.

    Rows missing "duration" count as 0 hours.
    """
    stats: dict[str, dict] = defaultdict(lambda: {"n": 0, "h": 0.0})
    for line in lines:
        row = json.loads(line)
        lang = row.get("lang", "unknown")
        stats[lang]["n"] += 1
        stats[lang]["h"] += row.get("duration", 0) / 3600.0
    return stats


def main():
    """CLI entry point: split a JSONL manifest into train/val sets.

    Reads --input, splits deterministically (seeded RNG, optionally
    stratified by language), writes the two output manifests, prints a
    per-language summary table, and records split metadata alongside the
    train output. Exits nonzero on invalid arguments or an empty manifest.
    """
    parser = argparse.ArgumentParser(description="Split manifest into train/val")
    parser.add_argument("--input", type=Path, required=True)
    parser.add_argument("--train-output", type=Path, required=True)
    parser.add_argument("--val-output", type=Path, required=True)
    parser.add_argument("--val-ratio", type=float, default=0.01)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--stratify-by-lang",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Language-stratified splitting (default: on)",
    )
    # Backward-compatible alias
    parser.add_argument(
        "--no-stratify",
        action="store_true",
        default=False,
        help=argparse.SUPPRESS,  # hidden, use --no-stratify-by-lang instead
    )
    args = parser.parse_args()

    # --no-stratify is a legacy alias for --no-stratify-by-lang
    if args.no_stratify:
        args.stratify_by_lang = False

    # Fail fast: a ratio of 0 or 1 would leave one split empty.
    if args.val_ratio <= 0 or args.val_ratio >= 1:
        print(
            f"ERROR: --val-ratio must be >0 and <1, got {args.val_ratio}",
            file=sys.stderr,
        )
        sys.exit(1)

    if not args.input.exists():
        print(f"ERROR: Input not found: {args.input}", file=sys.stderr)
        sys.exit(1)

    by_lang, total = _read_manifest_by_lang(args.input)
    if total == 0:
        print("ERROR: Empty manifest", file=sys.stderr)
        sys.exit(1)

    # Seeded RNG makes the split reproducible for a given (input, seed).
    rng = random.Random(args.seed)
    train_lines: list[str] = []
    val_lines: list[str] = []

    if not args.stratify_by_lang:
        # Simple random split over the whole manifest.
        all_lines = [line for lines in by_lang.values() for line in lines]
        train_lines, val_lines = _split_group(all_lines, args.val_ratio, rng)
    else:
        # Per-language split preserves the language distribution in both
        # sets. Sorted iteration keeps RNG consumption order deterministic
        # regardless of manifest row order within each language.
        for lang in sorted(by_lang.keys()):
            tr, va = _split_group(by_lang[lang], args.val_ratio, rng)
            train_lines.extend(tr)
            val_lines.extend(va)

    _write_jsonl(args.train_output, train_lines)
    _write_jsonl(args.val_output, val_lines)

    train_stats = _lang_stats(train_lines)
    val_stats = _lang_stats(val_lines)

    # Print summary
    all_langs = sorted(set(train_stats) | set(val_stats))
    print(f"Input:  {args.input} ({total:,} rows)")
    print(f"Train:  {args.train_output} ({len(train_lines):,} rows)")
    print(f"Val:    {args.val_output} ({len(val_lines):,} rows)")
    print(f"Seed:   {args.seed}")
    print(f"Ratio:  {args.val_ratio}")
    print()
    hdr = f"  {'Lang':<6} {'Train':>8} {'Tr hrs':>8} {'Val':>8} {'Val hrs':>8}"
    print(hdr)
    print(f"  {'-' * 6} {'-' * 8} {'-' * 8} {'-' * 8} {'-' * 8}")
    for lang in all_langs:
        ts = train_stats[lang]
        vs = val_stats[lang]
        print(f"  {lang:<6} {ts['n']:>8,} {ts['h']:>8.1f} {vs['n']:>8,} {vs['h']:>8.1f}")
    print(f"  {'-' * 6} {'-' * 8} {'-' * 8} {'-' * 8} {'-' * 8}")
    t_n = sum(s["n"] for s in train_stats.values())
    t_h = sum(s["h"] for s in train_stats.values())
    v_n = sum(s["n"] for s in val_stats.values())
    v_h = sum(s["h"] for s in val_stats.values())
    print(f"  {'TOTAL':<6} {t_n:>8,} {t_h:>8.1f} {v_n:>8,} {v_h:>8.1f}")

    # Write split metadata so the split can be audited/reproduced later.
    metadata = {
        "source_manifest": str(args.input.resolve()),
        "source_manifest_sha256": file_sha256(args.input),
        "seed": args.seed,
        "val_ratio": args.val_ratio,
        "stratify_by_lang": args.stratify_by_lang,
        "train_rows": len(train_lines),
        "val_rows": len(val_lines),
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    metadata_file = args.train_output.parent / "split_metadata.json"
    with open(metadata_file, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
    print(f"\nMetadata: {metadata_file}")


# Script entry point: run the split only when executed directly, not on import.
if __name__ == "__main__":
    main()
