"""
Download parquet shards from an R2 bucket to local disk.

By default the script drops `vox_speaker_embedding` (a column of binary
blobs) to save 30-40% of disk space. This keeps both the historical
validation archive and the recover-validation archive compact enough for
local DuckDB analytics.

Usage:
  python scripts/pull_shards.py
  python scripts/pull_shards.py --bucket validationsrecoverfinal \
      --output data/recover_validation_shards
  python scripts/pull_shards.py --keep-all-columns
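
Credentials come from a .env file at the repo root (loaded in get_s3):
R2_ENDPOINT_URL, R2_ACCESS_KEY_ID, and R2_SECRET_ACCESS_KEY.

The pruned shards can be queried in place with DuckDB, e.g. (a sketch; point
the glob at your --output directory):
  duckdb -c "SELECT count(*) FROM 'data/validation_shards/**/*.parquet'"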
"""
from __future__ import annotations

import argparse
import logging
import os
import tempfile
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import boto3
import pyarrow.parquet as pq

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", datefmt="%H:%M:%S")
logger = logging.getLogger(__name__)

DEFAULT_DROP_COLUMNS = ("vox_speaker_embedding",)


def get_s3():
    """Build an R2-compatible boto3 S3 client from credentials in .env."""
    from dotenv import load_dotenv
    load_dotenv(Path(__file__).resolve().parent.parent / ".env")
    return boto3.client(
        "s3",
        endpoint_url=os.getenv("R2_ENDPOINT_URL"),
        aws_access_key_id=os.getenv("R2_ACCESS_KEY_ID"),
        aws_secret_access_key=os.getenv("R2_SECRET_ACCESS_KEY"),
        region_name="auto",
    )
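
# Quick connectivity check (illustrative; run from a REPL with .env in place):
#   get_s3().head_bucket(Bucket="validation-results")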


def list_shards(s3, bucket: str, prefix: str) -> list[dict]:
    """List all .parquet objects under the given prefix, with key and size."""
    shards = []
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            if obj["Key"].endswith(".parquet"):
                shards.append({"key": obj["Key"], "size": obj["Size"]})
    return shards


def download_and_prune(
    s3,
    *,
    bucket: str,
    prefix: str,
    key: str,
    output_dir: Path,
    drop_columns: set[str],
) -> tuple[str, int, int]:
    """Download a shard, drop heavy columns, write pruned parquet. Returns (key, orig_size, pruned_size)."""
    rel_path = key.replace(prefix, "", 1)
    local_path = output_dir / rel_path
    local_path.parent.mkdir(parents=True, exist_ok=True)

    if local_path.exists():
        # Already downloaded on a previous run; skip so re-runs can resume.
        return key, 0, local_path.stat().st_size

    # Download to a temp file first so an interrupted run never leaves a
    # partial shard at local_path; mkstemp replaces the deprecated, racy mktemp.
    fd, tmp = tempfile.mkstemp(suffix=".parquet", dir=str(output_dir))
    os.close(fd)
    try:
        s3.download_file(bucket, key, tmp)
        orig_size = Path(tmp).stat().st_size

        if drop_columns:
            # Re-read only the kept columns and rewrite with zstd compression.
            keep = [c for c in pq.read_schema(tmp).names if c not in drop_columns]
            table = pq.read_table(tmp, columns=keep)
            pq.write_table(table, str(local_path), compression="zstd")
        else:
            # Nothing to prune: move the download into place as-is.
            Path(tmp).replace(local_path)
        pruned_size = local_path.stat().st_size

        return key, orig_size, pruned_size
    except Exception as e:
        logger.warning(f"Failed {key}: {e}")
        return key, 0, 0
    finally:
        # tmp may already have been renamed into place; ignore a missing file.
        try:
            os.unlink(tmp)
        except OSError:
            pass
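
# Spot-check that pruning worked on a downloaded shard (illustrative path):
#   pq.read_schema("data/validation_shards/example.parquet").names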


def main():
    parser = argparse.ArgumentParser(description="Download parquet shards from R2")
    parser.add_argument("--bucket", type=str, default="validation-results")
    parser.add_argument("--prefix", type=str, default="shards/")
    parser.add_argument("--workers", type=int, default=32)
    parser.add_argument("--output", type=str, default="data/validation_shards")
    parser.add_argument(
        "--drop-column",
        action="append",
        default=list(DEFAULT_DROP_COLUMNS),
        help="Column to drop from downloaded parquet (repeatable; values are "
        "appended to the default set rather than replacing it)",
    )
    parser.add_argument(
        "--keep-all-columns",
        action="store_true",
        help="Do not prune any columns from the downloaded parquet",
    )
    args = parser.parse_args()

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    prefix = args.prefix if args.prefix.endswith("/") else f"{args.prefix}/"
    drop_columns = set() if args.keep_all_columns else {c for c in args.drop_column if c}

    s3 = get_s3()
    logger.info(f"Listing shards in s3://{args.bucket}/{prefix} ...")
    shards = list_shards(s3, args.bucket, prefix)
    total_r2_bytes = sum(s["size"] for s in shards)
    logger.info(f"Found {len(shards)} shards ({total_r2_bytes / 1e9:.2f} GB in R2)")
    if drop_columns:
        logger.info(f"Pruning columns: {sorted(drop_columns)}")
    else:
        logger.info("Keeping all parquet columns")

    # boto3 Sessions aren't thread-safe, so cache one client per worker thread
    # rather than rebuilding a client (and re-reading .env) for every shard.
    thread_local = threading.local()

    def _worker(shard):
        if not hasattr(thread_local, "s3"):
            thread_local.s3 = get_s3()
        return download_and_prune(
            thread_local.s3,
            bucket=args.bucket,
            prefix=prefix,
            key=shard["key"],
            output_dir=output_dir,
            drop_columns=drop_columns,
        )

    t0 = time.time()
    done = 0
    orig_total = 0
    pruned_total = 0
    failed = 0

    with ThreadPoolExecutor(max_workers=args.workers) as pool:
        futures = {pool.submit(_worker, s): s for s in shards}
        for fut in as_completed(futures):
            key, orig, pruned = fut.result()
            done += 1
            if pruned > 0:
                orig_total += orig
                pruned_total += pruned
            else:
                # Zero bytes back means the worker raised; don't count it as
                # failed if the shard is already on disk from a prior run.
                shard = futures[fut]
                if not (output_dir / shard["key"].replace(prefix, "", 1)).exists():
                    failed += 1

            if done % 200 == 0 or done == len(shards):
                elapsed = time.time() - t0
                rate = done / elapsed if elapsed > 0 else 0
                logger.info(
                    f"[{done}/{len(shards)}] {rate:.1f} shards/s | "
                    f"pruned {pruned_total / 1e9:.2f} GB | "
                    f"failed {failed} | {elapsed:.0f}s"
                )

    elapsed = time.time() - t0
    logger.info(
        f"\nDone: {done} shards in {elapsed:.0f}s\n"
        f"  R2 size:    {total_r2_bytes / 1e9:.2f} GB\n"
        f"  Downloaded: {orig_total / 1e9:.2f} GB (before pruning)\n"
        f"  On disk:    {pruned_total / 1e9:.2f} GB\n"
        f"  Savings:    {(1 - pruned_total / max(orig_total, 1)) * 100:.0f}%\n"
        f"  Failed:     {failed}"
    )


if __name__ == "__main__":
    main()
