# jina-v3-arctic-s/modeling_jina_v3_arctic_s.py
import os
from typing import List, Optional

import torch
import torch.nn.functional as F
from transformers import (
    AutoModel,
    BertTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    XLMRobertaTokenizerFast,
)


class ConcatModelConfig(PretrainedConfig):
    model_type = "jina-v3-arctic-s"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class ConcatModel(PreTrainedModel):
    """Wraps several encoder models and concatenates their pooled, normalized embeddings."""

    config_class = ConcatModelConfig

    def __init__(self, models: List[PreTrainedModel]):
        super().__init__(ConcatModelConfig())
        # The sub-models are kept in a plain Python list (not an nn.ModuleList),
        # so they are moved to the GPU and set to eval mode explicitly below.
        self.models = models

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        embeddings = []
        for i, model in enumerate(self.models):
            if i == 0:
                # The first model consumes the unprefixed inputs.
                model_output = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )
            else:
                # Subsequent models read their inputs from suffixed kwargs
                # (e.g. input_ids_1) as produced by ConcatTokenizer.
                model_output = model(
                    input_ids=kwargs["input_ids_" + str(i)],
                    attention_mask=kwargs["attention_mask_" + str(i)],
                    token_type_ids=kwargs.get("token_type_ids_" + str(i)),
                )
            # CLS pooling followed by L2 normalization of each sub-embedding.
            pooled_output = model_output[0][:, 0]
            pooled_output = F.normalize(pooled_output, p=2, dim=-1)
            embeddings.append(pooled_output)
        return torch.cat(embeddings, dim=-1)
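
    # For illustration (matching what ConcatTokenizer.__call__ below produces): with two
    # sub-models, forward() expects a keyword layout like
    #   input_ids, attention_mask[, token_type_ids]        -> consumed by models[0]
    #   input_ids_1, attention_mask_1[, token_type_ids_1]  -> consumed by models[1]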

    def save_pretrained(self, save_directory):
        for i, model in enumerate(self.models):
            path = os.path.join(save_directory, f"model_{i}")
            model.save_pretrained(path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        models = [
            AutoModel.from_pretrained(f"{pretrained_model_name_or_path}/model_0", trust_remote_code=True),
            AutoModel.from_pretrained(f"{pretrained_model_name_or_path}/model_1", trust_remote_code=True),
        ]
        return cls(models)

    def __repr__(self):
        s = "ConcatModel with models:"
        for i, model in enumerate(self.models):
            s += f"\nModel {i}: {model}"
        return s

    def eval(self):
        for model in self.models:
            model.eval()
        return self

    def cuda(self):
        for i, model in enumerate(self.models):
            self.models[i] = model.cuda()
        return self


class ConcatTokenizer(PreTrainedTokenizer):
    """
    A custom tokenizer that handles multiple tokenizers for the concatenated models.
    Tokenization is delegated to the underlying individual tokenizers.
    """

    def __init__(self, tokenizers, **kwargs):
        # Note: PreTrainedTokenizer.__init__ is not called; every method used below is overridden.
        self.tokenizers = tokenizers

    def tokenize(self, text: str, **kwargs):
        """
        Tokenizes the text with all tokenizers and returns one token list per tokenizer.
        """
        return [tokenizer.tokenize(text, **kwargs) for tokenizer in self.tokenizers]

    def __call__(self, text, **kwargs):
        """
        Tokenizes and encodes the input text with all tokenizers and returns the combined inputs.
        """
        combined_inputs = {}
        for i, tokenizer in enumerate(self.tokenizers):
            encoded = tokenizer(text, **kwargs)
            # Suffix the keys of all but the first tokenizer to distinguish between
            # tokenizers (e.g. input_ids, input_ids_1, ...).
            for key, value in encoded.items():
                _key = key
                if i > 0:
                    _key = f"{key}_{i}"
                combined_inputs[_key] = value
        return combined_inputs

    def batch_encode_plus(self, batch_text_or_text_pairs, **kwargs):
        """
        Handles batch tokenization for all tokenizers.
        """
        combined_inputs = {}
        for i, tokenizer in enumerate(self.tokenizers):
            encoded_batch = tokenizer.batch_encode_plus(batch_text_or_text_pairs, **kwargs)
            for key, value in encoded_batch.items():
                _key = key
                if i > 0:
                    _key = f"{key}_{i}"
                combined_inputs[_key] = value
        return combined_inputs

    def decode(self, token_ids, **kwargs):
        """
        Decode tokens using the first tokenizer (or a specific one, if required).
        """
        # The primary tokenizer (model_0) is used for decoding by default.
        return self.tokenizers[0].decode(token_ids, **kwargs)

    def save_pretrained(self, save_directory):
        """
        Save the tokenizers to the specified directory, one sub-directory per tokenizer.
        """
        for i, tokenizer in enumerate(self.tokenizers):
            path = os.path.join(save_directory, f"model_{i}")
            tokenizer.save_pretrained(path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the tokenizers from the specified directory.
        """
        tokenizers = [
            XLMRobertaTokenizerFast.from_pretrained(f"{pretrained_model_name_or_path}/model_0"),
            BertTokenizer.from_pretrained(f"{pretrained_model_name_or_path}/model_1"),
        ]
        return cls(tokenizers)

    def __repr__(self):
        s = "ConcatTokenizer with tokenizers:"
        for i, tokenizer in enumerate(self.tokenizers):
            s += f"\nTokenizer {i}: {tokenizer}"
        return s
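

# A minimal usage sketch, not part of the original upload. It assumes the repository has been
# downloaded to a local directory containing the model_0/model_1 sub-folders that the
# save_pretrained()/from_pretrained() methods above use; the path below is a placeholder.
if __name__ == "__main__":
    repo_path = "path/to/jina-v3-arctic-s"  # placeholder path, adjust to your local copy

    tokenizer = ConcatTokenizer.from_pretrained(repo_path)
    model = ConcatModel.from_pretrained(repo_path).eval()

    # ConcatTokenizer returns a single dict with suffixed keys for the second tokenizer,
    # which maps directly onto the keyword arguments expected by ConcatModel.forward().
    batch = tokenizer(
        ["What is deep learning?", "Deep learning is a subfield of machine learning."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        embeddings = model(**batch)

    # One L2-normalized CLS embedding per sub-model, concatenated along the last dimension.
    print(embeddings.shape)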