#!/usr/bin/env python3
"""Validate a NeMo JSONL manifest for common issues.

Checks:
- Required fields present
- No empty/missing text
- Duration within sane bounds
- Language is in expected set
- Audio tar files exist on disk
- No duplicate segments (same audio_filepath + tar_member pair)

Usage:
  python scripts/validate_manifest.py data/manifests/smoke.jsonl
"""

import argparse
import json
import sys
from collections import Counter
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
from maya_asr.config import LANGUAGES

# Keys that every manifest row must contain; a row lacking any of them is
# reported as an error by validate().
REQUIRED_FIELDS = [
    "audio_filepath",
    "text",
    "duration",
    "lang",
    "taskname",
    "source_lang",
    "target_lang",
]
# Duration bounds in seconds. Rows outside this range only produce warnings,
# so they do not fail validation on their own.
MIN_DURATION = 0.1
MAX_DURATION = 60.0


def _is_finite_number(value) -> bool:
    """Return True if *value* is a finite int or float (bools are rejected)."""
    # bool is a subclass of int, but `true`/`false` is never a valid duration.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return False
    try:
        as_float = float(value)
    except OverflowError:
        # Integer too large to be represented as a float.
        return False
    # NaN fails both comparisons; +/-Infinity fails one of them. json.loads
    # accepts the non-standard NaN/Infinity literals, so they can show up here.
    return float("-inf") < as_float < float("inf")


def _print_issues(title: str, issues: list, limit: int = 20) -> None:
    """Print up to *limit* issues under *title*; print nothing if there are none."""
    if not issues:
        return
    print(f"{title} ({len(issues)}):")
    for issue in issues[:limit]:
        print(f"  {issue}")
    if len(issues) > limit:
        print(f"  ... and {len(issues) - limit} more")
    print()


def validate(manifest_path: Path) -> bool:
    """Validate a JSONL manifest and print a summary report to stdout.

    Each non-blank line is parsed as JSON and checked for: required fields,
    non-empty text, a finite numeric duration (out-of-range values only warn),
    an expected language, an existing audio file, the "asr" taskname, and
    duplicate (audio_filepath, tar_member) pairs.

    Malformed rows are reported as errors or warnings instead of raising, so
    a single bad row cannot abort validation of the remaining rows.

    Args:
        manifest_path: Path to the JSONL manifest file to read.

    Returns:
        True if no errors were found (warnings do not fail validation),
        False otherwise.
    """
    errors = []
    warnings = []
    lang_counts = Counter()
    lang_durations = Counter()
    tar_paths_checked = {}  # audio path -> exists on disk (cached per path)
    total = 0
    seen_keys = set()

    # Manifests carry multilingual text; do not depend on the locale encoding.
    with open(manifest_path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue

            try:
                row = json.loads(line)
            except json.JSONDecodeError as e:
                errors.append(f"Line {lineno}: Invalid JSON: {e}")
                continue

            total += 1

            # Valid JSON that is not an object (list, number, ...) has no
            # fields to inspect; report it and move on.
            if not isinstance(row, dict):
                errors.append(
                    f"Line {lineno}: Row is not a JSON object: {type(row).__name__}"
                )
                continue

            # Check required fields
            missing = [field for field in REQUIRED_FIELDS if field not in row]
            for field in missing:
                errors.append(f"Line {lineno}: Missing field '{field}'")

            # Check text (skip if already reported as a missing field)
            if "text" not in missing:
                text = row.get("text", "")
                if not isinstance(text, str) or not text.strip():
                    errors.append(f"Line {lineno}: Empty or missing text")

            # Check duration
            duration = row.get("duration")
            duration_ok = _is_finite_number(duration)
            if duration is None:
                # Avoid double-reporting when the key itself is absent.
                if "duration" not in missing:
                    errors.append(f"Line {lineno}: Missing duration")
            elif not duration_ok:
                errors.append(f"Line {lineno}: Duration not numeric: {duration}")
            elif duration < MIN_DURATION:
                warnings.append(f"Line {lineno}: Very short duration: {duration:.3f}s")
            elif duration > MAX_DURATION:
                warnings.append(f"Line {lineno}: Very long duration: {duration:.1f}s")

            # Check language
            lang = row.get("lang", "")
            if not isinstance(lang, str):
                warnings.append(f"Line {lineno}: Unexpected language: {lang}")
                # Keep the report keys homogeneous strings so they can be
                # counted, sorted and formatted.
                lang = str(lang)
            elif lang not in LANGUAGES:
                warnings.append(f"Line {lineno}: Unexpected language: {lang}")
            lang_counts[lang] += 1
            if duration_ok:
                lang_durations[lang] += duration / 3600.0  # seconds -> hours

            # Check audio filepath exists (cache result per tar path)
            audio_path = row.get("audio_filepath", "")
            # Path("") resolves to the current directory and would "exist",
            # so empty or non-string paths must be rejected explicitly.
            path_ok = isinstance(audio_path, str) and bool(audio_path.strip())
            if path_ok:
                if audio_path not in tar_paths_checked:
                    tar_paths_checked[audio_path] = Path(audio_path).exists()
                if not tar_paths_checked[audio_path]:
                    errors.append(f"Line {lineno}: Audio file not found: {audio_path}")
            elif "audio_filepath" not in missing:
                errors.append(
                    f"Line {lineno}: Empty or invalid audio_filepath: {audio_path!r}"
                )

            # Check taskname
            if row.get("taskname") != "asr":
                warnings.append(f"Line {lineno}: Unexpected taskname: {row.get('taskname')}")

            # Check for duplicates (only meaningful when the path is usable)
            if path_ok:
                member = row.get("tar_member", "")
                if isinstance(member, (list, dict)):
                    # Unhashable JSON values cannot go into a set directly.
                    member = json.dumps(member, sort_keys=True)
                key = (audio_path, member)
                if key in seen_keys:
                    warnings.append(f"Line {lineno}: Duplicate segment")
                seen_keys.add(key)

    # Report
    print(f"Manifest: {manifest_path}")
    print(f"Total rows: {total:,}")
    print(f"Unique audio tars: {len(tar_paths_checked)}")
    print()

    print("Per-language breakdown:")
    print(f"  {'Lang':<6} {'Segments':>10} {'Hours':>10}")
    print(f"  {'-' * 6} {'-' * 10} {'-' * 10}")
    for lang in sorted(lang_counts):
        print(f"  {lang:<6} {lang_counts[lang]:>10,} {lang_durations[lang]:>10.1f}")
    print()

    _print_issues("ERRORS", errors)
    _print_issues("WARNINGS", warnings)

    ok = not errors
    print(f"Result: {'PASS' if ok else 'FAIL'}")
    return ok


def main():
    """Parse the command line, validate the given manifest, and exit.

    Exits with status 0 when validation passes, and 1 when it fails or the
    manifest path does not exist.
    """
    cli = argparse.ArgumentParser(description="Validate NeMo JSONL manifest")
    cli.add_argument("manifest", type=Path, help="Path to JSONL manifest file")
    manifest_path = cli.parse_args().manifest

    if manifest_path.exists():
        passed = validate(manifest_path)
        sys.exit(0 if passed else 1)

    # Path does not exist: report on stderr and signal failure.
    print(f"ERROR: File not found: {manifest_path}", file=sys.stderr)
    sys.exit(1)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
