"""Upload final-export reference snapshots to Cloudflare R2.

Stages the canonical/raw inputs (plus optional validation, YouTube-metadata,
and variants files), writes a manifest.json with file sizes and row counts,
then uploads everything via R2's S3-compatible API. Credentials
(R2_ENDPOINT_URL, R2_ACCESS_KEY_ID, R2_SECRET_ACCESS_KEY) are loaded from .env.
"""
from __future__ import annotations

import argparse
import json
import os
import shutil
import time
from pathlib import Path

import boto3
import duckdb
import pyarrow.parquet as pq
from boto3.s3.transfer import TransferConfig
from dotenv import load_dotenv

# Default input locations on this host; each can be overridden via a CLI flag.
ROOT = Path("/home/ubuntu/transcripts")
CANONICAL_DEFAULT = ROOT / "final_data" / "final_cleaned_segments_with_variants_rerouted_repetition_filtered.parquet"
RAW_TX_DEFAULT = ROOT / "data" / "transcription_results.parquet"
VALIDATION_DEFAULT = ROOT / "data" / "recover_v2_consolidated.parquet"
YOUTUBE_DEFAULT = ROOT / "data" / "youtube_video_metadata_all.csv"


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Upload final export reference snapshots to R2")
    parser.add_argument("--canonical", type=Path, default=CANONICAL_DEFAULT, help="Canonical segments parquet")
    parser.add_argument("--raw-transcripts", type=Path, default=RAW_TX_DEFAULT, help="Raw transcription results parquet")
    parser.add_argument("--validation", type=Path, default=VALIDATION_DEFAULT, help="Validation parquet; skipped if missing")
    parser.add_argument("--youtube-meta", type=Path, default=YOUTUBE_DEFAULT, help="YouTube video metadata CSV; skipped if missing")
    parser.add_argument("--variants", type=Path, default=None, help="Optional variants parquet; skipped if missing")
    parser.add_argument("--output-dir", type=Path, default=ROOT / "data" / "final_export_reference", help="Staging directory for copies and manifest.json")
    parser.add_argument("--bucket", default="", help="R2 bucket; defaults to FINAL_EXPORT_REFERENCE_BUCKET / R2_BUCKET env vars")
    parser.add_argument("--prefix", default="final-export-reference", help="Key prefix under the bucket")
    parser.add_argument("--copy-local", action="store_true", help="Copy inputs into output-dir before upload")
    parser.add_argument("--no-upload", action="store_true", help="Write the manifest locally but skip the R2 upload")
    return parser.parse_args()


def parquet_rows(path: Path) -> int:
    """Row count from the parquet footer metadata (no data scan)."""
    return int(pq.ParquetFile(path).metadata.num_rows)


def csv_rows(path: Path) -> int:
    """Row count via DuckDB's CSV reader, which handles quoted fields and
    embedded newlines that a naive line count would miscount."""
    con = duckdb.connect()
    try:
        return int(con.execute(f"SELECT count(*) FROM read_csv_auto('{_sql_path(path)}', header=true)").fetchone()[0])
    finally:
        con.close()


def upload_file(bucket: str, key: str, path: Path):
    """Upload one file to R2, using multipart transfers for large objects."""
    # 64 MiB parts with up to 16 concurrent threads suit multi-GB parquet files.
    transfer = TransferConfig(
        multipart_threshold=64 * 1024 * 1024,
        multipart_chunksize=64 * 1024 * 1024,
        max_concurrency=16,
        use_threads=True,
    )
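    # R2 exposes an S3-compatible API: boto3 only needs the account endpoint
    # (https://<account_id>.r2.cloudflarestorage.com) and region "auto".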
    s3 = boto3.client(
        "s3",
        endpoint_url=os.environ["R2_ENDPOINT_URL"],
        aws_access_key_id=os.environ["R2_ACCESS_KEY_ID"],
        aws_secret_access_key=os.environ["R2_SECRET_ACCESS_KEY"],
        region_name="auto",
    )
    print(f"Uploading {path.name} -> s3://{bucket}/{key}")
    t0 = time.time()
    s3.upload_file(str(path), bucket, key, Config=transfer)
    print(f"Uploaded {path.name} in {time.time() - t0:.1f}s")


def maybe_stage(src: Path, dest: Path, copy_local: bool) -> Path:
    """Copy src into the staging dir when --copy-local is set; otherwise return src in place."""
    dest.parent.mkdir(parents=True, exist_ok=True)
    if copy_local:
        shutil.copy2(src, dest)
        return dest
    return src


def _sql_path(path: Path) -> str:
    """Escape a path for use inside a single-quoted SQL string literal,
    e.g. Path("it's.csv") -> "it''s.csv"."""
    return path.as_posix().replace("'", "''")


def main():
    args = parse_args()
    load_dotenv(ROOT / ".env")
    # Bucket precedence: CLI flag, then env vars, then the default bucket.
    bucket = args.bucket or os.getenv("FINAL_EXPORT_REFERENCE_BUCKET") or os.getenv("R2_BUCKET") or "1-cleaned-data"
    prefix = args.prefix.strip("/")
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Logical upload name -> local path; optional inputs resolve to None when absent.
    staged = {
        "canonical_segments.parquet": maybe_stage(
            args.canonical, args.output_dir / "canonical_segments.parquet", args.copy_local
        ),
        "raw_transcripts.parquet": maybe_stage(
            args.raw_transcripts, args.output_dir / "raw_transcripts.parquet", args.copy_local
        ),
        "validation.parquet": maybe_stage(
            args.validation, args.output_dir / "validation.parquet", args.copy_local
        ) if args.validation and args.validation.exists() else None,
        "youtube_meta.csv": maybe_stage(
            args.youtube_meta, args.output_dir / "youtube_meta.csv", args.copy_local
        ) if args.youtube_meta and args.youtube_meta.exists() else None,
        "variants.parquet": maybe_stage(
            args.variants, args.output_dir / "variants.parquet", args.copy_local
        ) if args.variants and args.variants.exists() else None,
    }

    # Record the destination plus per-file size/row stats alongside the snapshot.
    manifest = {
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "bucket": bucket,
        "prefix": prefix,
        "files": {},
    }

    for logical_name, path in staged.items():
        if path is None:
            continue
        entry = {
            "logical_name": logical_name,
            "path": path.as_posix(),
            "size_bytes": path.stat().st_size,
        }
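        # Record row counts for tabular inputs: parquet from footer metadata,
        # CSV via DuckDB.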
        if path.suffix == ".parquet":
            entry["rows"] = parquet_rows(path)
        elif path.suffix == ".csv":
            entry["rows"] = csv_rows(path)
        manifest["files"][logical_name] = entry

    manifest_path = args.output_dir / "manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2, sort_keys=True) + "\n")
    print(json.dumps(manifest, indent=2, sort_keys=True))
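    # manifest.json / stdout shape, with sort_keys=True (values illustrative):
    #   {"bucket": ..., "created_at": ..., "files": {...}, "prefix": ...}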

    if not args.no_upload:
        for logical_name, path in staged.items():
            if path is None:
                continue
            upload_file(bucket, f"{prefix}/{logical_name}", path)
        upload_file(bucket, f"{prefix}/manifest.json", manifest_path)


if __name__ == "__main__":
    main()
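
# Example invocations (script name and bucket are illustrative):
#   python upload_final_export_reference.py --no-upload        # manifest only
#   python upload_final_export_reference.py --copy-local --bucket my-reference-bucket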
