"""Prepare Modi dataset for VibeVoice finetuning."""
import os, json, wave
from datasets import Dataset, Audio

SEG_DIR = '/home/ubuntu/modi_processed/segments'
TRANS_DIR = '/home/ubuntu/modi_processed/transcripts'

# Segment filters: clip duration window in seconds, and minimum transcript
# length in characters (after stripping surrounding whitespace).
MIN_DUR, MAX_DUR = 1.0, 30
MIN_CHARS = 10


def _load_sample(wav_path, json_path):
    """Return the training text for one segment, or None to skip it.

    The text is the stripped 'transcription' field of ``json_path`` prefixed
    with "Speaker 1: ". The segment is skipped (None is returned) when the
    wav reports a non-positive frame rate, its duration lies outside
    [MIN_DUR, MAX_DUR], or the transcript is missing, not a string, or
    shorter than MIN_CHARS characters.

    Raises wave.Error / EOFError for malformed wav files, OSError for
    unreadable files, and ValueError (including UnicodeDecodeError and
    json.JSONDecodeError) for undecodable transcripts.
    """
    with wave.open(wav_path) as w:
        rate = w.getframerate()
        if rate <= 0:
            return None
        dur = w.getnframes() / rate
    if not MIN_DUR <= dur <= MAX_DUR:
        return None
    # JSON is UTF-8 by spec; without an explicit encoding, non-ASCII
    # transcripts can fail to decode on hosts with a non-UTF-8 locale.
    with open(json_path, encoding='utf-8') as fp:
        data = json.load(fp)
    raw = data.get('transcription') if isinstance(data, dict) else None
    if not isinstance(raw, str):
        return None
    raw = raw.strip()
    if len(raw) < MIN_CHARS:
        return None
    return f"Speaker 1: {raw}"


audio_paths, texts = [], []
unreadable = []  # (wav_path, exception) for segments that failed to load
for vid in sorted(os.listdir(SEG_DIR)):
    vid_seg = os.path.join(SEG_DIR, vid)
    vid_trans = os.path.join(TRANS_DIR, vid)
    if not os.path.isdir(vid_seg):
        continue
    for f in sorted(os.listdir(vid_seg)):
        if not f.endswith('.wav'):
            continue
        wav_path = os.path.join(vid_seg, f)
        # splitext swaps only the final extension; str.replace would also
        # rewrite any '.wav' occurring earlier in the file name.
        json_path = os.path.join(vid_trans, os.path.splitext(f)[0] + '.json')
        if not os.path.exists(json_path):
            continue
        try:
            text = _load_sample(wav_path, json_path)
        except (wave.Error, EOFError, OSError, ValueError) as e:
            # Catch only file/format errors so Ctrl-C and genuine bugs still
            # propagate; record the failure instead of dropping it silently.
            unreadable.append((wav_path, e))
            continue
        if text is not None:
            audio_paths.append(wav_path)
            texts.append(text)

if unreadable:
    first_path, first_err = unreadable[0]
    print(f'Skipped {len(unreadable)} unreadable segment(s); '
          f'first: {first_path}: {first_err!r}')

OUT_DIR = '/home/ubuntu/modi_vibevoice_dataset'

print(f'Samples: {len(audio_paths)}')
if not audio_paths:
    # Nothing passed the filters: stop with a clear message rather than
    # saving an empty dataset and then crashing on ds[0] below.
    raise SystemExit(f'No usable segments found under {SEG_DIR}')

ds = Dataset.from_dict({'audio': audio_paths, 'text': texts})
# Audio(sampling_rate=24000) has `datasets` decode each file and resample it
# to 24 kHz when a row is accessed (checked by the sample-rate print below).
ds = ds.cast_column('audio', Audio(sampling_rate=24000))
# NOTE(review): the saved dataset presumably references the wav files by
# path rather than embedding their bytes — confirm before moving SEG_DIR.
ds.save_to_disk(OUT_DIR)
print(f'Saved to {OUT_DIR}')
print(f'Sample text: {ds[0]["text"][:80]}')
print(f'Sample audio SR: {ds[0]["audio"]["sampling_rate"]}')
