"""
Convert modi_processed/ to HuggingFace Dataset for F5-TTS finetuning.
Format: {audio: {array, sampling_rate}, text: str}
F5-TTS requires audio between 0.3s - 30s.
"""

import os
import json
import wave
import numpy as np
from datasets import Dataset, Audio

SEG_DIR = "/home/ubuntu/modi_processed/segments"
TRANS_DIR = "/home/ubuntu/modi_processed/transcripts"
OUTPUT_DIR = "/home/ubuntu/modi_dataset"

samples = []  # dicts of {"audio": wav path, "text": transcription} that passed all filters
skipped = 0   # segments rejected for any reason (missing/short/long/unreadable)

for vid in sorted(os.listdir(SEG_DIR)):
    vid_seg = os.path.join(SEG_DIR, vid)
    vid_trans = os.path.join(TRANS_DIR, vid)

    if not os.path.isdir(vid_seg):
        continue

    for f in sorted(os.listdir(vid_seg)):
        if not f.endswith(".wav"):
            continue

        wav_path = os.path.join(vid_seg, f)
        # splitext swaps only the final extension; str.replace(".wav", ".json")
        # would also corrupt any ".wav" occurring mid-filename.
        json_path = os.path.join(vid_trans, os.path.splitext(f)[0] + ".json")

        if not os.path.exists(json_path):
            skipped += 1
            continue

        try:
            # Duration from the WAV header only -- no need to decode samples.
            with wave.open(wav_path) as w:
                dur = w.getnframes() / w.getframerate()

            # F5-TTS hard requirement: clips must be between 0.3s and 30s.
            if dur < 0.3 or dur > 30:
                skipped += 1
                continue

            # Explicit UTF-8: transcripts are presumably Devanagari text and
            # must not depend on the platform's default locale encoding.
            with open(json_path, encoding="utf-8") as fp:
                trans = json.load(fp)

            text = trans.get("transcription", "").strip()
            if len(text) < 5:  # also rejects the empty string
                skipped += 1
                continue

            samples.append({
                "audio": wav_path,
                "text": text,
            })
        except Exception as e:
            # Best-effort scan: a single corrupt wav/json must not abort the
            # whole run, but do say why the file was dropped.
            print(f"Skipping {wav_path}: {e}")
            skipped += 1
            continue

print(f"Valid samples: {len(samples)}")
print(f"Skipped: {skipped}")

# Fail loudly here rather than with an opaque IndexError at ds[0] below.
if not samples:
    raise SystemExit("No valid samples found; nothing to save.")

ds = Dataset.from_dict({
    "audio": [s["audio"] for s in samples],
    "text": [s["text"] for s in samples],
})
# Casting to Audio makes the string paths decode (and resample to the
# F5-TTS rate of 24 kHz) lazily on access.
ds = ds.cast_column("audio", Audio(sampling_rate=24000))

ds.save_to_disk(OUTPUT_DIR)
print(f"Dataset saved to {OUTPUT_DIR}")
print(f"Columns: {ds.column_names}")
print(f"Example: {ds[0]['text'][:100]}")
print(f"Audio SR: {ds[0]['audio']['sampling_rate']}")
print(f"Audio len: {len(ds[0]['audio']['array'])}")
