import os
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoModel,
    BertTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    XLMRobertaTokenizerFast,
)


class ConcatModelConfig(PretrainedConfig):
    """Configuration for the concatenated embedding model."""

    model_type = "jina-v3-arctic-s"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class ConcatModel(PreTrainedModel):
    """Runs several encoders side by side and concatenates their pooled embeddings."""

    config_class = ConcatModelConfig

    def __init__(self, models):
        super().__init__(ConcatModelConfig())
        # Register the sub-models in an nn.ModuleList so their parameters are
        # visible to .to(), .state_dict(), and the rest of the nn.Module API.
        self.models = nn.ModuleList(models)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        embeddings = []
        for i, model in enumerate(self.models):
            if i == 0:
                # The first model reads the standard, unsuffixed inputs.
                model_output = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )
            else:
                # Every other model reads the inputs suffixed with its index,
                # e.g. input_ids_1 / attention_mask_1 (see ConcatTokenizer).
                model_output = model(
                    input_ids=kwargs["input_ids_" + str(i)],
                    attention_mask=kwargs["attention_mask_" + str(i)],
                    token_type_ids=kwargs.get("token_type_ids_" + str(i)),
                )
            # CLS pooling: take the first token of the last hidden state, then
            # L2-normalize so every sub-embedding has unit length.
            pooled_output = model_output[0][:, 0]
            pooled_output = F.normalize(pooled_output, p=2, dim=-1)
            embeddings.append(pooled_output)

        return torch.cat(embeddings, dim=-1)

    def save_pretrained(self, save_directory):
        # Each sub-model is written to its own subdirectory: model_0, model_1, ...
        for i, model in enumerate(self.models):
            path = os.path.join(save_directory, f"model_{i}")
            model.save_pretrained(path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # Expects the layout written by save_pretrained: exactly two
        # sub-models under model_0 and model_1. kwargs are accepted for
        # interface compatibility but not forwarded.
        models = [
            AutoModel.from_pretrained(f"{pretrained_model_name_or_path}/model_0", trust_remote_code=True),
            AutoModel.from_pretrained(f"{pretrained_model_name_or_path}/model_1", trust_remote_code=True),
        ]
        return cls(models)

    def __repr__(self):
        s = "ConcatModel with models:"
        for i, model in enumerate(self.models):
            s += f"\nModel {i}: {model}"
        return s

    def eval(self):
        # Explicitly put every sub-model in eval mode.
        for model in self.models:
            model.eval()
        return self

    def cuda(self):
        # Move every sub-model to the default CUDA device.
        for i, model in enumerate(self.models):
            self.models[i] = model.cuda()
        return self
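
# Sketch of the input layout ConcatModel.forward expects; `batch` is a
# hypothetical dict such as the one ConcatTokenizer below produces, where
# keys for sub-model i > 0 carry an _i suffix:
#
#   embedding = model(
#       input_ids=batch["input_ids"],               # sub-model 0
#       attention_mask=batch["attention_mask"],
#       input_ids_1=batch["input_ids_1"],           # sub-model 1
#       attention_mask_1=batch["attention_mask_1"],
#   )
#
# The result has shape (batch_size, hidden_0 + hidden_1): one L2-normalized
# CLS embedding per sub-model, concatenated along the last dimension.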
|
|
|
class ConcatTokenizer(PreTrainedTokenizer):
    """
    A custom tokenizer that wraps one tokenizer per concatenated model and
    delegates tokenization to the underlying individual tokenizers.
    """

    def __init__(self, tokenizers, **kwargs):
        # super().__init__() is deliberately not called: this wrapper keeps no
        # vocabulary of its own and only delegates to the wrapped tokenizers.
        self.tokenizers = tokenizers

    def tokenize(self, text: str, **kwargs):
        """
        Tokenize the text with every tokenizer; returns one token list per tokenizer.
        """
        return [tokenizer.tokenize(text, **kwargs) for tokenizer in self.tokenizers]

    def __call__(self, text, **kwargs):
        """
        Tokenize and encode the input text with every tokenizer and return the
        combined inputs. Keys from tokenizer i > 0 are suffixed with _i (e.g.
        input_ids_1), which is exactly the layout ConcatModel.forward expects.
        """
        combined_inputs = {}
        for i, tokenizer in enumerate(self.tokenizers):
            encoded = tokenizer(text, **kwargs)

            for key, value in encoded.items():
                _key = key
                if i > 0:
                    _key = f"{key}_{i}"
                combined_inputs[_key] = value

        return combined_inputs

    def batch_encode_plus(self, batch_text_or_text_pairs, **kwargs):
        """
        Batch tokenization with every tokenizer, using the same _i key
        suffixing as __call__.
        """
        combined_inputs = {}
        for i, tokenizer in enumerate(self.tokenizers):
            encoded_batch = tokenizer.batch_encode_plus(batch_text_or_text_pairs, **kwargs)
            for key, value in encoded_batch.items():
                _key = key
                if i > 0:
                    _key = f"{key}_{i}"
                combined_inputs[_key] = value

        return combined_inputs

    def decode(self, token_ids, **kwargs):
        """
        Decode token ids using the first tokenizer (or a specific one, if required).
        """
        return self.tokenizers[0].decode(token_ids, **kwargs)

    def save_pretrained(self, save_directory):
        """
        Save each tokenizer next to its model, under model_0, model_1, ...
        """
        for i, tokenizer in enumerate(self.tokenizers):
            path = os.path.join(save_directory, f"model_{i}")
            tokenizer.save_pretrained(path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load the tokenizers from the specified directory (expects the
        model_0 / model_1 layout written by save_pretrained).
        """
        tokenizers = [
            XLMRobertaTokenizerFast.from_pretrained(f"{pretrained_model_name_or_path}/model_0"),
            BertTokenizer.from_pretrained(f"{pretrained_model_name_or_path}/model_1"),
        ]
        return cls(tokenizers)

    def __repr__(self):
        s = "ConcatTokenizer with tokenizers:"
        for i, tokenizer in enumerate(self.tokenizers):
            s += f"\nTokenizer {i}: {tokenizer}"
        return s
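

if __name__ == "__main__":
    # Minimal usage sketch. The two checkpoint names below are assumptions
    # inferred from the "jina-v3-arctic-s" model_type, not confirmed by this
    # module; substitute the checkpoints the model was actually built from.
    tokenizer = ConcatTokenizer([
        XLMRobertaTokenizerFast.from_pretrained("jinaai/jina-embeddings-v3"),
        BertTokenizer.from_pretrained("Snowflake/snowflake-arctic-embed-s"),
    ])
    model = ConcatModel([
        AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True),
        AutoModel.from_pretrained("Snowflake/snowflake-arctic-embed-s"),
    ]).eval()

    inputs = tokenizer(["hello world"], return_tensors="pt", padding=True)
    with torch.no_grad():
        embedding = model(**inputs)
    # One L2-normalized CLS vector per sub-model, concatenated along dim -1.
    print(embedding.shape)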
|