princepride committed on
Commit 863598a · verified · 1 Parent(s): 8eb4aa9

Delete model.py

Files changed (1)
  1. model.py +0 -524
model.py DELETED
@@ -1,524 +0,0 @@
- from transformers import AutoModelForCausalLM, AutoTokenizer
- import torch
- from modules.file import ExcelFileWriter
- import os
-
- from abc import ABC, abstractmethod
- from typing import List
- import re
-
- class FilterPipeline():
-     def __init__(self, filter_list):
-         self._filter_list: List[Filter] = filter_list
-
-     def append(self, filter):
-         self._filter_list.append(filter)
-
-     def batch_encoder(self, inputs):
-         # Run every filter's encoder over the batch, in registration order
-         for filter in self._filter_list:
-             inputs = filter.encoder(inputs)
-         return inputs
-
-     def batch_decoder(self, inputs):
-         # Undo the filters in reverse order so original positions are restored correctly
-         for filter in reversed(self._filter_list):
-             inputs = filter.decoder(inputs)
-         return inputs
-
- class Filter(ABC):
-     # Abstract base class defining the basic interface for filters
-     def __init__(self):
-         self.name = 'filter'  # Name of the filter
-         self.code = []  # Stores filtering/encoding bookkeeping
-
-     @abstractmethod
-     def encoder(self, inputs):
-         # Abstract method: encode or filter the inputs
-         pass
-
-     @abstractmethod
-     def decoder(self, inputs):
-         # Abstract method: decode or restore the inputs
-         pass
-
- class SpecialTokenFilter(Filter):
-     # Special-token filter: removes strings made up entirely of special characters
-     def __init__(self):
-         self.name = 'special token filter'
-         self.code = []
-         self.special_tokens = ['!', '!', '-']  # The set of special characters
-
-     def encoder(self, inputs):
-         # Drop strings consisting only of special characters
-         filtered_inputs = []
-         self.code = []
-         for i, input_str in enumerate(inputs):
-             if not all(char in self.special_tokens for char in input_str):
-                 filtered_inputs.append(input_str)
-             else:
-                 self.code.append([i, input_str])  # Record the position and content of the removed string
-         return filtered_inputs
-
-     def decoder(self, inputs):
-         # Restore the filtered special-character strings
-         original_inputs = inputs.copy()
-         for removed_indice in self.code:
-             original_inputs.insert(removed_indice[0], removed_indice[1])  # Re-insert at the original position
-         return original_inputs
-
- class SperSignFilter(Filter):
-     # Percent-s filter: handles strings containing '%s' placeholders
-     def __init__(self):
-         self.name = 's percentage sign filter'
-         self.code = []
-
-     def encoder(self, inputs):
-         # Replace '%s' with '*' so the placeholder survives translation
-         encoded_inputs = []
-         self.code = []
-         for i, input_str in enumerate(inputs):
-             if '%s' in input_str:
-                 encoded_str = input_str.replace('%s', '*')
-                 self.code.append(i)  # Record which strings contained '%s'
-             else:
-                 encoded_str = input_str
-             encoded_inputs.append(encoded_str)
-         return encoded_inputs
-
-     def decoder(self, inputs):
-         # Restore '*' back to '%s'
-         decoded_inputs = inputs.copy()
-         for i in self.code:
-             decoded_inputs[i] = decoded_inputs[i].replace('*', '%s')
-         return decoded_inputs
-
- class ParenSParenFilter(Filter):
-     # '(s)' filter: handles strings containing the literal '(s)'
-     def __init__(self):
-         self.name = 'Paren s paren filter'
-         self.code = []
-
-     def encoder(self, inputs):
-         # Replace '(s)' with '$' so it survives translation
-         encoded_inputs = []
-         self.code = []
-         for i, input_str in enumerate(inputs):
-             if '(s)' in input_str:
-                 encoded_str = input_str.replace('(s)', '$')
-                 self.code.append(i)  # Record which strings contained '(s)'
-             else:
-                 encoded_str = input_str
-             encoded_inputs.append(encoded_str)
-         return encoded_inputs
-
-     def decoder(self, inputs):
-         # Restore '$' back to '(s)'
-         decoded_inputs = inputs.copy()
-         for i in self.code:
-             decoded_inputs[i] = decoded_inputs[i].replace('$', '(s)')
-         return decoded_inputs
-
- class ChevronsFilter(Filter):
-     # Chevrons filter: handles strings containing '<...>' tags
-     def __init__(self):
-         self.name = 'chevrons filter'
-         self.code = []
-
-     def encoder(self, inputs):
-         # Replace any '<...>' content with '#'
-         encoded_inputs = []
-         self.code = []
-         pattern = re.compile(r'<.*?>')
-         for i, input_str in enumerate(inputs):
-             if pattern.search(input_str):
-                 matches = pattern.findall(input_str)
-                 encoded_str = pattern.sub('#', input_str)
-                 self.code.append((i, matches))  # Record the position and the matched tags
-             else:
-                 encoded_str = input_str
-             encoded_inputs.append(encoded_str)
-         return encoded_inputs
-
-     def decoder(self, inputs):
-         # Restore each '#' to its original '<...>' content, in order
-         decoded_inputs = inputs.copy()
-         for i, matches in self.code:
-             for match in matches:
-                 decoded_inputs[i] = decoded_inputs[i].replace('#', match, 1)
-         return decoded_inputs
-
- class SimilarFilter(Filter):
-     # Similar-string filter: collapses consecutive strings that differ only in digits
-     def __init__(self):
-         self.name = 'similar filter'
-         self.code = []
-
-     def is_similar(self, str1, str2):
-         # Two strings are similar if they are equal once digits are removed
-         pattern = re.compile(r'\d+')
-         return pattern.sub('', str1) == pattern.sub('', str2)
-
-     def encoder(self, inputs):
-         # Detect runs of consecutive similar strings; keep the first and record the rest
-         encoded_inputs = []
-         self.code = []
-         i = 0
-         while i < len(inputs):
-             encoded_inputs.append(inputs[i])
-             similar_strs = [inputs[i]]
-             j = i + 1
-             while j < len(inputs) and self.is_similar(inputs[i], inputs[j]):
-                 similar_strs.append(inputs[j])
-                 j += 1
-             if len(similar_strs) > 1:
-                 self.code.append((i, similar_strs))
-             i = j
-         return encoded_inputs
-
-     def decoder(self, inputs):
-         # Re-insert the collapsed strings at their original positions,
-         # substituting each recorded number back into the translated template
-         decoded_inputs = inputs
-         for i, similar_strs in self.code:
-             pattern = re.compile(r'\d+')
-             for j in range(len(similar_strs)):
-                 if pattern.search(similar_strs[j]):
-                     number = re.findall(r'\d+', similar_strs[j])[0]
-                     new_str = pattern.sub(number, inputs[i])
-                 else:
-                     new_str = inputs[i]
-                 if j > 0:
-                     decoded_inputs.insert(i + j, new_str)
-         return decoded_inputs
-
- class ChineseFilter(Filter):
-     # Pinyin filter: detects and removes words that are likely Chinese pinyin
-     def __init__(self, pinyin_lib_file='pinyin.txt'):
-         self.name = 'chinese filter'
-         self.code = []
-         self.pinyin_lib = self.load_pinyin_lib(pinyin_lib_file)  # Load the pinyin library
-
-     def load_pinyin_lib(self, file_path):
-         # Load the pinyin library file into memory (relies on the module-level script_dir)
-         with open(os.path.join(script_dir, file_path), 'r', encoding='utf-8') as f:
-             return set(line.strip().lower() for line in f)
-
-     def is_valid_chinese(self, word):
-         # A candidate qualifies if it is a single word whose first letter is uppercase
-         if len(word.split()) == 1 and word[0].isupper():
-             return self.is_pinyin(word.lower())
-         return False
-
-     def encoder(self, inputs):
-         # Detect and remove words that match the pinyin rules
-         encoded_inputs = []
-         self.code = []
-         for i, word in enumerate(inputs):
-             if self.is_valid_chinese(word):
-                 self.code.append((i, word))  # Record the pinyin word and its index
-             else:
-                 encoded_inputs.append(word)
-         return encoded_inputs
-
-     def decoder(self, inputs):
-         # Re-insert the removed pinyin words at their original positions
-         decoded_inputs = inputs.copy()
-         for i, word in self.code:
-             decoded_inputs.insert(i, word)
-         return decoded_inputs
-
-     def is_pinyin(self, string):
-         # Greedily segment the string against the pinyin library;
-         # True if the whole string can be segmented into pinyin syllables or letters
-         string = string.lower()
-         stringlen = len(string)
-         max_len = 6
-         result = []
-         n = 0
-         while n < stringlen:
-             matched = 0
-             temp_result = []
-             for i in range(max_len, 0, -1):
-                 s = string[0:i]
-                 if s in self.pinyin_lib:
-                     temp_result.append(string[:i])
-                     matched = i
-                     break
-                 if i == 1 and len(temp_result) == 0:
-                     return False
-             result.extend(temp_result)
-             string = string[matched:]
-             n += matched
-         return True
-
- # Path of the directory containing this script, used when loading the pinyin file
- script_dir = os.path.dirname(os.path.abspath(__file__))
- parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))
-
-
- class Model():
-     def __init__(self, modelname, selected_lora_model, selected_gpu):
-         def get_gpu_index(gpu_info, target_gpu_name):
-             """
-             Get the index of the target GPU from the GPU info list.
-             Args:
-                 gpu_info (list): List of GPU names
-                 target_gpu_name (str): Name of the target GPU
-
-             Returns:
-                 int: Index of the target GPU, or -1 if not found
-             """
-             for i, name in enumerate(gpu_info):
-                 if target_gpu_name.lower() in name.lower():
-                     return i
-             return -1
-
-         if selected_gpu != "cpu":
-             gpu_count = torch.cuda.device_count()
-             gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
-             selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
-             self.device_name = f"cuda:{selected_gpu_index}"
-         else:
-             self.device_name = "cpu"
-         print("device_name", self.device_name)
-         self.model = AutoModelForCausalLM.from_pretrained(modelname).to(self.device_name)
-         self.tokenizer = AutoTokenizer.from_pretrained(modelname)
-         # self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)
-
-     def generate(self, inputs, original_language, target_languages, max_batch_size):
-         filter_list = [SpecialTokenFilter(), ChevronsFilter(), SimilarFilter(), ChineseFilter()]
-         filter_pipeline = FilterPipeline(filter_list)
-         def language_mapping(original_language):
-             d = {
-                 "Achinese (Arabic script)": "ace_Arab",
-                 "Achinese (Latin script)": "ace_Latn",
-                 "Mesopotamian Arabic": "acm_Arab",
-                 "Ta'izzi-Adeni Arabic": "acq_Arab",
-                 "Tunisian Arabic": "aeb_Arab",
-                 "Afrikaans": "afr_Latn",
-                 "South Levantine Arabic": "ajp_Arab",
-                 "Akan": "aka_Latn",
-                 "Amharic": "amh_Ethi",
-                 "North Levantine Arabic": "apc_Arab",
-                 "Standard Arabic": "arb_Arab",
-                 "Najdi Arabic": "ars_Arab",
-                 "Moroccan Arabic": "ary_Arab",
-                 "Egyptian Arabic": "arz_Arab",
-                 "Assamese": "asm_Beng",
-                 "Asturian": "ast_Latn",
-                 "Awadhi": "awa_Deva",
-                 "Central Aymara": "ayr_Latn",
-                 "South Azerbaijani": "azb_Arab",
-                 "North Azerbaijani": "azj_Latn",
-                 "Bashkir": "bak_Cyrl",
-                 "Bambara": "bam_Latn",
-                 "Balinese": "ban_Latn",
-                 "Belarusian": "bel_Cyrl",
-                 "Bemba": "bem_Latn",
-                 "Bengali": "ben_Beng",
-                 "Bhojpuri": "bho_Deva",
-                 "Banjar (Arabic script)": "bjn_Arab",
-                 "Banjar (Latin script)": "bjn_Latn",
-                 "Tibetan": "bod_Tibt",
-                 "Bosnian": "bos_Latn",
-                 "Buginese": "bug_Latn",
-                 "Bulgarian": "bul_Cyrl",
-                 "Catalan": "cat_Latn",
-                 "Cebuano": "ceb_Latn",
-                 "Czech": "ces_Latn",
-                 "Chokwe": "cjk_Latn",
-                 "Central Kurdish": "ckb_Arab",
-                 "Crimean Tatar": "crh_Latn",
-                 "Welsh": "cym_Latn",
-                 "Danish": "dan_Latn",
-                 "German": "deu_Latn",
-                 "Dinka": "dik_Latn",
-                 "Jula": "dyu_Latn",
-                 "Dzongkha": "dzo_Tibt",
-                 "Greek": "ell_Grek",
-                 "English": "eng_Latn",
-                 "Esperanto": "epo_Latn",
-                 "Estonian": "est_Latn",
-                 "Basque": "eus_Latn",
-                 "Ewe": "ewe_Latn",
-                 "Faroese": "fao_Latn",
-                 "Persian": "pes_Arab",
-                 "Fijian": "fij_Latn",
-                 "Finnish": "fin_Latn",
-                 "Fon": "fon_Latn",
-                 "French": "fra_Latn",
-                 "Friulian": "fur_Latn",
-                 "Nigerian Fulfulde": "fuv_Latn",
-                 "Scottish Gaelic": "gla_Latn",
-                 "Irish": "gle_Latn",
-                 "Galician": "glg_Latn",
-                 "Guarani": "grn_Latn",
-                 "Gujarati": "guj_Gujr",
-                 "Haitian Creole": "hat_Latn",
-                 "Hausa": "hau_Latn",
-                 "Hebrew": "heb_Hebr",
-                 "Hindi": "hin_Deva",
-                 "Chhattisgarhi": "hne_Deva",
-                 "Croatian": "hrv_Latn",
-                 "Hungarian": "hun_Latn",
-                 "Armenian": "hye_Armn",
-                 "Igbo": "ibo_Latn",
-                 "Iloko": "ilo_Latn",
-                 "Indonesian": "ind_Latn",
-                 "Icelandic": "isl_Latn",
-                 "Italian": "ita_Latn",
-                 "Javanese": "jav_Latn",
-                 "Japanese": "jpn_Jpan",
-                 "Kabyle": "kab_Latn",
-                 "Kachin": "kac_Latn",
-                 "Arabic": "ar_AR",
-                 "Chinese": "zho_Hans",
-                 "Spanish": "spa_Latn",
-                 "Dutch": "nld_Latn",
-                 "Kazakh": "kaz_Cyrl",
-                 "Korean": "kor_Hang",
-                 "Lithuanian": "lit_Latn",
-                 "Malayalam": "mal_Mlym",
-                 "Marathi": "mar_Deva",
-                 "Nepali": "ne_NP",
-                 "Polish": "pol_Latn",
-                 "Portuguese": "por_Latn",
-                 "Russian": "rus_Cyrl",
-                 "Sinhala": "sin_Sinh",
-                 "Tamil": "tam_Taml",
-                 "Turkish": "tur_Latn",
-                 "Ukrainian": "ukr_Cyrl",
-                 "Urdu": "urd_Arab",
-                 "Vietnamese": "vie_Latn",
-                 "Thai": "tha_Thai",
-                 "Khmer": "khm_Khmr"
-             }
-             return d[original_language]
-         def process_gpu_translate_result(temp_outputs):
-             # Transpose per-batch results into per-row records and flush them to Excel
-             outputs = []
-             for temp_output in temp_outputs:
-                 length = len(temp_output[0]["generated_translation"])
-                 for i in range(length):
-                     temp = []
-                     for trans in temp_output:
-                         temp.append({
-                             "target_language": trans["target_language"],
-                             "generated_translation": trans['generated_translation'][i],
-                         })
-                     outputs.append(temp)
-             excel_writer = ExcelFileWriter()
-             excel_writer.write_text(os.path.join(parent_dir, r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))
-         self.tokenizer.src_lang = language_mapping(original_language)
-         if self.device_name == "cpu":
-             # Tokenize input
-             input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True, max_length=128).to(self.device_name)
-             output = []
-             for target_language in target_languages:
-                 # Get the language code for the target language
-                 target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
-                 # Generate translation
-                 generated_tokens = self.model.generate(
-                     **input_ids,
-                     forced_bos_token_id=target_lang_code,
-                     max_length=128
-                 )
-                 generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-                 # Append result to output
-                 output.append({
-                     "target_language": target_language,
-                     "generated_translation": generated_translation,
-                 })
-             outputs = []
-             length = len(output[0]["generated_translation"])
-             for i in range(length):
-                 temp = []
-                 for trans in output:
-                     temp.append({
-                         "target_language": trans["target_language"],
-                         "generated_translation": trans['generated_translation'][i],
-                     })
-                 outputs.append(temp)
-             return outputs
-         else:
-             # Max batch size = available GPU memory (bytes) / 4 / (tensor size + trainable parameters)
-             # max_batch_size = 10
-             # Ensure batch size is within model limits:
-             print("length of inputs: ", len(inputs))
-             batch_size = min(len(inputs), int(max_batch_size))
-             batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
-             print("length of batches size: ", len(batches))
-             temp_outputs = []
-             processed_num = 0
-             for index, batch in enumerate(batches):
-                 # Tokenize input
-                 print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
-                 print(len(batch))
-                 print(batch)
-                 batch = filter_pipeline.batch_encoder(batch)
-                 print(batch)
-                 temp = []
-                 if len(batch) > 0:
-                     for target_language in target_languages:
-                         # System prompt (in Chinese): "You are an ERP-system translation expert; translate the
-                         # markdown-formatted text from {original_language} to {target_language}, preserving its
-                         # formatting and adding nothing extra."
-                         batch_messages = [[
-                             {"role": "system", "content": f"你是一个ERP系统中译英专家,你任务是把markdown格式的文本,保留其格式并从{original_language}翻译成{target_language},不要添加多余的内容。"},
-                             {"role": "user", "content": input},
-                         ] for input in batch]
-                         batch_texts = [self.tokenizer.apply_chat_template(
-                             messages,
-                             tokenize=False,
-                             add_generation_prompt=True
-                         ) for messages in batch_messages]
-                         self.tokenizer.padding_side = "left"
-                         model_inputs = self.tokenizer(
-                             batch_texts,
-                             return_tensors="pt",
-                             padding="longest",
-                             truncation=True,
-                         ).to(self.device_name)
-                         generated_ids = self.model.generate(
-                             max_new_tokens=512,
-                             **model_inputs
-                         )
-                         # Calculate the length of new tokens generated for each sequence
-                         new_tokens = [
-                             output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-                         ]
-                         generated_translation = self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
-                         # Restore items the filters removed or masked
-                         generated_translation = filter_pipeline.batch_decoder(generated_translation)
-                         # Append result to output
-                         temp.append({
-                             "target_language": target_language,
-                             "generated_translation": generated_translation,
-                         })
-                     # Release the tokenized batch from GPU memory
-                     model_inputs.to('cpu')
-                     del model_inputs
-                 else:
-                     for target_language in target_languages:
-                         generated_translation = filter_pipeline.batch_decoder(batch)
-                         print(generated_translation)
-                         print(len(generated_translation))
-                         # Append result to output
-                         temp.append({
-                             "target_language": target_language,
-                             "generated_translation": generated_translation,
-                         })
-                 temp_outputs.append(temp)
-                 processed_num += len(batch)
-                 # Flush intermediate results to Excel roughly every 1000 processed rows
-                 if (index + 1) * batch_size // 1000 - index * batch_size // 1000 == 1:
-                     print("Already processed number: ", len(temp_outputs))
-                     process_gpu_translate_result(temp_outputs)
-             outputs = []
-             for temp_output in temp_outputs:
-                 length = len(temp_output[0]["generated_translation"])
-                 for i in range(length):
-                     temp = []
-                     for trans in temp_output:
-                         temp.append({
-                             "target_language": trans["target_language"],
-                             "generated_translation": trans['generated_translation'][i],
-                         })
-                     outputs.append(temp)
-             return outputs
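
For reference, a minimal sketch of the filter round trip the deleted classes implemented (illustrative only: the driver code, the sample strings, and the uppercase stand-in for the model call are assumptions, not part of the deleted file):

    # Hypothetical driver: encode the batch, "translate" it, then decode to
    # restore what the filters removed or masked. ChineseFilter is omitted
    # here because it needs the pinyin.txt library on disk.
    filters = FilterPipeline([SpecialTokenFilter(), ChevronsFilter(), SimilarFilter()])
    batch = ["Save <b>file</b>", "!!!", "Line 1", "Line 2"]
    encoded = filters.batch_encoder(batch)         # ["Save #file#", "Line 1"]
    translated = [s.upper() for s in encoded]      # stand-in for Model.generate
    restored = filters.batch_decoder(translated)   # ["SAVE <b>FILE</b>", "!!!", "LINE 1", "LINE 2"]

Decoding runs the filters in reverse registration order, so each filter sees the list shape it produced during encoding.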