#!/usr/bin/env python3
"""Phase 2: Orchestrate parallel shard conversion with resumability.

Reads conversion_queue.parquet, distributes shards across workers,
tracks progress, retries failures, saves checkpoints.

Usage:
  python3 tools/phase2_orchestrate_conversion.py --workers 32
  python3 tools/phase2_orchestrate_conversion.py --workers 32 --max-shards 10  # test
  python3 tools/phase2_orchestrate_conversion.py --workers 32 --resume  # continue from saved state
"""

import argparse
import json
import os
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import pyarrow.parquet as pq

# Cap the thread pools of the numeric libraries at one thread each. The
# variables are set at import time, before any worker process is created.
os.environ.update(
    dict.fromkeys(
        (
            "OMP_NUM_THREADS",
            "MKL_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "NUMEXPR_NUM_THREADS",
            "TORCH_NUM_THREADS",
        ),
        "1",
    )
)

# Directory holding the phase 2 artifacts: the conversion queue that is read
# and the checkpoint file that is written.
ARTIFACTS_DIR = Path("/workspace/maya-asr/artifacts/phase2")
QUEUE_PATH = ARTIFACTS_DIR.joinpath("conversion_queue.parquet")
STATE_PATH = ARTIFACTS_DIR.joinpath("conversion_state.json")


def convert_one_shard(shard_dir_str: str) -> dict:
    """Convert one shard directory; executed inside a worker process.

    The converter script ``phase2_convert_shard_audio16.py`` is loaded by
    file path from the directory containing this file, and its
    ``convert_shard`` function is called with the shard directory.

    Args:
        shard_dir_str: Shard directory as a string; it is turned into a
            ``Path`` before being handed to the converter.

    Returns:
        Whatever ``convert_shard`` returns.
    """
    from importlib.util import module_from_spec, spec_from_file_location

    converter_path = Path(__file__).parent / "phase2_convert_shard_audio16.py"
    converter_spec = spec_from_file_location("phase2_convert", converter_path)
    converter = module_from_spec(converter_spec)
    # Run the script's top-level code so that convert_shard gets defined.
    converter_spec.loader.exec_module(converter)

    shard_dir = Path(shard_dir_str)
    return converter.convert_shard(shard_dir)


def load_state() -> dict:
    """Load the conversion state from ``STATE_PATH``, or start a fresh one.

    Returns:
        A dict that always contains the keys ``"completed"`` (dict),
        ``"failed"`` (dict) and ``"started_at"`` (epoch seconds). Keys that
        are missing from the file are filled in with empty defaults so
        callers can index them unconditionally.

    Raises:
        json.JSONDecodeError: If the state file exists but is not valid JSON.
        ValueError: If the state file does not hold a JSON object.
    """
    try:
        # Read directly instead of checking exists() first: avoids a race
        # between the check and the read.
        state = json.loads(STATE_PATH.read_text())
    except FileNotFoundError:
        state = {}

    if not isinstance(state, dict):
        raise ValueError(
            f"{STATE_PATH} must contain a JSON object, got {type(state).__name__}"
        )

    state.setdefault("completed", {})
    state.setdefault("failed", {})
    state.setdefault("started_at", time.time())
    return state


def save_state(state: dict):
    """Persist the conversion state to ``STATE_PATH`` atomically.

    Stamps ``state["last_updated"]`` with the current time (mutating the
    dict that was passed in), writes the JSON to a sibling temporary file
    and then renames it over ``STATE_PATH``. Because the final file is
    swapped in with a single rename, an interruption during the write
    cannot leave a truncated checkpoint behind.

    Args:
        state: The state dict to serialize; must be JSON-serializable.
    """
    state["last_updated"] = time.time()
    payload = json.dumps(state, indent=2)

    tmp_path = STATE_PATH.with_name(STATE_PATH.name + ".tmp")
    tmp_path.write_text(payload)
    # os.replace overwrites the destination in one step.
    os.replace(tmp_path, STATE_PATH)


def _select_pending(all_shards: list, state: dict, max_retries: int) -> list:
    """Return the shards that still need a conversion attempt, in queue order.

    A shard is skipped when it is already recorded in ``state["completed"]``,
    or when it is recorded in ``state["failed"]`` with at least
    ``max_retries`` attempts. Lookups go through the state dicts, so the
    selection is linear in the queue length.
    """
    completed = state["completed"]
    failed = state["failed"]

    pending = []
    for shard in all_shards:
        if shard in completed:
            continue
        failure = failed.get(shard)
        # An entry without an attempt counter is treated as one attempt.
        if failure is not None and failure.get("attempts", 1) >= max_retries:
            continue
        pending.append(shard)
    return pending


