hoalq commited on
Commit
350bd07
·
1 Parent(s): b942d77

Upload ASTRA trained model

Browse files
Files changed (2) hide show
  1. astra_model.py +15 -15
  2. pytorch_model.bin +1 -1
astra_model.py CHANGED
@@ -34,26 +34,26 @@ class ASTRAViSoBERTForMaskedLM(XLMRobertaPreTrainedModel):
34
  remaps = []
35
  for key in list(state_dict.keys()):
36
  new_key = key
37
- if key.startswith(prefix + "cls.dense."):
38
- new_key = key.replace(prefix + "cls.dense.", prefix + "cls_dense.")
39
- elif key.startswith(prefix + "cls.layer_norm."):
40
- new_key = key.replace(prefix + "cls.layer_norm.", prefix + "cls_layer_norm.")
41
- elif key.startswith(prefix + "cls.decoder."):
42
- new_key = key.replace(prefix + "cls.decoder.", prefix + "cls_decoder.")
43
- elif key.startswith(prefix + "lm_head.dense."):
44
- new_key = key.replace(prefix + "lm_head.dense.", prefix + "cls_dense.")
45
- elif key.startswith(prefix + "lm_head.layer_norm."):
46
- new_key = key.replace(prefix + "lm_head.layer_norm.", prefix + "cls_layer_norm.")
47
- elif key.startswith(prefix + "lm_head.decoder."):
48
- new_key = key.replace(prefix + "lm_head.decoder.", prefix + "cls_decoder.")
49
- elif key.startswith(prefix + "lm_head."):
50
- new_key = key.replace(prefix + "lm_head.", prefix + "cls_decoder.")
51
 
52
  if new_key != key:
53
  remaps.append((key, new_key))
54
 
55
  for old_key, new_key in remaps:
56
- state_dict[new_key] = state_dict.pop(old_key)
 
57
 
58
  self.register_load_state_dict_pre_hook(_remap_pre_hook)
59
 
 
34
  remaps = []
35
  for key in list(state_dict.keys()):
36
  new_key = key
37
+ # Replace occurrences anywhere in the key path after this module's prefix
38
+ if key.startswith(prefix):
39
+ suffix = key[len(prefix):]
40
+ suffix = suffix.replace("cls.dense.", "cls_dense.")
41
+ suffix = suffix.replace("cls.layer_norm.", "cls_layer_norm.")
42
+ suffix = suffix.replace("cls.decoder.", "cls_decoder.")
43
+ suffix = suffix.replace("lm_head.dense.", "cls_dense.")
44
+ suffix = suffix.replace("lm_head.layer_norm.", "cls_layer_norm.")
45
+ suffix = suffix.replace("lm_head.decoder.", "cls_decoder.")
46
+ # bare lm_head.* -> cls_decoder.*
47
+ if suffix.startswith("lm_head."):
48
+ suffix = suffix.replace("lm_head.", "cls_decoder.", 1)
49
+ new_key = prefix + suffix
 
50
 
51
  if new_key != key:
52
  remaps.append((key, new_key))
53
 
54
  for old_key, new_key in remaps:
55
+ if new_key not in state_dict:
56
+ state_dict[new_key] = state_dict.pop(old_key)
57
 
58
  self.register_load_state_dict_pre_hook(_remap_pre_hook)
59
 
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ea42663a4a066b017945bb706b7058f7c80b8097936295df4680aba04763cc3b
3
  size 393222899
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec0edd3975d6ba755d4070a9b887ecdb234ffea69406d2be147ea7dd16a764af
3
  size 393222899