# op_nodef/op_tokenizer.py
import json
import os
import re
from typing import Dict, List, Optional

import torch
from transformers import PreTrainedTokenizer

default_config = {
"custom_digits": "0123456789ABCDEF",
"variable_atoms": {
"left_operand": "a", # 左操作数变量名
"right_operand": "b" # 右操作数变量名
},
"other_symbols_atoms": {
"left_parenthesis": "(", # 左括号
"right_parenthesis": ")", # 右括号
"equals_sign": "=", # 等号,常用于赋值或比较
"nan_symbol": "NaN", # 非数(Not a Number)
"inf_symbol": "Inf" # 无穷大(Infinity)
},
"operator_symbol_min_len": 1,
"operator_symbol_max_len": 3,
"basic_operator_symbols": ["+", "-", "*", "/", "%"],
"base_symbols": [
"≮⫘↔",
"⫏≰",
"⪩⨒∯",
"⇑⪆",
"↹⩛",
"≴∭⊉",
"⪪⊹⋣",
"⋋%⋟",
"⊺⇮",
"⋰*⋻",
"⫖↰⪸",
"⪎⋱⫍",
"⨗⨭⨅",
"⫶⩼⫲",
"∃⊬"
],
"comparison_ops": ["==", ">", "<", ">=", "<=", "!="],
"logical_connectors": ["and", "or"],
"definition_symbols": [
",",
";",
"if",
"else",
"{",
"}",
"abs"
]
}
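
# Note: default_config describes the symbol inventory used by OpTokenizer below:
# hexadecimal digits, the operand variables "a"/"b", the basic arithmetic operators,
# the custom multi-character Unicode symbols in "base_symbols", comparison and
# logical operators, and the punctuation/keywords that appear in operator definitions.
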
class OpTokenizer(PreTrainedTokenizer):
def __init__(self, vocab_file, **kwargs):
        self.param_config = default_config
self.vocab = self.load_vocab(vocab_file)
self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
super().__init__(**kwargs)
        # define the basic symbols
self.basic_symbols = list("0123456789()=ABCDEFab")
self.special_results = ['NaN', 'Inf']
self.comparison_ops = ["==", ">", "<", ">=", "<=", "!="]
self.logical_connectors = ["and", "or"]
self.definition_symbols = [",", ";", "if", "else", "{", "}", "abs"]
self.token_regex = self.build_token_regex()
        # initialize the special-token IDs
self.pad_id = self.vocab['[PAD]']
self.unk_id = self.vocab['[UNK]']
self.sep_id = self.vocab['[SEP]']
self.mask_id = self.vocab['[MASK]']
self.bos_id = self.vocab['[BOS]']
self.eos_id = self.vocab['[EOS]']
self.eod_id = self.vocab['[EOD]']
def load_vocab(self, vocab_file):
        # load the vocabulary (token -> id) from a JSON file
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
return vocab
def save_vocabulary(self, save_directory, filename_prefix=""):
if filename_prefix is None:
filename_prefix = ""
if not os.path.exists(save_directory):
os.makedirs(save_directory)
vocab_file_path = os.path.join(save_directory, filename_prefix + "vocab.json")
with open(vocab_file_path, "w", encoding="utf-8") as f:
json.dump(self.vocab, f, ensure_ascii=False, indent=4)
print(f"Vocabulary saved to {vocab_file_path}")
        return (vocab_file_path,)  # return a tuple rather than a list
def build_token_regex(self):
"""构建分词正则表达式,逐字符、符号进行匹配"""
# 特殊结果的正则表达式(比如 NaN, Inf)
special_results = [re.escape(result) for result in self.special_results]
# 比较操作符的正则表达式
comparison_ops = [re.escape(op) for op in self.comparison_ops]
# 逻辑连接符的正则表达式
logical_connectors = [re.escape(connector) for connector in self.logical_connectors]
operator_pattern = r"(?P<OPERATOR>([+\-*/%]|[\u2200-\u22FF\u2A00-\u2BFF\u2190-\u21FF])+)"
variable_pattern = r"(?P<VARIABLE>[a-b])"
digit_pattern = r"(?P<DIGIT>[0-9A-F])"
special_result_pattern = r"(?P<SPECIAL_RESULT>" + "|".join(special_results) + ")"
comparison_ops_pattern = r"(?P<COMPARISON_OP>" + "|".join(comparison_ops) + ")"
logical_connectors_pattern = r"(?P<LOGICAL_CONNECTOR>" + "|".join(logical_connectors) + ")"
if_else_pattern = r"(?P<IF_ELSE>if|else)"
whitespace_pattern = r"(?P<WHITESPACE>\s+)"
abs_pattern = r"(?P<ABS>abs)"
punctuation_patterns = [
r"(?P<PARENTHESIS_LEFT>\()",
r"(?P<PARENTHESIS_RIGHT>\))",
r"(?P<CURLY_BRACE_LEFT>{)",
r"(?P<CURLY_BRACE_RIGHT>})",
r"(?P<SEMICOLON>;)",
r"(?P<COMMA>,)",
r"(?P<EQUAL>=)"
]
        # combine all patterns; order matters, longer alternatives must be tried first
        token_patterns = [
            operator_pattern,
            special_result_pattern,      # special results (e.g. NaN, Inf)
            comparison_ops_pattern,      # comparison operators
            logical_connectors_pattern,  # logical connectors
            if_else_pattern,             # if / else keywords
            abs_pattern,
            digit_pattern,
            variable_pattern,            # variable names (a, b)
            whitespace_pattern,          # spaces and newlines
        ] + punctuation_patterns         # append the punctuation patterns
        # join all patterns with "|"
        combined_pattern = "|".join(token_patterns)
        # return the compiled regular expression
        return re.compile(combined_pattern)
def tokenize(self, text: str, mode: str = 'text', add_special_tokens: bool = True):
if mode == 'definition':
return self._tokenize_definition(text, add_special_tokens)
elif mode == 'text':
return self._tokenize_equation(text, add_special_tokens)
elif mode == 'withdef_text':
return self._tokenize_withdef_text(text, add_special_tokens)
else:
            raise ValueError(f"Unsupported mode: {mode}")
def _tokenize_definition(self, text, add_special_tokens):
tokens = []
if add_special_tokens:
tokens.append('[DEF_START]')
for match in self.token_regex.finditer(text):
token_type = match.lastgroup
token_value = match.group(token_type)
if token_type != "WHITESPACE":
tokens.append(token_value)
if add_special_tokens:
tokens.append('[DEF_END]')
return tokens
def _tokenize_equation(self, text, add_special_tokens):
tokens = []
if add_special_tokens:
tokens.append('[EQ_START]')
        # split the equation on "<base_symbol><number>" chunks so that numbers
        # prefixed by a multi-character base symbol are isolated before the
        # character-level regex is applied
        self.digit_pattern = f"[{re.escape(self.param_config['custom_digits'])}]"
        self.number_pattern = f"[-]?{self.digit_pattern}+"
        self.base_symbols_pattern = f"(?:{'|'.join(map(re.escape, self.param_config['base_symbols']))})"
        self.base_symbols_number_pattern = f"({self.base_symbols_pattern}{self.number_pattern})"
        parts = re.split(self.base_symbols_number_pattern, text)
final_parts = []
for part in parts:
if re.search(self.number_pattern, part):
sub_parts = re.split(f"({self.number_pattern})", part)
final_parts.extend(sub_parts)
else:
final_parts.append(part)
for part in final_parts:
for match in self.token_regex.finditer(part):
token_type = match.lastgroup
token_value = match.group(token_type)
if token_type != "WHITESPACE":
tokens.append(token_value)
if add_special_tokens:
tokens.append('[EQ_END]')
return tokens
def _tokenize_withdef_text(self, text, add_special_tokens):
tokens = []
segments = re.split(r'(\[DEF_START\]|\[DEF_JOIN\]|\[DEF_END\]|\[EQ_START\]|\[EQ_END\])', text)
current_mode = None
for seg in segments:
seg = seg.strip()
if not seg:
continue
if seg in ['[DEF_START]', '[DEF_JOIN]']:
if add_special_tokens:
tokens.append(seg)
current_mode = 'definition'
elif seg == '[DEF_END]':
if add_special_tokens:
tokens.append(seg)
current_mode = None
elif seg == '[EQ_START]':
if add_special_tokens:
tokens.append(seg)
current_mode = 'text'
elif seg == '[EQ_END]':
if add_special_tokens:
tokens.append(seg)
current_mode = None
else:
if current_mode == 'definition':
inner_tokens = self._tokenize_definition(seg, add_special_tokens=False)
tokens.extend(inner_tokens)
elif current_mode == 'text':
inner_tokens = self._tokenize_equation(seg, add_special_tokens=False)
tokens.extend(inner_tokens)
else:
tokens.extend(seg.split())
return tokens
    def convert_tokens_to_ids(self, tokens):
        if isinstance(tokens, str):
            return self.vocab.get(tokens, self.vocab['[UNK]'])
        if tokens and isinstance(tokens[0], str):
            return [self.vocab.get(token, self.vocab['[UNK]']) for token in tokens]
        return tokens
    def convert_ids_to_tokens(self, ids):
        return [self.ids_to_tokens.get(i, '[UNK]') for i in ids]
def get_vocab(self):
return self.vocab
    def encode(self, texts, mode='text', add_special_tokens=True, padding=True, truncation=True, max_length=None):
        # accept a single string or a list of strings
        if isinstance(texts, str):
            texts = [texts]
        all_tokens = [self.tokenize(text, mode=mode, add_special_tokens=add_special_tokens) for text in texts]
        all_ids = [self.convert_tokens_to_ids(tokens) for tokens in all_tokens]
        # truncate first, then pad every sequence to the length of the longest one
        if truncation and max_length:
            all_ids = [ids[:max_length] for ids in all_ids]
        if padding:
            max_len = max(len(ids) for ids in all_ids)
            padded_ids = [ids + [self.pad_id] * (max_len - len(ids)) for ids in all_ids]
        else:
            padded_ids = all_ids
        return torch.tensor(padded_ids)
    def decode(self, ids, skip_special_tokens=False):
        if torch.is_tensor(ids):
            ids = ids.tolist()
        tokens = self.convert_ids_to_tokens(ids)
        if skip_special_tokens:
            tokens = [t for t in tokens if not (t.startswith('[') and t.endswith(']'))]
        return " ".join(tokens).replace(" ##", "")
def __call__(self, texts, return_tensors=None, **kwargs):
if isinstance(texts, str):
texts = [texts]
input_ids = self.encode(texts, **kwargs)
if return_tensors == "pt":
return {"input_ids": input_ids}
return {"input_ids": input_ids.tolist()}