import copy
import json
import os
import re
from typing import Dict, List, Optional

import torch
from transformers import PreTrainedTokenizer

# Default symbol configuration used by OpTokenizer.
default_config = {
    "custom_digits": "0123456789ABCDEF",
    "variable_atoms": {
        "left_operand": "a",  # variable name of the left operand
        "right_operand": "b",  # variable name of the right operand
    },
    "other_symbols_atoms": {
        "left_parenthesis": "(",  # left parenthesis
        "right_parenthesis": ")",  # right parenthesis
        "equals_sign": "=",  # equals sign, used for assignment or comparison
        "nan_symbol": "NaN",  # not-a-number result
        "inf_symbol": "Inf",  # infinity result
    },
    "operator_symbol_min_len": 1,
    "operator_symbol_max_len": 3,
    "basic_operator_symbols": ["+", "-", "*", "/", "%"],
    "base_symbols": [
        "≮⫘↔", "⫏≰", "⪩⨒∯", "⇑⪆", "↹⩛", "≴∭⊉", "⪪⊹⋣", "⋋%⋟",
        "⊺⇮", "⋰*⋻", "⫖↰⪸", "⪎⋱⫍", "⨗⨭⨅", "⫶⩼⫲", "∃⊬",
    ],
    "comparison_ops": ["==", ">", "<", ">=", "<=", "!="],
    "logical_connectors": ["and", "or"],
    "definition_symbols": [",", ";", "if", "else", "{", "}", "abs"],
}


