#!/usr/bin/env python3
"""
Fast async upload of audio files to Cloudflare R2.
Uses asyncio + aioboto3 for 100+ concurrent uploads.
Resumes from dataset_hf_metadata.jsonl if interrupted.
"""

import json
import hashlib
import wave
import os
import sys
import time
import asyncio
import aioboto3
from pathlib import Path

# R2 connection settings.
# SECURITY: credentials are committed here as a fallback for this one-off
# script; prefer exporting R2_ACCESS_KEY / R2_SECRET_KEY in the environment,
# and rotate these keys if this file is ever shared or published.
R2_ENDPOINT = os.environ.get(
    "R2_ENDPOINT",
    "https://a08f8e1184ff290eab9ad59b357ba0a1.r2.cloudflarestorage.com",
)
R2_ACCESS_KEY = os.environ.get("R2_ACCESS_KEY", "14a9241bb8634d96de0314da18f67c43")
R2_SECRET_KEY = os.environ.get(
    "R2_SECRET_KEY",
    "79e49ccbbbf0cc070b346186db3e0a82c9cc4cb6568ff77f9ef3d6cd1964d817",
)
R2_BUCKET = os.environ.get("R2_BUCKET", "hindi")

COMBINED_JSONL = "/home/ubuntu/modi_processed/dataset_combined.jsonl"
METADATA_OUT = "/home/ubuntu/modi_processed/dataset_hf_metadata.jsonl"
CONCURRENCY = 80  # max simultaneous PUT requests in flight


def extract_source_and_speaker(audio_path: str) -> tuple:
    """Derive a (source, speaker_id) pair from the audio file's path layout."""
    # Modi corpus lives under its own root and has a single speaker.
    if "modi_processed/" in audio_path:
        return "modi", "narendra_modi"
    if "soprano_data/" not in audio_path:
        return "unknown", "unknown"

    relative = audio_path.split("soprano_data/")[1]
    segments = relative.split("/")
    top = segments[0]

    if top.startswith("google_tts_"):
        # Directory name encodes the voice, e.g. google_tts_female1.
        return "google_tts", top.replace("google_tts_", "")
    if top == "polly_kajal":
        return "polly", "kajal"
    if top == "sarvam_data":
        # Filenames look like <speaker>_<index>.wav; keep the speaker part.
        stem = Path(audio_path).stem
        return "sarvam", (stem.rsplit("_", 1)[0] if "_" in stem else stem)
    if top == "rasa_hindi":
        return "rasa_hindi", "rasa"
    if top == "IISc_SYSPIN_Data":
        # Speaker is the directory one level below the corpus root.
        if len(segments) >= 2:
            return "iisc_syspin", segments[1]
        return "iisc_syspin", "unknown"
    # Unrecognized corpus directory: keep it as the source for traceability.
    return top, "unknown"


def get_wav_info(path: str) -> tuple:
    """Return (duration_seconds, sample_rate) for a WAV file.

    Returns (0.0, 0) for missing, truncated, or otherwise unreadable files
    so callers can treat any bad clip uniformly.
    """
    try:
        with wave.open(path, "rb") as w:
            rate = w.getframerate()
            # Guard explicitly instead of relying on a broad except to
            # swallow ZeroDivisionError on a corrupt header.
            if rate <= 0:
                return 0.0, 0
            return w.getnframes() / rate, rate
    except (OSError, EOFError, wave.Error):
        # Narrowed from `except Exception`: these are the failures wave.open
        # and header parsing actually raise for bad paths/files.
        return 0.0, 0


def compute_id(audio_bytes: bytes, text: str) -> str:
    """Stable 16-hex-char row id derived from audio content plus transcript."""
    # Hashing the concatenation is equivalent to two sequential update() calls.
    digest = hashlib.sha256(audio_bytes + text.encode("utf-8")).hexdigest()
    return digest[:16]


def clean_text(text: str) -> str:
    """Strip a leading "Speaker 0: " diarization tag, if present."""
    # str.removeprefix (3.9+) only strips when the prefix matches —
    # identical semantics to the startswith + slice idiom, but clearer.
    return text.removeprefix("Speaker 0: ")


def prepare_entry(entry: dict) -> dict | None:
    """CPU-bound prep: validate the clip, hash it, and build the metadata row.

    Args:
        entry: combined-dataset row with "audio" (local path) and "text" keys.

    Returns:
        Metadata dict carrying the raw payload under the transient
        "_audio_bytes" key (stripped by upload_one after a successful PUT),
        or None when the file is missing or shorter than 0.1 s.
    """
    audio_path = entry["audio"]
    raw_text = entry["text"]

    if not os.path.exists(audio_path):
        return None

    # Reject unreadable/too-short clips first — only needs the WAV header,
    # so we avoid paying for the full read + hash on rejected files.
    duration_s, sample_rate = get_wav_info(audio_path)
    if duration_s <= 0.1:
        return None

    # Context manager fixes the file-handle leak in the original
    # open(path).read() form.
    with open(audio_path, "rb") as f:
        audio_bytes = f.read()

    text = clean_text(raw_text)
    source, speaker_id = extract_source_and_speaker(audio_path)
    r2_key = f"audio/{source}/{speaker_id}/{Path(audio_path).name}"

    return {
        "id": compute_id(audio_bytes, text),
        "text": text,
        "speaker_id": speaker_id,
        "source": source,
        "language": "hi",
        "duration_s": round(duration_s, 3),
        "r2_url": f"r2://{R2_BUCKET}/{r2_key}",
        "r2_key": r2_key,
        "sample_rate": sample_rate,
        "local_audio_path": audio_path,
        "_audio_bytes": audio_bytes,
    }


