import os | |
import torch | |
from safetensors.torch import save_file | |
# Source and destination directories; the defaults convert in place.
checkpoint_dir = '../out/cpt-core-pre-4'
output_dir = '../out/cpt-core-pre-4'


def _safetensors_ready(state_dict):
    """Return a name -> tensor dict that ``safetensors`` can serialize.

    ``save_file`` refuses non-contiguous tensors and refuses several names
    that share one storage (e.g. tied input-embedding / output-head weights).
    Each tensor is detached and made contiguous. Any tensor whose storage
    was already claimed by an earlier entry is cloned, so every saved entry
    owns its own memory. Clones are made only for duplicates, which keeps
    peak memory close to the checkpoint size.

    Raises:
        TypeError: if a value is not a ``torch.Tensor``.
    """
    prepared = {}
    seen_storages = set()
    for name, tensor in state_dict.items():
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(
                f"state_dict[{name!r}] is {type(tensor).__name__}, not a tensor; "
                "the checkpoint may wrap the weights (e.g. under a 'model' key)"
            )
        tensor = tensor.detach().contiguous()
        storage_ptr = tensor.untyped_storage().data_ptr()
        if storage_ptr in seen_storages:
            # Shared with an earlier entry (tied weights / views): give this
            # name its own copy so safetensors does not reject the dict.
            tensor = tensor.clone()
        else:
            seen_storages.add(storage_ptr)
        prepared[name] = tensor
    return prepared


def convert_checkpoint(src_dir=checkpoint_dir, dst_dir=output_dir):
    """Convert ``<src_dir>/model.pth`` into ``<dst_dir>/model.safetensors``.

    Args:
        src_dir: directory containing ``model.pth``. That file is expected
            to hold a flat mapping of parameter names to tensors.
        dst_dir: directory for ``model.safetensors``. It is created if it
            does not already exist.

    Raises:
        TypeError: if the loaded checkpoint is not a dict of tensors.
    """
    # map_location='cpu': load without the GPU(s) the checkpoint was saved on.
    # weights_only=True: restrict unpickling to tensors and plain containers
    # instead of executing arbitrary pickled code.
    state_dict = torch.load(
        os.path.join(src_dir, 'model.pth'),
        map_location='cpu',
        weights_only=True,
    )
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"expected a state_dict (dict of tensors), got {type(state_dict).__name__}"
        )
    os.makedirs(dst_dir, exist_ok=True)
    # The 'format': 'pt' metadata is what Hugging Face loaders check for.
    save_file(
        _safetensors_ready(state_dict),
        os.path.join(dst_dir, 'model.safetensors'),
        metadata={'format': 'pt'},
    )


if __name__ == '__main__':
    convert_checkpoint()