#!/usr/bin/env python3
"""
Build a HuggingFace dataset from R2 upload metadata + local audio files.
Includes Audio feature so files are playable/streamable on HF.
Pushes to saichranreddy/internal_indic_h (private).
"""

import json
import os
from datasets import Dataset, Audio, Features, Value

METADATA_JSONL = "/home/ubuntu/modi_processed/dataset_hf_metadata.jsonl"
HF_REPO = "saichranreddy/internal_indic_h"
BATCH_SIZE = 1000


def load_metadata(path=None):
    """Load metadata rows from a JSONL file, keeping the first row per id.

    Args:
        path: Path to the JSONL metadata file. Defaults to METADATA_JSONL.

    Returns:
        list[dict]: Parsed rows, deduplicated on the "id" key (first
        occurrence wins, preserving file order).

    Raises:
        KeyError: If a row is missing the "id" key.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if path is None:
        path = METADATA_JSONL
    print("Loading metadata...", flush=True)
    rows = []
    seen_ids = set()
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF),
                # which would otherwise raise json.JSONDecodeError.
                continue
            row = json.loads(line)
            if row["id"] in seen_ids:
                continue
            seen_ids.add(row["id"])
            rows.append(row)
    print(f"Loaded {len(rows)} unique rows (deduped from file)", flush=True)
    return rows


# Metadata keys copied verbatim into dataset columns; "audio" is handled
# separately because it comes from the local file path.
_META_FIELDS = (
    "id",
    "text",
    "speaker_id",
    "source",
    "language",
    "duration_s",
    "r2_url",
    "sample_rate",
)


def build_dataset(rows):
    """Build a HuggingFace Dataset from metadata rows.

    Rows whose local audio file does not exist on disk are skipped (with a
    count reported). The "audio" column is cast to an Audio feature so the
    files are decodable/streamable on the Hub.

    Args:
        rows: Metadata dicts; each must contain "local_audio_path" plus
            every key in _META_FIELDS.

    Returns:
        datasets.Dataset with columns: id, audio, <remaining metadata>.
    """
    print("Building HF dataset...", flush=True)

    # Single source of truth for the column list — adding a metadata field
    # now only requires touching _META_FIELDS.
    dataset_dict = {"id": [], "audio": []}
    for name in _META_FIELDS[1:]:
        dataset_dict[name] = []

    skipped = 0
    for row in rows:
        audio_path = row["local_audio_path"]
        if not os.path.exists(audio_path):
            skipped += 1
            continue

        dataset_dict["audio"].append(audio_path)
        for name in _META_FIELDS:
            dataset_dict[name].append(row[name])

    if skipped:
        print(f"Skipped {skipped} rows (audio file missing)", flush=True)

    print(f"Creating Dataset object with {len(dataset_dict['id'])} rows...", flush=True)

    ds = Dataset.from_dict(dataset_dict)
    # cast_column makes HF treat the string paths as audio to be embedded.
    ds = ds.cast_column("audio", Audio())

    print(f"Dataset created: {ds}", flush=True)
    print(f"Features: {ds.features}", flush=True)
    return ds


def push_to_hub(ds):
    """Upload *ds* to the configured HF repo as a private dataset.

    Shards are capped at 2GB so individual parquet/audio shards stay
    within comfortable upload limits.
    """
    print(f"\nPushing to {HF_REPO} (private)...", flush=True)
    ds.push_to_hub(HF_REPO, private=True, max_shard_size="2GB")
    print(f"Done! Dataset pushed to https://huggingface.co/datasets/{HF_REPO}", flush=True)


def main():
    """End-to-end driver: load metadata, build the dataset, report stats,
    and push the result to the HuggingFace Hub."""
    rows = load_metadata()
    if not rows:
        # Guard: rows[0] below would raise IndexError on an empty
        # metadata file — bail out cleanly instead.
        print("No metadata rows found; nothing to push.", flush=True)
        return

    print("\nSample row:", flush=True)
    # "_audio_bytes" may hold raw audio; exclude it so the sample stays
    # human-readable and small.
    sample = {k: v for k, v in rows[0].items() if k != "_audio_bytes"}
    print(json.dumps(sample, indent=2, ensure_ascii=False), flush=True)

    ds = build_dataset(rows)

    print("\nDataset stats:", flush=True)
    print(f"  Total rows: {len(ds)}", flush=True)
    print(f"  Unique speakers: {len(set(ds['speaker_id']))}", flush=True)
    print(f"  Unique sources: {len(set(ds['source']))}", flush=True)
    total_hours = sum(ds["duration_s"]) / 3600
    print(f"  Total duration: {total_hours:.1f} hours", flush=True)

    push_to_hub(ds)


if __name__ == "__main__":
    main()
