compatibility with new transformers (#60)
Browse files
- compatibility with new transformers (9e21dac48837929ca4df28e3dcb6ae04c184573d)
- Update modeling_chatglm.py (c59cdd3bafd43d5f9b3e82c88c088eccb3925e02)
Co-authored-by: Ekaterina Aidova <[email protected]>
- modeling_chatglm.py +17 -3
modeling_chatglm.py
CHANGED
|
@@ -14,6 +14,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
|
| 14 |
from torch.nn.utils import skip_init
|
| 15 |
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
| 16 |
from copy import deepcopy
|
|
|
|
| 17 |
|
| 18 |
from transformers.modeling_outputs import (
|
| 19 |
BaseModelOutputWithPast,
|
|
@@ -45,6 +46,9 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|
| 45 |
# See all ChatGLM models at https://huggingface.co/models?filter=chatglm
|
| 46 |
]
|
| 47 |
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
def default_init(cls, *args, **kwargs):
|
| 50 |
return cls(*args, **kwargs)
|
|
@@ -872,9 +876,19 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 872 |
standardize_cache_format: bool = False,
|
| 873 |
) -> Dict[str, Any]:
|
| 874 |
# update past_key_values
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 878 |
|
| 879 |
# update attention mask
|
| 880 |
if "attention_mask" in model_kwargs:
|
|
|
|
| 14 |
from torch.nn.utils import skip_init
|
| 15 |
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
| 16 |
from copy import deepcopy
|
| 17 |
+
import transformers
|
| 18 |
|
| 19 |
from transformers.modeling_outputs import (
|
| 20 |
BaseModelOutputWithPast,
|
|
|
|
| 46 |
# See all ChatGLM models at https://huggingface.co/models?filter=chatglm
|
| 47 |
]
|
| 48 |
|
| 49 |
+
# True when the installed transformers release is 4.42 or newer.
# Compare (major, minor) as a tuple instead of the minor component alone, so a
# later major release (e.g. 5.0, whose minor is 0) is not mistaken for an old one.
is_transformers_4_42_or_higher = tuple(int(part) for part in transformers.__version__.split(".")[:2]) >= (4, 42)
|
| 50 |
+
# True when the installed transformers release is 4.44 or newer.
# Compare (major, minor) as a tuple instead of the minor component alone, so a
# later major release (e.g. 5.0, whose minor is 0) is not mistaken for an old one.
is_transformers_4_44_or_higher = tuple(int(part) for part in transformers.__version__.split(".")[:2]) >= (4, 44)
|
| 51 |
+
|
| 52 |
|
| 53 |
def default_init(cls, *args, **kwargs):
|
| 54 |
return cls(*args, **kwargs)
|
|
|
|
| 876 |
standardize_cache_format: bool = False,
|
| 877 |
) -> Dict[str, Any]:
|
| 878 |
# update past_key_values
|
| 879 |
+
if is_transformers_4_44_or_higher:
|
| 880 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
| 881 |
+
outputs
|
| 882 |
+
)[1]
|
| 883 |
+
elif is_transformers_4_42_or_higher:
|
| 884 |
+
# update past_key_values
|
| 885 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
| 886 |
+
outputs, standardize_cache_format=standardize_cache_format
|
| 887 |
+
)[1]
|
| 888 |
+
else:
|
| 889 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
| 890 |
+
outputs, standardize_cache_format=standardize_cache_format
|
| 891 |
+
)
|
| 892 |
|
| 893 |
# update attention mask
|
| 894 |
if "attention_mask" in model_kwargs:
|