"""Pre-tokenize Modi dataset with Sooktam-2's CLS tokenizer."""
import sys, os, json, wave
sys.path.insert(0, '/home/ubuntu/sooktam2/src')
from f5_tts.infer.cls_tokenizer_v2 import cls_tokenize_text
from datasets import Dataset

SEG_DIR = '/home/ubuntu/modi_processed/segments'
TRANS_DIR = '/home/ubuntu/modi_processed/transcripts'

# Filters applied to every segment before it is kept.
MIN_DUR_S, MAX_DUR_S = 0.3, 30  # accepted clip duration range in seconds (inclusive)
MIN_TEXT_LEN = 5  # minimum stripped transcript length, in characters
PROGRESS_EVERY = 500  # print a progress line every N kept samples

# Parallel per-sample columns of the output dataset (same index = same sample).
audio_paths, texts, durations = [], [], []
# Segments dropped for any reason: missing transcript, duration out of range,
# transcript too short, or an error while reading/tokenizing.
skipped = 0

# Expected layout: SEG_DIR/<video>/<name>.wav pairs with TRANS_DIR/<video>/<name>.json.
for vid in sorted(os.listdir(SEG_DIR)):
    vid_seg = os.path.join(SEG_DIR, vid)
    vid_trans = os.path.join(TRANS_DIR, vid)
    if not os.path.isdir(vid_seg):
        continue
    for f in sorted(os.listdir(vid_seg)):
        if not f.endswith('.wav'):
            continue
        wav_path = os.path.join(vid_seg, f)
        # splitext swaps only the final extension; str.replace would also
        # rewrite any '.wav' appearing earlier in the file name.
        json_path = os.path.join(vid_trans, os.path.splitext(f)[0] + '.json')
        if not os.path.exists(json_path):
            skipped += 1
            continue
        try:
            with wave.open(wav_path) as w:
                dur = w.getnframes() / w.getframerate()
            if dur < MIN_DUR_S or dur > MAX_DUR_S:
                skipped += 1
                continue
            # Transcripts are Hindi text; decode explicitly as UTF-8 instead of
            # relying on the locale's default encoding.
            with open(json_path, encoding='utf-8') as fp:
                raw_text = json.load(fp).get('transcription', '').strip()
            if len(raw_text) < MIN_TEXT_LEN:
                skipped += 1
                continue

            # Store the CLS tokens as one space-separated string.
            # NOTE(review): if the downstream F5-TTS "custom" tokenizer maps each
            # *character* of this string to a vocab ID, multi-character CLS tokens
            # (e.g. "aa", "sh") will not survive as single units — verify against
            # get_tokenizer / the dataset's text-to-ID step before training.
            cls_tokens = cls_tokenize_text(raw_text, 'hindi')
            tokenized_text = ' '.join(cls_tokens)

            audio_paths.append(wav_path)
            texts.append(tokenized_text)
            durations.append(dur)

            # Report once per PROGRESS_EVERY kept samples (checked per sample,
            # not per video, so it actually fires).
            if len(audio_paths) % PROGRESS_EVERY == 0:
                print(f'Processed {len(audio_paths)} samples...')
        except Exception as e:
            # Best effort: a corrupt WAV, malformed JSON or tokenizer failure
            # drops only this segment, but is reported rather than hidden.
            skipped += 1
            print(f'Skipping {wav_path}: {type(e).__name__}: {e}', file=sys.stderr)

print(f'Total: {len(audio_paths)} samples, skipped: {skipped}')
# Without this guard, texts[0] below raises IndexError (or an empty dataset
# would be written); fail with an explicit message instead.
if not texts:
    sys.exit('No samples passed the filters; nothing to save.')
print(f'Sample text: {texts[0][:100]}')

ds = Dataset.from_dict({'audio_path': audio_paths, 'text': texts, 'duration': durations})

# Output goes to a path resolved relative to the installed f5_tts package
# (presumably <repo>/data/... for the src-layout checkout on sys.path — TODO confirm).
from importlib.resources import files
data_dir = str(files('f5_tts').joinpath('../../data/modi_hindi_cls_custom'))
os.makedirs(data_dir, exist_ok=True)
ds.save_to_disk(os.path.join(data_dir, 'raw'))
# Durations are written in the same order as the dataset rows above.
with open(os.path.join(data_dir, 'duration.json'), 'w') as fp:
    json.dump({'duration': durations}, fp)
print(f'Saved to {data_dir}')
