duzx16
commited on
Commit
·
35ca523
1
Parent(s):
0829959
Fix input embeds
Browse files- modeling_chatglm.py +2 -3
modeling_chatglm.py
CHANGED
|
@@ -918,7 +918,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 918 |
elif input_ids is not None:
|
| 919 |
batch_size, seq_length = input_ids.shape[:2]
|
| 920 |
elif inputs_embeds is not None:
|
| 921 |
-
batch_size, seq_length
|
| 922 |
else:
|
| 923 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 924 |
|
|
@@ -972,9 +972,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 972 |
|
| 973 |
if attention_mask is None:
|
| 974 |
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
| 975 |
-
|
| 976 |
else:
|
| 977 |
-
attention_mask = attention_mask.to(
|
| 978 |
|
| 979 |
for i, layer in enumerate(self.layers):
|
| 980 |
|
|
|
|
| 918 |
elif input_ids is not None:
|
| 919 |
batch_size, seq_length = input_ids.shape[:2]
|
| 920 |
elif inputs_embeds is not None:
|
| 921 |
+
batch_size, seq_length = inputs_embeds.shape[:2]
|
| 922 |
else:
|
| 923 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 924 |
|
|
|
|
| 972 |
|
| 973 |
if attention_mask is None:
|
| 974 |
attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
|
|
|
|
| 975 |
else:
|
| 976 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
| 977 |
|
| 978 |
for i, layer in enumerate(self.layers):
|
| 979 |
|