"""
Download validation parquet shards from R2 to local disk.
Drops vox_speaker_embedding (binary blobs) to save ~30-40% space.
Uses ThreadPoolExecutor for parallel downloads.

Usage:
  python scripts/pull_shards.py [--workers 32] [--output data/validation_shards]
"""
from __future__ import annotations

import argparse
import logging
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import boto3
import pyarrow.parquet as pq

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s", datefmt="%H:%M:%S")
logger = logging.getLogger(__name__)

DROP_COLUMNS = {"vox_speaker_embedding"}

BUCKET = "validation-results"
PREFIX = "shards/"


def get_s3():
    """Construct a boto3 S3 client aimed at the R2 endpoint.

    Credentials and the endpoint URL are pulled from environment variables
    (R2_ENDPOINT_URL, R2_ACCESS_KEY_ID, R2_SECRET_ACCESS_KEY), which are
    loaded from the repo-root .env file first.
    """
    from dotenv import load_dotenv

    # <repo-root>/.env, resolved relative to this script's location.
    env_file = Path(__file__).resolve().parent.parent / ".env"
    load_dotenv(env_file)

    client_kwargs = {
        "endpoint_url": os.getenv("R2_ENDPOINT_URL"),
        "aws_access_key_id": os.getenv("R2_ACCESS_KEY_ID"),
        "aws_secret_access_key": os.getenv("R2_SECRET_ACCESS_KEY"),
        "region_name": "auto",
    }
    return boto3.client("s3", **client_kwargs)


def list_shards(s3) -> list[dict]:
    """Enumerate every ``.parquet`` object under PREFIX in BUCKET.

    Returns one ``{"key": ..., "size": ...}`` dict per shard, in listing order.
    """
    paginator = s3.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=BUCKET, Prefix=PREFIX)
    return [
        {"key": obj["Key"], "size": obj["Size"]}
        for page in pages
        for obj in page.get("Contents", [])
        if obj["Key"].endswith(".parquet")
    ]


def download_and_prune(s3, key: str, output_dir: Path) -> tuple[str, int, int]:
    """Download a shard, drop heavy columns, write pruned parquet.

    Args:
        s3: boto3 S3 client (callers give each thread its own client).
        key: object key in BUCKET, e.g. "shards/xyz.parquet".
        output_dir: local root; the key minus PREFIX becomes the relative path.

    Returns:
        (key, orig_size, pruned_size). On any error returns (key, 0, 0) after
        logging a warning; if the output file already exists, returns
        (key, 0, existing_size) without re-downloading.
    """
    import tempfile

    rel_path = key.replace(PREFIX, "", 1)
    local_path = output_dir / rel_path
    local_path.parent.mkdir(parents=True, exist_ok=True)

    # Idempotent: a previous run already produced this shard.
    if local_path.exists():
        return key, 0, local_path.stat().st_size

    # mkstemp instead of the deprecated, race-prone mktemp: the file is
    # actually created here, so no other process can claim the name. Close
    # the fd immediately — download_file reopens the file by path.
    fd, tmp = tempfile.mkstemp(suffix=".parquet", dir=str(output_dir))
    os.close(fd)
    try:
        s3.download_file(BUCKET, key, tmp)
        orig_size = Path(tmp).stat().st_size

        # Read only the columns we keep, then rewrite with zstd compression.
        keep = [c for c in pq.read_schema(tmp).names if c not in DROP_COLUMNS]
        table = pq.read_table(tmp, columns=keep)
        pq.write_table(table, str(local_path), compression="zstd")
        pruned_size = local_path.stat().st_size

        return key, orig_size, pruned_size
    except Exception as e:
        # Best-effort per-shard policy: log and report failure via (0, 0);
        # the caller tallies failures instead of aborting the whole run.
        logger.warning(f"Failed {key}: {e}")
        return key, 0, 0
    finally:
        # Remove the temp download whether or not pruning succeeded.
        try:
            os.unlink(tmp)
        except OSError:
            pass


def main():
    """List all shards in R2, download + prune them in parallel, print stats."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=32)
    parser.add_argument("--output", type=str, default="data/validation_shards")
    args = parser.parse_args()

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    s3 = get_s3()
    logger.info("Listing shards in R2...")
    shards = list_shards(s3)
    total_r2_bytes = sum(s["size"] for s in shards)
    logger.info(f"Found {len(shards)} shards ({total_r2_bytes / 1e9:.2f} GB in R2)")

    # Each thread gets its own S3 client (boto3 clients aren't thread-safe)
    def _worker(shard):
        thread_s3 = get_s3()
        return download_and_prune(thread_s3, shard["key"], output_dir)

    t0 = time.time()
    done = 0
    orig_total = 0
    pruned_total = 0
    failed = 0

    with ThreadPoolExecutor(max_workers=args.workers) as pool:
        futures = {pool.submit(_worker, s): s for s in shards}
        for fut in as_completed(futures):
            shard = futures[fut]
            done += 1
            try:
                # Guard .result(): download_and_prune handles its own errors,
                # but _worker's get_s3() can still raise (bad env/credentials)
                # and an unguarded .result() would abort the whole run.
                key, orig, pruned = fut.result()
            except Exception as e:
                logger.warning(f"Failed {shard['key']}: {e}")
                failed += 1
            else:
                if pruned > 0:
                    # Success (or the file already existed, in which case
                    # orig is 0 and only the on-disk size is counted).
                    orig_total += orig
                    pruned_total += pruned
                elif orig == 0 and not (
                    output_dir / shard["key"].replace(PREFIX, "", 1)
                ).exists():
                    # (0, 0) result with no file on disk => genuine failure.
                    failed += 1

            if done % 200 == 0 or done == len(shards):
                elapsed = time.time() - t0
                rate = done / elapsed if elapsed > 0 else 0
                logger.info(
                    f"[{done}/{len(shards)}] {rate:.1f} shards/s | "
                    f"pruned {pruned_total / 1e9:.2f} GB | "
                    f"failed {failed} | {elapsed:.0f}s"
                )

    elapsed = time.time() - t0
    logger.info(
        f"\nDone: {done} shards in {elapsed:.0f}s\n"
        f"  R2 size:    {total_r2_bytes / 1e9:.2f} GB\n"
        f"  Downloaded: {orig_total / 1e9:.2f} GB (before pruning)\n"
        f"  On disk:    {pruned_total / 1e9:.2f} GB (after dropping {DROP_COLUMNS})\n"
        f"  Savings:    {(1 - pruned_total / max(orig_total, 1)) * 100:.0f}%\n"
        f"  Failed:     {failed}"
    )


if __name__ == "__main__":
    main()