async def upload_one(s3_client, meta: dict) -> dict | None:
    """PUT one audio object to R2.

    Returns the metadata dict (with the raw bytes removed) on success,
    or None on any upload failure.
    """
    key = meta["r2_key"]
    try:
        await s3_client.put_object(
            Bucket=R2_BUCKET,
            Key=key,
            Body=meta["_audio_bytes"],
            ContentType="audio/wav",
        )
    except Exception as e:
        # Best-effort: log and let the caller count the failure.
        print(f"  FAIL {key}: {e}", flush=True)
        return None
    # Drop the raw payload so only JSON-serializable metadata is persisted.
    del meta["_audio_bytes"]
    return meta


async def main():
    """Upload every pending audio file to R2, appending one metadata row each.

    Resume-safe: local paths already recorded in METADATA_OUT are skipped, so
    the script can be re-run after an interruption without re-uploading.
    """
    print("Loading dataset entries...", flush=True)
    entries = []
    with open(COMBINED_JSONL) as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank/trailing lines in the JSONL
                entries.append(json.loads(line))
    total = len(entries)
    print(f"Total entries: {total}", flush=True)

    # Resume support: collect local paths already recorded in the output file.
    already_done = set()
    if os.path.exists(METADATA_OUT):
        with open(METADATA_OUT) as f:
            for line in f:
                line = line.strip()
                if line:
                    already_done.add(json.loads(line)["local_audio_path"])
        print(f"Resuming: {len(already_done)} already processed", flush=True)

    remaining = [e for e in entries if e["audio"] not in already_done]
    print(f"Remaining to process: {len(remaining)}", flush=True)

    if not remaining:
        print("Nothing to do!", flush=True)
        return

    uploaded = 0
    failed = 0
    start_time = time.time()

    session = aioboto3.Session()
    # `with` guarantees the metadata file is closed (and buffers flushed) even
    # if an upload batch raises — the original leaked the handle on error.
    with open(METADATA_OUT, "a") as out_f:
        async with session.client(
            "s3",
            endpoint_url=R2_ENDPOINT,
            aws_access_key_id=R2_ACCESS_KEY,
            aws_secret_access_key=R2_SECRET_KEY,
            region_name="auto",
        ) as s3_client:

            # Cap concurrent PUTs; the semaphore only bounds the network phase.
            sem = asyncio.Semaphore(CONCURRENCY)
            # get_running_loop() replaces the deprecated get_event_loop()
            # pattern inside a coroutine; hoisted out of the per-entry path.
            loop = asyncio.get_running_loop()

            async def bounded_upload(entry):
                nonlocal uploaded, failed
                # prepare_entry is disk/CPU bound — run it in the default
                # thread pool so the event loop stays responsive.
                meta = await loop.run_in_executor(None, prepare_entry, entry)
                if meta is None:
                    failed += 1
                    return

                async with sem:
                    result = await upload_one(s3_client, meta)

                if result:
                    out_f.write(json.dumps(result, ensure_ascii=False) + "\n")
                    uploaded += 1
                else:
                    failed += 1

                done = uploaded + failed
                if done % 1000 == 0:
                    elapsed = time.time() - start_time
                    rate = done / elapsed if elapsed > 0 else 0
                    eta = (len(remaining) - done) / rate / 60 if rate > 0 else 0
                    out_f.flush()  # make progress durable at each milestone
                    print(
                        f"  Progress: {done}/{len(remaining)} "
                        f"({uploaded} ok, {failed} fail) "
                        f"Rate: {rate:.1f}/s  ETA: {eta:.1f}min",
                        flush=True,
                    )

            # Batch task creation so we never hold all task objects at once.
            batch_size = 500
            for i in range(0, len(remaining), batch_size):
                batch = remaining[i : i + batch_size]
                await asyncio.gather(*(bounded_upload(e) for e in batch))
                # Periodic flush between batches.
                out_f.flush()

    elapsed = time.time() - start_time
    print(
        f"\nDone! Uploaded: {uploaded}, Failed: {failed}, "
        f"Time: {elapsed/60:.1f}min, Rate: {(uploaded+failed)/elapsed:.1f}/s",
        flush=True,
    )


# Script entry point: run the async uploader under a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
