#!/usr/bin/env python3
"""
Upload audio files to Cloudflare R2 and build HuggingFace dataset metadata.

Two phases:
  Phase 1: Upload all audio to R2 with parallel workers, collect metadata.
  Phase 2: Build and push HF dataset with Audio feature.
"""

import json
import hashlib
import wave
import os
import sys
import time
import boto3
from botocore.config import Config
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

# R2 connection settings.
# SECURITY: credentials were previously hardcoded in this file. They are now
# read from the environment, falling back to the legacy literal values so
# existing deployments keep working — rotate these keys and delete the
# fallbacks as soon as possible.
R2_ENDPOINT = os.environ.get(
    "R2_ENDPOINT",
    "https://a08f8e1184ff290eab9ad59b357ba0a1.r2.cloudflarestorage.com",
)
R2_ACCESS_KEY = os.environ.get("R2_ACCESS_KEY", "14a9241bb8634d96de0314da18f67c43")
R2_SECRET_KEY = os.environ.get(
    "R2_SECRET_KEY",
    "79e49ccbbbf0cc070b346186db3e0a82c9cc4cb6568ff77f9ef3d6cd1964d817",
)
R2_BUCKET = os.environ.get("R2_BUCKET", "hindi")

# Input manifest and (append-mode, resumable) metadata output.
COMBINED_JSONL = "/home/ubuntu/modi_processed/dataset_combined.jsonl"
METADATA_OUT = "/home/ubuntu/modi_processed/dataset_hf_metadata.jsonl"
# Parallel upload workers (also sizes the boto3 connection pool).
NUM_WORKERS = 16


def get_s3_client():
    """Build a boto3 S3 client pointed at the Cloudflare R2 endpoint.

    Retries are adaptive (3 attempts) and the connection pool is sized a
    little above the worker count so threads never starve for connections.
    """
    client_config = Config(
        retries={"max_attempts": 3, "mode": "adaptive"},
        max_pool_connections=NUM_WORKERS + 4,
    )
    return boto3.client(
        "s3",
        endpoint_url=R2_ENDPOINT,
        aws_access_key_id=R2_ACCESS_KEY,
        aws_secret_access_key=R2_SECRET_KEY,
        region_name="auto",
        config=client_config,
    )


def extract_source_and_speaker(audio_path: str) -> tuple[str, str]:
    """Derive (source_dataset, speaker_id) from an audio file's path.

    Paths under modi_processed/ map to the Modi corpus; everything else is
    classified by the top-level directory under soprano_data/. Unrecognized
    layouts fall back to "unknown" values rather than raising.
    """
    if "modi_processed/" in audio_path:
        return "modi", "narendra_modi"

    marker = "soprano_data/"
    if marker not in audio_path:
        return "unknown", "unknown"

    relative = audio_path.split(marker)[1]
    parts = relative.split("/")
    top_dir = parts[0]

    if top_dir.startswith("google_tts_"):
        # Voice name is whatever follows the google_tts_ prefix.
        return "google_tts", top_dir.replace("google_tts_", "")
    if top_dir == "polly_kajal":
        return "polly", "kajal"
    if top_dir == "sarvam_data":
        # Filenames look like <speaker>_<index>.wav; drop the final _<index>.
        stem = Path(audio_path).stem
        speaker = stem.rsplit("_", 1)[0] if "_" in stem else stem
        return "sarvam", speaker
    if top_dir == "rasa_hindi":
        return "rasa_hindi", "rasa"
    if top_dir == "IISc_SYSPIN_Data":
        # Speaker is the folder one level below the dataset root.
        speaker = parts[1] if len(parts) >= 2 else "unknown"
        return "iisc_syspin", speaker
    return top_dir, "unknown"


def build_r2_key(source: str, speaker_id: str, audio_path: str) -> str:
    """Build the R2 object key ``audio/<source>/<speaker_id>/<filename>``.

    Bug fix: the previous version interpolated the literal string
    "(unknown)" instead of the computed filename, so every file in a
    (source, speaker) group mapped to the same key and overwrote the
    previous upload.
    """
    filename = Path(audio_path).name
    return f"audio/{source}/{speaker_id}/{filename}"


def get_wav_info(path: str) -> tuple[float, int]:
    """Return (duration_seconds, sample_rate) for a wav file.

    Unreadable or corrupt files yield the sentinel (0.0, 0) instead of
    raising, so callers can simply skip them.
    """
    try:
        with wave.open(path, "rb") as handle:
            rate = handle.getframerate()
            duration = handle.getnframes() / rate
    except Exception:
        return 0.0, 0
    return duration, rate


def compute_id(audio_bytes: bytes, text: str) -> str:
    """Deterministic 16-hex-char row id from audio content plus transcript.

    SHA-256 over the concatenation of the raw audio bytes and the UTF-8
    encoded text — identical to feeding the hasher two sequential updates.
    """
    digest = hashlib.sha256(audio_bytes + text.encode("utf-8"))
    return digest.hexdigest()[:16]


