import json
import os
import re

from transformers import PreTrainedTokenizer

default_config = {
    "custom_digits": "0123456789ABCDEF",
    "variable_atoms": {
        "left_operand": "a",
        "right_operand": "b"
    },
    "other_symbols_atoms": {
        "left_parenthesis": "(",
        "right_parenthesis": ")",
        "equals_sign": "=",
        "nan_symbol": "NaN",
        "inf_symbol": "Inf"
    },
    "operator_symbol_min_len": 1,
    "operator_symbol_max_len": 3,
    "basic_operator_symbols": ["+", "-", "*", "/", "%"],
    "base_symbols": [
        "≮⫘↔",
        "⫏≰",
        "⪩⨒∯",
        "⇑⪆",
        "↹⩛",
        "≴∭⊉",
        "⪪⊹⋣",
        "⋋%⋟",
        "⊺⇮",
        "⋰*⋻",
        "⫖↰⪸",
        "⪎⋱⫍",
        "⨗⨭⨅",
        "⫶⩼⫲",
        "∃⊬"
    ],
    "comparison_ops": ["==", ">", "<", ">=", "<=", "!="],
    "logical_connectors": ["and", "or"],
    "definition_symbols": [",", ";", "if", "else", "{", "}", "abs"]
}
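
# Configuration of the symbolic operator language: hexadecimal digits 0-F, the
# two operand variables "a" and "b", the basic arithmetic operators, and a set
# of multi-character Unicode "base_symbols" that prefix numerals (see the
# base-symbol splitting in OpTokenizer._tokenize_equation below).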
class OpTokenizer(PreTrainedTokenizer):

    def __init__(self, vocab_file, **kwargs):
        self.param_config = default_config
        self.vocab = self.load_vocab(vocab_file)
        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
        super().__init__(**kwargs)

        self.basic_symbols = list("0123456789()=ABCDEFab")
        self.special_results = ['NaN', 'Inf']
        self.comparison_ops = ["==", ">", "<", ">=", "<=", "!="]
        self.logical_connectors = ["and", "or"]
        self.definition_symbols = [",", ";", "if", "else", "{", "}", "abs"]

        self.token_regex = self.build_token_regex()

        # IDs of the special tokens; these entries must exist in the vocab file.
        self.pad_id = self.vocab['[PAD]']
        self.unk_id = self.vocab['[UNK]']
        self.sep_id = self.vocab['[SEP]']
        self.mask_id = self.vocab['[MASK]']
        self.bos_id = self.vocab['[BOS]']
        self.eos_id = self.vocab['[EOS]']
        self.eod_id = self.vocab['[EOD]']

    def load_vocab(self, vocab_file):
        with open(vocab_file, encoding="utf-8") as f:
            vocab = json.load(f)
        return vocab
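
    # The vocab file is a plain JSON mapping from token string to integer id,
    # for example (ids are illustrative):
    #   {"[PAD]": 0, "[UNK]": 1, "[SEP]": 2, "[MASK]": 3,
    #    "[BOS]": 4, "[EOS]": 5, "[EOD]": 6, "a": 7, "b": 8, "+": 9, ...}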
    def save_vocabulary(self, save_directory, filename_prefix=""):
        if filename_prefix is None:
            filename_prefix = ""

        os.makedirs(save_directory, exist_ok=True)

        vocab_file_path = os.path.join(save_directory, filename_prefix + "vocab.json")
        with open(vocab_file_path, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=4)

        print(f"Vocabulary saved to {vocab_file_path}")
        return (vocab_file_path,)
    def build_token_regex(self):
        """Build the tokenization regular expression that matches the text symbol by symbol."""
        special_results = [re.escape(result) for result in self.special_results]
        # Sort multi-character operators first so ">=" is not consumed as ">" followed by "=".
        comparison_ops = [re.escape(op) for op in sorted(self.comparison_ops, key=len, reverse=True)]
        logical_connectors = [re.escape(connector) for connector in self.logical_connectors]

        operator_pattern = r"(?P<OPERATOR>([+\-*/%]|[\u2200-\u22FF\u2A00-\u2BFF\u2190-\u21FF])+)"
        variable_pattern = r"(?P<VARIABLE>[ab])"
        digit_pattern = r"(?P<DIGIT>[0-9A-F])"
        special_result_pattern = r"(?P<SPECIAL_RESULT>" + "|".join(special_results) + ")"
        comparison_ops_pattern = r"(?P<COMPARISON_OP>" + "|".join(comparison_ops) + ")"
        logical_connectors_pattern = r"(?P<LOGICAL_CONNECTOR>" + "|".join(logical_connectors) + ")"
        if_else_pattern = r"(?P<IF_ELSE>if|else)"
        whitespace_pattern = r"(?P<WHITESPACE>\s+)"
        abs_pattern = r"(?P<ABS>abs)"
        punctuation_patterns = [
            r"(?P<PARENTHESIS_LEFT>\()",
            r"(?P<PARENTHESIS_RIGHT>\))",
            r"(?P<CURLY_BRACE_LEFT>{)",
            r"(?P<CURLY_BRACE_RIGHT>})",
            r"(?P<SEMICOLON>;)",
            r"(?P<COMMA>,)",
            r"(?P<EQUAL>=)"
        ]

        # Order matters: longer, more specific patterns must come before single characters.
        token_patterns = [
            operator_pattern,
            special_result_pattern,
            comparison_ops_pattern,
            logical_connectors_pattern,
            if_else_pattern,
            abs_pattern,
            digit_pattern,
            variable_pattern,
            whitespace_pattern,
        ] + punctuation_patterns

        combined_pattern = "|".join(token_patterns)
        return re.compile(combined_pattern)
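
    # Illustrative behaviour of the regex above (assuming the default config):
    #   "(a + b) >= F"  ->  ['(', 'a', '+', 'b', ')', '>=', 'F']
    # Runs of Unicode glyphs such as "≮⫘↔" match as a single OPERATOR token,
    # while WHITESPACE matches are discarded by the tokenize methods below.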
    def tokenize(self, text: str, mode: str = 'text', add_special_tokens: bool = True):
        if mode == 'definition':
            return self._tokenize_definition(text, add_special_tokens)
        elif mode == 'text':
            return self._tokenize_equation(text, add_special_tokens)
        elif mode == 'withdef_text':
            return self._tokenize_withdef_text(text, add_special_tokens)
        else:
            raise ValueError(f"Unsupported mode: {mode}")
    def _tokenize_definition(self, text, add_special_tokens):
        """Tokenize an operator definition, optionally wrapped in [DEF_START]/[DEF_END]."""
        tokens = []
        if add_special_tokens:
            tokens.append('[DEF_START]')
        for match in self.token_regex.finditer(text):
            token_type = match.lastgroup
            token_value = match.group(token_type)
            if token_type != "WHITESPACE":
                tokens.append(token_value)
        if add_special_tokens:
            tokens.append('[DEF_END]')
        return tokens
    def _tokenize_equation(self, text, add_special_tokens):
        """Tokenize an equation, optionally wrapped in [EQ_START]/[EQ_END]."""
        tokens = []
        if add_special_tokens:
            tokens.append('[EQ_START]')

        # Numbers may be written with a base symbol immediately in front of them
        # ("<base_symbol><digits>"), so split those units out of the text first.
        self.digit_pattern = f"[{re.escape(self.param_config['custom_digits'])}]"
        self.number_pattern = f"[-]?{self.digit_pattern}+"
        self.base_symbols_pattern = f"(?:{'|'.join(map(re.escape, self.param_config['base_symbols']))})"
        self.base_symbols_number_pattern = f"({self.base_symbols_pattern}{self.number_pattern})"

        parts = re.split(self.base_symbols_number_pattern, text)
        final_parts = []
        for part in parts:
            if re.search(self.number_pattern, part):
                # Split bare numbers out of the remaining text as well.
                sub_parts = re.split(f"({self.number_pattern})", part)
                final_parts.extend(sub_parts)
            else:
                final_parts.append(part)

        for part in final_parts:
            for match in self.token_regex.finditer(part):
                token_type = match.lastgroup
                token_value = match.group(token_type)
                if token_type != "WHITESPACE":
                    tokens.append(token_value)

        if add_special_tokens:
            tokens.append('[EQ_END]')
        return tokens
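
    # Example (illustrative, with the default config):
    #   _tokenize_equation("a ≮⫘↔ b = ⇑⪆1A", add_special_tokens=True)
    #   -> ['[EQ_START]', 'a', '≮⫘↔', 'b', '=', '⇑⪆', '1', 'A', '[EQ_END]']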
    def _tokenize_withdef_text(self, text, add_special_tokens):
        """Tokenize text that interleaves definition blocks and equation blocks."""
        tokens = []
        segments = re.split(r'(\[DEF_START\]|\[DEF_JOIN\]|\[DEF_END\]|\[EQ_START\]|\[EQ_END\])', text)
        current_mode = None

        for seg in segments:
            seg = seg.strip()
            if not seg:
                continue

            if seg in ['[DEF_START]', '[DEF_JOIN]']:
                if add_special_tokens:
                    tokens.append(seg)
                current_mode = 'definition'
            elif seg == '[DEF_END]':
                if add_special_tokens:
                    tokens.append(seg)
                current_mode = None
            elif seg == '[EQ_START]':
                if add_special_tokens:
                    tokens.append(seg)
                current_mode = 'text'
            elif seg == '[EQ_END]':
                if add_special_tokens:
                    tokens.append(seg)
                current_mode = None
            else:
                if current_mode == 'definition':
                    tokens.extend(self._tokenize_definition(seg, add_special_tokens=False))
                elif current_mode == 'text':
                    tokens.extend(self._tokenize_equation(seg, add_special_tokens=False))
                else:
                    tokens.extend(seg.split())
        return tokens
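
    # Example (illustrative): for the input
    #   "[DEF_START] a ⫏≰ b = a + b ; [DEF_END] [EQ_START] 1 ⫏≰ 2 = 3 [EQ_END]"
    # the method yields
    #   ['[DEF_START]', 'a', '⫏≰', 'b', '=', 'a', '+', 'b', ';', '[DEF_END]',
    #    '[EQ_START]', '1', '⫏≰', '2', '=', '3', '[EQ_END]']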
    def convert_tokens_to_ids(self, tokens):
        # Accept a single token string, a list of token strings, or a list of ids.
        if isinstance(tokens, str):
            return self.vocab.get(tokens, self.unk_id)
        if tokens and isinstance(tokens[0], str):
            return [self.vocab.get(token, self.unk_id) for token in tokens]
        return tokens

    def convert_ids_to_tokens(self, ids):
        return [self.ids_to_tokens.get(i, '[UNK]') for i in ids]

    def encode(self, text, mode='text', add_special_tokens=True):
        tokens = self.tokenize(text, mode=mode, add_special_tokens=add_special_tokens)
        return self.convert_tokens_to_ids(tokens)

    def decode(self, ids, skip_special_tokens=False):
        tokens = self.convert_ids_to_tokens(ids)
        if skip_special_tokens:
            tokens = [t for t in tokens if not (t.startswith('[') and t.endswith(']'))]
        return " ".join(tokens).replace(" ##", "")

    def get_vocab(self):
        return self.vocab
if __name__ == "__main__":

    tokenizer = OpTokenizer(vocab_file="/map-vepfs/kaijing/Megatron-LM-Op/vocab_100000.json")
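
    # Quick smoke test (illustrative; assumes the vocab file above contains the
    # special tokens and the symbols used here, otherwise ids fall back to [UNK]).
    sample = "a ≮⫘↔ b = ⇑⪆1A"
    print(tokenizer.tokenize(sample, mode='text'))
    ids = tokenizer.encode(sample, mode='text', add_special_tokens=True)
    print(ids)
    print(tokenizer.decode(ids, skip_special_tokens=True))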
    hf_tokenizer_dir = "/map-vepfs/kaijing/Megatron-LM-Op/op_plus_script/convert/op_hf_tokenizer"
    tokenizer.save_pretrained(save_directory=hf_tokenizer_dir)