#!/usr/bin/env python3
"""Download a few samples from each language in espnet/yodas2."""

import os, re, io, json, tarfile

import requests
import soundfile as sf

SAMPLES_PER_LANG = 3
OUTPUT_DIR = "/home/ubuntu/yodas2_samples"
HF_BASE = "https://huggingface.co/datasets/espnet/yodas2/resolve/main/data"

lang2shard_cnt = {
    'aa000': 2, 'ab000': 2, 'af000': 2, 'ak000': 2, 'am000': 9,
    'ar000': 154, 'as000': 2, 'ay000': 2, 'az000': 4, 'ba000': 2,
    'be000': 7, 'bg000': 12, 'bh000': 2, 'bi000': 2, 'bm000': 2,
    'bn000': 92, 'bo000': 2, 'br000': 2, 'bs000': 2, 'ca000': 10,
    'co000': 2, 'cr000': 2, 'cs000': 24, 'cy000': 2, 'da000': 6,
    'de000': 369, 'dz000': 2, 'ee000': 2, 'el000': 18, 'en000': 500,
    'eo000': 4, 'es000': 483, 'et000': 2, 'eu000': 4, 'fa000': 12,
    'ff000': 2, 'fi000': 28, 'fj000': 2, 'fo000': 2, 'fr000': 315,
    'fy000': 1, 'ga000': 2, 'gd000': 2, 'gl000': 3, 'gn000': 2,
    'gu000': 8, 'ha000': 4, 'hi000': 182, 'ho000': 2, 'hr000': 5,
    'ht000': 3, 'hu000': 32, 'hy000': 3, 'ia000': 2, 'id000': 493,
    'ie000': 2, 'ig000': 2, 'ik000': 2, 'is000': 2, 'it000': 185,
    'iu000': 2, 'iw000': 21, 'ja000': 211, 'jv000': 2, 'ka000': 4,
    'ki000': 1, 'kk000': 6, 'kl000': 2, 'km000': 10, 'kn000': 7,
    'ko000': 391, 'ks000': 2, 'ku000': 2, 'ky000': 4, 'la000': 2,
    'lb000': 2, 'lg000': 2, 'ln000': 2, 'lo000': 2, 'lt000': 4,
    'lv000': 2, 'mg000': 2, 'mi000': 2, 'mk000': 2, 'ml000': 12,
    'mn000': 2, 'mr000': 18, 'ms000': 8, 'my000': 2, 'na000': 2,
    'nd000': 1, 'ne000': 6, 'nl000': 52, 'no000': 17, 'nv000': 2,
    'oc000': 2, 'om000': 2, 'or000': 3, 'pa000': 5, 'pl000': 140,
    'ps000': 2, 'pt000': 202, 'qu000': 2, 'rm000': 2, 'rn000': 2,
    'ro000': 18, 'ru000': 500, 'rw000': 2, 'sa000': 2, 'sc000': 2,
    'sd000': 2, 'sg000': 1, 'sh000': 1, 'si000': 8, 'sk000': 6,
    'sl000': 4, 'sm000': 2, 'sn000': 2, 'so000': 4, 'sq000': 2,
    'sr000': 4, 'st000': 2, 'su000': 2, 'sv000': 17, 'sw000': 4,
    'ta000': 40, 'te000': 14, 'tg000': 2, 'th000': 113, 'ti000': 2,
    'tk000': 2, 'tn000': 2, 'to000': 2, 'tr000': 155, 'ts000': 1,
    'tt000': 2, 'ug000': 2, 'uk000': 63, 'ur000': 35, 'uz000': 8,
    've000': 2, 'vi000': 465, 'vo000': 2, 'wo000': 2, 'xh000': 2,
    'yi000': 2, 'yo000': 2, 'zh000': 42, 'zu000': 2,
}

# Keep one subset per language code (the first in sorted order, e.g. 'en000' for 'en').
seen_langs = set()
subsets = []
for subset in sorted(lang2shard_cnt.keys()):
    lang = re.match(r'^([a-z]+)', subset).group(1)
    shard_cnt = lang2shard_cnt[subset]
    if lang not in seen_langs and shard_cnt > 0:
        seen_langs.add(lang)
        subsets.append((lang, subset))

print(f"Will download {SAMPLES_PER_LANG} samples from {len(subsets)} languages")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Reuse one HTTP session so connections to the Hub are pooled across requests.
session = requests.Session()


def download_text(subset):
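    """Fetch the first text shard of a subset and return a mapping from
    audio_id to its per-utterance text dict."""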
    url = f"{HF_BASE}/{subset}/text/00000000.json"
    r = session.get(url, timeout=60)
    r.raise_for_status()
    raw = r.json()
    text_map = {}
    for entry in raw:
        aid = entry["audio_id"]
        text_map[aid] = entry.get("text", {})
    return text_map


def download_duration(subset):
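    """Fetch the first duration shard and return a mapping from audio_id to
    the duration value listed in the shard."""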
    url = f"{HF_BASE}/{subset}/duration/00000000.txt"
    r = session.get(url, timeout=60)
    r.raise_for_status()
    durations = {}
    for line in r.text.strip().split("\n"):
        parts = line.strip().split()
        if len(parts) >= 2:
            durations[parts[0]] = float(parts[1])
    return durations


def download_audio_tar(subset):
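    """Download the first audio shard and return raw WAV bytes for the first
    SAMPLES_PER_LANG files found in the archive."""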
    url = f"{HF_BASE}/{subset}/audio/00000000.tar.gz"
    head = session.head(url, timeout=30, allow_redirects=True)
    size_mb = int(head.headers.get("content-length", 0)) / (1024 * 1024)
    print(f"    Audio shard: {size_mb:.1f}MB", flush=True)

    # The full shard is downloaded and buffered in memory even though only a
    # few files from it are kept.
    r = session.get(url, timeout=600)
    r.raise_for_status()

    audio_data = {}
    buf = io.BytesIO(r.content)
    with tarfile.open(fileobj=buf, mode="r:gz") as tar:
        collected = 0
        for member in tar:
            if not member.isfile() or not member.name.endswith(".wav"):
                continue
            vid = os.path.splitext(os.path.basename(member.name))[0]
            f = tar.extractfile(member)
            if f:
                audio_data[vid] = f.read()
                collected += 1
                if collected >= SAMPLES_PER_LANG:
                    break
    return audio_data


def process_language(lang, subset):
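    """Fetch text, durations, and audio for one subset, write up to
    SAMPLES_PER_LANG clipped WAVs plus metadata.json into OUTPUT_DIR/<lang>,
    and return how many samples were saved."""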
    lang_dir = os.path.join(OUTPUT_DIR, lang)
    os.makedirs(lang_dir, exist_ok=True)

    text_map = download_text(subset)
    durations = download_duration(subset)
    audio_data = download_audio_tar(subset)

    samples_meta = []
    for j, (vid, wav_bytes) in enumerate(audio_data.items()):
        if j >= SAMPLES_PER_LANG:
            break

        wav_path = os.path.join(lang_dir, f"sample_{j}.wav")

        try:
            audio_buf = io.BytesIO(wav_bytes)
            data, sr = sf.read(audio_buf)
            # Keep at most the first 30 seconds of audio.
            max_samp = sr * 30
            if len(data) > max_samp:
                data = data[:max_samp]
            sf.write(wav_path, data, sr)
            saved_dur = round(len(data) / sr, 2)
        except Exception:
            # Decoding failed: write the raw bytes unchanged and fall back to
            # a nominal 24 kHz rate with an unknown saved duration.
            with open(wav_path, "wb") as wf:
                wf.write(wav_bytes)
            sr = 24000
            saved_dur = 0

        utterances = []
        if vid in text_map:
            for utt_id, text in text_map[vid].items():
                # The last two "-"-separated fields of utt_id are treated as
                # start/end times in hundredths of a second.
                parts = utt_id.split("-")
                start = int(parts[-2]) / 100.0 if len(parts) >= 3 else 0
                end = int(parts[-1]) / 100.0 if len(parts) >= 3 else 0
                utterances.append({
                    "utt_id": utt_id,
                    "text": text,
                    "start": start,
                    "end": end,
                })

        meta = {
            "video_id": vid,
            "duration": durations.get(vid, 0),
            "audio_file": f"sample_{j}.wav",
            "sampling_rate": sr,
            "saved_duration_s": saved_dur,
            "num_utterances": len(utterances),
            "utterances": utterances,
        }
        samples_meta.append(meta)

    with open(os.path.join(lang_dir, "metadata.json"), "w") as f:
        json.dump(samples_meta, f, indent=2, ensure_ascii=False)

    return len(samples_meta)


results = {}
failed = []

# Process each language in turn; failures are recorded rather than aborting.
for i, (lang, subset) in enumerate(subsets):
    print(f"[{i+1}/{len(subsets)}] {lang} ({subset})...", flush=True)
    try:
        count = process_language(lang, subset)
        print(f"    OK: {count} samples")
        results[lang] = {"subset": subset, "samples": count}
    except Exception as e:
        err_msg = str(e)[:200]
        print(f"    FAILED: {err_msg}")
        failed.append({"lang": lang, "subset": subset, "error": err_msg})

with open(os.path.join(OUTPUT_DIR, "summary.json"), "w") as f:
    json.dump({
        "samples_per_lang": SAMPLES_PER_LANG,
        "total_languages_ok": len(results),
        "total_languages_failed": len(failed),
        "languages": results,
        "failed": failed,
    }, f, indent=2, ensure_ascii=False)

print(f"\n{'='*60}")
print(f"DONE: {len(results)} languages OK, {len(failed)} failed")
print(f"Output directory: {OUTPUT_DIR}")
if failed:
    print(f"\nFailed:")
    for f_item in failed:
        print(f"  {f_item['lang']}: {f_item['error'][:80]}")
