#!/usr/bin/env python3
"""
Fast R2 upload with producer-consumer pipeline.
- Producer pool (multiprocessing): reads files, hashes, extracts metadata
- Consumer pool (threading): uploads to R2 concurrently
- Queue bridges the two for max throughput
"""

import hashlib
import io
import json
import os
import queue
import threading
import time
import wave
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from pathlib import Path

import boto3
from botocore.config import Config

# --- R2 connection settings -------------------------------------------------
# NOTE(security): these credentials were hard-coded in source. They are now
# overridable via environment variables; the literals remain as fallback
# defaults so existing deployments keep working, but the keys should be
# rotated and supplied via the environment instead.
R2_ENDPOINT = os.environ.get(
    "R2_ENDPOINT",
    "https://a08f8e1184ff290eab9ad59b357ba0a1.r2.cloudflarestorage.com",
)
R2_ACCESS_KEY = os.environ.get("R2_ACCESS_KEY", "14a9241bb8634d96de0314da18f67c43")
R2_SECRET_KEY = os.environ.get(
    "R2_SECRET_KEY",
    "79e49ccbbbf0cc070b346186db3e0a82c9cc4cb6568ff77f9ef3d6cd1964d817",
)
R2_BUCKET = os.environ.get("R2_BUCKET", "hindi")

# Input dataset and the append-only metadata output (doubles as resume journal).
COMBINED_JSONL = "/home/ubuntu/modi_processed/dataset_combined.jsonl"
METADATA_OUT = "/home/ubuntu/modi_processed/dataset_hf_metadata.jsonl"

# Pipeline sizing: many I/O-bound upload threads, a few CPU-bound reader
# processes, and a bounded queue so readers cannot outrun the uploaders.
UPLOAD_WORKERS = 64
READ_WORKERS = 8
QUEUE_SIZE = 200


def get_s3_client():
    """Build a boto3 S3 client pointed at the Cloudflare R2 endpoint.

    Retries are adaptive with up to 3 attempts; the connection pool is
    sized slightly above the upload-thread count so workers never block
    waiting for a free connection.
    """
    client_config = Config(
        retries={"max_attempts": 3, "mode": "adaptive"},
        # One pooled connection per uploader thread, plus a little slack.
        max_pool_connections=UPLOAD_WORKERS + 4,
    )
    return boto3.client(
        "s3",
        endpoint_url=R2_ENDPOINT,
        aws_access_key_id=R2_ACCESS_KEY,
        aws_secret_access_key=R2_SECRET_KEY,
        region_name="auto",
        config=client_config,
    )


def extract_source_and_speaker(audio_path: str):
    """Derive a (source, speaker_id) pair from an audio file's path.

    Paths under ``modi_processed/`` always map to the Modi speaker; paths
    under ``soprano_data/`` are classified by their top-level directory;
    anything else is reported as unknown.
    """
    if "modi_processed/" in audio_path:
        return "modi", "narendra_modi"
    if "soprano_data/" not in audio_path:
        return "unknown", "unknown"

    relative = audio_path.split("soprano_data/")[1]
    segments = relative.split("/")
    root = segments[0]

    if root.startswith("google_tts_"):
        # Speaker is whatever follows the google_tts_ prefix.
        return "google_tts", root.replace("google_tts_", "")
    if root == "polly_kajal":
        return "polly", "kajal"
    if root == "sarvam_data":
        # Filenames look like "<speaker>_<index>.wav"; keep the speaker part.
        stem = Path(audio_path).stem
        speaker = stem.rsplit("_", 1)[0] if "_" in stem else stem
        return "sarvam", speaker
    if root == "rasa_hindi":
        return "rasa_hindi", "rasa"
    if root == "IISc_SYSPIN_Data":
        # Speaker is encoded as the directory one level below the corpus root.
        speaker = segments[1] if len(segments) >= 2 else "unknown"
        return "iisc_syspin", speaker
    # Unrecognized corpus directory: report it verbatim, speaker unknown.
    return root, "unknown"


def prepare_entry(entry: dict):
    """CPU-bound producer step: read one audio file and build its metadata row.

    Reads the wav bytes from disk, probes duration/sample rate, derives
    source/speaker from the path, and hashes audio+text into a stable row id.

    Args:
        entry: dataset record with at least "audio" (local path) and "text".

    Returns:
        A metadata dict (raw bytes attached under "_audio_bytes" for the
        uploader to strip off), or None when the file is missing, unreadable,
        or too short (<= 0.1 s) to keep.
    """
    audio_path = entry["audio"]
    raw_text = entry["text"]

    # EAFP: a single guarded open covers both the missing-file and
    # unreadable-file cases, and the context manager closes the descriptor
    # promptly (the original relied on GC, leaking fds under the process pool).
    try:
        with open(audio_path, "rb") as f:
            audio_bytes = f.read()
    except OSError:
        return None

    # Probe the bytes already in memory instead of re-reading the file from
    # disk a second time. Any parse error (bad header, zero frame rate, ...)
    # degrades to a zero duration, which is rejected below.
    try:
        with wave.open(io.BytesIO(audio_bytes), "rb") as w:
            sample_rate = w.getframerate()
            duration_s = w.getnframes() / sample_rate
    except Exception:
        duration_s, sample_rate = 0.0, 0

    # Reject empty/corrupt/too-short clips before paying for the SHA-256.
    if duration_s <= 0.1:
        return None

    # Strip the synthetic "Speaker 0: " prefix some sources prepend.
    prefix = "Speaker 0: "
    text = raw_text[len(prefix):] if raw_text.startswith(prefix) else raw_text
    source, speaker_id = extract_source_and_speaker(audio_path)
    r2_key = f"audio/{source}/{speaker_id}/{Path(audio_path).name}"

    # Content-addressed row id: identical audio+text hashes identically, so
    # repeated runs produce stable ids.
    h = hashlib.sha256()
    h.update(audio_bytes)
    h.update(text.encode("utf-8"))
    row_id = h.hexdigest()[:16]

    return {
        "id": row_id,
        "text": text,
        "speaker_id": speaker_id,
        "source": source,
        "language": "hi",
        "duration_s": round(duration_s, 3),
        "r2_url": f"r2://{R2_BUCKET}/{r2_key}",
        "r2_key": r2_key,
        "sample_rate": sample_rate,
        "local_audio_path": audio_path,
        "_audio_bytes": audio_bytes,
    }


