model.safetensors

#8
by duchuy1 - opened

BẠN ƠI CHO MÌNH XIN FILE model.safetensors ĐƯỢC KHÔNG

from safetensors.torch import load_file, save_file

# Load the checkpoint and bring it into a dict that exposes the EMA weights
# under "ema_model_state_dict", regardless of the on-disk format.
if ckpt_path.endswith(".safetensors"):
    # load_file returns the tensors directly, so wrap them under the key
    # that the lookup below reads.
    ckpt = {"ema_model_state_dict": load_file(ckpt_path, device="cpu")}
elif ckpt_path.endswith(".pt"):
    # NOTE(review): torch.load unpickles the file; only use it on
    # checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")
else:
    # Without this branch `ckpt` would be unbound and the next statement
    # would fail with an unrelated NameError.
    raise ValueError(
        f"Unsupported checkpoint format: {ckpt_path!r} "
        "(expected a '.safetensors' or '.pt' file)"
    )

ema_sd = ckpt.get("ema_model_state_dict", {})
embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
if embed_key_ema not in ema_sd:
    # Same exception type as the plain lookup would raise, but with a
    # message that says which file is missing which tensor.
    raise KeyError(
        f"{embed_key_ema!r} not found in the 'ema_model_state_dict' "
        f"of {ckpt_path!r}"
    )
old_embed_ema = ema_sd[embed_key_ema]

# Current table size (rows, columns) and the row count after adding tokens.
vocab_old = old_embed_ema.size(0)
embed_dim = old_embed_ema.size(1)
vocab_new = vocab_old + num_new_tokens

def expand_embeddings(old_embeddings, *, num_new=None):
    """Return a copy of an embedding table with extra rows appended.

    The existing rows are kept unchanged at the top of the result; the
    appended rows are drawn from a standard normal distribution. The result
    has the same dtype and device as ``old_embeddings``.

    Args:
        old_embeddings: 2-D tensor of shape ``(rows, embed_dim)``.
        num_new: Number of rows to append. When ``None`` (the default), the
            module-level ``num_new_tokens`` is used.

    Returns:
        A new tensor of shape ``(rows + num_new, embed_dim)``.

    Raises:
        ValueError: If ``old_embeddings`` is not 2-D.
    """
    if num_new is None:
        num_new = num_new_tokens

    if old_embeddings.dim() != 2:
        raise ValueError(
            "old_embeddings must be 2-D, got shape "
            f"{tuple(old_embeddings.shape)}"
        )
    embed_dim = old_embeddings.size(1)

    # Sample in the default dtype and then convert: this keeps the random
    # draws the same as a plain torch.randn call while matching the dtype
    # and device of the existing table, so half-precision or non-CPU
    # weights are not silently turned into float32 CPU tensors.
    new_rows = torch.randn((num_new, embed_dim)).to(old_embeddings)
    return torch.cat([old_embeddings, new_rows], dim=0)

# Swap the enlarged embedding table into the EMA state dict in place.
ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])

# Write the result in the format implied by the output file's extension.
if new_ckpt_path.endswith(".safetensors"):
    # Only the tensor mapping itself is written in this format.
    save_file(ema_sd, new_ckpt_path)
elif new_ckpt_path.endswith(".pt"):
    # The whole checkpoint dict is written; ema_sd was taken from it, so
    # the in-place update above is included.
    torch.save(ckpt, new_ckpt_path)
else:
    # Fail loudly instead of finishing without having written anything.
    raise ValueError(
        f"Unsupported output format: {new_ckpt_path!r} "
        "(expected a '.safetensors' or '.pt' file)"
    )
You need to confirm your account before you can post a new comment.

Sign up or log in to comment