mkj69 committed
Commit 8e94991 · verified
1 Parent(s): 522cec8

Upload op_tokenizer.py with huggingface_hub

Files changed (1)
  1. op_tokenizer.py +296 -0
op_tokenizer.py ADDED
@@ -0,0 +1,296 @@
import json
import os
import re
from typing import List, Optional, Dict

import torch
from transformers import PreTrainedTokenizer

default_config = {
    "custom_digits": "0123456789ABCDEF",
    "variable_atoms": {
        "left_operand": "a",    # left-operand variable name
        "right_operand": "b"    # right-operand variable name
    },

    "other_symbols_atoms": {
        "left_parenthesis": "(",    # left parenthesis
        "right_parenthesis": ")",   # right parenthesis
        "equals_sign": "=",         # equals sign, used for assignment or comparison
        "nan_symbol": "NaN",        # Not a Number
        "inf_symbol": "Inf"         # Infinity
    },

    "operator_symbol_min_len": 1,
    "operator_symbol_max_len": 3,

    "basic_operator_symbols": ["+", "-", "*", "/", "%"],

    "base_symbols": [
        "≮⫘↔",
        "⫏≰",
        "⪩⨒∯",
        "⇑⪆",
        "↹⩛",
        "≴∭⊉",
        "⪪⊹⋣",
        "⋋%⋟",
        "⊺⇮",
        "⋰*⋻",
        "⫖↰⪸",
        "⪎⋱⫍",
        "⨗⨭⨅",
        "⫶⩼⫲",
        "∃⊬"
    ],

    "comparison_ops": ["==", ">", "<", ">=", "<=", "!="],

    "logical_connectors": ["and", "or"],

    "definition_symbols": [
        ",",
        ";",
        "if",
        "else",
        "{",
        "}",
        "abs"
    ]
}

class OpTokenizer(PreTrainedTokenizer):
    def __init__(self, vocab_file, **kwargs):
        self.param_config = default_config
        self.vocab = self.load_vocab(vocab_file)
        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
        super().__init__(**kwargs)

        # Basic symbols
        self.basic_symbols = list("0123456789()=ABCDEFab")
        self.special_results = ['NaN', 'Inf']
        self.comparison_ops = ["==", ">", "<", ">=", "<=", "!="]
        self.logical_connectors = ["and", "or"]
        self.definition_symbols = [",", ";", "if", "else", "{", "}", "abs"]

        self.token_regex = self.build_token_regex()

        # Special token IDs
        self.pad_id = self.vocab['[PAD]']
        self.unk_id = self.vocab['[UNK]']
        self.sep_id = self.vocab['[SEP]']
        self.mask_id = self.vocab['[MASK]']
        self.bos_id = self.vocab['[BOS]']
        self.eos_id = self.vocab['[EOS]']
        self.eod_id = self.vocab['[EOD]']

    def load_vocab(self, vocab_file):
        # Load the vocabulary (token -> id mapping) from a JSON file
        with open(vocab_file, encoding="utf-8") as f:
            vocab = json.load(f)
        return vocab

    def save_vocabulary(self, save_directory, filename_prefix=""):
        if filename_prefix is None:
            filename_prefix = ""

        if not os.path.exists(save_directory):
            os.makedirs(save_directory)

        vocab_file_path = os.path.join(save_directory, filename_prefix + "vocab.json")

        with open(vocab_file_path, "w", encoding="utf-8") as f:
            json.dump(self.vocab, f, ensure_ascii=False, indent=4)

        print(f"Vocabulary saved to {vocab_file_path}")

        return (vocab_file_path,)  # return a tuple rather than a list

    def build_token_regex(self):
        """Build the tokenization regex that matches the text symbol by symbol."""
        # Special results (e.g. NaN, Inf)
        special_results = [re.escape(result) for result in self.special_results]
        # Comparison operators, sorted longest-first so ">=" matches before ">"
        comparison_ops = [re.escape(op) for op in sorted(self.comparison_ops, key=len, reverse=True)]
        # Logical connectors
        logical_connectors = [re.escape(connector) for connector in self.logical_connectors]

        operator_pattern = r"(?P<OPERATOR>([+\-*/%]|[\u2200-\u22FF\u2A00-\u2BFF\u2190-\u21FF])+)"
        variable_pattern = r"(?P<VARIABLE>[a-b])"
        digit_pattern = r"(?P<DIGIT>[0-9A-F])"
        special_result_pattern = r"(?P<SPECIAL_RESULT>" + "|".join(special_results) + ")"
        comparison_ops_pattern = r"(?P<COMPARISON_OP>" + "|".join(comparison_ops) + ")"
        logical_connectors_pattern = r"(?P<LOGICAL_CONNECTOR>" + "|".join(logical_connectors) + ")"
        if_else_pattern = r"(?P<IF_ELSE>if|else)"
        whitespace_pattern = r"(?P<WHITESPACE>\s+)"
        abs_pattern = r"(?P<ABS>abs)"
        punctuation_patterns = [
            r"(?P<PARENTHESIS_LEFT>\()",
            r"(?P<PARENTHESIS_RIGHT>\))",
            r"(?P<CURLY_BRACE_LEFT>{)",
            r"(?P<CURLY_BRACE_RIGHT>})",
            r"(?P<SEMICOLON>;)",
            r"(?P<COMMA>,)",
            r"(?P<EQUAL>=)"
        ]

        # Combine all patterns; order matters, so longer alternatives come first
        token_patterns = [
            operator_pattern,
            special_result_pattern,       # special results (NaN, Inf)
            comparison_ops_pattern,       # comparison operators
            logical_connectors_pattern,   # logical connectors
            if_else_pattern,              # if / else keywords
            abs_pattern,
            digit_pattern,
            variable_pattern,             # variable names (a, b)
            whitespace_pattern,           # spaces and newlines
        ] + punctuation_patterns          # punctuation patterns

        # Join all patterns with |
        combined_pattern = "|".join(token_patterns)

        # Return the compiled regex object
        return re.compile(combined_pattern)

    def tokenize(self, text: str, mode: str = 'text', add_special_tokens: bool = True):
        if mode == 'definition':
            return self._tokenize_definition(text, add_special_tokens)
        elif mode == 'text':
            return self._tokenize_equation(text, add_special_tokens)
        elif mode == 'withdef_text':
            return self._tokenize_withdef_text(text, add_special_tokens)
        else:
            raise ValueError(f"Unsupported mode: {mode}")

    def _tokenize_definition(self, text, add_special_tokens):
        tokens = []
        if add_special_tokens:
            tokens.append('[DEF_START]')
        for match in self.token_regex.finditer(text):
            token_type = match.lastgroup
            token_value = match.group(token_type)
            if token_type != "WHITESPACE":
                tokens.append(token_value)
        if add_special_tokens:
            tokens.append('[DEF_END]')
        return tokens

    def _tokenize_equation(self, text, add_special_tokens):
        tokens = []
        if add_special_tokens:
            tokens.append('[EQ_START]')

        # Patterns for numbers in the custom digit set and for base-symbol + number groups
        self.digit_pattern = f"[{re.escape(self.param_config['custom_digits'])}]"
        self.number_pattern = f"[-]?{self.digit_pattern}+"
        self.base_symbols_pattern = f"(?:{'|'.join(map(re.escape, self.param_config['base_symbols']))})"
        self.base_symbols_number_pattern = f"({self.base_symbols_pattern}{self.number_pattern})"

        # First split out "base symbol + number" groups, then split the remaining parts on bare numbers
        parts = re.split(self.base_symbols_number_pattern, text)
        final_parts = []
        for part in parts:
            if re.search(self.number_pattern, part):
                sub_parts = re.split(f"({self.number_pattern})", part)
                final_parts.extend(sub_parts)
            else:
                final_parts.append(part)

        for part in final_parts:
            for match in self.token_regex.finditer(part):
                token_type = match.lastgroup
                token_value = match.group(token_type)
                if token_type != "WHITESPACE":
                    tokens.append(token_value)

        if add_special_tokens:
            tokens.append('[EQ_END]')
        return tokens

    def _tokenize_withdef_text(self, text, add_special_tokens):
        tokens = []
        segments = re.split(r'(\[DEF_START\]|\[DEF_JOIN\]|\[DEF_END\]|\[EQ_START\]|\[EQ_END\])', text)
        current_mode = None

        for seg in segments:
            seg = seg.strip()
            if not seg:
                continue

            if seg in ['[DEF_START]', '[DEF_JOIN]']:
                if add_special_tokens:
                    tokens.append(seg)
                current_mode = 'definition'
            elif seg == '[DEF_END]':
                if add_special_tokens:
                    tokens.append(seg)
                current_mode = None
            elif seg == '[EQ_START]':
                if add_special_tokens:
                    tokens.append(seg)
                current_mode = 'text'
            elif seg == '[EQ_END]':
                if add_special_tokens:
                    tokens.append(seg)
                current_mode = None
            else:
                if current_mode == 'definition':
                    inner_tokens = self._tokenize_definition(seg, add_special_tokens=False)
                    tokens.extend(inner_tokens)
                elif current_mode == 'text':
                    inner_tokens = self._tokenize_equation(seg, add_special_tokens=False)
                    tokens.extend(inner_tokens)
                else:
                    tokens.extend(seg.split())
        return tokens

    def convert_tokens_to_ids(self, tokens):
        if isinstance(tokens[0], str):
            return [self.vocab.get(token, self.vocab['[UNK]']) for token in tokens]
        return tokens

    def convert_ids_to_tokens(self, ids):
        reverse_vocab = {v: k for k, v in self.vocab.items()}
        return [reverse_vocab.get(i, '[UNK]') for i in ids]

    def get_vocab(self):
        return self.vocab

    def encode(self, texts, mode='text', add_special_tokens=True, padding=True, truncation=True, max_length=None):
        # Accept either a single string or a list of strings
        if isinstance(texts, str):
            texts = [texts]
        all_tokens = [self.tokenize(text, mode=mode, add_special_tokens=add_special_tokens) for text in texts]
        all_ids = [self.convert_tokens_to_ids(tokens) for tokens in all_tokens]

        # Pad to the longest sequence in the batch, then truncate if requested
        if padding:
            max_len = max(len(ids) for ids in all_ids)
            padded_ids = [ids + [self.pad_id] * (max_len - len(ids)) for ids in all_ids]
        else:
            padded_ids = all_ids

        if truncation and max_length:
            padded_ids = [ids[:max_length] for ids in padded_ids]

        input_ids_tensor = torch.tensor(padded_ids)
        return input_ids_tensor

    def decode(self, ids, skip_special_tokens=False):
        tokens = self.convert_ids_to_tokens(ids)
        if skip_special_tokens:
            tokens = [t for t in tokens if not (t.startswith('[') and t.endswith(']'))]
        return " ".join(tokens).replace(" ##", "")

    def __call__(self, texts, return_tensors=None, **kwargs):
        if isinstance(texts, str):
            texts = [texts]

        input_ids = self.encode(texts, **kwargs)

        if return_tensors == "pt":
            return {"input_ids": input_ids}

        return {"input_ids": input_ids.tolist()}
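
For reference, a minimal usage sketch of the uploaded tokenizer follows. It assumes a local vocab.json mapping every token (including the [PAD], [UNK], [SEP], [MASK], [BOS], [EOS], and [EOD] specials, the digits, variables, and the operator symbols listed in default_config) to an integer id; the file name, module import path, and the example expression are illustrative and not part of this commit.

# Minimal usage sketch (assumes a local vocab.json as described above).
from op_tokenizer import OpTokenizer

tokenizer = OpTokenizer(vocab_file="vocab.json")

# Tokenize a single expression in "text" (equation) mode.
expr = "(a ≮⫘↔ b) = NaN"
tokens = tokenizer.tokenize(expr, mode="text")      # ['[EQ_START]', '(', 'a', ...]
batch = tokenizer(expr, return_tensors="pt", mode="text")
print(tokens)
print(batch["input_ids"].shape)                     # (1, sequence_length)

# Round-trip the ids back to a space-separated string.
print(tokenizer.decode(batch["input_ids"][0].tolist(), skip_special_tokens=True))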