#!/usr/bin/env python3
"""Download remaining shards from R2 finalsftdata bucket.

Compares R2 inventory with local shards, downloads missing ones.
Excludes v1 and nemotron-ckpts prefixes.

Usage:
  python3 tools/phase2_download_remaining.py --dry-run
  python3 tools/phase2_download_remaining.py --workers 16
  python3 tools/phase2_download_remaining.py --workers 16 --prefix final-export
"""

import argparse
import json
import os
import subprocess
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import boto3
from dotenv import load_dotenv

# Load R2 credentials (R2_ENDPOINT_URL / R2_ACCESS_KEY_ID / R2_SECRET_ACCESS_KEY)
# from the project .env before any get_s3() call reads os.environ.
load_dotenv(Path("/workspace/maya-asr/.env"))

# Local root that mirrors the R2 key layout: a shard with R2 prefix P
# is stored at DATA_ROOT / P.
DATA_ROOT = Path("/root/sft_data")
# Directory where the final download-state JSON is written.
ARTIFACTS_DIR = Path("/workspace/maya-asr/artifacts/phase2")
# Inventory ("blueprint") of the R2 bucket: bucket name plus per-prefix,
# per-language sample shard paths.
BLUEPRINT = Path("/workspace/maya-asr/finalsftdata.json")

# Component files expected inside every shard directory in R2.
COMPONENTS = ["audio.tar", "audio_index.parquet", "manifest.json", "metadata.parquet"]


def get_s3():
    """Build a boto3 S3 client pointed at the R2 endpoint.

    Reads R2_ENDPOINT_URL, R2_ACCESS_KEY_ID and R2_SECRET_ACCESS_KEY from
    the environment (populated by load_dotenv at import time); raises
    KeyError if any of them is unset.
    """
    client_kwargs = {
        "endpoint_url": os.environ["R2_ENDPOINT_URL"],
        "aws_access_key_id": os.environ["R2_ACCESS_KEY_ID"],
        "aws_secret_access_key": os.environ["R2_SECRET_ACCESS_KEY"],
        "region_name": "auto",
    }
    return boto3.client("s3", **client_kwargs)


def find_missing_shards(blueprint: dict, prefix_filter: str = None) -> list[dict]:
    """Compare R2 blueprint with local files to find missing shards."""
    missing = []
    for prefix, info in blueprint["prefixes"].items():
        if prefix_filter and prefix != prefix_filter:
            continue

        path_pattern = info["path_pattern"]

        for lang_code, lang_info in info["languages"].items():
            # Determine local base path
            # Pattern like: final-export/production/shards/lang=<lang>/<shard_id>/<component>
            # or: hifitts2/lang=<lang>/<shard_id>/<component>
            sample = lang_info["sample_shard"]
            # sample: "final-export/production/shards/lang=en/en_shard_..."
            # Local path: /root/sft_data/final-export/production/shards/lang=en/en_shard_.../

            # List all shards for this language from R2 by listing the blueprint
            # We need to enumerate from R2 since blueprint doesn't list individual shard IDs
            # Use s5cmd or paginated list
            pass

    return missing


def list_r2_shards(s3, bucket: str, prefix: str) -> list[str]:
    """List all shard directories under a key prefix in R2.

    A "shard directory" is the parent prefix of any known component file,
    e.g. ".../lang=hi/hi_shard_xxx/audio.tar" -> ".../lang=hi/hi_shard_xxx".
    Objects that are not one of COMPONENTS are ignored.

    Args:
        s3: boto3 S3 client (anything exposing a list_objects_v2 paginator).
        bucket: R2 bucket name.
        prefix: Key prefix to scan, e.g. "final-export/.../lang=hi/".

    Returns:
        Sorted list of unique shard-directory prefixes (no trailing slash).
    """
    shards = set()
    paginator = s3.get_paginator("list_objects_v2")

    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            # key like: final-export/production/shards/lang=hi/hi_shard_xxx/audio.tar
            for comp in COMPONENTS:
                # Fix: match "/<component>" rather than a bare endswith.
                # The original endswith(comp) also matched keys such as
                # ".../extra_audio.tar" and then truncated the shard
                # prefix by one character when stripping the suffix.
                suffix = "/" + comp
                if key.endswith(suffix):
                    shards.add(key[: -len(suffix)])
                    break

    return sorted(shards)


def download_shard(args_tuple) -> dict:
    """Download one shard's component files from R2 into a local directory.

    Runs in a worker process, so it constructs its own S3 client rather
    than sharing one across processes. Components that already exist
    locally are skipped, making the worker resumable.

    Args:
        args_tuple: (r2_prefix, local_dir_str, bucket) — shard key prefix
            in R2, destination directory path, and bucket name. Packed as
            a single tuple so it can be submitted to ProcessPoolExecutor.

    Returns:
        Dict with keys: "shard" (r2 prefix), "local_dir", "downloaded"
        (number of files fetched this call), "total_bytes" (bytes fetched
        this call), "errors" (list of "component: exception" strings).
    """
    r2_prefix, local_dir_str, bucket = args_tuple
    local_dir = Path(local_dir_str)
    local_dir.mkdir(parents=True, exist_ok=True)

    s3 = get_s3()
    downloaded = 0
    errors = []
    total_bytes = 0

    for comp in COMPONENTS:
        key = f"{r2_prefix}/{comp}"
        local_path = local_dir / comp
        if local_path.exists():
            # Already present from a previous (possibly partial) run.
            continue
        try:
            s3.download_file(bucket, key, str(local_path))
            total_bytes += local_path.stat().st_size
            downloaded += 1
        except Exception as e:
            # Record the failure and keep going so one bad component does
            # not abort the rest of the shard. (The original guard
            # `if "xcodec2" not in comp` was dead code: no COMPONENTS
            # entry contains "xcodec2", so it was always true.)
            errors.append(f"{comp}: {e}")

    return {
        "shard": r2_prefix,
        "local_dir": local_dir_str,
        "downloaded": downloaded,
        "total_bytes": total_bytes,
        "errors": errors,
    }


