#!/usr/bin/env python3
"""Build NeMo JSONL manifests from local sft_data shard metadata.

Reads metadata.parquet + audio_index.parquet from each shard and produces
NeMo-compatible JSONL manifest lines with the fields:
  {"audio_filepath", "tar_member", "text", "duration", "lang", "taskname",
   "source_lang", "target_lang"}

"audio_filepath" points at the shard's audio.tar and "tar_member" is the
tar_member_name of the clip inside it, suitable for Lhotse/NeMo tarred
datasets.
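
An illustrative manifest line (all values are made-up examples, not real data):
  {"audio_filepath": ".../lang=hi/shard-00000/audio.tar", "tar_member":
   "segment_000123.flac", "text": "...", "duration": 4.2, "lang": "hi",
   "taskname": "asr", "source_lang": "hi", "target_lang": "hi"}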

Usage:
  # Smoke (default): 1 shard per language, en+hi only
  python scripts/build_manifest.py --output data/manifests/smoke.jsonl

  # Custom smoke
  python scripts/build_manifest.py --languages en hi ta --max-shards 2 \
    --output data/manifests/custom.jsonl

  # Full build (all languages, all shards)
  python scripts/build_manifest.py --max-shards 0 --output data/manifests/full.jsonl
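
  # Full build with a transcription-quality floor (0.8 is an illustrative
  # threshold; tune it to the tx_quality_score scale in your metadata)
  python scripts/build_manifest.py --max-shards 0 --min-quality 0.8 \
    --output data/manifests/full_q80.jsonl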
"""

import argparse
import json
import sys
from pathlib import Path

import pyarrow.parquet as pq
from tqdm import tqdm

# Make src/ importable when this script is run directly from the repo checkout
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
from maya_asr.config import FINAL_EXPORT_ROOT, LANGUAGES

# All data source roots that follow the same shard format
DATA_SOURCES = {
    "final-export": FINAL_EXPORT_ROOT,
}

# Additional sources with plain {lang}/ directories at the top level
LANG_DIR_SOURCES = {
    "indicvoices": Path("/root/sft_data/indicvoices"),
    "indicvoices-r": Path("/root/sft_data/indicvoices-r"),
    "josh": Path("/root/sft_data/josh"),
    "joshdelivery": Path("/root/sft_data/joshdelivery"),
}

# Sources with Hive-style lang={lang} partition subdirectories
HIVE_SOURCES = {
    "ears": Path("/root/sft_data/ears"),
    "expresso": Path("/root/sft_data/expresso"),
    "globe": Path("/root/sft_data/globe"),
    "librittsr": Path("/root/sft_data/librittsr"),
    "ljspeech": Path("/root/sft_data/ljspeech"),
    "vctk": Path("/root/sft_data/vctk"),
}
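
# Note: discover_shards() probes both layouts for every root, so the split
# between LANG_DIR_SOURCES and HIVE_SOURCES is organizational only.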


def discover_shards(source_root: Path, languages: list[str]) -> list[tuple[str, Path]]:
    """Discover shard directories for given languages under a source root.

    Handles two directory layouts:
    1. Hive-partitioned: source_root/lang={code}/{shard_id}/
    2. Plain language dirs: source_root/{code}/{shard_id}/
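
    Returns a list of (lang, shard_dir) pairs; within each language, shards
    are ordered by directory name.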
    """
    shards = []
    for lang in languages:
        # Try hive-partitioned first
        lang_dir = source_root / f"lang={lang}"
        if not lang_dir.exists():
            # Try plain language directory
            lang_dir = source_root / lang
        if not lang_dir.exists():
            continue
        for shard_dir in sorted(lang_dir.iterdir()):
            if shard_dir.is_dir() and (shard_dir / "metadata.parquet").exists():
                shards.append((lang, shard_dir))
    return shards


def process_shard(lang: str, shard_dir: Path, min_quality: float = 0.0) -> tuple[list[dict], int]:
    """Convert one shard's metadata into NeMo manifest rows.

    Returns (rows, dropped_low_quality_count).
    """
    meta_path = shard_dir / "metadata.parquet"
    audio_index_path = shard_dir / "audio_index.parquet"
    audio_tar_path = shard_dir / "audio.tar"

    # Read metadata
    meta_cols = [
        "segment_id",
        "transcription_native",
        "duration_s",
        "lang",
        "final_bucket",
        "tx_quality_score",
    ]
    try:
        meta_table = pq.read_table(meta_path, columns=meta_cols)
    except Exception as e:
        print(f"  WARN: Cannot read {meta_path}: {e}", file=sys.stderr)
        return [], 0

    meta_df = meta_table.to_pandas()

    # Read audio index for tar member mapping
    audio_lookup = {}
    if audio_index_path.exists():
        try:
            ai_table = pq.read_table(audio_index_path, columns=["segment_id", "tar_member_name"])
            ai_df = ai_table.to_pandas()
            audio_lookup = dict(zip(ai_df["segment_id"], ai_df["tar_member_name"]))
        except Exception as e:
            print(f"  WARN: Cannot read {audio_index_path}: {e}", file=sys.stderr)

    rows = []
    dropped_quality = 0
    for _, row in meta_df.iterrows():
        text = row.get("transcription_native", "")
        if not isinstance(text, str) or not text.strip():
            continue

        duration = row.get("duration_s")
        if duration is None or duration <= 0 or duration > 60:
            continue

        # Apply quality filter; missing/NaN scores count as below threshold
        quality = row.get("tx_quality_score")
        if min_quality > 0 and (quality is None or not (float(quality) >= min_quality)):
            dropped_quality += 1
            continue

        segment_id = row["segment_id"]
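        # Index fallback: reuse the segment_id as the tar member name when
        # audio_index has no entry for it (or could not be read)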
        tar_member = audio_lookup.get(segment_id, segment_id)
        raw_lang = row.get("lang")  # may be NaN/None; str() would yield "nan"
        row_lang = raw_lang if isinstance(raw_lang, str) and raw_lang else lang

        rows.append(
            {
                "audio_filepath": str(audio_tar_path),
                "tar_member": tar_member,
                "text": text.strip(),
                "duration": round(float(duration), 3),
                "lang": row_lang,
                "taskname": "asr",
                "source_lang": row_lang,
                "target_lang": row_lang,
            }
        )

    return rows, dropped_quality


def main():
    parser = argparse.ArgumentParser(description="Build NeMo JSONL manifests from sft_data shards")
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("data/manifests/smoke.jsonl"),
        help="Output JSONL manifest path",
    )
    parser.add_argument(
        "--languages",
        nargs="+",
        default=["en", "hi"],
        choices=LANGUAGES,
        help="Languages to process (default: en hi)",
    )
    parser.add_argument(
        "--max-shards",
        type=int,
        default=1,
        help="Max shards per language per source (0=all, default=1 for smoke)",
    )
    parser.add_argument(
        "--sources",
        nargs="+",
        default=["final-export"],
        help="Data sources to include (default: final-export)",
    )
    parser.add_argument(
        "--min-quality",
        type=float,
        default=0.0,
        help="Minimum tx_quality_score threshold (default: 0.0, no filter)",
    )
    parser.add_argument(
        "--no-dedupe",
        action="store_true",
        help="Disable deduplication of (audio_filepath, tar_member) pairs",
    )
    args = parser.parse_args()

    # Ensure output directory exists
    args.output.parent.mkdir(parents=True, exist_ok=True)

    # Collect all source roots
    source_roots = {}
    for src in args.sources:
        if src in DATA_SOURCES:
            source_roots[src] = DATA_SOURCES[src]
        elif src in LANG_DIR_SOURCES:
            source_roots[src] = LANG_DIR_SOURCES[src]
        elif src in HIVE_SOURCES:
            source_roots[src] = HIVE_SOURCES[src]
        else:
            print(f"WARN: Unknown source '{src}', skipping", file=sys.stderr)

    # Discover and process shards
    stats = {lang: {"segments": 0, "duration_h": 0.0} for lang in args.languages}
    total_rows = 0
    total_dropped_quality = 0
    total_dropped_dupes = 0
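    # (audio_filepath, tar_member) uniquely identifies a clip inside a tar,
    # so the pair serves as the dedupe key across sources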
    seen_keys: set[tuple[str, str]] = set()
    dedupe = not args.no_dedupe

    with open(args.output, "w", encoding="utf-8") as fout:
        for source_name, source_root in source_roots.items():
            print(f"\n=== Source: {source_name} ({source_root}) ===")
            shards = discover_shards(source_root, args.languages)

            if args.max_shards > 0:
                # Limit per language
                limited = []
                lang_counts = {}
                for lang, shard_dir in shards:
                    lang_counts.setdefault(lang, 0)
                    if lang_counts[lang] < args.max_shards:
                        limited.append((lang, shard_dir))
                        lang_counts[lang] += 1
                shards = limited

            print(f"  Processing {len(shards)} shards...")

            for lang, shard_dir in tqdm(shards, desc=f"  {source_name}", unit="shard"):
                rows, dropped_q = process_shard(lang, shard_dir, min_quality=args.min_quality)
                total_dropped_quality += dropped_q
                for row in rows:
                    # Deduplicate
                    if dedupe:
                        key = (row["audio_filepath"], row["tar_member"])
                        if key in seen_keys:
                            total_dropped_dupes += 1
                            continue
                        seen_keys.add(key)

                    fout.write(json.dumps(row, ensure_ascii=False) + "\n")
                    stats[lang]["segments"] += 1
                    stats[lang]["duration_h"] += row["duration"] / 3600.0
                    total_rows += 1

    # Print summary
    print(f"\n{'=' * 60}")
    print(f"Manifest: {args.output}")
    print(f"Total segments: {total_rows:,}")
    if total_dropped_quality > 0:
        print(f"Dropped (low quality): {total_dropped_quality:,}")
    if total_dropped_dupes > 0:
        print(f"Dropped (duplicates):  {total_dropped_dupes:,}")
    print("\nPer-language breakdown:")
    print(f"  {'Lang':<6} {'Segments':>10} {'Hours':>10}")
    print(f"  {'-' * 6} {'-' * 10} {'-' * 10}")
    total_hours = 0.0
    for lang in sorted(stats.keys()):
        s = stats[lang]
        if s["segments"] > 0:
            print(f"  {lang:<6} {s['segments']:>10,} {s['duration_h']:>10.1f}")
            total_hours += s["duration_h"]
    print(f"  {'-' * 6} {'-' * 10} {'-' * 10}")
    print(f"  {'TOTAL':<6} {total_rows:>10,} {total_hours:>10.1f}")


if __name__ == "__main__":
    main()
