"""
Finetune Sooktam-2 on Modi data using Sooktam-2's own model code + pip Trainer.
Key: import model from sooktam2/src, import Trainer from pip.
"""
import sys, os, time, shutil

# Use Sooktam-2's model code
sys.path.insert(0, '/home/ubuntu/sooktam2/src')
from f5_tts.model import CFM, DiT
from f5_tts.model.utils import get_tokenizer

# Use pip Trainer and dataset (these are generic, not model-specific)
sys.path.insert(0, '/home/ubuntu/.local/lib/python3.10/site-packages')
from f5_tts.model.trainer import Trainer
from f5_tts.model.dataset import load_dataset

import torch

# Architecture hyperparameters for the DiT backbone and the mel-spectrogram
# frontend (24 kHz audio, 100 mel bins, vocos-style spectrograms).
model_cfg = {
    'dim': 1024,
    'depth': 22,
    'heads': 16,
    'ff_mult': 2,
    'text_dim': 512,
    'conv_layers': 4,
}
mel_spec_kwargs = {
    'n_fft': 1024,
    'hop_length': 256,
    'win_length': 1024,
    'n_mel_channels': 100,
    'target_sample_rate': 24000,
    'mel_spec_type': 'vocos',
}

# Build the character->id map from the custom vocab file; vocab_size drives
# the text embedding table size below.
vocab_char_map, vocab_size = get_tokenizer('/home/ubuntu/sooktam2/vocab.txt', 'custom')
print(f'Vocab: {vocab_size}')

# Instantiate the flow-matching model (CFM) around a DiT transformer backbone.
# mel_dim must agree with n_mel_channels in mel_spec_kwargs (100).
backbone = DiT(**model_cfg, text_num_embeds=vocab_size, mel_dim=100)
model = CFM(
    transformer=backbone,
    mel_spec_kwargs=mel_spec_kwargs,
    vocab_char_map=vocab_char_map,
)

# Load pretrained EMA weights and remap the keys so they match CFM's own
# state_dict layout.
# NOTE(review): weights_only=False runs the full pickle machinery and can
# execute arbitrary code; acceptable only because this checkpoint is a
# locally produced, trusted file. Prefer weights_only=True if the checkpoint
# holds nothing but tensors/plain containers — confirm before changing.
ckpt = torch.load('/home/ubuntu/sooktam2_finetune_ready.pt', map_location='cpu', weights_only=False)
ema = ckpt['ema_model_state_dict']
# EMA-wrapper bookkeeping entries that have no counterpart in the model.
_EMA_BOOKKEEPING = {'initted', 'update', 'step'}
# Strip the 'ema_model.' prefix the EMA wrapper prepends to every weight key.
clean = {
    k.replace('ema_model.', ''): v
    for k, v in ema.items()
    if k not in _EMA_BOOKKEEPING
}
# strict=False tolerates key drift between checkpoint and model; the counts
# are printed so a silently partial load is visible in the training log.
missing, unexpected = model.load_state_dict(clean, strict=False)
print(f'Weights loaded: missing={len(missing)}, unexpected={len(unexpected)}')
# Release the checkpoint tensors before training to keep peak RAM down.
del ckpt, ema, clean

# Directory where finetuning checkpoints are written.
ckpt_dir = '/home/ubuntu/ckpts_sooktam_modi'
os.makedirs(ckpt_dir, exist_ok=True)

# Finetuning schedule: small learning rate with a short warmup, frame-based
# dynamic batching, and gradient accumulation for a larger effective batch.
trainer_kwargs = {
    'epochs': 10,
    'learning_rate': 1e-5,
    'num_warmup_updates': 200,
    'save_per_updates': 500,
    'keep_last_n_checkpoints': 3,
    'checkpoint_path': ckpt_dir,
    'batch_size_per_gpu': 9600,
    'batch_size_type': 'frame',
    'max_samples': 4,
    'grad_accumulation_steps': 8,
    'max_grad_norm': 1.0,
    'logger': None,
    'last_per_updates': 200,
}
trainer = Trainer(model, **trainer_kwargs)

# NOTE(review): a previous comment here claimed weights would be "manually
# loaded into trainer's model and EMA", but no such loading happens below.
# The pretrained weights were loaded into `model` *before* Trainer was
# constructed, so the trainer — and presumably its EMA copy, if Trainer
# initializes EMA from the current model state (confirm in Trainer.__init__)
# — already starts from the pretrained weights.
print('Loading dataset...')
# 'custom' dataset type with the same mel settings as the model frontend.
train_dataset = load_dataset('modi_hindi', 'custom', mel_spec_kwargs=mel_spec_kwargs)
print(f'Dataset: {len(train_dataset)} samples')

print('Training...')
# Wall-clock timing of the whole run; resumable_with_seed fixes the data
# order so an interrupted run can resume deterministically.
t0 = time.perf_counter()
trainer.train(train_dataset, resumable_with_seed=42)
elapsed = time.perf_counter() - t0
print(f'Done in {elapsed/60:.1f} min')
