Update app.py
Browse files
app.py
CHANGED
@@ -107,9 +107,26 @@ def download_and_merge_model(base_model_name, lora_model_name, output_dir, device):
         device_map={"": device}
     )

+    old_vocab_size = base_model.get_input_embeddings().weight.shape[0]
+    print(f"原始词表大小: {old_vocab_size}")
     # 加载tokenizer
     tokenizer = AutoTokenizer.from_pretrained(base_model_name)

+    new_vocab_size = tokenizer.vocab_size
+    print(f"调整词表大小: {old_vocab_size} -> {new_vocab_size}")
+
+    # 保存原始权重
+    old_embeddings = base_model.get_input_embeddings().weight.data.clone()
+    old_lm_head = base_model.lm_head.weight.data.clone()
+
+    # 调整词表大小
+    base_model.resize_token_embeddings(new_vocab_size)
+
+    # 复制原始权重到新的张量
+    with torch.no_grad():
+        base_model.get_input_embeddings().weight.data[:new_vocab_size] = old_embeddings[:new_vocab_size]
+        base_model.lm_head.weight.data[:new_vocab_size] = old_lm_head[:new_vocab_size]
+
     log(f"正在加载LoRA模型: {lora_model_name}")
     log("基础模型配置:" + str(base_model.config))