"""Fix checkpoint and launch training in one shot."""
import torch, os, shutil

# Fix checkpoint: remove training-state keys so the Trainer takes the
# finetuning path instead of trying to resume optimizer/scheduler state.
SRC = '/home/ubuntu/sooktam2/model_1250000.pt'
DST = '/home/ubuntu/sooktam2_finetune_ready.pt'

# Keys whose presence would make the Trainer resume training.  NOTE:
# 'model_state_dict' is deliberately dropped as well, so only the EMA
# weights ('ema_model_state_dict') survive in the cleaned checkpoint.
TRAINING_STATE_KEYS = (
    'update',
    'step',
    'optimizer_state_dict',
    'scheduler_state_dict',
    'model_state_dict',
)


def prepare_finetune_checkpoint(src=SRC, dst=DST, *, verbose=True):
    """Write a finetune-ready copy of the checkpoint at *src* to *dst*.

    Loads *src*, strips every key in ``TRAINING_STATE_KEYS``, saves the
    result to *dst*, then reloads *dst* to verify the cleanup took effect.
    Does nothing if *dst* already exists.

    Args:
        src: Path to the original training checkpoint.
        dst: Path for the cleaned, finetune-ready checkpoint.
        verbose: Print progress/verification messages when True.

    Returns:
        The list of keys in the verified, cleaned checkpoint, or ``None``
        if *dst* already existed and no work was done.
    """
    if os.path.exists(dst):
        if verbose:
            print(f"{dst} already exists")
        return None

    if verbose:
        print("Loading original checkpoint...")
    # weights_only=False allows arbitrary pickled objects (code execution
    # risk) -- acceptable only because this is a trusted local artifact.
    ckpt = torch.load(src, map_location='cpu', weights_only=False)
    if verbose:
        print(f"Original keys: {list(ckpt.keys())}")
        print(f"Original update: {ckpt.get('update')}")

    # pop with default tolerates checkpoints that lack some of the keys.
    for key in TRAINING_STATE_KEYS:
        ckpt.pop(key, None)

    if verbose:
        print(f"Cleaned keys: {list(ckpt.keys())}")
        print("Saving...")
    torch.save(ckpt, dst)
    if verbose:
        print(f"Saved to {dst}")

    # Verify: reload the file we just wrote and confirm the stripped keys
    # are really gone and the EMA weights are still present.
    v = torch.load(dst, map_location='cpu', weights_only=False)
    if verbose:
        print(f"Verify keys: {list(v.keys())}")
        print(f"Has update: {'update' in v}")
        print(f"Has ema_model_state_dict: {'ema_model_state_dict' in v}")
    return list(v.keys())


if __name__ == "__main__":
    prepare_finetune_checkpoint()
    print("DONE")