class OpTokenizer(PreTrainedTokenizer):
    """Tokenizer for operator definitions and operator equations.

    Token ids come from a JSON vocabulary file (token string -> id), which
    must contain ``[PAD]``, ``[UNK]``, ``[SEP]``, ``[MASK]``, ``[BOS]``,
    ``[EOS]`` and ``[EOD]``. Tokens missing from the vocabulary map to
    ``[UNK]``.
    """

    def __init__(self, vocab_file, **kwargs):
        # Deep-copy so mutating one instance's config cannot alter the
        # module-level default or any other tokenizer instance.
        self.param_config = copy.deepcopy(default_config)
        # The vocabulary is loaded before the base-class constructor runs;
        # the base class may query get_vocab()/vocab_size while initialising.
        self.vocab = self.load_vocab(vocab_file)
        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
        super().__init__(**kwargs)

        # Basic symbols
        self.basic_symbols = list("0123456789()=ABCDEFab")
        self.special_results = ['NaN', 'Inf']
        self.comparison_ops = ["==", ">", "<", ">=", "<=", "!="]
        self.logical_connectors = ["and", "or"]
        self.definition_symbols = [",", ";", "if", "else", "{", "}", "abs"]

        self.token_regex = self.build_token_regex()

        # Ids of the special tokens (KeyError if the vocab lacks any of them).
        self.pad_id = self.vocab['[PAD]']
        self.unk_id = self.vocab['[UNK]']
        self.sep_id = self.vocab['[SEP]']
        self.mask_id = self.vocab['[MASK]']
        self.bos_id = self.vocab['[BOS]']
        self.eos_id = self.vocab['[EOS]']
        self.eod_id = self.vocab['[EOD]']

    def load_vocab(self, vocab_file) -> Dict[str, int]:
        """Load the token -> id mapping from a UTF-8 JSON file."""
        with open(vocab_file, encoding="utf-8") as f:
            return json.load(f)

    def save_vocabulary(self, save_directory, filename_prefix=""):
        """Write the vocabulary to ``<save_directory>/<prefix>vocab.json``.

        Returns:
            A 1-tuple with the written path (the transformers convention is
            a tuple, not a list).
        """
        if filename_prefix is None:
            filename_prefix = ""
        os.makedirs(save_directory, exist_ok=True)
        vocab_file_path = os.path.join(save_directory, filename_prefix + "vocab.json")
        with open(vocab_file_path, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=4)
        print(f"Vocabulary saved to {vocab_file_path}")
        return (vocab_file_path,)

    def build_token_regex(self):
        """Build the compiled regex that splits text into typed tokens.

        Every alternative is a named group; tokenizers read
        ``match.lastgroup`` to get the token type and discard ``WHITESPACE``
        matches. Characters matched by no alternative are skipped by
        ``finditer`` and therefore silently dropped.
        """

        def alternation(options):
            # Regex alternation picks the FIRST alternative that matches, not
            # the longest, so longer literals must come first; otherwise
            # ">=" would be split into ">" and "=".
            ordered = sorted(options, key=len, reverse=True)
            return "|".join(re.escape(option) for option in ordered)

        # One or more ASCII operators and/or symbols from the Unicode
        # arrows / math-operators / supplemental-math blocks.
        operator_pattern = (
            r"(?P<OPERATOR>(?:[+\-*/%]|[\u2200-\u22FF\u2A00-\u2BFF\u2190-\u21FF])+)"
        )
        variable_pattern = r"(?P<VARIABLE>[a-b])"
        digit_pattern = r"(?P<DIGIT>[0-9A-F])"
        special_result_pattern = (
            r"(?P<SPECIAL_RESULT>" + alternation(self.special_results) + ")"
        )
        comparison_ops_pattern = (
            r"(?P<COMPARISON_OP>" + alternation(self.comparison_ops) + ")"
        )
        logical_connectors_pattern = (
            r"(?P<LOGICAL_CONNECTOR>" + alternation(self.logical_connectors) + ")"
        )
        if_else_pattern = r"(?P<IF_ELSE>if|else)"
        whitespace_pattern = r"(?P<WHITESPACE>\s+)"
        abs_pattern = r"(?P<ABS>abs)"
        punctuation_patterns = [
            r"(?P<LPAREN>\()",
            r"(?P<RPAREN>\))",
            r"(?P<LBRACE>\{)",
            r"(?P<RBRACE>\})",
            r"(?P<SEMICOLON>;)",
            r"(?P<COMMA>,)",
            r"(?P<EQUALS>=)",
        ]

        # Order matters: multi-character tokens (NaN, ==, and, abs, ...) must
        # be tried before the single-character patterns that would otherwise
        # claim their first character.
        token_patterns = [
            operator_pattern,
            special_result_pattern,  # NaN, Inf
            comparison_ops_pattern,
            logical_connectors_pattern,
            if_else_pattern,
            abs_pattern,
            digit_pattern,
            variable_pattern,  # lowercase variable names
            whitespace_pattern,  # spaces and newlines
        ] + punctuation_patterns

        return re.compile("|".join(token_patterns))

    def tokenize(self, text: str, mode: str = 'text', add_special_tokens: bool = True):
        """Tokenize one string.

        Args:
            text: The input string.
            mode: ``'definition'``, ``'text'`` (equation) or ``'withdef_text'``
                (text containing ``[DEF_*]`` / ``[EQ_*]`` section markers).
            add_special_tokens: Whether to emit the section marker tokens.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        if mode == 'definition':
            return self._tokenize_definition(text, add_special_tokens)
        if mode == 'text':
            return self._tokenize_equation(text, add_special_tokens)
        if mode == 'withdef_text':
            return self._tokenize_withdef_text(text, add_special_tokens)
        raise ValueError(f"Unsupported mode: {mode}")

    def _regex_tokens(self, text: str) -> List[str]:
        """Return the non-whitespace token strings matched in ``text``."""
        return [
            match.group(match.lastgroup)
            for match in self.token_regex.finditer(text)
            if match.lastgroup != "WHITESPACE"
        ]

    def _tokenize_definition(self, text, add_special_tokens):
        """Tokenize a definition, optionally wrapped in [DEF_START]/[DEF_END]."""
        tokens = ['[DEF_START]'] if add_special_tokens else []
        tokens.extend(self._regex_tokens(text))
        if add_special_tokens:
            tokens.append('[DEF_END]')
        return tokens

    def _tokenize_equation(self, text, add_special_tokens):
        """Tokenize an equation, optionally wrapped in [EQ_START]/[EQ_END]."""
        tokens = ['[EQ_START]'] if add_special_tokens else []

        # Rebuilt on each call so they always reflect the current
        # param_config; kept as attributes for existing readers.
        self.digit_pattern = f"[{re.escape(self.param_config['custom_digits'])}]"
        self.number_pattern = f"[-]?{self.digit_pattern}+"
        self.base_symbols_pattern = (
            f"(?:{'|'.join(map(re.escape, self.param_config['base_symbols']))})"
        )
        self.base_symbols_number_pattern = (
            f"({self.base_symbols_pattern}{self.number_pattern})"
        )

        # Isolate "<base symbol><number>" runs, then split numbers (with an
        # optional leading "-") out of every part. This keeps a "-" that
        # directly precedes a number from being merged into a preceding
        # operator-symbol run by the greedy OPERATOR pattern.
        final_parts = []
        for part in re.split(self.base_symbols_number_pattern, text):
            if re.search(self.number_pattern, part):
                final_parts.extend(re.split(f"({self.number_pattern})", part))
            else:
                final_parts.append(part)

        for part in final_parts:
            tokens.extend(self._regex_tokens(part))

        if add_special_tokens:
            tokens.append('[EQ_END]')
        return tokens

    def _tokenize_withdef_text(self, text, add_special_tokens):
        """Tokenize text containing [DEF_*] / [EQ_*] section markers.

        Segments after [DEF_START]/[DEF_JOIN] are tokenized as definitions,
        segments after [EQ_START] as equations, and anything outside a
        section is split on whitespace.
        """
        # Marker -> tokenization mode that applies to the following segments.
        markers = {
            '[DEF_START]': 'definition',
            '[DEF_JOIN]': 'definition',
            '[DEF_END]': None,
            '[EQ_START]': 'text',
            '[EQ_END]': None,
        }
        tokens = []
        segments = re.split(r'(\[DEF_START\]|\[DEF_JOIN\]|\[DEF_END\]|\[EQ_START\]|\[EQ_END\])', text)
        current_mode = None
        for seg in segments:
            seg = seg.strip()
            if not seg:
                continue
            if seg in markers:
                if add_special_tokens:
                    tokens.append(seg)
                current_mode = markers[seg]
            elif current_mode == 'definition':
                tokens.extend(self._tokenize_definition(seg, add_special_tokens=False))
            elif current_mode == 'text':
                tokens.extend(self._tokenize_equation(seg, add_special_tokens=False))
            else:
                tokens.extend(seg.split())
        return tokens

    @property
    def vocab_size(self) -> int:
        """Number of entries in the base vocabulary."""
        return len(self.vocab)

    def _convert_token_to_id(self, token):
        """Base-class hook: map one token to its id (``[UNK]`` if unknown)."""
        return self.vocab.get(token, self.vocab['[UNK]'])

    def _convert_id_to_token(self, index):
        """Base-class hook: map one id to its token (``'[UNK]'`` if unknown)."""
        return self.ids_to_tokens.get(index, '[UNK]')

    def convert_tokens_to_ids(self, tokens):
        """Map tokens to ids.

        A single string returns a single id (transformers convention); a
        sequence of strings returns a list of ids. A non-empty sequence whose
        first element is not a string is assumed to already hold ids and is
        returned unchanged.
        """
        unk_id = self.vocab['[UNK]']
        if isinstance(tokens, str):
            return self.vocab.get(tokens, unk_id)
        if len(tokens) == 0:
            return []
        if not isinstance(tokens[0], str):
            return tokens
        return [self.vocab.get(token, unk_id) for token in tokens]

    def convert_ids_to_tokens(self, ids):
        """Map ids to tokens; unknown ids become ``'[UNK]'``.

        Accepts a single int, a sequence of ints, or a 1-D tensor/array
        (converted via ``tolist()`` so elements are hashable Python ints).
        """
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        if isinstance(ids, int):
            return self.ids_to_tokens.get(ids, '[UNK]')
        return [self.ids_to_tokens.get(i, '[UNK]') for i in ids]

    def get_vocab(self):
        """Return the token -> id vocabulary."""
        return self.vocab

    def encode(self, texts, mode=None, add_special_tokens=True, padding=True,
               truncation=True, max_length: Optional[int] = None):
        """Tokenize a batch of texts and return their ids as a tensor.

        Args:
            texts: A single string (treated as a batch of one) or an
                iterable of strings.
            mode: Tokenization mode forwarded to ``tokenize``; ``None`` means
                ``'text'``.
            add_special_tokens: Whether to emit section marker tokens.
            padding: Right-pad every sequence with ``[PAD]`` to the longest
                one. With ``padding=False`` and sequences of different
                lengths, ``torch.tensor`` raises.
            truncation: Cut sequences to ``max_length`` when it is truthy.
            max_length: Maximum sequence length used by truncation.

        Returns:
            A ``torch.Tensor`` of ids, shape ``(batch, seq_len)``; an empty
            batch yields an empty tensor.
        """
        if isinstance(texts, str):
            texts = [texts]
        if mode is None:
            mode = 'text'

        all_ids = [
            self.convert_tokens_to_ids(
                self.tokenize(text, mode=mode, add_special_tokens=add_special_tokens)
            )
            for text in texts
        ]

        # Truncate before padding so padding never exceeds max_length.
        if truncation and max_length:
            all_ids = [ids[:max_length] for ids in all_ids]

        if padding and all_ids:
            max_len = max(len(ids) for ids in all_ids)
            all_ids = [ids + [self.pad_id] * (max_len - len(ids)) for ids in all_ids]

        return torch.tensor(all_ids)

    def decode(self, ids, skip_special_tokens=False):
        """Convert ids back to a space-joined token string.

        With ``skip_special_tokens`` every token of the form ``[...]`` is
        dropped. WordPiece-style ``" ##"`` joins are collapsed.
        """
        tokens = self.convert_ids_to_tokens(ids)
        if isinstance(tokens, str):
            tokens = [tokens]
        if skip_special_tokens:
            tokens = [t for t in tokens if not (t.startswith('[') and t.endswith(']'))]
        return " ".join(tokens).replace(" ##", "")

    def __call__(self, texts, return_tensors=None, **kwargs):
        """Encode ``texts`` and return ``{"input_ids": ...}``.

        ``return_tensors="pt"`` returns a tensor; otherwise nested lists.
        Remaining keyword arguments are forwarded to ``encode``.
        """
        if isinstance(texts, str):
            texts = [texts]
        input_ids = self.encode(texts, **kwargs)
        if return_tensors == "pt":
            return {"input_ids": input_ids}
        return {"input_ids": input_ids.tolist()}