from typing import List, Optional, Union
from transformers import PreTrainedTokenizerFast
from tokenizers.processors import TemplateProcessing
from tokenizers import Tokenizer
from transformers.tokenization_utils_base import BatchEncoding, EncodedInput, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import PaddingStrategy, TensorType
import torch
import numpy as np


def create_tokenizer_custom(file):
    """Load a ``tokenizers.Tokenizer`` from a serialized JSON file.

    Args:
        file: Path to the serialized tokenizer JSON.

    Returns:
        The ``Tokenizer`` built via ``Tokenizer.from_str``.
    """
    # Read as UTF-8 explicitly: relying on the platform default encoding
    # (e.g. cp1252 on Windows) can fail on non-ASCII vocabulary entries.
    with open(file, 'r', encoding='utf-8') as f:
        return Tokenizer.from_str(f.read())


class iPLMTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer for ``"<structure>|<sequence>"`` inputs.

    For each input string, the part after the last ``'|'`` is tokenized with
    a prompt of ``'<|bos|>'`` tokens prepended (length controlled by
    ``n_queries``). After tokenization, the first ``len(structure)``
    positions of the row's ``input_ids`` are overwritten with the Unicode
    code points (``ord``) of the part before the first ``'|'``.

    NOTE(review): the overwrite always starts at column 0. This assumes
    right padding and that the post-processor adds no leading special token
    in front of the prompt. Confirm against the tokenizer config.
    """

    def __init__(self, parallel=False, **kwargs):
        """Build the fast tokenizer from ``kwargs['tokenizer_file']`` and register ``'<|pad|>'`` as pad token."""
        super().__init__(tokenizer_object=create_tokenizer_custom(kwargs.get('tokenizer_file')), **kwargs)
        self.add_special_tokens({'pad_token': '<|pad|>'})
        self.parallel = parallel  # stored only; not read by __call__

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        n_queries = -1,  # -1 for vary-length prompt, int with larger than 0 for fix-length, 0 for no prompt
        text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        text_pair_target: Optional[
            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
        ] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """Tokenize ``"structure|sequence"`` strings with a ``<|bos|>`` prompt.

        Args:
            text: A single ``"structure|sequence"`` string, or a list of them.
            n_queries: Prompt length policy.
                * ``> 0``: exactly ``n_queries`` ``'<|bos|>'`` tokens.
                * ``< 0``: one ``'<|bos|>'`` per character of the sequence
                  after removing every ``'1'`` and ``'2'``.
                * ``0``: no prompt.
            Remaining arguments are forwarded unchanged to
            ``PreTrainedTokenizerFast.__call__`` (with ``padding_side=None``).

        Returns:
            For list input, the ``BatchEncoding`` from the parent call, with
            structure code points written into ``input_ids`` and any
            ``token_type_ids`` removed. For a single non-list input, a plain
            ``dict`` mapping each key to element ``[0]`` of the batched value.

        Raises:
            ValueError: if an input string contains no ``'|'`` separator.
        """
        batching = isinstance(text, list)
        if not batching:
            text = [text]

        # Split every input into its structure string and its prompted sequence.
        structures = []
        text_with_prompt = []
        for t in text:
            # Explicit check rather than `assert`: asserts are stripped under
            # `python -O`, which would silently treat the whole input as the
            # sequence and skip the structure ids.
            if '|' not in t:
                raise ValueError('prompt not found')
            parts = t.split('|')
            structure_ids, raw_text = parts[0], parts[-1]
            if n_queries > 0:  # fixed-length prompt
                prompt_length = n_queries
            elif n_queries < 0:  # variable-length prompt
                # NOTE(review): '1'/'2' are excluded from the count; presumably
                # they are marker characters. Confirm with the data format.
                prompt_length = len(raw_text.replace('1', '').replace('2', ''))
            else:  # no prompt
                prompt_length = 0
            structures.append(structure_ids)
            text_with_prompt.append('<|bos|>' * prompt_length + raw_text)

        batch = super().__call__(
            text=text_with_prompt,
            text_pair=text_pair,
            text_target=text_target,
            text_pair_target=text_pair_target,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=None,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs
        )

        # Add structure ids: overwrite the leading input-id positions of each
        # row with the code points of that row's structure string (in place).
        input_ids = batch['input_ids']
        for row, structure_ids in enumerate(structures):
            codes = [ord(c) for c in structure_ids]
            if return_tensors is None:
                # Python lists: per-index writes (IndexError if the row is shorter).
                for col, code in enumerate(codes):
                    input_ids[row][col] = code
            elif isinstance(input_ids, torch.Tensor):
                input_ids[row, :len(codes)] = torch.tensor(codes, dtype=input_ids.dtype)
            else:
                # e.g. return_tensors='np': write with a matching-dtype array.
                input_ids[row, :len(codes)] = np.asarray(codes, dtype=input_ids.dtype)

        if 'token_type_ids' in batch:
            del batch['token_type_ids']
        if batching:
            return batch
        return {k: v[0] for k, v in batch.items()}