import os
import shutil
import time
from f5_tts.model import CFM, DiT, Trainer
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer

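# DiT backbone hyperparameters; these match the published F5-TTS Base configuration.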
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
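# Mel-spectrogram settings for the Vocos vocoder: 100 mel channels at 24 kHz,
# hop length 256 (i.e. 93.75 frames per second of audio).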
mel_spec_kwargs = dict(n_fft=1024, hop_length=256, win_length=1024, n_mel_channels=100, target_sample_rate=24000, mel_spec_type="vocos")

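# "custom" tokenizer: build the character map from the project's own vocab.txt.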
vocab_char_map, vocab_size = get_tokenizer("/home/ubuntu/sooktam2/vocab.txt", "custom")
print(f"Vocab size: {vocab_size}")

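# Conditional flow matching (CFM) model with the DiT backbone as its text-to-mel transformer.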
model = CFM(
    transformer=DiT(**model_cfg, text_num_embeds=vocab_size, mel_dim=mel_spec_kwargs["n_mel_channels"]),
    mel_spec_kwargs=mel_spec_kwargs,
    vocab_char_map=vocab_char_map,
)

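# Seed the run directory with the pretrained F5-TTS checkpoint (1.25M updates).
# In the F5-TTS fine-tuning convention, a file named "pretrained_*.pt" in
# checkpoint_path is picked up as initialization weights rather than a resume point.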
ckpt_dir = "/home/ubuntu/ckpts/modi_hindi"
os.makedirs(ckpt_dir, exist_ok=True)
pretrain_dst = os.path.join(ckpt_dir, "pretrained_model_1250000.pt")
if not os.path.exists(pretrain_dst):
    shutil.copy2("/home/ubuntu/sooktam2/model_1250000.pt", pretrain_dst)
    print("Copied pretrained checkpoint")

trainer = Trainer(
    model,
    epochs=50,
    learning_rate=1e-5,  # low LR, appropriate for fine-tuning from a pretrained checkpoint
    num_warmup_updates=200,  # LR warmup over the first 200 updates
    save_per_updates=500,  # write a numbered checkpoint every 500 updates
    keep_last_n_checkpoints=3,  # prune all but the 3 newest numbered checkpoints
    checkpoint_path=ckpt_dir,
    batch_size_per_gpu=38400,  # measured in mel frames: ~410 s of audio per GPU batch
    batch_size_type="frame",  # batch by total frame count rather than sample count
    max_samples=16,  # at most 16 utterances per batch
    grad_accumulation_steps=4,  # effective batch of 4 x 38400 = 153,600 frames
    max_grad_norm=1.0,  # gradient clipping
    logger=None,  # disable wandb/tensorboard logging
    last_per_updates=200,  # refresh the rolling "last" checkpoint every 200 updates
)

print("Loading dataset...")
train_dataset = load_dataset("modi_hindi", "custom", mel_spec_kwargs=mel_spec_kwargs)
print(f"Dataset loaded: {len(train_dataset)} samples")

print(f"Starting training ({trainer.epochs} epochs)...")
t0 = time.perf_counter()
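# resumable_with_seed fixes the sampler seed so an interrupted run can resume
# deterministically from the last checkpoint.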
trainer.train(train_dataset, resumable_with_seed=42)
elapsed = time.perf_counter() - t0
print(f"Training done in {elapsed:.1f}s ({elapsed/60:.1f}min)")