def _record_result(state: dict, shard: str, result) -> tuple:
    """Fold one worker result into ``state``; return ``(succeeded, hours)``.

    ``"success"`` and ``"skipped"`` results are stored under
    ``state["completed"]`` and any earlier failure entry for the shard is
    dropped. Every other result is stored under ``state["failed"]`` with an
    incremented attempt counter.
    """
    if not isinstance(result, dict):
        # Normalize an unexpected return value into an error record instead
        # of crashing the collector loop.
        result = {
            "status": "error",
            "error": f"unexpected worker result: {result!r}",
        }

    status = result.get("status")
    if status == "success":
        hours = result.get("output_hours", 0)
        state["completed"][shard] = {
            "output_count": result.get("output_count", 0),
            "hours": hours,
            "elapsed_s": result.get("elapsed_s", 0),
        }
        state["failed"].pop(shard, None)
        return True, hours

    if status == "skipped":
        state["completed"][shard] = {"skipped": True}
        state["failed"].pop(shard, None)
        return True, 0.0

    attempts = state["failed"].get(shard, {}).get("attempts", 0) + 1
    state["failed"][shard] = {
        "error": result.get("error", "unknown"),
        "attempts": attempts,
    }
    return False, 0.0


def main():
    """Command-line entry point: convert every pending shard in parallel.

    Reads the shard list from ``QUEUE_PATH``, drops shards that are already
    completed or that have used up ``--max-retries`` attempts, and runs
    ``convert_one_shard`` for the rest in a process pool. Progress is
    printed and checkpointed at most every 60 seconds, and the state is
    saved once more when the run ends, including when it ends because of an
    exception or Ctrl-C.

    Without ``--resume`` the run starts from an empty state, which replaces
    any existing state file on the first save.

    Exits with status 1 if the queue file does not exist.
    """
    parser = argparse.ArgumentParser(description="Orchestrate shard conversion")
    parser.add_argument("--workers", type=int, default=32)
    parser.add_argument("--max-shards", type=int, default=0, help="0=all")
    parser.add_argument("--max-retries", type=int, default=2)
    parser.add_argument("--resume", action="store_true", help="Resume from state")
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    if not QUEUE_PATH.exists():
        print("ERROR: Run phase2_scan_inventory.py first", file=sys.stderr)
        sys.exit(1)

    # Load queue
    queue_df = pq.read_table(QUEUE_PATH).to_pandas()
    all_shards = list(queue_df["shard_dir"])
    print(f"Conversion queue: {len(all_shards)} shards")

    if args.max_shards > 0:
        all_shards = all_shards[: args.max_shards]
        print(f"  Limited to {args.max_shards} shards")

    # Load state for resume
    if args.resume:
        state = load_state()
    else:
        state = {"completed": {}, "failed": {}, "started_at": time.time()}

    pending = _select_pending(all_shards, state, args.max_retries)

    print(f"  Completed: {len(state['completed'])}")
    print(f"  Failed: {len(state['failed'])}")
    print(f"  Pending: {len(pending)}")
    print(f"  Workers: {args.workers}")
    print()

    if args.dry_run:
        print("[DRY RUN] Would process these shards:")
        for s in pending[:5]:
            print(f"  {s}")
        if len(pending) > 5:
            print(f"  ... and {len(pending) - 5} more")
        return

    if not pending:
        print("Nothing to do — all shards converted.")
        return

    t0 = time.time()
    done = 0
    errors = 0
    total_hours = 0.0
    last_report = time.time()
    total = len(pending)

    try:
        with ProcessPoolExecutor(max_workers=args.workers) as executor:
            futures = {
                executor.submit(convert_one_shard, shard): shard
                for shard in pending
            }

            try:
                for future in as_completed(futures):
                    shard = futures[future]
                    try:
                        result = future.result()
                    except Exception as e:
                        # A worker exception counts as a failed attempt for
                        # this shard; the remaining shards keep running.
                        result = {
                            "status": "error",
                            "error": str(e),
                            "shard_dir": shard,
                        }

                    succeeded, hours = _record_result(state, shard, result)
                    if succeeded:
                        done += 1
                        total_hours += hours
                    else:
                        errors += 1

                    # Progress report every 60s, and once after the last shard
                    now = time.time()
                    if now - last_report >= 60 or done + errors == total:
                        elapsed = now - t0
                        rate = total_hours * 3600 / max(elapsed, 1)
                        remaining = total - done - errors
                        eta_h = (remaining * elapsed / max(done + errors, 1)) / 3600

                        print(
                            f"[{elapsed/60:.0f}m] "
                            f"done={done}/{total} "
                            f"errors={errors} "
                            f"hours={total_hours:.1f}h "
                            f"rate={rate:.0f}x "
                            f"ETA={eta_h:.1f}h"
                        )
                        save_state(state)
                        last_report = now
            except BaseException:
                # The run is being aborted: cancel work that has not started
                # so the pool is not asked to run the rest of the queue.
                for queued in futures:
                    queued.cancel()
                raise
    finally:
        # Final save; also reached on errors so collected results are kept.
        save_state(state)

    elapsed = time.time() - t0

    print(f"\n{'='*50}")
    print(f"Conversion complete in {elapsed/3600:.1f}h")
    print(f"  Converted: {done}")
    print(f"  Errors: {errors}")
    print(f"  Total audio hours: {total_hours:.1f}h")
    if errors > 0:
        print("\nFailed shards:")
        for shard, info in state["failed"].items():
            print(f"  {shard}: {info['error']} (attempts={info['attempts']})")

# Run the orchestrator only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
