Commit by Yuxuan Zhang: support transformers>=4.37.2 for finetuning

Changed file: modeling_chatglm.py (+5 -4)
@@ -634,7 +634,8 @@ class GLMTransformer(torch.nn.Module):
                     attention_mask,
                     rotary_pos_emb,
                     kv_caches[index],
-                    use_cache
+                    use_cache,
+                    use_reentrant=False
                 )
             else:
                 layer_ret = layer(
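The first hunk opts the checkpointed layer call into PyTorch's non-reentrant checkpointing. Below is a minimal runnable sketch of what `use_reentrant=False` changes, with a toy `nn.Linear` standing in for the GLM transformer block; everything except the `use_reentrant=False` argument is illustrative, not part of this commit.

# Minimal sketch of non-reentrant activation checkpointing.
# Assumption: a toy nn.Linear stands in for the GLM layer being checkpointed.
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)
hidden_states = torch.randn(2, 8, requires_grad=True)

# use_reentrant=False recomputes the forward pass during backward instead
# of saving activations, and (unlike the legacy reentrant path) supports
# keyword arguments and inputs that do not all require grad, which is why
# newer transformers fine-tuning paths expect it.
out = checkpoint(layer, hidden_states, use_reentrant=False)
out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 8])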
@@ -697,9 +698,9 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
         return position_ids

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, GLMTransformer):
-            module.gradient_checkpointing = value
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+        if not self.supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")


 class Embedding(torch.nn.Module):
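The second hunk replaces the legacy `_set_gradient_checkpointing` hook, which transformers dropped in 4.37, with an override of `gradient_checkpointing_enable`; the body shown in the diff only guards against unsupported models. A hedged usage sketch follows, with the checkpoint name purely illustrative and not named by this commit.

# Hypothetical fine-tuning setup; the model id below is an illustrative
# assumption, not part of this commit.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "THUDM/chatglm3-6b", trust_remote_code=True
)

# With transformers>=4.37.2, Trainer (via gradient_checkpointing=True in
# TrainingArguments) calls this method directly instead of the removed
# _set_gradient_checkpointing hook, so the model class overrides it.
model.gradient_checkpointing_enable()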