#!/usr/bin/env python3
"""Phase 2: Scan all shard directories and build inventory."""

import json
import sys
import time
from pathlib import Path

import pyarrow as pa
import pyarrow.parquet as pq

DATA_ROOT = Path("/root/sft_data")
OUTPUT_DIR = Path("/workspace/maya-asr/artifacts/phase2")

# One spec per data source: (prefix, path relative to DATA_ROOT, shard_layout).
# shard_layout "hive"  -> directories look like lang={code}/{shard_id}/
# shard_layout "plain" -> directories look like {code}/{shard_id}/
_SOURCE_SPECS = (
    ("final-export", "final-export/production/shards", "hive"),
    ("indicvoices", "indicvoices", "plain"),
    ("indicvoices-r", "indicvoices-r", "plain"),
    ("josh", "josh", "hive"),
    ("joshdelivery", "joshdelivery", "hive"),
    ("ears", "ears", "hive"),
    ("expresso", "expresso", "hive"),
    ("globe", "globe", "hive"),
    ("librittsr", "librittsr", "hive"),
    ("ljspeech", "ljspeech", "hive"),
    ("vctk", "vctk", "hive"),
    ("hifitts2", "hifitts2", "hive"),
)

# Resolved source configs scanned in this order: (prefix, base_path, shard_layout).
SOURCES = [
    (prefix, DATA_ROOT / rel_path, layout)
    for prefix, rel_path, layout in _SOURCE_SPECS
]


def scan_source(prefix: str, base_path: Path, layout: str) -> list[dict]:
    """Scan one data source for shard directories.

    Walks ``base_path/<lang_dir>/<shard_dir>/`` two levels deep. It emits one
    row per shard directory that contains an ``audio.tar`` file. Shard
    directories without ``audio.tar`` are skipped.

    Args:
        prefix: Source name recorded in each row's ``"prefix"`` column.
            Rows whose prefix is ``"final-export"`` get
            ``needs_conversion=True``. All other rows get ``False``.
        base_path: Root directory of the source. If it does not exist, no
            rows are returned.
        layout: Either ``"hive"`` or ``"plain"``.
            ``"hive"`` means first-level directories are named
            ``lang=<code>``; other first-level names are ignored.
            ``"plain"`` means first-level directory names are the language
            code. Any other value yields no rows.

    Returns:
        One dict per shard, in sorted directory order.
        ``segment_count`` is the row count from the footer metadata of
        ``audio_index.parquet``. It is 0 when the index is absent, or when
        it cannot be read; in the unreadable case a warning goes to stderr.
        ``audio_index_path`` and ``metadata_path`` are ``""`` when the
        corresponding file is absent.
    """
    rows = []
    if not base_path.exists():
        return rows

    for lang_dir in sorted(base_path.iterdir()):
        if not lang_dir.is_dir():
            continue
        # Derive the language code from the first-level directory name.
        name = lang_dir.name
        if layout == "hive" and name.startswith("lang="):
            lang = name.split("=", 1)[1]
        elif layout == "plain":
            lang = name
        else:
            continue

        for shard_dir in sorted(lang_dir.iterdir()):
            if not shard_dir.is_dir():
                continue

            audio_tar = shard_dir / "audio.tar"
            if not audio_tar.exists():
                continue
            audio_index = shard_dir / "audio_index.parquet"
            metadata = shard_dir / "metadata.parquet"
            has_index = audio_index.exists()

            segment_count = 0
            if has_index:
                try:
                    # Reads only the Parquet footer: row count without loading data.
                    segment_count = pq.read_metadata(str(audio_index)).num_rows
                except Exception as exc:
                    # Best effort: keep the shard with segment_count=0, but make the
                    # unreadable index visible instead of silently masking it.
                    print(
                        f"  WARNING: could not read {audio_index}: {exc}",
                        file=sys.stderr,
                    )

            rows.append(
                {
                    "prefix": prefix,
                    "language": lang,
                    "shard_id": shard_dir.name,
                    "shard_dir": str(shard_dir),
                    "audio_tar_path": str(audio_tar),
                    "audio_index_path": str(audio_index) if has_index else "",
                    "metadata_path": str(metadata) if metadata.exists() else "",
                    "segment_count": segment_count,
                    "audio_tar_bytes": audio_tar.stat().st_size,
                    "needs_conversion": prefix == "final-export",  # Only 48kHz source
                }
            )

    return rows


