#!/usr/bin/env python3
"""
Download delivery2/ from GCS → stream upload to R2.
Reads: /root/joshtalks_delivery2_filelist/{lang}.jsonl
Target: r2://indiccharan/joshtalks/delivery2/{lang}/{id}/{files}

Usage: python3 delivery2_fetch.py              # all languages
       python3 delivery2_fetch.py hindi english  # specific languages
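
Filelist line format (fields as read by this script; values illustrative):
    {"session_id": "<id>", "files": [{"key": ".../delivery2/<lang>/<id>/<file>.wav"}, ...]}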
"""

import json, sys, time, random, threading, queue
from pathlib import Path
from datetime import datetime
import requests, urllib3
import boto3
from botocore.config import Config as BotoConfig

urllib3.disable_warnings()

CONCURRENCY = 200  # worker threads per language
PART_SIZE = 5 * 1024 * 1024  # 5 MiB: the S3/R2 minimum size for a non-final multipart part

R2_ENDPOINT = "https://cb908ed13329eb7b186e06ab51bda190.r2.cloudflarestorage.com"
R2_ACCESS_KEY = "5ee7d72f38521105f780e3d45935b7dc"
R2_SECRET_KEY = "d69432b9958cc77f87d7f0d986636e83de690178fd351ce5b192bf53e518199a"
R2_BUCKET = "indiccharan"
R2_PREFIX = "joshtalks/delivery2"

GCS_BUCKET = "joshtalks-data-collection"
FILELIST_DIR = Path("/root/joshtalks_delivery2_filelist")
ALL_LANGUAGES = ["hindi", "english", "bengali", "gujarati", "telugu"]

OXY_USER = "user-humming_Ows6w-country-US"
OXY_PASS = "mOIb_8PL7ieGppJW"
OXY_PORTS = list(range(8001, 8022))

CHECKPOINT_DIR = Path("/root/joshtalks_checkpoint")
LOG_FILE = Path("/root/joshtalks_r2_delivery2.log")

lock = threading.Lock()


def get_proxy(port=None):
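    """Return an Oxylabs datacenter proxy URL; a fixed port pins the caller to one
    exit IP, otherwise a random port is picked for basic load spreading."""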
    p = port or random.choice(OXY_PORTS)
    return f"http://{OXY_USER}:{OXY_PASS}@dc.oxylabs.io:{p}"


def log(msg):
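    """Print a timestamped line and append it to LOG_FILE (write errors are swallowed)."""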
    line = f"[{datetime.now().strftime('%H:%M:%S')}] {msg}"
    print(line, flush=True)
    with lock:
        try:
            with open(LOG_FILE, "a") as f:
                f.write(line + "\n")
        except OSError:
            pass  # logging must never take down a worker


def make_s3():
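    """Build the boto3 S3 client for R2, shared by all worker threads (the
    connection pool is sized to CONCURRENCY so parallel uploads don't stall)."""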
    return boto3.client("s3",
        endpoint_url=R2_ENDPOINT, aws_access_key_id=R2_ACCESS_KEY,
        aws_secret_access_key=R2_SECRET_KEY,
        config=BotoConfig(retries={"max_attempts": 3, "mode": "adaptive"},
                          max_pool_connections=CONCURRENCY),
        region_name="auto")


def load_checkpoint(name):
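    """Load the set of already-uploaded session_ids for a language; a missing or
    corrupt checkpoint just means starting that language from scratch."""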
    path = CHECKPOINT_DIR / f"{name}.json"
    if path.exists():
        try:
            return set(json.loads(path.read_text()))
        except (OSError, ValueError):
            pass  # unreadable or corrupt checkpoint: fall through to an empty set
    return set()


def save_checkpoint(name, done_set):
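    """Persist the done-set atomically: write to a temp file, then rename over the old one."""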
    path = CHECKPOINT_DIR / f"{name}.json"
    tmp = path.with_suffix(".tmp")
    with lock:
        tmp.write_text(json.dumps(list(done_set)))
        tmp.rename(path)


def content_type_for(key):
    if key.endswith(".wav"):
        return "audio/wav"
    if key.endswith(".json"):
        return "application/json"
    return "application/octet-stream"


def stream_upload(s3, r2_key, resp, ctype):
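    """Stream an HTTP response into R2 via multipart upload: buffer 5 MiB parts and
    upload them as they fill, so no file is ever held fully in memory or on disk."""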
    mpu = s3.create_multipart_upload(Bucket=R2_BUCKET, Key=r2_key, ContentType=ctype)
    uid = mpu["UploadId"]
    parts = []
    pn = 1
    buf = b""
    total = 0
    try:
        for chunk in resp.iter_content(chunk_size=65536):
            buf += chunk
            total += len(chunk)
            while len(buf) >= PART_SIZE:
                part_data = buf[:PART_SIZE]
                buf = buf[PART_SIZE:]
                r = s3.upload_part(Bucket=R2_BUCKET, Key=r2_key, UploadId=uid, PartNumber=pn, Body=part_data)
                parts.append({"ETag": r["ETag"], "PartNumber": pn})
                pn += 1
        if buf:
            r = s3.upload_part(Bucket=R2_BUCKET, Key=r2_key, UploadId=uid, PartNumber=pn, Body=buf)
            parts.append({"ETag": r["ETag"], "PartNumber": pn})
        if not parts:
            # Zero-byte object: multipart needs at least one part, so fall back to a
            # plain PUT instead of failing the whole session.
            s3.abort_multipart_upload(Bucket=R2_BUCKET, Key=r2_key, UploadId=uid)
            s3.put_object(Bucket=R2_BUCKET, Key=r2_key, Body=b"", ContentType=ctype)
            return True, 0
        s3.complete_multipart_upload(Bucket=R2_BUCKET, Key=r2_key, UploadId=uid, MultipartUpload={"Parts": parts})
        return True, total
    except BaseException:
        # On any failure (including interrupts), abort so R2 keeps no orphaned parts,
        # then re-raise for the caller to handle the retry.
        try:
            s3.abort_multipart_upload(Bucket=R2_BUCKET, Key=r2_key, UploadId=uid)
        except Exception:
            pass
        raise


def download_and_upload_file(s3, gcs_key, r2_key, proxy_url):
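    """Fetch one GCS object through a proxy and stream it into R2.
    Up to 3 attempts; 403/404 are treated as permanently missing and not retried."""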
    ctype = content_type_for(gcs_key)
    gcs_url = f"https://storage.googleapis.com/{GCS_BUCKET}/{gcs_key}"
    proxies = {"https": proxy_url}
    for attempt in range(3):
        try:
            resp = requests.get(gcs_url, proxies=proxies, verify=False, timeout=300, stream=True)
            if resp.status_code != 200:
                if resp.status_code in (403, 404):
                    return False, 0
                time.sleep(1)
                continue
            ok, nbytes = stream_upload(s3, r2_key, resp, ctype)
            return ok, nbytes
        except Exception:
            # Download or upload error: rotate to a fresh proxy and retry after a short backoff.
            if attempt < 2:
                proxy_url = get_proxy()
                proxies = {"https": proxy_url}
                time.sleep(1 + random.random())
    return False, 0


stats = {"ok": 0, "fail": 0, "bytes": 0, "files": 0}  # per-language counters, updated under `lock`


def worker(q, done_set, s3, worker_id):
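    """Drain conversations from the queue and transfer each of their files; a session
    is checkpointed only when every one of its files uploaded successfully."""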
    my_proxy = get_proxy(OXY_PORTS[worker_id % len(OXY_PORTS)])

    while True:
        try:
            sess_id, files = q.get_nowait()
        except queue.Empty:
            break

        sess_bytes = 0
        sess_files = 0
        all_ok = True

        for fobj in files:
            gcs_key = fobj["key"]
            r2_key = f"{R2_PREFIX}/{gcs_key.split('delivery2/')[-1]}"
            ok, nbytes = download_and_upload_file(s3, gcs_key, r2_key, my_proxy)
            if ok:
                sess_bytes += nbytes
                sess_files += 1
            else:
                all_ok = False

        with lock:
            if all_ok and sess_files > 0:
                done_set.add(sess_id)
                stats["ok"] += 1
            else:
                stats["fail"] += 1
            stats["bytes"] += sess_bytes
            stats["files"] += sess_files

        q.task_done()


def main():
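    """Per-language driver: load the filelist, skip checkpointed sessions, fan the rest
    out to CONCURRENCY worker threads, and log progress / save checkpoints periodically."""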
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    langs = sys.argv[1:] if len(sys.argv) > 1 else ALL_LANGUAGES

    log(f"Concurrency: {CONCURRENCY}")
    log(f"Proxies: {len(OXY_PORTS)}")
    log(f"Languages: {langs}")

    for lang in langs:
        checkpoint_name = f"delivery2_{lang}"
        done = load_checkpoint(checkpoint_name)

        log(f"\n{'='*60}")
        log(f"LANGUAGE: {lang} (already done: {len(done)})")
        log(f"{'='*60}")

        path = FILELIST_DIR / f"{lang}.jsonl"
        if not path.exists():
            log(f"{lang}: filelist not found, run delivery2_list.py first")
            continue

        all_sessions = []
        with open(path) as f:
            for line in f:
                line = line.strip()
                if line:
                    all_sessions.append(json.loads(line))

        log(f"Loaded {len(all_sessions)} conversations")

        todo = [(s["session_id"], s["files"]) for s in all_sessions if s["session_id"] not in done]
        log(f"Todo: {len(todo)} | Skipped: {len(all_sessions) - len(todo)}")

        if not todo:
            log(f"{lang}: nothing to do")
            continue

        stats["ok"] = 0
        stats["fail"] = 0
        stats["bytes"] = 0
        stats["files"] = 0
        start = time.time()

        q = queue.Queue()
        for item in todo:
            q.put(item)

        s3 = make_s3()
        threads = []
        for i in range(CONCURRENCY):
            t = threading.Thread(target=worker, args=(q, done, s3, i), daemon=True)
            t.start()
            threads.append(t)

        last_save = time.time()
        while any(t.is_alive() for t in threads):
            time.sleep(10)
            elapsed = time.time() - start
            speed = stats["bytes"] / elapsed / 1024 / 1024 if elapsed > 0 else 0
            left = q.qsize()
            eta_s = (left / (stats["ok"] / elapsed)) if stats["ok"] > 0 else 0

            log(f"  {lang}: ok={stats['ok']} fail={stats['fail']} left={left} "
                f"files={stats['files']} {stats['bytes']/1024/1024/1024:.2f}GB "
                f"{speed:.1f}MB/s ETA={eta_s/60:.0f}m")

            if time.time() - last_save >= 30:
                save_checkpoint(checkpoint_name, done)
                last_save = time.time()

        for t in threads:
            t.join()

        save_checkpoint(checkpoint_name, done)
        elapsed = time.time() - start
        log(f"\n  {lang} DONE: ok={stats['ok']} fail={stats['fail']} files={stats['files']} "
            f"{stats['bytes']/1024/1024/1024:.2f}GB {elapsed:.0f}s")

    log(f"\n=== ALL LANGUAGES COMPLETE ===")


if __name__ == "__main__":
    main()
