|
import torch |
|
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PretrainedConfig |
|
from transformers.models.auto import AutoConfig |
|
from peft import PeftModel |
|
from torch import nn |
|
from huggingface_hub import hf_hub_download |
|
|
|
class MatryoshkaWrapper(nn.Module):
    """Mean-pool an encoder's hidden states and project them to several sizes.

    One projection head is created per entry of ``dims`` and stored in
    ``self.projections`` under the key ``str(dim)``:

    * ``dim == base_dim`` -> ``nn.Identity``
    * ``dim >= 128``      -> ``nn.Linear(base_dim, dim)``
    * otherwise           -> ``Linear(base_dim, 128) -> ReLU -> Linear(128, dim)``

    The head layout defines the ``state_dict`` keys, so it is kept exactly as
    is to stay loadable from previously saved weights.

    Device handling relies on the standard ``nn.Module`` machinery: ``to`` is
    inherited (so dtype / keyword forms work) and ``device`` is derived from
    the current parameters instead of being cached.
    """

    def __init__(self, peft_model, dims=(8, 64, 128, 256)):
        """Build the projection heads around ``peft_model``.

        Args:
            peft_model: Encoder module. It is called as
                ``peft_model(input_ids, attention_mask)`` and its result must
                expose ``last_hidden_state``. The hidden size is read from
                ``peft_model.config.hidden_size`` when present, otherwise from
                ``peft_model.base_model.config.hidden_size``.
            dims: Iterable of output embedding sizes. An immutable default is
                used so the default cannot be mutated between instances.
        """
        super().__init__()
        self.peft_model = peft_model

        encoder_config = peft_model.config
        if hasattr(encoder_config, "hidden_size"):
            self.base_dim = encoder_config.hidden_size
        else:
            self.base_dim = peft_model.base_model.config.hidden_size

        self.projections = nn.ModuleDict(
            {str(dim): self._create_projection(dim) for dim in dims}
        )

    @property
    def device(self):
        """Device of the first parameter, or CPU when there are no parameters.

        Computed on every access so it stays correct after ``.to()``,
        ``.cuda()``, ``.cpu()`` or when a parent module moves this one.
        """
        try:
            return next(self.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    def _create_projection(self, dim):
        """Return the projection head mapping ``base_dim`` features to ``dim``."""
        if dim == self.base_dim:
            return nn.Identity()
        if dim >= 128:
            return nn.Linear(self.base_dim, dim)
        # Small targets go through a 128-wide hidden layer.
        return nn.Sequential(
            nn.Linear(self.base_dim, 128),
            nn.ReLU(),
            nn.Linear(128, dim),
        )

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return every projection.

        Args:
            input_ids: Token id tensor passed to the encoder.
            attention_mask: Attention mask tensor passed to the encoder.

        Returns:
            dict mapping ``str(dim)`` to the projected embedding tensor.
        """
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        outputs = self.peft_model(input_ids, attention_mask)
        # The mean runs over every position of dim 1; ``attention_mask`` is
        # only given to the encoder, so padded positions are averaged in too.
        # NOTE(review): left unchanged because the saved projection weights
        # presumably expect this pooling - confirm before using a masked mean.
        base_emb = outputs.last_hidden_state.mean(dim=1)
        return {str(dim): proj(base_emb) for dim, proj in self.projections.items()}

    def get_embedding(self, text, tokenizer, dim="256"):
        """Tokenize ``text`` and return its embedding of size ``dim``.

        Puts the module in eval mode and runs without gradient tracking.

        Args:
            text: Input forwarded to ``tokenizer``.
            tokenizer: Callable returning an object that supports ``.to(device)``
                and indexing by ``'input_ids'`` / ``'attention_mask'``.
            dim: Requested size; compared as ``str(dim)`` against the heads
                created in ``__init__``.

        Returns:
            The embedding tensor produced by the requested head.

        Raises:
            KeyError: If no projection head exists for ``dim``. Checked up
                front so an invalid request does not cost a forward pass.
        """
        key = str(dim)
        if key not in self.projections:
            raise KeyError(
                f"No projection for dim {key!r}; available: {list(self.projections.keys())}"
            )

        self.eval()
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=256,
            add_special_tokens=True,
            return_token_type_ids=False,
        ).to(self.device)

        with torch.no_grad():
            outputs = self(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
            )

        emb = outputs[key]
        if emb.dim() > 2:
            emb = emb.squeeze(0)
        return emb
|
|
|
class MatryoshkaConfig(PretrainedConfig):
    """Configuration holding the encoder name and the embedding sizes.

    Attributes:
        base_model_name: Identifier of the underlying encoder; this file's
            model class passes it to ``AutoModel.from_pretrained``.
        dims: List of embedding sizes for which projection heads are built.
    """

    model_type = "matryoshka-arabert"

    def __init__(self,
                 base_model_name="aubmindlab/bert-base-arabertv02",
                 dims=None,
                 **kwargs):
        """Create the configuration.

        Args:
            base_model_name: Identifier of the underlying encoder.
            dims: Iterable of embedding sizes. ``None`` selects the default
                ``[8, 64, 128, 256]``.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        # Always store a fresh list: a shared default (or the caller's own
        # list) must not be mutated through ``config.dims``.
        self.dims = [8, 64, 128, 256] if dims is None else list(dims)
|
|
|
class MatryoshkaBertForSentenceSimilarity(PreTrainedModel):
    """Sentence-embedding model: base encoder + PEFT adapter + projection heads.

    Construction loads the encoder named by ``config.base_model_name``, applies
    the PEFT adapter found at the resolved repository path, wraps the result in
    ``MatryoshkaWrapper`` and then tries (best effort) to load
    ``matryoshka_wrapper.pt`` into the wrapper. The model is left in eval mode.
    """

    config_class = MatryoshkaConfig

    @staticmethod
    def _resolve_repo_path(config):
        """Return the repository id / path to load adapter and wrapper weights from.

        ``getattr`` defaults only apply to *missing* attributes, but
        ``name_or_path`` can be present and empty (e.g. a config that was not
        loaded from a path), so empty values fall through to the default repo.
        """
        return (
            getattr(config, 'name_or_path', None)
            or getattr(config, '_name_or_path', None)
            or 'Abdalrahmankamel/matryoshka-arabert'
        )

    def __init__(self, config):
        """Load all sub-models described by ``config``.

        Args:
            config: A ``MatryoshkaConfig``; ``base_model_name`` and ``dims``
                are read from it, plus its name/path for the adapter location.
        """
        super().__init__(config)
        self.config = config

        # Resolved once and reused for both the adapter and the wrapper weights.
        repo_path = self._resolve_repo_path(config)

        print("Loading base model...")
        encoder = AutoModel.from_pretrained(config.base_model_name)
        self.base_model = encoder

        print("Loading PEFT model...")
        # Use the local ``encoder`` rather than reading ``self.base_model``
        # back: ``PreTrainedModel`` defines a ``base_model`` property, so the
        # attribute read may not return the module assigned above.
        self.peft_model = PeftModel.from_pretrained(encoder, repo_path)

        print("Creating wrapper...")
        self.matryoshka_wrapper = MatryoshkaWrapper(self.peft_model, dims=config.dims)

        # Best effort: without this file the projection heads keep their
        # freshly initialised weights, and only a warning is printed.
        try:
            import os

            local_weights = os.path.join(str(repo_path), "matryoshka_wrapper.pt")
            if os.path.isfile(local_weights):
                # ``repo_path`` is a local directory that ships the file.
                wrapper_weights_path = local_weights
            else:
                wrapper_weights_path = hf_hub_download(repo_id=repo_path, filename="matryoshka_wrapper.pt")

            # SECURITY: ``torch.load`` unpickles the file, which can execute
            # arbitrary code unless restricted (e.g. ``weights_only=True``).
            # Only load from repositories you trust.
            state_dict = torch.load(wrapper_weights_path, map_location='cpu')
            self.matryoshka_wrapper.load_state_dict(state_dict, strict=False)
            print(f"✓ Loaded matryoshka_wrapper.pt from {repo_path}")
        except Exception as e:
            print(f"⚠ Could not load matryoshka_wrapper.pt: {e}")

        self.eval()

    def to(self, *args, **kwargs):
        """Move and/or cast the model; accepts everything ``nn.Module.to`` does.

        All sub-models are registered sub-modules, so the parent call moves
        them. The wrapper's own ``to`` is then invoked with the resulting
        device so that any device bookkeeping it keeps stays in sync.

        Returns:
            Whatever the parent ``to`` returns (the model itself).
        """
        moved = super().to(*args, **kwargs)
        first_param = next(self.parameters(), None)
        if first_param is not None:
            self.matryoshka_wrapper.to(first_param.device)
        return moved

    def forward(self, input_ids, attention_mask):
        """Return the wrapper's dict of ``str(dim) -> embedding`` for the batch."""
        return self.matryoshka_wrapper(input_ids, attention_mask)

    def get_embedding(self, text, tokenizer, dim="256"):
        """Delegate to ``MatryoshkaWrapper.get_embedding`` with the same arguments."""
        return self.matryoshka_wrapper.get_embedding(text, tokenizer, dim)
|
|
|
def load_matryoshka_model(repo_id="Abdalrahmankamel/matryoshka-arabert"):
    """Load the tokenizer and the model stored under ``repo_id``.

    The tokenizer is fetched first, then the model. The model is loaded with
    ``trust_remote_code=True``, which lets Transformers run model code shipped
    with the repository - only use repositories you trust.

    Args:
        repo_id: Repository id or path handed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Example:
        model, tokenizer = load_matryoshka_model("Abdalrahmankamel/matryoshka-arabert")
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(repo_id)
    loaded_model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    return loaded_model, loaded_tokenizer