#!/usr/bin/env python3
"""Extract audio files from tar shards for a smoke manifest.

Reads a JSONL manifest with (audio_filepath=tar, tar_member) fields,
extracts the referenced audio files into a flat output directory,
and writes an updated manifest with direct file paths.

Usage:
  python3 scripts/extract_smoke_audio.py \
    --input data/manifests/smoke_train.jsonl \
    --output data/manifests/smoke_train_extracted.jsonl \
    --audio-dir data/audio_cache \
    --max-rows 500
"""

import argparse
import json
import sys
import tarfile
from datetime import datetime, timezone
from pathlib import Path

from tqdm import tqdm

sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
from maya_asr.config import file_sha256


def _stratified_sample(rows: list[dict], max_rows: int) -> list[dict]:
    """Return up to *max_rows* rows, sampled proportionally per language.

    Uses a fixed seed (42) so smoke subsets are reproducible across runs.
    Every language keeps at least one row; rounding overshoot is trimmed
    after a final shuffle so no language is systematically favored.
    """
    import random

    rng = random.Random(42)
    by_lang: dict[str, list[dict]] = {}
    for row in rows:
        by_lang.setdefault(row.get("lang", "xx"), []).append(row)

    sampled: list[dict] = []
    total = len(rows)
    for lang in sorted(by_lang):
        lang_rows = by_lang[lang]
        n_lang = max(1, round(len(lang_rows) / total * max_rows))
        rng.shuffle(lang_rows)
        sampled.extend(lang_rows[:n_lang])

    rng.shuffle(sampled)
    return sampled[:max_rows]


def main():
    """Extract tar-packed audio referenced by a JSONL manifest.

    Reads (audio_filepath=tar, tar_member) rows, writes each referenced
    member into a flat --audio-dir, and emits a new manifest whose
    audio_filepath points at the extracted file. A .meta.json sidecar
    records provenance next to the output manifest.
    """
    parser = argparse.ArgumentParser(description="Extract audio from tar for smoke runs")
    parser.add_argument("--input", type=Path, required=True)
    parser.add_argument("--output", type=Path, required=True)
    parser.add_argument("--audio-dir", type=Path, required=True)
    parser.add_argument(
        "--max-rows",
        type=int,
        default=0,
        help="Max rows to extract (0=all)",
    )
    args = parser.parse_args()

    if not args.input.exists():
        print(f"ERROR: Input not found: {args.input}", file=sys.stderr)
        sys.exit(1)

    args.audio_dir.mkdir(parents=True, exist_ok=True)
    args.output.parent.mkdir(parents=True, exist_ok=True)

    # Read manifest rows (skip blank lines).
    rows = []
    with open(args.input, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))

    # Stratified sampling: proportional by language, not first-N.
    if args.max_rows > 0 and len(rows) > args.max_rows:
        rows = _stratified_sample(rows, args.max_rows)

    # Group rows by tar file so each archive is opened/scanned only once.
    # A single tar member may be referenced by multiple manifest rows, so
    # the index map holds a LIST of (row_index, row) per member — the old
    # dict-of-single-tuple silently dropped duplicates from the output.
    tar_groups: dict[str, list[tuple[int, dict]]] = {}
    for i, row in enumerate(rows):
        tar_groups.setdefault(row["audio_filepath"], []).append((i, row))

    extracted = 0
    updated_rows: list[dict | None] = [None] * len(rows)
    for tar_path, items in tar_groups.items():
        members_needed = {row["tar_member"] for _, row in items}
        idx_map: dict[str, list[tuple[int, dict]]] = {}
        for i, row in items:
            idx_map.setdefault(row["tar_member"], []).append((i, row))

        # Drive the progress bar manually: we iterate every member of the
        # tar but only "needed" ones count toward the total, so wrapping
        # the iterator directly would overshoot the bar.
        with tarfile.open(tar_path, "r") as tf, tqdm(
            desc=f"  {Path(tar_path).parent.name}",
            unit="file",
            total=len(members_needed),
            leave=False,
        ) as pbar:
            for member in tf:
                if member.name not in members_needed:
                    continue

                # Flatten any path separators in the member name so the
                # output file lands directly in audio_dir (write_bytes
                # would otherwise fail on a missing subdirectory).
                safe_member = member.name.replace("/", "_")

                for i, row in idx_map[member.name]:
                    lang = row.get("lang", "xx")
                    out_path = args.audio_dir / f"{lang}_{safe_member}"
                    # Skip extraction if a previous run already cached it.
                    if not out_path.exists():
                        f_in = tf.extractfile(member)
                        if f_in:
                            out_path.write_bytes(f_in.read())

                    new_row = {**row, "audio_filepath": str(out_path)}
                    del new_row["tar_member"]
                    updated_rows[i] = new_row
                    extracted += 1

                pbar.update(1)
                members_needed.discard(member.name)
                if not members_needed:
                    break  # all needed members found; skip rest of archive

        if members_needed:
            # Don't silently drop rows whose member never appeared.
            print(
                f"WARNING: {len(members_needed)} member(s) not found in {tar_path}",
                file=sys.stderr,
            )

    # Write updated manifest (rows whose member was missing are omitted).
    with open(args.output, "w", encoding="utf-8") as f:
        for row in updated_rows:
            if row is not None:
                f.write(json.dumps(row, ensure_ascii=False) + "\n")

    # Write per-output metadata (sidecar next to output manifest).
    metadata = {
        "source_manifest": str(args.input.resolve()),
        "source_manifest_sha256": file_sha256(args.input),
        "max_rows": args.max_rows,
        "extracted_count": extracted,
        "audio_dir": str(args.audio_dir.resolve()),
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    # Name metadata after output: smoke_train_extracted.jsonl -> smoke_train_extracted.meta.json
    metadata_file = args.output.with_suffix(".meta.json")
    with open(metadata_file, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print(f"Extracted {extracted} audio files to {args.audio_dir}")
    print(f"Updated manifest: {args.output}")
    print(f"Metadata: {metadata_file}")


# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
