#!/usr/bin/env python3
"""Phase 3: Build training manifests with tar offsets from Phase 2 splits.

Joins the Phase 2 global manifest (transcript, language, split) with each
shard's tar_offset_index.parquet (byte offset, size) and audio_16k_index.parquet
(duration, sample rate), falling back to the legacy audio_index.parquet when
the 16k index is missing.

Usage:
  python3 tools/phase3_build_training_manifest.py
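
Outputs (written to artifacts/phase3/):
  {train,dev,test}_manifest.parquet -- per-split training manifests
  manifest_summary.json             -- row/hour totals per split and per language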
"""

import json
import sys
import time
from pathlib import Path

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

PHASE2_DIR = Path("/workspace/maya-asr/artifacts/phase2")
PHASE3_DIR = Path("/workspace/maya-asr/artifacts/phase3")
INVENTORY_PATH = PHASE2_DIR / "shard_inventory.parquet"


def load_shard_data(shard_dir: Path, tar_path: str) -> pd.DataFrame:
    """Load offset index + duration data for one shard."""
    offset_idx = shard_dir / "tar_offset_index.parquet"
    idx_16k = shard_dir / "audio_16k_index.parquet"
    idx_orig = shard_dir / "audio_index.parquet"

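    # A shard without the offset index cannot be byte-addressed; skip it.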
    if not offset_idx.exists():
        return pd.DataFrame()

    # Load offsets
    offsets = pq.read_table(offset_idx).to_pandas()

    # Load duration from the best available source: the 16k index carries the
    # post-resample duration and rate; the legacy index needs column mapping.
    if idx_16k.exists():
        dur_df = pq.read_table(
            idx_16k, columns=["member_name", "duration_s", "sample_rate_hz", "channels"]
        ).to_pandas()
        merged = offsets.merge(dur_df, on="member_name", how="left")
    elif idx_orig.exists():
        # Legacy shards: the original index uses older column names, so map
        # them onto member_name/duration_s before joining.
        try:
            orig_cols = pq.read_schema(idx_orig).names
            # Pick deterministically when both naming variants are present.
            name_col = next((c for c in ("tar_member_name", "segment_id") if c in orig_cols), None)
            dur_col = next((c for c in ("audio_duration_s", "duration_s") if c in orig_cols), None)

            if name_col and dur_col:
                dur_df = pq.read_table(idx_orig, columns=[name_col, dur_col]).to_pandas()
                dur_df = dur_df.rename(columns={name_col: "member_name", dur_col: "duration_s"})
                merged = offsets.merge(dur_df, on="member_name", how="left")
            else:
                merged = offsets.copy()
                merged["duration_s"] = 0.0
        except Exception:
            # Unreadable file or unexpected schema: keep the rows, zero durations.
            merged = offsets.copy()
            merged["duration_s"] = 0.0
    else:
        merged = offsets.copy()
        merged["duration_s"] = 0.0

    merged["tar_path"] = tar_path

    # Fill missing audio metadata; 16 kHz mono is assumed for shards that lack
    # the 16k index (matching the audio_16k re-encode convention).
    if "sample_rate_hz" not in merged.columns:
        merged["sample_rate_hz"] = 16000
    if "channels" not in merged.columns:
        merged["channels"] = 1

    return merged


def main():
    PHASE3_DIR.mkdir(parents=True, exist_ok=True)

    inv_df = pq.read_table(INVENTORY_PATH).to_pandas()
    print(f"Loading data from {len(inv_df)} shards...")

    # Load Phase 2 global manifest for transcript + metadata
    global_df = pq.read_table(
        PHASE2_DIR / "global_manifest.parquet",
        columns=["sample_id", "prefix", "language", "shard_id", "transcript",
                 "tar_path", "tar_member_name", "video_id", "segment_id",
                 "quality_bucket", "quality_score", "split"],
    ).to_pandas()
    print(f"Global manifest: {len(global_df):,} rows")

    # Build offset+duration index across all shards
    t0 = time.time()
    all_offset_rows = []
    for _, row in inv_df.iterrows():
        shard_dir = Path(row["shard_dir"])
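        # Assumed shard layout: a single audio.tar next to the index parquets.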
        tar_path = str(shard_dir / "audio.tar")
        shard_data = load_shard_data(shard_dir, tar_path)
        if len(shard_data) > 0:
            shard_data["shard_dir"] = str(shard_dir)
            all_offset_rows.append(shard_data)

    if not all_offset_rows:
        sys.exit("No shard offset indexes found; run the Phase 2 offset indexing step first.")
    offset_df = pd.concat(all_offset_rows, ignore_index=True)
    print(f"Offset index: {len(offset_df):,} rows in {time.time()-t0:.0f}s")

    # Build join key: tar_path + member_name (unique across shards)
    merged = global_df.merge(
        offset_df[["tar_path", "member_name", "tar_offset_data", "tar_nbytes", "duration_s",
                    "sample_rate_hz", "channels"]].rename(columns={"duration_s": "duration_s_offset"}),
        left_on=["tar_path", "tar_member_name"],
        right_on=["tar_path", "member_name"],
        how="left",
        suffixes=("", "_off"),
    )
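    # Left-join semantics: manifest rows with no matching offset-index entry
    # get NaN offsets here and are dropped below.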

    # Take duration from the per-shard indexes (more accurate than Phase 2's).
    merged["duration_s"] = merged["duration_s_offset"]
    merged = merged.drop(columns=["member_name", "duration_s_offset"], errors="ignore")

    # Clean transcripts: coerce to native Python str so Arrow serialization
    # round-trips, and clear "nan" strings left by upstream NaN stringification.
    merged["transcript"] = merged["transcript"].astype("object").fillna("").astype(str)
    merged["transcript"] = merged["transcript"].replace("nan", "")

    # Drop rows with no offset (failed to join)
    before = len(merged)
    merged = merged.dropna(subset=["tar_offset_data"])
    after = len(merged)
    dropped = before - after
    print(f"Dropped {dropped:,} rows with no offset data ({dropped/max(before,1)*100:.2f}%)")
    print(f"Final manifest: {len(merged):,} rows")

    # Stats
    total_hours = merged["duration_s"].sum() / 3600
    print(f"Total audio: {total_hours:,.1f} hours")

    # Write per-split manifests
    for split in ["train", "dev", "test"]:
        split_df = merged[merged["split"] == split].copy()
        # Select output columns
        out_cols = [
            "sample_id", "transcript", "language", "tar_path", "tar_member_name",
            "tar_offset_data", "tar_nbytes", "duration_s", "split",
            "quality_bucket", "quality_score", "video_id", "prefix",
        ]
        out_df = split_df[[c for c in out_cols if c in split_df.columns]]
        out_path = PHASE3_DIR / f"{split}_manifest.parquet"
        pq.write_table(pa.Table.from_pandas(out_df, preserve_index=False), out_path)
        hours = out_df["duration_s"].sum() / 3600
        print(f"  {split}: {len(out_df):,} rows, {hours:,.1f}h -> {out_path}")

    # Summary
    summary = {
        "total_rows": len(merged),
        "total_hours": round(total_hours, 1),
        "dropped_no_offset": dropped,
        "splits": {
            split: {
                "rows": int((merged["split"] == split).sum()),
                "hours": round(merged[merged["split"] == split]["duration_s"].sum() / 3600, 1),
            }
            for split in ["train", "dev", "test"]
        },
        "per_language": {
            lang: {
                "rows": int((merged["language"] == lang).sum()),
                "hours": round(merged[merged["language"] == lang]["duration_s"].sum() / 3600, 1),
            }
            for lang in sorted(merged["language"].dropna().unique())
        },
    }
    with open(PHASE3_DIR / "manifest_summary.json", "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\nManifest summary: {PHASE3_DIR / 'manifest_summary.json'}")


if __name__ == "__main__":
    main()