def main():
    """Run the read -> hash -> upload pipeline end to end, with resume support.

    Flow: load the combined JSONL, skip entries already recorded in the
    metadata output, fan file reads out to a process pool, and stream the
    prepared items through a bounded queue to a pool of uploader threads.
    One metadata line is appended per successful upload.
    """
    print("Loading dataset entries...", flush=True)
    entries = []
    with open(COMBINED_JSONL) as f:
        for line in f:
            entries.append(json.loads(line.strip()))
    print(f"Total entries: {len(entries)}", flush=True)

    # Resume support: the metadata file is append-only, so each line in it
    # represents a file that was already uploaded successfully.
    already_done = set()
    if os.path.exists(METADATA_OUT):
        with open(METADATA_OUT) as f:
            for line in f:
                d = json.loads(line.strip())
                already_done.add(d["local_audio_path"])
        print(f"Resuming: {len(already_done)} already processed", flush=True)

    remaining = [e for e in entries if e["audio"] not in already_done]
    total_remaining = len(remaining)
    print(f"Remaining to process: {total_remaining}", flush=True)

    if not remaining:
        print("Nothing to do!", flush=True)
        return

    # Bounded queue applies backpressure: readers block on put() instead of
    # buffering unbounded audio bytes ahead of the uploaders.
    upload_queue = queue.Queue(maxsize=QUEUE_SIZE)

    uploaded = 0
    failed = 0
    read_failed = 0  # NOTE(review): never incremented — read failures are counted in `failed` below
    lock = threading.Lock()  # guards the counters and the shared output file
    start_time = time.time()
    out_f = open(METADATA_OUT, "a")  # append mode doubles as the resume journal

    # One shared client; its connection pool is sized for UPLOAD_WORKERS in
    # get_s3_client().
    s3_client = get_s3_client()

    def uploader():
        # Consumer thread: drain the queue and PUT each item to R2 until a
        # None sentinel arrives.
        nonlocal uploaded, failed
        while True:
            item = upload_queue.get()
            if item is None:  # shutdown sentinel
                break

            # Strip the raw bytes so only metadata gets serialized below.
            audio_bytes = item.pop("_audio_bytes")
            try:
                s3_client.put_object(
                    Bucket=R2_BUCKET,
                    Key=item["r2_key"],
                    Body=audio_bytes,
                    ContentType="audio/wav",
                )
                # Metadata is written only after a successful upload, so a
                # crash can never journal a file that did not reach R2.
                with lock:
                    out_f.write(json.dumps(item, ensure_ascii=False) + "\n")
                    uploaded += 1
                    done = uploaded + failed
                    if done % 1000 == 0:
                        elapsed = time.time() - start_time
                        rate = done / elapsed
                        eta = (total_remaining - done) / rate / 60 if rate > 0 else 0
                        out_f.flush()  # periodic journal checkpoint
                        print(
                            f"  Progress: {done}/{total_remaining} "
                            f"({uploaded} ok, {failed} fail) "
                            f"Rate: {rate:.1f}/s  ETA: {eta:.1f}min",
                            flush=True,
                        )
            except Exception as e:
                with lock:
                    failed += 1
                    print(f"  FAIL upload: {e}", flush=True)
            finally:
                upload_queue.task_done()  # pairs with every get(), success or failure

    upload_threads = []
    for _ in range(UPLOAD_WORKERS):
        t = threading.Thread(target=uploader, daemon=True)
        t.start()
        upload_threads.append(t)

    print(f"Started {UPLOAD_WORKERS} upload threads, {READ_WORKERS} reader procs", flush=True)

    # Producer side: hash/parse files in worker processes in modest batches so
    # results stream into the queue rather than materializing all at once.
    with ProcessPoolExecutor(max_workers=READ_WORKERS) as reader_pool:
        batch_size = 200
        for batch_start in range(0, len(remaining), batch_size):
            batch = remaining[batch_start : batch_start + batch_size]
            results = reader_pool.map(prepare_entry, batch)
            for meta in results:
                if meta is not None:
                    upload_queue.put(meta)  # blocks while the queue is full
                else:
                    # prepare_entry rejected the file (missing/unreadable/too short).
                    with lock:
                        failed += 1

    # Wait for in-flight uploads to finish, then stop each worker with one
    # sentinel and join them.
    upload_queue.join()

    for _ in upload_threads:
        upload_queue.put(None)
    for t in upload_threads:
        t.join()

    out_f.flush()
    out_f.close()

    elapsed = time.time() - start_time
    total_done = uploaded + failed
    rate = total_done / elapsed if elapsed > 0 else 0
    print(
        f"\nDone! Uploaded: {uploaded}, Failed: {failed}, "
        f"Time: {elapsed/60:.1f}min, Rate: {rate:.1f}/s",
        flush=True,
    )


# Entry-point guard: run the pipeline only when executed as a script.
if __name__ == "__main__":
    main()
