"""
High-throughput parallel shard downloader.

Uses multiprocessing to bypass GIL + per-process thread pools for massive
concurrent S3 downloads. Skips pyarrow pruning entirely (the #1 bottleneck
in the original pull_shards.py) — raw parquet files land directly on disk.

Architecture:
  - 1 main process lists all shard keys from R2
  - splits the work across N worker processes
  - each process runs its own M-thread download pool; boto3 clients are created
    inside the worker threads and are never shared across processes
  - already-downloaded files (matching size) are skipped instantly

Usage:
  python scripts/pull_shards_fast.py
  python scripts/pull_shards_fast.py --bucket validationsrecoverfinal --procs 16 --threads-per-proc 64
"""
from __future__ import annotations

import argparse
import logging
import multiprocessing as mp
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

# Every record carries the emitting PID so output from the worker processes can be told apart.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(process)d] %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)


def get_s3():
    """Return a newly built boto3 S3 client configured from ``R2_*`` environment variables.

    The ``.env`` file one directory above this script's folder is loaded first
    (python-dotenv), so the endpoint URL and key pair may live there. A fresh
    client is constructed on every call; nothing is cached.
    """
    import boto3
    from dotenv import load_dotenv

    dotenv_path = Path(__file__).resolve().parent.parent / ".env"
    load_dotenv(dotenv_path)
    client_kwargs = dict(
        endpoint_url=os.getenv("R2_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("R2_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("R2_SECRET_ACCESS_KEY"),
        region_name="auto",
    )
    return boto3.client("s3", **client_kwargs)


def list_all_shards(bucket: str, prefix: str) -> list[dict]:
    """List every ``.parquet`` object under ``prefix`` in ``bucket``.

    Walks all ``list_objects_v2`` pages (pages with no ``Contents`` contribute
    nothing) and returns one ``{"key": <object key>, "size": <bytes>}`` dict per
    matching object, in listing order.
    """
    client = get_s3()
    pages = client.get_paginator("list_objects_v2").paginate(Bucket=bucket, Prefix=prefix)
    return [
        {"key": entry["Key"], "size": entry["Size"]}
        for page in pages
        for entry in page.get("Contents", [])
        if entry["Key"].endswith(".parquet")
    ]


def _download_one(s3, bucket: str, prefix: str, shard: dict, output_dir: Path) -> tuple[int, int]:
    """Returns (downloaded_bytes, skipped). skipped=1 means file already existed."""
    rel_path = shard["key"].replace(prefix, "", 1)
    local_path = output_dir / rel_path
    expected_size = shard["size"]

    if local_path.exists() and local_path.stat().st_size == expected_size:
        return 0, 1

    local_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = local_path.with_suffix(".parquet.tmp")
    try:
        s3.download_file(bucket, shard["key"], str(tmp_path))
        tmp_path.rename(local_path)
        return expected_size, 0
    except Exception as e:
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass
        raise


def worker_process(
    proc_id: int,
    shards: list[dict],
    bucket: str,
    prefix: str,
    output_dir: Path,
    threads: int,
    counter: mp.Value,
    bytes_counter: mp.Value,
    skip_counter: mp.Value,
    fail_counter: mp.Value,
):
    """Download ``shards`` with a ``threads``-wide thread pool inside one worker process.

    After each shard finishes, the shared counters are updated:
    ``counter`` += 1, ``bytes_counter`` += bytes downloaded, ``skip_counter`` += 1
    if the file was already on disk, and ``fail_counter`` += 1 if the download
    raised. Failures are logged and counted instead of propagated, so one bad
    shard never stops the rest of this process's slice.

    Each download thread lazily builds a single S3 client and reuses it for all of
    its shards, rather than constructing a new client for every shard.
    """
    import threading

    thread_state = threading.local()
    create_lock = threading.Lock()

    def _thread_client():
        client = getattr(thread_state, "s3", None)
        if client is None:
            # boto3.client() goes through boto3's shared default session, which is
            # not thread-safe (and get_s3 also reloads .env into os.environ), so
            # client construction is serialized across this process's threads.
            with create_lock:
                client = get_s3()
            thread_state.s3 = client
        return client

    def _do_one(shard: dict) -> tuple[int, int, int]:
        # Client creation is inside the try on purpose: an exception escaping here
        # would be re-raised by fut.result() below and abort the whole worker,
        # leaving its remaining shards uncounted.
        try:
            downloaded, skipped = _download_one(_thread_client(), bucket, prefix, shard, output_dir)
        except Exception as exc:
            logger.warning(f"[proc {proc_id}] failed {shard['key']}: {exc!r}")
            return 0, 0, 1
        return downloaded, skipped, 0

    with ThreadPoolExecutor(max_workers=threads) as pool:
        futures = [pool.submit(_do_one, s) for s in shards]
        for fut in as_completed(futures):
            downloaded, skipped, failed = fut.result()
            with counter.get_lock():
                counter.value += 1
            with bytes_counter.get_lock():
                bytes_counter.value += downloaded
            with skip_counter.get_lock():
                skip_counter.value += skipped
            with fail_counter.get_lock():
                fail_counter.value += failed


def monitor_loop(
    total: int,
    counter: mp.Value,
    bytes_counter: mp.Value,
    skip_counter: mp.Value,
    fail_counter: mp.Value,
    start_time: float,
):
    """Log one progress line every 2 seconds until ``counter`` reaches ``total``.

    Only reads the shared counters, each under its own lock. Sleeps before the
    first report, and always reports at least once (even when ``total`` is 0).
    The rate and ETA are based on all completed shards, skipped ones included.
    """
    def snapshot(shared):
        with shared.get_lock():
            return shared.value

    finished = False
    while not finished:
        time.sleep(2)
        done = snapshot(counter)
        dl_bytes = snapshot(bytes_counter)
        skipped = snapshot(skip_counter)
        failed = snapshot(fail_counter)

        elapsed = time.time() - start_time
        if elapsed > 0:
            rate = done / elapsed
        else:
            rate = 0
        if rate > 0:
            eta_m = (total - done) / rate / 60
        else:
            eta_m = 0.0

        logger.info(
            f"[{done:,}/{total:,}] {rate:.0f} shards/s | "
            f"{dl_bytes / 1e9:.2f} GB downloaded | "
            f"{skipped:,} skipped | {failed} failed | "
            f"ETA {eta_m:.1f}min"
        )
        finished = done >= total


def main():
    """Command-line entry point: list shards, fan them out to worker processes, report.

    Lists every ``.parquet`` object under ``--prefix`` in ``--bucket``, deals the
    shards round-robin across ``--procs`` worker processes (each running
    ``--threads-per-proc`` download threads), runs a progress-monitor process,
    waits for all workers and logs a final summary.
    """
    parser = argparse.ArgumentParser(description="High-throughput parallel shard downloader")
    parser.add_argument("--bucket", default="validationsrecoverfinal")
    parser.add_argument("--prefix", default="shards/")
    parser.add_argument("--output", default="data/recover_validation_shards")
    parser.add_argument(
        "--procs", type=int, default=0, help="Worker processes (0 = auto = 2x CPU, capped at 32)"
    )
    parser.add_argument("--threads-per-proc", type=int, default=48, help="Download threads per process")
    args = parser.parse_args()
    # Fail fast on values that would otherwise crash later with confusing errors
    # (negative procs break the round-robin split; ThreadPoolExecutor needs >= 1).
    if args.procs < 0:
        parser.error("--procs must be >= 0")
    if args.threads_per_proc < 1:
        parser.error("--threads-per-proc must be >= 1")

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Normalize to a trailing "/"; an empty prefix stays empty (whole bucket)
    # instead of becoming "/", which would match nothing.
    prefix = args.prefix
    if prefix and not prefix.endswith("/"):
        prefix = f"{prefix}/"
    # os.cpu_count() returns None when the CPU count cannot be determined.
    num_procs = args.procs or min((os.cpu_count() or 1) * 2, 32)

    logger.info(f"Listing shards in s3://{args.bucket}/{prefix} ...")
    t_list = time.time()
    shards = list_all_shards(args.bucket, prefix)
    total = len(shards)
    total_bytes = sum(s["size"] for s in shards)
    logger.info(
        f"Found {total:,} shards ({total_bytes / 1e9:.2f} GB) in {time.time() - t_list:.0f}s"
    )

    if total == 0:
        logger.info("Nothing to download.")
        return

    total_concurrency = num_procs * args.threads_per_proc
    logger.info(
        f"Launching {num_procs} processes x {args.threads_per_proc} threads = "
        f"{total_concurrency} concurrent downloads"
    )

    # Round-robin split: chunk i gets shards i, i + num_procs, i + 2*num_procs, ...
    chunks = [shards[i::num_procs] for i in range(num_procs)]

    counter = mp.Value("i", 0)
    # A C unsigned long ("L") is only 32 bits on Windows and silently wraps past
    # ~4.29 GB, so the byte counter uses an unsigned 64-bit "Q".
    bytes_counter = mp.Value("Q", 0)
    skip_counter = mp.Value("i", 0)
    fail_counter = mp.Value("i", 0)
    t0 = time.time()

    monitor = mp.Process(
        target=monitor_loop,
        args=(total, counter, bytes_counter, skip_counter, fail_counter, t0),
        daemon=True,
    )
    monitor.start()

    processes = []
    for proc_id, chunk in enumerate(chunks):
        if not chunk:
            continue
        p = mp.Process(
            target=worker_process,
            args=(
                proc_id, chunk, args.bucket, prefix, output_dir,
                args.threads_per_proc, counter, bytes_counter, skip_counter, fail_counter,
            ),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        if p.exitcode != 0:
            # A crashed worker leaves its remaining shards uncounted.
            logger.warning(f"Worker pid {p.pid} exited with code {p.exitcode}")

    # The monitor stops by itself once every shard is counted; it is a daemon, so
    # it cannot keep the program alive if some shards were never counted.
    monitor.join(timeout=5)

    elapsed = time.time() - t0
    with counter.get_lock():
        final_done = counter.value
    with bytes_counter.get_lock():
        final_bytes = bytes_counter.value
    with skip_counter.get_lock():
        final_skipped = skip_counter.value
    with fail_counter.get_lock():
        final_failed = fail_counter.value

    final_rate = final_done / elapsed if elapsed > 0 else 0.0
    logger.info(
        f"\nDone: {final_done:,} shards in {elapsed:.0f}s ({final_rate:.0f} shards/s)\n"
        f"  Downloaded: {final_bytes / 1e9:.2f} GB\n"
        f"  Skipped:    {final_skipped:,} (already on disk)\n"
        f"  Failed:     {final_failed}\n"
        f"  R2 total:   {total_bytes / 1e9:.2f} GB"
    )
    if final_done < total:
        logger.warning(f"{total - final_done:,} shards were never processed (worker crash?)")


if __name__ == "__main__":
    # Force "spawn" so every worker starts in a fresh interpreter rather than
    # inheriting forked parent state; spawn is available on all platforms.
    mp.set_start_method(method="spawn", force=True)
    main()
