Abdalrahmankamel commited on
Commit
2cd8049
·
verified ·
1 Parent(s): 02bdcdf

Update modeling_matryoshka.py

Browse files
Files changed (1) hide show
  1. modeling_matryoshka.py +6 -8
modeling_matryoshka.py CHANGED
@@ -90,17 +90,20 @@ class MatryoshkaBertForSentenceSimilarity(PreTrainedModel):
90
  self.base_model = AutoModel.from_pretrained(config.base_model_name)
91
 
92
  print("Loading PEFT model...")
93
- self.peft_model = PeftModel.from_pretrained(self.base_model, config.name_or_path)
 
 
94
 
95
  print("Creating wrapper...")
96
  self.matryoshka_wrapper = MatryoshkaWrapper(self.peft_model, dims=config.dims)
97
 
98
  # Load wrapper weights
99
  try:
100
- wrapper_weights_path = hf_hub_download(repo_id=config.name_or_path, filename="matryoshka_wrapper.pt")
 
101
  state_dict = torch.load(wrapper_weights_path, map_location='cpu')
102
  self.matryoshka_wrapper.load_state_dict(state_dict, strict=False)
103
- print(f"✓ Loaded matryoshka_wrapper.pt from {config.name_or_path}")
104
  except Exception as e:
105
  print(f"⚠ Could not load matryoshka_wrapper.pt: {e}")
106
 
@@ -118,11 +121,6 @@ class MatryoshkaBertForSentenceSimilarity(PreTrainedModel):
118
  def get_embedding(self, text, tokenizer, dim="256"):
119
  return self.matryoshka_wrapper.get_embedding(text, tokenizer, dim)
120
 
121
- # Register the model and config
122
- AutoConfig.register("matryoshka-arabert", MatryoshkaConfig)
123
- AutoModel.register(MatryoshkaConfig, MatryoshkaBertForSentenceSimilarity)
124
-
125
- # Add helper function for easy loading
126
  def load_matryoshka_model(repo_id="Abdalrahmankamel/matryoshka-arabert"):
127
  """
128
  Helper function to load the model easily
 
90
  self.base_model = AutoModel.from_pretrained(config.base_model_name)
91
 
92
  print("Loading PEFT model...")
93
+ # Fix the name_or_path attribute access
94
+ repo_path = getattr(config, 'name_or_path', getattr(config, '_name_or_path', 'Abdalrahmankamel/matryoshka-arabert'))
95
+ self.peft_model = PeftModel.from_pretrained(self.base_model, repo_path)
96
 
97
  print("Creating wrapper...")
98
  self.matryoshka_wrapper = MatryoshkaWrapper(self.peft_model, dims=config.dims)
99
 
100
  # Load wrapper weights
101
  try:
102
+ repo_path = getattr(config, 'name_or_path', getattr(config, '_name_or_path', 'Abdalrahmankamel/matryoshka-arabert'))
103
+ wrapper_weights_path = hf_hub_download(repo_id=repo_path, filename="matryoshka_wrapper.pt")
104
  state_dict = torch.load(wrapper_weights_path, map_location='cpu')
105
  self.matryoshka_wrapper.load_state_dict(state_dict, strict=False)
106
+ print(f"✓ Loaded matryoshka_wrapper.pt from {repo_path}")
107
  except Exception as e:
108
  print(f"⚠ Could not load matryoshka_wrapper.pt: {e}")
109
 
 
121
  def get_embedding(self, text, tokenizer, dim="256"):
122
  return self.matryoshka_wrapper.get_embedding(text, tokenizer, dim)
123
 
 
 
 
 
 
124
  def load_matryoshka_model(repo_id="Abdalrahmankamel/matryoshka-arabert"):
125
  """
126
  Helper function to load the model easily