from typing import List, Optional, Union

import torch
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import PaddingStrategy, TensorType


def create_tokenizer_custom(file):
    """Load a `tokenizers.Tokenizer` serialized as JSON from disk."""
    with open(file, 'r') as f:
        return Tokenizer.from_str(f.read())


class iPLMTokenizer(PreTrainedTokenizerFast):
    def __init__(self, parallel=False, **kwargs):
        super().__init__(tokenizer_object=create_tokenizer_custom(kwargs.get('tokenizer_file')), **kwargs)
        # Register the padding token used by the underlying tokenizer.
        self.add_special_tokens({'pad_token': '<|pad|>'})
        self.parallel = parallel

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        n_queries: int = -1,  # -1: prompt length follows the sequence length; > 0: fixed-length prompt; 0: no prompt
        text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        text_pair_target: Optional[
            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
        ] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        if not isinstance(text, list):
            text = [text]
            batching = False
        else:
            batching = True

        # Prepend the prompt: one <|bos|> placeholder per prompt position,
        # followed by the raw sequence (the part after the '|' separator).
        text_with_prompt = []
        for t in text:
            prompt_length = 0
            assert '|' in t, 'prompt not found'
            raw_text = t.split('|')[-1]
            if n_queries > 0:  # fixed-length prompt
                prompt_length = n_queries
            elif n_queries < 0:  # prompt length matches the sequence (digits '1'/'2' excluded)
                prompt_length = len(raw_text.replace('1', '').replace('2', ''))
            text_with_prompt.append('<|bos|>' * prompt_length + raw_text)

        batch = super().__call__(
            text=text_with_prompt,
            text_pair=text_pair,
            text_target=text_target,
            text_pair_target=text_pair_target,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=None,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )

        # Overwrite the leading <|bos|> placeholders with the structure ids: each
        # character before the '|' separator is stored as its Unicode code point.
        for i in range(len(text)):
            if '|' not in text[i]:
                continue
            structure_ids = text[i].split('|')[0]
            if return_tensors is None:
                for j in range(len(structure_ids)):
                    batch['input_ids'][i][j] = ord(structure_ids[j])
            else:
                batch['input_ids'][i, :len(structure_ids)] = torch.tensor([ord(c) for c in structure_ids])

        if "token_type_ids" in batch:
            del batch["token_type_ids"]

        if batching:
            return batch
        else:
            # Single (non-batched) input: strip the batch dimension.
            return {k: v[0] for k, v in batch.items()}
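

# Usage sketch (illustrative only): the tokenizer file path and the example inputs
# below are hypothetical placeholders, not defined by this module.
#
#     tokenizer = iPLMTokenizer(tokenizer_file="tokenizer.json")
#     # Inputs follow "<structure_ids>|<sequence>". With the default n_queries=-1,
#     # one <|bos|> slot is reserved per sequence character (excluding '1'/'2') and
#     # then overwritten with the Unicode code points of the structure ids.
#     batch = tokenizer(["ABCD|MKTV", "EFGHI|LLAVS"], padding=True, return_tensors="pt")
#     print(batch["input_ids"])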