#!/usr/bin/env python3
"""Generate NeMo input_cfg YAML from one or more JSONL manifests.

Groups manifest entries by language and produces a YAML with input_cfg
entries suitable for NeMo's multi-corpus training.

Usage:
  python3 scripts/generate_input_cfg.py \
    --manifests data/manifests/smoke_train.jsonl \
    --output configs/data/stage1_smoke_input_cfg.yaml
"""

import argparse
import json
import sys
from collections import defaultdict
from datetime import datetime, timezone
from pathlib import Path

import yaml

# Put "<parent of this script's directory>/src" first on sys.path so the maya_asr import below resolves
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
from maya_asr.config import file_sha256


def _fail(message):
    """Print ``ERROR: <message>`` to stderr and terminate with exit status 1."""
    print(f"ERROR: {message}", file=sys.stderr)
    sys.exit(1)


def _collect_language_stats(manifest_paths):
    """Read the JSONL manifests and aggregate statistics per language.

    Args:
        manifest_paths: Iterable of ``Path`` objects pointing at JSONL manifests.

    Returns:
        A dict mapping language code to a dict with keys ``tar_paths`` (set of
        ``audio_filepath`` values), ``hours`` (float), ``segments`` (int) and
        ``manifests`` (set of resolved manifest paths as strings).

    Exits with status 1 (after printing an ``ERROR:`` line) when a manifest is
    missing, a line is not valid JSON, or a row is not an object containing
    ``audio_filepath``.
    """
    lang_data = defaultdict(
        lambda: {"tar_paths": set(), "hours": 0.0, "segments": 0, "manifests": set()}
    )
    seen_manifests = set()

    for manifest_path in manifest_paths:
        if not manifest_path.exists():
            _fail(f"Manifest not found: {manifest_path}")

        # The same manifest given twice would double-count hours/segments and
        # inflate the proportional weight, so each resolved path is read once.
        resolved = str(manifest_path.resolve())
        if resolved in seen_manifests:
            continue
        seen_manifests.add(resolved)

        # Explicit UTF-8: do not depend on the locale's default encoding.
        with open(manifest_path, encoding="utf-8") as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    row = json.loads(line)
                except json.JSONDecodeError as exc:
                    _fail(f"{manifest_path}:{line_no}: invalid JSON ({exc})")
                if not isinstance(row, dict) or "audio_filepath" not in row:
                    _fail(
                        f"{manifest_path}:{line_no}: expected a JSON object"
                        " with an 'audio_filepath' field"
                    )

                # `or` also covers an explicit null, which .get(key, default)
                # would pass through as None.
                lang = row.get("lang") or "unknown"
                duration = row.get("duration") or 0

                info = lang_data[lang]
                # NOTE(review): audio_filepath values are later emitted as
                # tarred_audio_filepaths; presumably the manifests store tar
                # shard paths in this field -- confirm against the data prep.
                info["tar_paths"].add(row["audio_filepath"])
                info["hours"] += duration / 3600.0
                info["segments"] += 1
                info["manifests"].add(resolved)

    return lang_data


def _build_entries(lang_data, uniform_weights):
    """Build the ``input_cfg`` entry list, one entry per language (sorted).

    Args:
        lang_data: Mapping produced by ``_collect_language_stats``.
        uniform_weights: When true every entry gets weight 1; otherwise the
            weight is the language's hours rounded to an int, with a floor of 1.

    Returns:
        A list of dicts with keys ``corpus``, ``language``,
        ``manifest_filepath``, ``tarred_audio_filepaths``, ``type`` and
        ``weight`` (in that order, which is preserved in the YAML output).
    """
    entries = []
    for lang in sorted(lang_data):
        info = lang_data[lang]
        manifests = sorted(info["manifests"])
        tar_paths = sorted(info["tar_paths"])
        weight = 1 if uniform_weights else max(1, round(info["hours"]))

        entries.append(
            {
                "corpus": f"maya_{lang}",
                "language": lang,
                # Single-element lists are collapsed to a bare string.
                "manifest_filepath": manifests[0] if len(manifests) == 1 else manifests,
                "tarred_audio_filepaths": tar_paths[0] if len(tar_paths) == 1 else tar_paths,
                "type": "nemo_tarred",
                "weight": weight,
            }
        )
    return entries


def _print_summary(output_path, metadata_file, entries, lang_data):
    """Print the generated file locations and a per-language stats table."""
    print(f"Generated: {output_path}")
    print(f"Metadata:  {metadata_file}")
    print(f"Languages: {len(entries)}")
    print()
    print(f"  {'Lang':<6} {'Segments':>10} {'Hours':>8} {'Tars':>6} {'Weight':>8}")
    print(f"  {'-' * 6} {'-' * 10} {'-' * 8} {'-' * 6} {'-' * 8}")
    for entry in entries:
        lang = entry["language"]
        info = lang_data[lang]
        print(
            f"  {lang:<6} {info['segments']:>10,} {info['hours']:>8.1f}"
            f" {len(info['tar_paths']):>6} {entry['weight']:>8}"
        )


def main():
    """CLI entry point: turn JSONL manifests into a NeMo ``input_cfg`` YAML.

    Parses ``--manifests``, ``--output`` and ``--uniform-weights``, groups the
    manifest rows by language, writes the YAML to ``--output``, writes
    ``input_cfg_metadata.json`` next to it (manifest SHA-256 hashes, weight
    mode, language count, UTC timestamp) and prints a summary table.

    Exits with status 1 if a manifest is missing or malformed, or if the
    manifests contain no rows.
    """
    parser = argparse.ArgumentParser(description="Generate NeMo input_cfg YAML from manifests")
    parser.add_argument(
        "--manifests",
        nargs="+",
        type=Path,
        required=True,
        help="One or more JSONL manifest files",
    )
    parser.add_argument("--output", type=Path, required=True)
    parser.add_argument(
        "--uniform-weights",
        action="store_true",
        default=False,
        help="Use equal weights for all languages (default: proportional to hours)",
    )
    args = parser.parse_args()

    lang_data = _collect_language_stats(args.manifests)
    if not lang_data:
        _fail("No data found in manifests")

    entries = _build_entries(lang_data, args.uniform_weights)
    cfg = {"input_cfg": entries}

    # Write YAML (sort_keys=False keeps the entry key order built above).
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        yaml.dump(cfg, f, default_flow_style=False, sort_keys=False)

    # Write metadata. The file name is fixed, so generating several configs
    # into the same directory overwrites the previous metadata file.
    manifest_hashes = {str(m.resolve()): file_sha256(m) for m in args.manifests}
    metadata = {
        "output": str(args.output.resolve()),
        "source_manifests": sorted(manifest_hashes.keys()),
        "source_manifest_sha256": manifest_hashes,
        "weight_mode": "uniform" if args.uniform_weights else "proportional",
        "language_count": len(entries),
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    metadata_file = args.output.parent / "input_cfg_metadata.json"
    with open(metadata_file, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    _print_summary(args.output, metadata_file, entries, lang_data)


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
