mkj69 committed on
Commit 522cec8 · verified · 1 Parent(s): 747cf01

Delete op_tokenizer.py

Files changed (1)
  1. op_tokenizer.py +0 -268
op_tokenizer.py DELETED
@@ -1,268 +0,0 @@
- import json
- import os
- import re
- from typing import List, Optional, Dict
- from transformers import PreTrainedTokenizer
-
- default_config = {
-     "custom_digits": "0123456789ABCDEF",
-     "variable_atoms": {
-         "left_operand": "a",    # variable name of the left operand
-         "right_operand": "b"    # variable name of the right operand
-     },
-
-     "other_symbols_atoms": {
-         "left_parenthesis": "(",    # left parenthesis
-         "right_parenthesis": ")",   # right parenthesis
-         "equals_sign": "=",         # equals sign, used for assignment or comparison
-         "nan_symbol": "NaN",        # Not a Number
-         "inf_symbol": "Inf"         # Infinity
-     },
-
-     "operator_symbol_min_len": 1,
-     "operator_symbol_max_len": 3,
-
-     "basic_operator_symbols": ["+", "-", "*", "/", "%"],
-
-     "base_symbols": [
-         "≮⫘↔",
-         "⫏≰",
-         "⪩⨒∯",
-         "⇑⪆",
-         "↹⩛",
-         "≴∭⊉",
-         "⪪⊹⋣",
-         "⋋%⋟",
-         "⊺⇮",
-         "⋰*⋻",
-         "⫖↰⪸",
-         "⪎⋱⫍",
-         "⨗⨭⨅",
-         "⫶⩼⫲",
-         "∃⊬"
-     ],
-
-     "comparison_ops": ["==", ">", "<", ">=", "<=", "!="],
-
-     "logical_connectors": ["and", "or"],
-
-     "definition_symbols": [
-         ",",
-         ";",
-         "if",
-         "else",
-         "{",
-         "}",
-         "abs"
-     ]
- }
-
- class OpTokenizer(PreTrainedTokenizer):
-     def __init__(self, vocab_file, **kwargs):
-         self.param_config = default_config
-         self.vocab = self.load_vocab(vocab_file)
-         self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
-         super().__init__(**kwargs)
-         # Basic symbols of the expression language
-         self.basic_symbols = list("0123456789()=ABCDEFab")
-         self.special_results = ['NaN', 'Inf']
-         self.comparison_ops = ["==", ">", "<", ">=", "<=", "!="]
-         self.logical_connectors = ["and", "or"]
-         self.definition_symbols = [",", ";", "if", "else", "{", "}", "abs"]
-
-         self.token_regex = self.build_token_regex()
-
-         # IDs of the special tokens
-         self.pad_id = self.vocab['[PAD]']
-         self.unk_id = self.vocab['[UNK]']
-         self.sep_id = self.vocab['[SEP]']
-         self.mask_id = self.vocab['[MASK]']
-         self.bos_id = self.vocab['[BOS]']
-         self.eos_id = self.vocab['[EOS]']
-         self.eod_id = self.vocab['[EOD]']
-
-     def load_vocab(self, vocab_file):
-         # Vocabulary loading logic: read the token-to-id map from a JSON file
-         with open(vocab_file, encoding="utf-8") as f:
-             vocab = json.load(f)
-         return vocab
-
-     def save_vocabulary(self, save_directory, filename_prefix=""):
-         if filename_prefix is None:
-             filename_prefix = ""
-
-         if not os.path.exists(save_directory):
-             os.makedirs(save_directory)
-
-         vocab_file_path = os.path.join(save_directory, filename_prefix + "vocab.json")
-
-         with open(vocab_file_path, "w", encoding="utf-8") as f:
-             json.dump(self.vocab, f, ensure_ascii=False, indent=4)
-
-         print(f"Vocabulary saved to {vocab_file_path}")
-
-         return (vocab_file_path,)  # return a tuple rather than a list
-
-     def build_token_regex(self):
-         """Build the tokenization regex that matches the text symbol by symbol."""
-         # Regexes for the special results (e.g. NaN, Inf)
-         special_results = [re.escape(result) for result in self.special_results]
-         # Regexes for the comparison operators; sort longest first so ">=" is not split into ">" and "="
-         comparison_ops = [re.escape(op) for op in sorted(self.comparison_ops, key=len, reverse=True)]
-         # Regexes for the logical connectors
-         logical_connectors = [re.escape(connector) for connector in self.logical_connectors]
-
-         operator_pattern = r"(?P<OPERATOR>([+\-*/%]|[\u2200-\u22FF\u2A00-\u2BFF\u2190-\u21FF])+)"
-         variable_pattern = r"(?P<VARIABLE>[a-b])"
-         digit_pattern = r"(?P<DIGIT>[0-9A-F])"
-         special_result_pattern = r"(?P<SPECIAL_RESULT>" + "|".join(special_results) + ")"
-         comparison_ops_pattern = r"(?P<COMPARISON_OP>" + "|".join(comparison_ops) + ")"
-         logical_connectors_pattern = r"(?P<LOGICAL_CONNECTOR>" + "|".join(logical_connectors) + ")"
-         if_else_pattern = r"(?P<IF_ELSE>if|else)"
-         whitespace_pattern = r"(?P<WHITESPACE>\s+)"
-         abs_pattern = r"(?P<ABS>abs)"
-         punctuation_patterns = [
-             r"(?P<PARENTHESIS_LEFT>\()",
-             r"(?P<PARENTHESIS_RIGHT>\))",
-             r"(?P<CURLY_BRACE_LEFT>{)",
-             r"(?P<CURLY_BRACE_RIGHT>})",
-             r"(?P<SEMICOLON>;)",
-             r"(?P<COMMA>,)",
-             r"(?P<EQUAL>=)"
-         ]
-
-         # Combine all patterns; order matters, longer alternatives must be matched first
-         token_patterns = [
-             operator_pattern,
-             special_result_pattern,       # special results (NaN, Inf)
-             comparison_ops_pattern,       # comparison operators
-             logical_connectors_pattern,   # logical connectors
-             if_else_pattern,              # if and else
-             abs_pattern,
-             digit_pattern,
-             variable_pattern,             # lowercase letters (variable names)
-             whitespace_pattern,           # spaces and newlines
-         ] + punctuation_patterns          # append the punctuation patterns
-
-         # Join all patterns with |
-         combined_pattern = "|".join(token_patterns)
-
-         # Return the compiled regex object
-         return re.compile(combined_pattern)
-
-     def tokenize(self, text: str, mode: str = 'text', add_special_tokens: bool = True):
-         if mode == 'definition':
-             return self._tokenize_definition(text, add_special_tokens)
-         elif mode == 'text':
-             return self._tokenize_equation(text, add_special_tokens)
-         elif mode == 'withdef_text':
-             return self._tokenize_withdef_text(text, add_special_tokens)
-         else:
-             raise ValueError(f"Unsupported mode: {mode}")
-
-     def _tokenize_definition(self, text, add_special_tokens):
-         tokens = []
-         if add_special_tokens:
-             tokens.append('[DEF_START]')
-         for match in self.token_regex.finditer(text):
-             token_type = match.lastgroup
-             token_value = match.group(token_type)
-             if token_type != "WHITESPACE":
-                 tokens.append(token_value)
-         if add_special_tokens:
-             tokens.append('[DEF_END]')
-         return tokens
-
-     def _tokenize_equation(self, text, add_special_tokens):
-         tokens = []
-         if add_special_tokens:
-             tokens.append('[EQ_START]')
-
-         self.digit_pattern = f"[{re.escape(self.param_config['custom_digits'])}]"
-         self.number_pattern = f"[-]?{self.digit_pattern}+"
-         self.base_symbols_pattern = f"(?:{'|'.join(map(re.escape, self.param_config['base_symbols']))})"
-         self.base_symbols_number_pattern = f"({self.base_symbols_pattern}{self.number_pattern})"
-
-         # First split on "base symbol followed by a number", then split the remaining chunks on bare numbers
-         parts = re.split(self.base_symbols_number_pattern, text)
-         final_parts = []
-         for part in parts:
-             if re.search(self.number_pattern, part):
-                 sub_parts = re.split(f"({self.number_pattern})", part)
-                 final_parts.extend(sub_parts)
-             else:
-                 final_parts.append(part)
-
-         for part in final_parts:
-             for match in self.token_regex.finditer(part):
-                 token_type = match.lastgroup
-                 token_value = match.group(token_type)
-                 if token_type != "WHITESPACE":
-                     tokens.append(token_value)
-
-         if add_special_tokens:
-             tokens.append('[EQ_END]')
-         return tokens
-
-     def _tokenize_withdef_text(self, text, add_special_tokens):
-         tokens = []
-         segments = re.split(r'(\[DEF_START\]|\[DEF_JOIN\]|\[DEF_END\]|\[EQ_START\]|\[EQ_END\])', text)
-         current_mode = None
-
-         for seg in segments:
-             seg = seg.strip()
-             if not seg:
-                 continue
-
-             if seg in ['[DEF_START]', '[DEF_JOIN]']:
-                 if add_special_tokens:
-                     tokens.append(seg)
-                 current_mode = 'definition'
-             elif seg == '[DEF_END]':
-                 if add_special_tokens:
-                     tokens.append(seg)
-                 current_mode = None
-             elif seg == '[EQ_START]':
-                 if add_special_tokens:
-                     tokens.append(seg)
-                 current_mode = 'text'
-             elif seg == '[EQ_END]':
-                 if add_special_tokens:
-                     tokens.append(seg)
-                 current_mode = None
-             else:
-                 if current_mode == 'definition':
-                     inner_tokens = self._tokenize_definition(seg, add_special_tokens=False)
-                     tokens.extend(inner_tokens)
-                 elif current_mode == 'text':
-                     inner_tokens = self._tokenize_equation(seg, add_special_tokens=False)
-                     tokens.extend(inner_tokens)
-                 else:
-                     tokens.extend(seg.split())
-         return tokens
-
-     def convert_tokens_to_ids(self, tokens):
-         if isinstance(tokens[0], str):
-             return [self.vocab.get(token, self.vocab['[UNK]']) for token in tokens]
-         return tokens
-
-     def convert_ids_to_tokens(self, ids):
-         # Reuse the reverse mapping built in __init__ instead of rebuilding it on every call
-         return [self.ids_to_tokens.get(i, '[UNK]') for i in ids]
-
-     def encode(self, text, mode='text', add_special_tokens=True):
-         # Defaults mirror tokenize(), so encode(text) works without extra arguments
-         tokens = self.tokenize(text, mode=mode, add_special_tokens=add_special_tokens)
-         return self.convert_tokens_to_ids(tokens)
-
-     def decode(self, ids, skip_special_tokens=False):
-         tokens = self.convert_ids_to_tokens(ids)
-         if skip_special_tokens:
-             tokens = [t for t in tokens if not (t.startswith('[') and t.endswith(']'))]
-         return " ".join(tokens).replace(" ##", "")
-
-     def get_vocab(self):
-         return self.vocab
-
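
For reference, a minimal usage sketch of the deleted class, not part of the original file: it assumes a compatible transformers version and a vocab.json that contains every special token referenced in __init__; the path and the expected outputs are illustrative.

# Hypothetical example: exercising OpTokenizer as defined in the deleted file.
tokenizer = OpTokenizer("vocab.json")  # "vocab.json" is an assumed path

# 'text' mode wraps the equation in [EQ_START]/[EQ_END] and splits digits individually.
ids = tokenizer.encode("a + b = 10", mode="text", add_special_tokens=True)
print(tokenizer.convert_ids_to_tokens(ids))
# -> ['[EQ_START]', 'a', '+', 'b', '=', '1', '0', '[EQ_END]'] if all tokens are in the vocab,
#    with '[UNK]' substituted for any token missing from vocab.json
print(tokenizer.decode(ids, skip_special_tokens=True))  # -> "a + b = 1 0"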