import os

import torch
import torch.nn.functional as F
from transformers import (
    AutoModel,
    BertTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    XLMRobertaTokenizerFast,
)


class ConcatModelConfig(PretrainedConfig):
    model_type = "jina-v3-arctic-s"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class ConcatModel(PreTrainedModel):
    """Wraps several embedding models and concatenates their pooled outputs."""

    config_class = ConcatModelConfig

    def __init__(self, models):
        super().__init__(ConcatModelConfig())
        # Kept as a plain list: loading, saving, eval() and cuda() are handled
        # explicitly below rather than through nn.Module registration.
        self.models = models

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        embeddings = []
        for i, model in enumerate(self.models):
            if i == 0:
                # The first model consumes the unsuffixed inputs.
                model_output = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )
            else:
                # Subsequent models read their inputs from "_<i>"-suffixed keys
                # produced by ConcatTokenizer below.
                model_output = model(
                    input_ids=kwargs[f"input_ids_{i}"],
                    attention_mask=kwargs[f"attention_mask_{i}"],
                    token_type_ids=kwargs.get(f"token_type_ids_{i}"),
                )
            # CLS pooling followed by L2 normalization.
            pooled_output = model_output[0][:, 0]
            pooled_output = F.normalize(pooled_output, p=2, dim=-1)
            embeddings.append(pooled_output)
        return torch.cat(embeddings, dim=-1)

    def save_pretrained(self, save_directory):
        # Each sub-model is saved to its own "model_<i>" subdirectory.
        for i, model in enumerate(self.models):
            path = os.path.join(save_directory, f"model_{i}")
            model.save_pretrained(path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        models = [
            AutoModel.from_pretrained(f"{pretrained_model_name_or_path}/model_0", trust_remote_code=True),
            AutoModel.from_pretrained(f"{pretrained_model_name_or_path}/model_1", trust_remote_code=True),
        ]
        return cls(models)

    def __repr__(self):
        s = "ConcatModel with models:"
        for i, model in enumerate(self.models):
            s += f"\nModel {i}: {model}"
        return s

    def eval(self):
        for model in self.models:
            model.eval()
        return self

    def cuda(self):
        for i, model in enumerate(self.models):
            self.models[i] = model.cuda()
        return self


class ConcatTokenizer(PreTrainedTokenizer):
    """
    A custom tokenizer that handles multiple tokenizers for concatenated models.
    It delegates tokenization to the underlying individual tokenizers.
    """

    def __init__(self, tokenizers, **kwargs):
        # super().__init__() is intentionally not called: this class only
        # delegates to the already-initialized wrapped tokenizers.
        self.tokenizers = tokenizers

    def tokenize(self, text: str, **kwargs):
        """Tokenize text with every tokenizer; returns one token list per tokenizer."""
        return [tokenizer.tokenize(text, **kwargs) for tokenizer in self.tokenizers]

    def __call__(self, text, **kwargs):
        """
        Tokenize and encode input text with every tokenizer and return the
        combined inputs in a single dictionary.
        """
        combined_inputs = {}
        for i, tokenizer in enumerate(self.tokenizers):
            encoded = tokenizer(text, **kwargs)
            # Suffix the keys of all but the first tokenizer so the outputs
            # can be distinguished (e.g. "input_ids_1").
            for key, value in encoded.items():
                _key = key if i == 0 else f"{key}_{i}"
                combined_inputs[_key] = value
        return combined_inputs

    def batch_encode_plus(self, batch_text_or_text_pairs, **kwargs):
        """Handle batch tokenization for all tokenizers."""
        combined_inputs = {}
        for i, tokenizer in enumerate(self.tokenizers):
            encoded_batch = tokenizer.batch_encode_plus(batch_text_or_text_pairs, **kwargs)
            for key, value in encoded_batch.items():
                _key = key if i == 0 else f"{key}_{i}"
                combined_inputs[_key] = value
        return combined_inputs

    def decode(self, token_ids, **kwargs):
        """Decode token ids using the primary tokenizer (model_0)."""
        return self.tokenizers[0].decode(token_ids, **kwargs)

    def save_pretrained(self, save_directory):
        """Save each tokenizer to its own "model_<i>" subdirectory."""
        for i, tokenizer in enumerate(self.tokenizers):
            path = os.path.join(save_directory, f"model_{i}")
            tokenizer.save_pretrained(path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the tokenizers from the specified directory."""
        tokenizers = [
            XLMRobertaTokenizerFast.from_pretrained(f"{pretrained_model_name_or_path}/model_0"),
            BertTokenizer.from_pretrained(f"{pretrained_model_name_or_path}/model_1"),
        ]
        return cls(tokenizers)

    def __repr__(self):
        s = "ConcatTokenizer with tokenizers:"
        for i, tokenizer in enumerate(self.tokenizers):
            s += f"\nTokenizer {i}: {tokenizer}"
        return s