Yuxuan Zhang
commited on
add set_input_embeddings(self, value):
Browse files- modeling_chatglm.py +3 -0
modeling_chatglm.py
CHANGED
|
@@ -769,6 +769,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 769 |
def get_input_embeddings(self):
|
| 770 |
return self.embedding.word_embeddings
|
| 771 |
|
|
|
|
|
|
|
|
|
|
| 772 |
def get_prompt(self, batch_size, device, dtype=torch.half):
|
| 773 |
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
| 774 |
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|
|
|
|
| 769 |
def get_input_embeddings(self):
|
| 770 |
return self.embedding.word_embeddings
|
| 771 |
|
| 772 |
+
def set_input_embeddings(self, value):
|
| 773 |
+
self.embedding.word_embeddings = value
|
| 774 |
+
|
| 775 |
def get_prompt(self, batch_size, device, dtype=torch.half):
|
| 776 |
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
| 777 |
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|