"""Patch a base single-file checkpoint with weights from a ``.pt`` checkpoint.

Loads ``file_path`` (safetensors), copies every entry of the ``"module"``
dict stored in ``model_path`` into it under the ``key_prefix`` key prefix
(entries whose prefixed name already exists in the base file are
overwritten), and writes the merged dict to ``output_path``.

NOTE(review): ``model_path``'s name and its ``"module"`` key suggest a
DeepSpeed training checkpoint — confirm against the training setup.
"""
import torch
from safetensors.torch import load_file, save_file

file_path = "./hunyuan_dit_1.1.safetensors"  # base checkpoint to patch
model_path = "./mp_rank_00_model_states.pt"  # checkpoint holding the new weights
output_path = "Freeway_Animation_HunYuan_Demo_comfyui.safetensors"
key_prefix = "model."  # prepended to every key taken from sd["module"]

loaded = load_file(file_path)

# torch.load unpickles the file, which can execute arbitrary code: only run
# this on checkpoints you trust. Since PyTorch 2.6 torch.load defaults to
# weights_only=True; if that rejects non-tensor objects stored in this file,
# pass weights_only=False explicitly (again, only for trusted files).
# map_location="cpu" keeps every tensor on the CPU, like the old
# `lambda storage, loc: storage` idiom.
sd = torch.load(model_path, map_location="cpu")

# Copy the new weights into the base checkpoint. safetensors' save_file
# refuses non-contiguous tensors, so pack each one first (.contiguous() returns
# the tensor itself when it is already contiguous). Tensors that still share
# storage with each other would also be rejected by save_file.
loaded.update(
    {
        key_prefix + str(name): tensor.contiguous()
        for name, tensor in sd["module"].items()
    }
)

save_file(loaded, output_path)
# manual surgery