def _rows_to_table(rows: list[dict]) -> pa.Table:
    """Build a column-oriented Arrow table from a non-empty list of same-keyed row dicts."""
    return pa.table({col: [r[col] for r in rows] for col in rows[0]})


def _build_summary(all_rows: list[dict], conv_rows: list[dict]) -> dict:
    """Aggregate shard counts and audio.tar byte totals overall, per prefix and per language."""
    per_prefix: dict[str, dict[str, int]] = {}
    per_language: dict[str, dict[str, int]] = {}
    for r in all_rows:
        for bucket, key in ((per_prefix, r["prefix"]), (per_language, r["language"])):
            entry = bucket.setdefault(key, {"shards": 0, "bytes": 0})
            entry["shards"] += 1
            entry["bytes"] += r["audio_tar_bytes"]

    return {
        "total_shards": len(all_rows),
        "needs_conversion": len(conv_rows),
        "already_16k": len(all_rows) - len(conv_rows),
        "total_bytes": sum(r["audio_tar_bytes"] for r in all_rows),
        "conversion_bytes": sum(r["audio_tar_bytes"] for r in conv_rows),
        "per_prefix": per_prefix,
        "per_language": per_language,
    }


def main():
    """Scan every configured source and write the phase-2 inventory artifacts.

    Writes these files to OUTPUT_DIR:
      * shard_inventory.parquet - one row per shard found in any source.
      * conversion_queue.parquet - rows with ``needs_conversion`` set, sorted
        largest ``audio.tar`` first. If no shard needs conversion, any queue
        file left by an earlier run is removed.
      * inventory_summary.json - shard and byte totals, per prefix and per
        language.

    Exits with an error message when no shards are found: table columns are
    derived from the rows, so an empty scan has no schema to write.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("Phase 2: Scanning shard inventory...")
    t0 = time.perf_counter()

    all_rows = []
    for prefix, base_path, layout in SOURCES:
        rows = scan_source(prefix, base_path, layout)
        print(f"  {prefix}: {len(rows)} shards")
        all_rows.extend(rows)

    elapsed = time.perf_counter() - t0
    print(f"\nTotal: {len(all_rows)} shards in {elapsed:.1f}s")

    if not all_rows:
        sys.exit(f"ERROR: no shards found under {DATA_ROOT}; nothing to inventory.")

    # Write inventory parquet
    inv_path = OUTPUT_DIR / "shard_inventory.parquet"
    pq.write_table(_rows_to_table(all_rows), inv_path)
    print(f"Inventory: {inv_path}")

    # Conversion queue: only shards needing conversion, largest first (stable sort).
    conv_rows = sorted(
        (r for r in all_rows if r["needs_conversion"]),
        key=lambda r: r["audio_tar_bytes"],
        reverse=True,
    )
    queue_path = OUTPUT_DIR / "conversion_queue.parquet"
    if conv_rows:
        pq.write_table(_rows_to_table(conv_rows), queue_path)
        print(f"Conversion queue: {queue_path} ({len(conv_rows)} shards)")
    else:
        # A queue from a previous run would no longer match this inventory.
        queue_path.unlink(missing_ok=True)
        print("Conversion queue: none (no shards need conversion)")

    summary = _build_summary(all_rows, conv_rows)
    summary_path = OUTPUT_DIR / "inventory_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print(f"Summary: {summary_path}")

    # Print overview. The "Convert" column comes from the rows' needs_conversion
    # flag rather than a hard-coded prefix, so it stays in sync with scan_source.
    convert_prefixes = {r["prefix"] for r in conv_rows}
    print(f"\n{'Prefix':<20} {'Shards':>8} {'Size':>10} {'Convert':>8}")
    print("-" * 50)
    for p, info in sorted(summary["per_prefix"].items()):
        size_gb = info["bytes"] / (1024**3)
        needs = "YES" if p in convert_prefixes else "no"
        print(f"{p:<20} {info['shards']:>8} {size_gb:>9.1f}G {needs:>8}")


if __name__ == "__main__":
    # Entry point when run as a script; importing the module does not trigger a scan.
    raise SystemExit(main())
