#!/usr/bin/env python3
import os
import sys
import re
import boto3
from boto3.s3.transfer import TransferConfig
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
import time

# --- Configuration ---------------------------------------------------------
# Destination bucket and the local directory whose files are uploaded.
BUCKET = "voices"
LOCAL_DIR = "/home/ubuntu/modi"

# Cloudflare R2 (S3-compatible) endpoint and credentials.
# SECURITY: these credentials were hard-coded in source. They are now read
# from the environment first; the inline fallbacks keep existing behavior,
# but the committed keys should be rotated and the fallbacks removed.
ENDPOINT = os.environ.get(
    "R2_ENDPOINT",
    "https://cb908ed13329eb7b186e06ab51bda190.r2.cloudflarestorage.com",
)
ACCESS_KEY = os.environ.get(
    "R2_ACCESS_KEY_ID", "bf40b8ba75e6e08bd32e12e74226c7d2"
)
SECRET_KEY = os.environ.get(
    "R2_SECRET_ACCESS_KEY",
    "766fd70a7c794e89fe4ba0fd1e185be923b5a91cf9627a354a57ed5eaad07296",
)

# Multipart-transfer tuning: files of 100MB or more are split into 50MB
# parts, with up to 4 part uploads in flight per file.
transfer_config = TransferConfig(
    multipart_threshold=100 * 1024 * 1024,
    multipart_chunksize=50 * 1024 * 1024,
    max_concurrency=4,
    use_threads=True,
)

# Shared progress state, mutated by worker threads; guard every access
# with `lock`.
lock = threading.Lock()
uploaded_count = 0  # files successfully uploaded so far
uploaded_bytes = 0  # cumulative bytes successfully uploaded
failed = []  # (local_path, error_message) pairs for failed uploads
start_time = time.time()  # wall-clock start, used for throughput reporting

def make_client():
    """Build a boto3 S3 client pointed at the R2-compatible endpoint."""
    client_kwargs = {
        "endpoint_url": ENDPOINT,
        "aws_access_key_id": ACCESS_KEY,
        "aws_secret_access_key": SECRET_KEY,
        "region_name": "auto",
    }
    return boto3.client("s3", **client_kwargs)

def upload_file(local_path, s3_key, file_size):
    """Upload one local file to BUCKET under s3_key and update progress state.

    Returns True on success, False on failure. Failures are recorded in the
    shared `failed` list instead of raised, so one bad file never aborts the
    whole run.
    """
    global uploaded_count, uploaded_bytes
    # Reuse one client per worker thread instead of constructing a fresh
    # client (and re-resolving endpoint/credentials) for every single file.
    tls = upload_file._tls
    client = getattr(tls, "client", None)
    if client is None:
        client = tls.client = make_client()
    try:
        # Only .wav audio and the *_metadata.json sidecars are expected;
        # any other file is tagged application/json.
        content_type = "audio/wav" if local_path.endswith(".wav") else "application/json"
        client.upload_file(
            local_path, BUCKET, s3_key,
            Config=transfer_config,
            ExtraArgs={"ContentType": content_type},
        )
        with lock:
            uploaded_count += 1
            uploaded_bytes += file_size
            elapsed = time.time() - start_time
            rate = uploaded_bytes / elapsed / 1024 / 1024 if elapsed > 0 else 0
            # NOTE: total_files is a module-level global assigned after this
            # function's definition but before any call — safe at runtime.
            print(f"[{uploaded_count}/{total_files}] {s3_key} ({file_size/1024/1024:.1f}MB) - "
                  f"{uploaded_bytes/1024/1024/1024:.2f}GB uploaded - {rate:.1f}MB/s",
                  flush=True)
        return True
    except Exception as e:
        with lock:
            failed.append((local_path, str(e)))
            print(f"FAILED: {s3_key} - {e}", flush=True)
        return False

# Per-thread client cache for upload_file: each worker thread gets its own
# client, avoiding both repeated construction and cross-thread sharing.
upload_file._tls = threading.local()

# Enumerate files directly inside LOCAL_DIR (non-recursive) and compute each
# destination key. Layout: modi/{id}/{id}.wav and modi/{id}/{id}_metadata.json;
# any other file lands flat under modi/.
files_to_upload = []
for fname in sorted(os.listdir(LOCAL_DIR)):
    fpath = os.path.join(LOCAL_DIR, fname)
    if not os.path.isfile(fpath):
        continue

    # Strip only the trailing suffix. The previous str.replace() would also
    # rewrite an earlier occurrence of the pattern inside the filename,
    # producing a corrupted file id.
    if fname.endswith("_metadata.json"):
        file_id = fname[:-len("_metadata.json")]
        s3_key = f"modi/{file_id}/{file_id}_metadata.json"
    elif fname.endswith(".wav"):
        file_id = fname[:-len(".wav")]
        s3_key = f"modi/{file_id}/{file_id}.wav"
    else:
        s3_key = f"modi/{fname}"

    fsize = os.path.getsize(fpath)
    files_to_upload.append((fpath, s3_key, fsize))

# Announce the planned upload before starting.
total_files = len(files_to_upload)
total_size = sum(f[2] for f in files_to_upload)
print(f"Uploading {total_files} files ({total_size/1024/1024/1024:.2f}GB) to s3://{BUCKET}/")
print(f"Structure: modi/{{id}}/{{id}}.wav + modi/{{id}}/{{id}}_metadata.json")
# Keep this message in sync with the executor's max_workers value below.
print(f"Using 8 parallel workers with multipart uploads")
print("=" * 80, flush=True)

# Fan the uploads out across a fixed pool of worker threads. upload_file
# catches its own exceptions, so result() only surfaces unexpected errors.
with ThreadPoolExecutor(max_workers=8) as executor:
    pending = [
        executor.submit(upload_file, path, key, size)
        for path, key, size in files_to_upload
    ]
    for done in as_completed(pending):
        done.result()

# Final summary. Guard against elapsed == 0 (e.g. an empty directory on a
# coarse clock) the same way the per-file progress line does.
elapsed = time.time() - start_time
avg_rate = uploaded_bytes / elapsed / 1024 / 1024 if elapsed > 0 else 0
print("=" * 80)
print(f"Done! {uploaded_count}/{total_files} files uploaded in {elapsed:.0f}s")
print(f"Total: {uploaded_bytes/1024/1024/1024:.2f}GB at {avg_rate:.1f}MB/s avg")
if failed:
    # Non-zero exit so callers (cron, CI) can detect partial failure.
    print(f"\nFailed uploads ({len(failed)}):")
    for path, err in failed:
        print(f"  {path}: {err}")
    sys.exit(1)
