"""
Bare PyTorch training loop for Sooktam-2.
Uses only Sooktam-2's own model code, not the Trainer from the pip-installed F5-TTS package.
CLS-tokenized data is loaded manually (no DataLoader; one utterance per optimizer step).
"""
import sys, os, time, json, torch, torchaudio
from torch.utils.data import Dataset, DataLoader

sys.path.insert(0, '/home/ubuntu/sooktam2/src')
from f5_tts.model import CFM, DiT
from f5_tts.model.utils import get_tokenizer
from f5_tts.model.modules import MelSpec

# Config
# DiT backbone hyper-parameters; these must agree with the checkpoint loaded
# below, otherwise load_state_dict(strict=False) reports missing/unexpected keys.
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
vocab_char_map, vocab_size = get_tokenizer('/home/ubuntu/sooktam2/vocab.txt', 'custom')
print(f'Vocab: {vocab_size}')

# Model
mel_spec_kwargs = dict(n_fft=1024, hop_length=256, win_length=1024, n_mel_channels=100, target_sample_rate=24000, mel_spec_type='vocos')
model = CFM(
    # mel_dim is taken from the feature config so the transformer's input size
    # can never drift from the number of mel channels MelSpec produces.
    transformer=DiT(**model_cfg, text_num_embeds=vocab_size, mel_dim=mel_spec_kwargs['n_mel_channels']),
    mel_spec_kwargs=mel_spec_kwargs, vocab_char_map=vocab_char_map,
)

# Load weights
# NOTE(security): weights_only=False unpickles arbitrary Python objects, so only
# point this at a trusted checkpoint file.
ckpt = torch.load('/home/ubuntu/sooktam2_finetune_ready.pt', map_location='cpu', weights_only=False)
# Prefer the EMA weights; fall back to the raw model weights when the
# checkpoint carries no EMA copy.
ema = ckpt['ema_model_state_dict'] if 'ema_model_state_dict' in ckpt else ckpt['model_state_dict']
# Drop the EMA wrapper's bookkeeping entries and its 'ema_model.' key prefix so
# the names line up with the bare CFM module.
clean = {k.replace('ema_model.', ''): v for k, v in ema.items() if k not in ['initted', 'update', 'step']}
missing, unexpected = model.load_state_dict(clean, strict=False)
print(f'Loaded: missing={len(missing)}, unexpected={len(unexpected)}')
# strict=False tolerates mismatches silently; show a few names so a partially
# loaded model is noticed before any training happens.
if missing:
    print(f'  missing keys (first 5): {list(missing)[:5]}')
if unexpected:
    print(f'  unexpected keys (first 5): {list(unexpected)[:5]}')
del ckpt, ema, clean

model = model.cuda().train()

# Simple dataset
class ModiDataset(Dataset):
    """Map-style dataset yielding ``(mel, text, duration)`` per utterance.

    ``data_dir`` must contain:
      * ``raw/`` - a HuggingFace ``datasets`` dataset readable by
        ``load_from_disk``, whose rows have ``audio_path`` and ``text`` fields;
      * ``duration.json`` - ``{"duration": [...]}``, a list indexed in parallel
        with the rows (values presumably in seconds - confirm with the data prep).
    """

    def __init__(self, data_dir):
        from datasets import load_from_disk
        self.ds = load_from_disk(os.path.join(data_dir, 'raw'))
        with open(os.path.join(data_dir, 'duration.json'), encoding='utf-8') as f:
            self.durations = json.load(f)['duration']
        self.mel_spec = MelSpec(**mel_spec_kwargs)
        # Rate the mel extractor is configured for; audio is resampled to it.
        self.target_sample_rate = mel_spec_kwargs['target_sample_rate']
        # One Resample transform per source rate. Building a Resample computes
        # its filter kernel, so constructing one per item repeats that work.
        self._resamplers = {}

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(mel [d, t], text, duration)``."""
        row = self.ds[idx]
        audio, sr = torchaudio.load(row['audio_path'])
        # Down-mix multi-channel audio to mono while keeping the channel dim.
        if audio.shape[0] > 1:
            audio = audio.mean(0, keepdim=True)
        if sr != self.target_sample_rate:
            resampler = self._resamplers.get(sr)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
                self._resamplers[sr] = resampler
            audio = resampler(audio)
        mel = self.mel_spec(audio).squeeze(0)  # [d, t]
        return mel, row['text'], self.durations[idx]

dataset = ModiDataset('/home/ubuntu/sooktam2/data/modi_hindi_cls_custom')
print(f'Dataset: {len(dataset)} samples')

# Training
# One utterance per optimizer step (batch size 1), iterated in dataset order.
CKPT_DIR = '/home/ubuntu/ckpts_bare'
# Keep samples whose duration.json value lies in [MIN_DUR, MAX_DUR]
# (presumably seconds).
MIN_DUR, MAX_DUR = 0.3, 30

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# LR ramps linearly from 1% to 100% of the base LR over the first 200
# scheduler steps, then stays at the base LR.
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=200)

# The directory must exist before the first torch.save below. Previously it
# was created only after saving, so the epoch-1 save failed on a fresh machine.
os.makedirs(CKPT_DIR, exist_ok=True)

print('Training...')
for epoch in range(10):
    total_loss = 0.0
    n = 0  # successful optimizer steps in this epoch
    for i in range(len(dataset)):
        # Check the precomputed duration first, so filtered-out samples are
        # never loaded, resampled or turned into mel spectrograms.
        dur = dataset.durations[i]
        if dur < MIN_DUR or dur > MAX_DUR:
            continue
        mel, text, _ = dataset[i]
        mel = mel.unsqueeze(0).cuda()  # [1, d, t]
        mel = mel.permute(0, 2, 1)     # [1, t, d], the layout passed to model()
        lens = torch.tensor([mel.shape[1]]).cuda()  # frame count of the single item

        optimizer.zero_grad()
        try:
            loss, _, _ = model(mel, text=[text], lens=lens)
        except Exception as e:
            # Best-effort: log and skip samples the model cannot process.
            print(f'  Error at sample {i}: {e}')
            continue

        # Skip the update on NaN *or* Inf loss; either would poison the
        # gradients (and thus the weights) even after clipping.
        if not torch.isfinite(loss):
            print(f'  NaN/Inf loss ({loss.item()}) at epoch {epoch+1} step {n+1} (sample {i})')
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        n += 1

        if n % 50 == 0:
            print(f'  Epoch {epoch+1} step {n}: loss={total_loss/n:.4f} lr={scheduler.get_last_lr()[0]:.2e}', flush=True)

    if n > 0:
        print(f'Epoch {epoch+1}: avg_loss={total_loss/n:.4f}, steps={n}', flush=True)

    # Save model weights only (no optimizer/scheduler state), one file per epoch.
    torch.save({
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
    }, os.path.join(CKPT_DIR, f'epoch_{epoch+1}.pt'))

print('Done')
