# matryoshka-arabert / modeling_matryoshka.py
# Author: Abdalrahmankamel
# Last change: "Update modeling_matryoshka.py" (commit 2cd8049)
import logging
import os

import torch
from huggingface_hub import hf_hub_download
from peft import PeftModel
from torch import nn
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PretrainedConfig
from transformers.models.auto import AutoConfig
class MatryoshkaWrapper(nn.Module):
    """Encoder + mean pooling + one projection head per Matryoshka dimension.

    ``forward`` averages the encoder's ``last_hidden_state`` over axis 1 and
    feeds the result to every head, so a single pass yields one embedding per
    size listed in ``dims``.

    Args:
        peft_model: Encoder module (e.g. a ``PeftModel``) whose output has a
            ``last_hidden_state`` attribute and whose ``config`` -- or
            ``base_model.config`` -- provides ``hidden_size``.
        dims: Embedding sizes to build heads for. They are stored as string
            keys of :attr:`projections` because ``nn.ModuleDict`` only accepts
            string keys.
    """

    def __init__(self, peft_model, dims=(8, 64, 128, 256)):
        super().__init__()
        self.peft_model = peft_model
        config = peft_model.config
        if hasattr(config, 'hidden_size'):
            self.base_dim = config.hidden_size
        else:
            self.base_dim = peft_model.base_model.config.hidden_size
        self.projections = nn.ModuleDict({
            str(dim): self._create_projection(dim) for dim in dims
        })

    @property
    def device(self):
        """Device of the first registered parameter (CPU if there are none).

        Computed on every access so it stays correct after ``.to()``,
        ``.cuda()``, ``.cpu()`` or a dtype-only ``.to(torch.float16)`` -- a
        cached value would go stale (or hold a dtype) after those calls.
        """
        try:
            return next(self.parameters()).device
        except StopIteration:
            return torch.device('cpu')

    def _create_projection(self, dim):
        """Build the head mapping ``base_dim`` features to ``dim`` features.

        ``dim == base_dim`` -> identity; ``dim >= 128`` -> one linear layer;
        smaller dims -> Linear(base_dim, 128) + ReLU + Linear(128, dim).
        """
        if dim == self.base_dim:
            return nn.Identity()
        if dim >= 128:
            return nn.Linear(self.base_dim, dim)
        return nn.Sequential(
            nn.Linear(self.base_dim, 128),
            nn.ReLU(),
            nn.Linear(128, dim),
        )

    def forward(self, input_ids, attention_mask):
        """Encode a batch and return ``{str(dim): embedding}`` for every head.

        Inputs are moved to :attr:`device` first. Note that pooling is a plain
        ``mean(dim=1)`` that does NOT use ``attention_mask``: padding positions
        contribute to the average. It is kept this way because the heads were
        presumably trained with it (TODO confirm against the training code).
        """
        device = self.device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        outputs = self.peft_model(input_ids, attention_mask)
        # NOTE(review): assumes last_hidden_state is (batch, seq, hidden), so axis 1 is the sequence.
        base_emb = outputs.last_hidden_state.mean(dim=1)
        # ModuleDict keys are already strings.
        return {name: head(base_emb) for name, head in self.projections.items()}

    def get_embedding(self, text, tokenizer, dim="256"):
        """Embed ``text`` and return the embedding from the ``dim`` head.

        Puts the module in eval mode (and leaves it there), tokenizes with
        fixed-length padding/truncation to 256 tokens and runs without
        gradients.

        Args:
            text: Input accepted by ``tokenizer`` (a string or list of strings).
            tokenizer: Callable tokenizer returning a ``.to()``-able mapping
                with ``input_ids`` and ``attention_mask`` tensors.
            dim: Head to read; an int or string matching one of ``dims``.

        Returns:
            The head's output tensor, typically ``(batch, dim)``.

        Raises:
            KeyError: if ``dim`` is not one of the configured dimensions.
        """
        key = str(dim)
        if key not in self.projections:
            raise KeyError(
                f"dim {dim!r} is not available; configured dims: {list(self.projections)}"
            )
        self.eval()
        inputs = tokenizer(
            text,
            return_tensors="pt",
            # Fixed-length padding matters: forward() mean-pools over every
            # position, padding included, so the pad length changes the result.
            padding="max_length",
            truncation=True,
            max_length=256,
            add_special_tokens=True,
            return_token_type_ids=False,
        ).to(self.device)
        with torch.no_grad():
            outputs = self(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        emb = outputs[key]
        # With a 3-D last_hidden_state the heads return 2-D output, so this only
        # fires for encoders that produce extra leading axes; kept for compatibility.
        if emb.dim() > 2:
            emb = emb.squeeze(0)
        return emb
class MatryoshkaConfig(PretrainedConfig):
    """Configuration for the Matryoshka AraBERT sentence-similarity model.

    Args:
        base_model_name: Hub id or path of the encoder loaded with ``AutoModel``.
        dims: Embedding sizes to expose; one projection head is built per
            entry. Defaults to ``[8, 64, 128, 256]``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "matryoshka-arabert"

    def __init__(self,
                 base_model_name="aubmindlab/bert-base-arabertv02",
                 dims=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        # Always store a fresh list: a shared mutable default would leak
        # mutations between config instances, and a list stays JSON-serialisable
        # even when the caller passes a tuple.
        self.dims = [8, 64, 128, 256] if dims is None else list(dims)
class MatryoshkaBertForSentenceSimilarity(PreTrainedModel):
    """``PreTrainedModel`` front-end for the Matryoshka AraBERT sentence encoder.

    Construction performs these steps:

    1. Load ``config.base_model_name`` with ``AutoModel``.
    2. Apply the PEFT adapter stored at the config's source repository.
    3. Wrap the result in ``MatryoshkaWrapper``.
    4. Best effort, restore the projection heads from ``matryoshka_wrapper.pt``.

    The model is left in eval mode.
    """

    config_class = MatryoshkaConfig
    # Repository used when the config does not record where it was loaded from.
    default_repo_id = "Abdalrahmankamel/matryoshka-arabert"
    # File holding the MatryoshkaWrapper state dict.
    wrapper_weights_filename = "matryoshka_wrapper.pt"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        logger = logging.getLogger(__name__)

        logger.info("Loading base model...")
        # Keep the encoder in a local and hand THAT to PEFT. PreTrainedModel
        # defines `base_model` as a read-only property returning
        # getattr(self, self.base_model_prefix, self). This class does not set
        # base_model_prefix (library default is empty), so reading
        # `self.base_model` back yields `self`, not the encoder.
        base_model = AutoModel.from_pretrained(config.base_model_name)
        # Still register it under the original name so state-dict keys are unchanged.
        self.base_model = base_model

        repo_path = self._resolve_repo_path(config)
        logger.info("Loading PEFT model from %s...", repo_path)
        self.peft_model = PeftModel.from_pretrained(base_model, repo_path)

        logger.info("Creating wrapper...")
        self.matryoshka_wrapper = MatryoshkaWrapper(self.peft_model, dims=config.dims)
        self._load_wrapper_weights(repo_path)
        self.eval()
        # No custom `to()`: every sub-model is a registered child, so the
        # inherited implementation already moves/casts them all.

    @classmethod
    def _resolve_repo_path(cls, config):
        """Return the repo id or local directory ``config`` was loaded from.

        ``name_or_path`` is ``""`` for configs built in code, so empty values
        fall through to ``_name_or_path`` and finally ``default_repo_id``.
        """
        for attr in ("name_or_path", "_name_or_path"):
            value = getattr(config, attr, None)
            if value:
                return value
        return cls.default_repo_id

    def _load_wrapper_weights(self, repo_path):
        """Best-effort restore of the projection heads from ``wrapper_weights_filename``.

        A local directory ``repo_path`` is read directly. Anything else is
        treated as a Hub repo id and downloaded. Errors are logged rather than
        raised, which leaves the freshly initialised heads in place.
        """
        logger = logging.getLogger(__name__)
        filename = self.wrapper_weights_filename
        try:
            local_path = os.path.join(repo_path, filename)
            if os.path.isfile(local_path):
                weights_path = local_path
            else:
                weights_path = hf_hub_download(repo_id=repo_path, filename=filename)
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code; only load checkpoints from repositories you trust.
            state_dict = torch.load(weights_path, map_location='cpu')
            result = self.matryoshka_wrapper.load_state_dict(state_dict, strict=False)
        except Exception as e:  # deliberate best effort: nothing here requires these weights
            logger.warning("⚠ Could not load %s: %s", filename, e)
            return
        # strict=False silently skips absent keys; surface any projection head
        # that kept its random initialisation.
        missing_heads = [key for key in result.missing_keys if key.startswith("projections.")]
        if missing_heads:
            logger.warning(
                "⚠ %s has no weights for: %s (left at initialisation)",
                filename,
                ", ".join(missing_heads),
            )
        logger.info("✓ Loaded %s from %s", filename, repo_path)

    def forward(self, input_ids, attention_mask):
        """Return ``{str(dim): embedding}`` for every configured head."""
        return self.matryoshka_wrapper(input_ids, attention_mask)

    def get_embedding(self, text, tokenizer, dim="256"):
        """Embed ``text`` with ``tokenizer`` and return the ``dim``-sized embedding."""
        return self.matryoshka_wrapper.get_embedding(text, tokenizer, dim)
def load_matryoshka_model(repo_id="Abdalrahmankamel/matryoshka-arabert"):
    """Load the Matryoshka AraBERT model and its tokenizer from ``repo_id``.

    The tokenizer is fetched first, then the model via ``AutoModel`` with
    ``trust_remote_code=True``, so the repository's own modeling code
    supplies the model class.

    Usage:
        model, tokenizer = load_matryoshka_model("Abdalrahmankamel/matryoshka-arabert")

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(repo_id)
    return AutoModel.from_pretrained(repo_id, trust_remote_code=True), tok