Update modeling_matryoshka.py
Browse files
- modeling_matryoshka.py +6 -8
modeling_matryoshka.py
CHANGED
@@ -90,17 +90,20 @@ class MatryoshkaBertForSentenceSimilarity(PreTrainedModel):
|
|
90 |
self.base_model = AutoModel.from_pretrained(config.base_model_name)
|
91 |
|
92 |
print("Loading PEFT model...")
|
93 |
-
|
|
|
|
|
94 |
|
95 |
print("Creating wrapper...")
|
96 |
self.matryoshka_wrapper = MatryoshkaWrapper(self.peft_model, dims=config.dims)
|
97 |
|
98 |
# Load wrapper weights
|
99 |
try:
|
100 |
-
|
|
|
101 |
state_dict = torch.load(wrapper_weights_path, map_location='cpu')
|
102 |
self.matryoshka_wrapper.load_state_dict(state_dict, strict=False)
|
103 |
-
print(f"✓ Loaded matryoshka_wrapper.pt from {config.name_or_path}")
|
104 |
except Exception as e:
|
105 |
print(f"⚠ Could not load matryoshka_wrapper.pt: {e}")
|
106 |
|
@@ -118,11 +121,6 @@ class MatryoshkaBertForSentenceSimilarity(PreTrainedModel):
|
|
118 |
def get_embedding(self, text, tokenizer, dim="256"):
|
119 |
return self.matryoshka_wrapper.get_embedding(text, tokenizer, dim)
|
120 |
|
121 |
-
# Register the model and config
|
122 |
-
AutoConfig.register("matryoshka-arabert", MatryoshkaConfig)
|
123 |
-
AutoModel.register(MatryoshkaConfig, MatryoshkaBertForSentenceSimilarity)
|
124 |
-
|
125 |
-
# Add helper function for easy loading
|
126 |
def load_matryoshka_model(repo_id="Abdalrahmankamel/matryoshka-arabert"):
|
127 |
"""
|
128 |
Helper function to load the model easily
|
|
|
90 |
self.base_model = AutoModel.from_pretrained(config.base_model_name)
|
91 |
|
92 |
print("Loading PEFT model...")
|
93 |
+
# Fix the name_or_path attribute access
|
94 |
+
repo_path = getattr(config, 'name_or_path', getattr(config, '_name_or_path', 'Abdalrahmankamel/matryoshka-arabert'))
|
95 |
+
self.peft_model = PeftModel.from_pretrained(self.base_model, repo_path)
|
96 |
|
97 |
print("Creating wrapper...")
|
98 |
self.matryoshka_wrapper = MatryoshkaWrapper(self.peft_model, dims=config.dims)
|
99 |
|
100 |
# Load wrapper weights
|
101 |
try:
|
102 |
+
repo_path = getattr(config, 'name_or_path', getattr(config, '_name_or_path', 'Abdalrahmankamel/matryoshka-arabert'))
|
103 |
+
wrapper_weights_path = hf_hub_download(repo_id=repo_path, filename="matryoshka_wrapper.pt")
|
104 |
state_dict = torch.load(wrapper_weights_path, map_location='cpu')
|
105 |
self.matryoshka_wrapper.load_state_dict(state_dict, strict=False)
|
106 |
+
print(f"✓ Loaded matryoshka_wrapper.pt from {repo_path}")
|
107 |
except Exception as e:
|
108 |
print(f"⚠ Could not load matryoshka_wrapper.pt: {e}")
|
109 |
|
|
|
121 |
def get_embedding(self, text, tokenizer, dim="256"):
|
122 |
return self.matryoshka_wrapper.get_embedding(text, tokenizer, dim)
|
123 |
|
|
|
|
|
|
|
|
|
|
|
124 |
def load_matryoshka_model(repo_id="Abdalrahmankamel/matryoshka-arabert"):
|
125 |
"""
|
126 |
Helper function to load the model easily
|