def clean_text(text: str) -> str:
    """Strip a leading "Speaker 0: " diarization tag, if present.

    Uses str.removeprefix (Python 3.9+; this file already relies on 3.10+
    syntax), which is exactly the manual startswith/slice idiom it replaces.
    """
    return text.removeprefix("Speaker 0: ")


def process_and_upload(entry: dict, s3_client) -> dict | None:
    """Process one entry: read audio, compute metadata, upload to R2.

    Args:
        entry: dict with at least "audio" (local wav path) and "text".
        s3_client: boto3 S3 client used for the put_object call.

    Returns:
        The metadata row dict, or None when the entry is skipped (missing
        file, near-empty/unreadable audio) or the upload fails.
    """
    audio_path = entry["audio"]
    raw_text = entry["text"]

    if not os.path.exists(audio_path):
        return None

    source, speaker_id = extract_source_and_speaker(audio_path)
    r2_key = build_r2_key(source, speaker_id, audio_path)
    text = clean_text(raw_text)

    # Fix: read via a context manager so the descriptor is released
    # promptly — the previous bare open().read() left closing to the GC,
    # which can exhaust file handles with many concurrent worker threads.
    with open(audio_path, "rb") as f:
        audio_bytes = f.read()
    row_id = compute_id(audio_bytes, text)
    duration_s, sample_rate = get_wav_info(audio_path)

    # Skip clips that are essentially empty (<= 0.1 s) or unreadable
    # (get_wav_info returns 0.0 on failure).
    if duration_s <= 0.1:
        return None

    try:
        s3_client.put_object(
            Bucket=R2_BUCKET,
            Key=r2_key,
            Body=audio_bytes,
            ContentType="audio/wav",
        )
    except Exception as e:
        # Best-effort: log and skip so one bad upload doesn't kill the run.
        print(f"  FAILED upload {r2_key}: {e}", flush=True)
        return None

    r2_url = f"r2://{R2_BUCKET}/{r2_key}"

    return {
        "id": row_id,
        "text": text,
        "speaker_id": speaker_id,
        "source": source,
        "language": "hi",
        "duration_s": round(duration_s, 3),
        "r2_url": r2_url,
        "r2_key": r2_key,
        "sample_rate": sample_rate,
        "local_audio_path": audio_path,
    }


def main():
    """Phase 1 driver: upload all pending audio to R2, appending metadata.

    Resumable: rows already present in METADATA_OUT (keyed by
    local_audio_path) are skipped, so the script can safely be re-run
    after a crash or interruption.
    """
    print("Loading dataset entries...", flush=True)
    with open(COMBINED_JSONL) as f:
        entries = [json.loads(line.strip()) for line in f]
    total = len(entries)
    print(f"Total entries: {total}", flush=True)

    # Resume support: collect local paths recorded by a previous run.
    already_done = set()
    if os.path.exists(METADATA_OUT):
        with open(METADATA_OUT) as f:
            for line in f:
                d = json.loads(line.strip())
                already_done.add(d["local_audio_path"])
        print(f"Resuming: {len(already_done)} already processed", flush=True)

    remaining = [e for e in entries if e["audio"] not in already_done]
    print(f"Remaining to process: {len(remaining)}", flush=True)

    if not remaining:
        print("Nothing to do!", flush=True)
        return

    # Fix: dropped the unused module-scope s3_client here — each worker
    # builds its own client below, so the shared one was dead weight.

    uploaded = 0
    failed = 0
    start_time = time.time()

    def _worker(entry):
        # One client per submitted task, matching the original behavior
        # (avoids sharing a connection pool across worker threads).
        client = get_s3_client()
        return process_and_upload(entry, client)

    # Fix: the output file is now managed by `with`, so it is closed (and
    # buffers flushed) even if a worker raises. Append mode keeps partial
    # progress on disk for the resume logic above.
    with open(METADATA_OUT, "a") as out_f, ThreadPoolExecutor(
        max_workers=NUM_WORKERS
    ) as executor:
        futures = {executor.submit(_worker, e): e for e in remaining}

        for future in as_completed(futures):
            result = future.result()
            if result:
                out_f.write(json.dumps(result, ensure_ascii=False) + "\n")
                uploaded += 1
            else:
                failed += 1

            done = uploaded + failed
            if done % 500 == 0:
                elapsed = time.time() - start_time
                rate = done / elapsed if elapsed > 0 else 0
                eta = (len(remaining) - done) / rate / 60 if rate > 0 else 0
                print(
                    f"  Progress: {done}/{len(remaining)} "
                    f"({uploaded} ok, {failed} fail) "
                    f"Rate: {rate:.1f}/s  ETA: {eta:.1f}min",
                    flush=True,
                )

    elapsed = time.time() - start_time
    print(
        f"\nDone! Uploaded: {uploaded}, Failed: {failed}, "
        f"Time: {elapsed/60:.1f}min",
        flush=True,
    )


# Allow importing this module (e.g. for reuse of the helpers) without
# triggering the upload run.
if __name__ == "__main__":
    main()