def _lang_key_prefix(lang_sample: str) -> str:
    """Derive the R2 key prefix covering one language's shards.

    Given a sample shard path like
    "final-export/production/shards/lang=hi/hi_shard_xxx", return
    everything up to and including the language directory
    ("final-export/production/shards/lang=hi/"). Falls back to the
    sample's parent directory when no language part is found.
    """
    parts = lang_sample.split("/")
    for i, part in enumerate(parts):
        # Match either the "lang=xx" convention or a bare two-letter
        # language directory.
        if part.startswith("lang=") or (len(part) == 2 and part.isalpha()):
            return "/".join(parts[: i + 1]) + "/"
    # Fallback: everything except the last path component.
    return "/".join(parts[:-1]) + "/"


def _scan_missing(s3, bucket: str, blueprint: dict, prefix_filter: str | None) -> list[dict]:
    """List shards in R2 and return those without a local audio.tar.

    Prints a per-prefix count as it goes. Deduplicates shards in case
    two language prefixes overlap (fix: the original could append the
    same shard twice).
    """
    all_missing: list[dict] = []
    seen: set[str] = set()

    for prefix_name, prefix_info in blueprint["prefixes"].items():
        if prefix_filter and prefix_name != prefix_filter:
            continue

        # Without a sample path there is no key prefix to list from.
        if not prefix_info.get("sample_shard", ""):
            continue

        for lang_code, lang_info in prefix_info["languages"].items():
            lang_prefix = _lang_key_prefix(lang_info["sample_shard"])

            for r2_shard in list_r2_shards(s3, bucket, lang_prefix):
                if r2_shard in seen:
                    continue
                seen.add(r2_shard)
                # audio.tar is the required payload; its presence marks
                # the shard as already downloaded locally.
                local_dir = DATA_ROOT / r2_shard
                if (local_dir / "audio.tar").exists():
                    continue
                all_missing.append(
                    {
                        "prefix": prefix_name,
                        "language": lang_code,
                        "r2_prefix": r2_shard,
                        "local_dir": str(local_dir),
                    }
                )

        count = sum(1 for m in all_missing if m["prefix"] == prefix_name)
        print(f"  {prefix_name}: found {count} missing")

    return all_missing


def main():
    """Scan R2 for shards missing locally and download them in parallel."""
    parser = argparse.ArgumentParser(description="Download remaining R2 shards")
    parser.add_argument("--workers", type=int, default=16)
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--prefix", type=str, default=None, help="Filter to one prefix")
    args = parser.parse_args()

    # Fix: close the blueprint file (original leaked it via json.load(open(...))).
    with open(BLUEPRINT) as f:
        blueprint = json.load(f)
    bucket = blueprint["bucket"]
    s3 = get_s3()

    print("Scanning R2 for missing shards...")
    all_missing = _scan_missing(s3, bucket, blueprint, args.prefix)
    print(f"\nTotal missing: {len(all_missing)} shards")

    if args.dry_run:
        print("\n[DRY RUN] Would download:")
        for m in all_missing[:10]:
            print(f"  {m['r2_prefix']} -> {m['local_dir']}")
        if len(all_missing) > 10:
            print(f"  ... and {len(all_missing) - 10} more")
        return

    if not all_missing:
        print("Nothing to download!")
        return

    print(f"\nDownloading {len(all_missing)} shards with {args.workers} workers...")
    t0 = time.time()
    done = 0
    errors = 0
    total_gb = 0.0

    download_args = [
        (m["r2_prefix"], m["local_dir"], bucket) for m in all_missing
    ]

    with ProcessPoolExecutor(max_workers=args.workers) as executor:
        futures = {
            executor.submit(download_shard, a): a for a in download_args
        }
        last_report = time.time()
        for future in as_completed(futures):
            result = future.result()
            done += 1
            total_gb += result["total_bytes"] / (1024**3)
            if result["errors"]:
                errors += 1

            now = time.time()
            # Report progress at most once a minute, plus a final report.
            if now - last_report >= 60 or done == len(all_missing):
                elapsed = now - t0
                rate_mb = total_gb / max(elapsed, 1) * 1024  # GB/s -> MB/s
                print(
                    f"  [{elapsed/60:.0f}m] {done}/{len(all_missing)} "
                    f"({total_gb:.1f} GB, {rate_mb:.0f} MB/s, errors={errors})"
                )
                last_report = now

    elapsed = time.time() - t0
    print(f"\nDownload complete: {done} shards, {total_gb:.1f} GB in {elapsed/60:.0f}m")
    if errors:
        print(f"  Errors: {errors}")

    # Fix: ensure the artifacts directory exists before writing state
    # (the original open() raised FileNotFoundError on a fresh checkout).
    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
    state = {
        "downloaded": done,
        "errors": errors,
        "total_gb": round(total_gb, 1),
        "elapsed_min": round(elapsed / 60, 1),
    }
    with open(ARTIFACTS_DIR / "download_remaining_state.json", "w") as f:
        json.dump(state, f, indent=2)


# Script entry point; the guard prevents re-execution of main() when the
# module is imported (e.g. by ProcessPoolExecutor worker processes on
# spawn-based platforms).
if __name__ == "__main__":
    main()
