Upload ASTRA trained model
Browse files
- astra_model.py +15 -15
- pytorch_model.bin +1 -1
astra_model.py
CHANGED
@@ -34,26 +34,26 @@ class ASTRAViSoBERTForMaskedLM(XLMRobertaPreTrainedModel):
|
|
34 |
remaps = []
|
35 |
for key in list(state_dict.keys()):
|
36 |
new_key = key
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
new_key = key.replace(prefix + "lm_head.", prefix + "cls_decoder.")
|
51 |
|
52 |
if new_key != key:
|
53 |
remaps.append((key, new_key))
|
54 |
|
55 |
for old_key, new_key in remaps:
|
56 |
-
|
|
|
57 |
|
58 |
self.register_load_state_dict_pre_hook(_remap_pre_hook)
|
59 |
|
|
|
34 |
remaps = []
|
35 |
for key in list(state_dict.keys()):
|
36 |
new_key = key
|
37 |
+
# Replace occurrences anywhere in the key path after this module's prefix
|
38 |
+
if key.startswith(prefix):
|
39 |
+
suffix = key[len(prefix):]
|
40 |
+
suffix = suffix.replace("cls.dense.", "cls_dense.")
|
41 |
+
suffix = suffix.replace("cls.layer_norm.", "cls_layer_norm.")
|
42 |
+
suffix = suffix.replace("cls.decoder.", "cls_decoder.")
|
43 |
+
suffix = suffix.replace("lm_head.dense.", "cls_dense.")
|
44 |
+
suffix = suffix.replace("lm_head.layer_norm.", "cls_layer_norm.")
|
45 |
+
suffix = suffix.replace("lm_head.decoder.", "cls_decoder.")
|
46 |
+
# bare lm_head.* -> cls_decoder.*
|
47 |
+
if suffix.startswith("lm_head."):
|
48 |
+
suffix = suffix.replace("lm_head.", "cls_decoder.", 1)
|
49 |
+
new_key = prefix + suffix
|
|
|
50 |
|
51 |
if new_key != key:
|
52 |
remaps.append((key, new_key))
|
53 |
|
54 |
for old_key, new_key in remaps:
|
55 |
+
if new_key not in state_dict:
|
56 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
57 |
|
58 |
self.register_load_state_dict_pre_hook(_remap_pre_hook)
|
59 |
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 393222899
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec0edd3975d6ba755d4070a9b887ecdb234ffea69406d2be147ea7dd16a764af
|
3 |
size 393222899
|