Delete model.py
model.py
DELETED
@@ -1,524 +0,0 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from modules.file import ExcelFileWriter
import os

from abc import ABC, abstractmethod
from typing import List
import re

class FilterPipeline():
    def __init__(self, filter_list):
        self._filter_list: List[Filter] = filter_list

    def append(self, filter):
        self._filter_list.append(filter)

    def batch_encoder(self, inputs):
        # Apply every filter's encoder, front to back
        for filter in self._filter_list:
            inputs = filter.encoder(inputs)
        return inputs

    def batch_decoder(self, inputs):
        # Undo the filters in reverse (LIFO) order
        for filter in reversed(self._filter_list):
            inputs = filter.decoder(inputs)
        return inputs

class Filter(ABC):
    # Abstract base class defining the filter interface
    def __init__(self):
        self.name = 'filter'  # name of the filter
        self.code = []  # stores whatever the encoder removed or rewrote

    @abstractmethod
    def encoder(self, inputs):
        # Abstract method: encode or filter the inputs
        pass

    @abstractmethod
    def decoder(self, inputs):
        # Abstract method: decode or restore the inputs
        pass

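To illustrate the encoder/decoder contract, here is a minimal sketch (not part of the deleted file) with a hypothetical `LowercaseFilter`; the pipeline applies encoders front to back and decoders back to front, so each filter's transformation is undone in LIFO order:

```python
class LowercaseFilter(Filter):
    # Hypothetical filter: lowercases strings on the way in and
    # remembers the originals so decoding can restore them.
    def __init__(self):
        self.name = 'lowercase filter'
        self.code = []

    def encoder(self, inputs):
        self.code = inputs.copy()           # remember the originals
        return [s.lower() for s in inputs]

    def decoder(self, inputs):
        return self.code.copy()             # restore the originals

pipeline = FilterPipeline([LowercaseFilter()])
encoded = pipeline.batch_encoder(["Hello", "WORLD"])  # ['hello', 'world']
restored = pipeline.batch_decoder(encoded)            # ['Hello', 'WORLD']
```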
class SpecialTokenFilter(Filter):
    # Filters out strings made up entirely of special characters
    def __init__(self):
        self.name = 'special token filter'
        self.code = []
        self.special_tokens = ['!', '!', '-']  # the set of special characters

    def encoder(self, inputs):
        # Drop strings that consist only of special characters
        filtered_inputs = []
        self.code = []
        for i, input_str in enumerate(inputs):
            if not all(char in self.special_tokens for char in input_str):
                filtered_inputs.append(input_str)
            else:
                self.code.append([i, input_str])  # remember position and content of the dropped string
        return filtered_inputs

    def decoder(self, inputs):
        # Re-insert the dropped strings at their original positions
        original_inputs = inputs.copy()
        for removed_indice in self.code:
            original_inputs.insert(removed_indice[0], removed_indice[1])
        return original_inputs

class SperSignFilter(Filter):
    # Placeholder filter for strings containing '%s'
    def __init__(self):
        self.name = 's percentage sign filter'
        self.code = []

    def encoder(self, inputs):
        # Replace '%s' with '*' so the placeholder survives translation
        encoded_inputs = []
        self.code = []
        for i, input_str in enumerate(inputs):
            if '%s' in input_str:
                encoded_str = input_str.replace('%s', '*')
                self.code.append(i)  # remember which strings contained '%s'
            else:
                encoded_str = input_str
            encoded_inputs.append(encoded_str)
        return encoded_inputs

    def decoder(self, inputs):
        # Restore '*' back to '%s'
        decoded_inputs = inputs.copy()
        for i in self.code:
            decoded_inputs[i] = decoded_inputs[i].replace('*', '%s')
        return decoded_inputs

class ParenSParenFilter(Filter):
    # Placeholder filter for strings containing the literal '(s)'
    def __init__(self):
        self.name = 'Paren s paren filter'
        self.code = []

    def encoder(self, inputs):
        # Replace '(s)' with '$'
        encoded_inputs = []
        self.code = []
        for i, input_str in enumerate(inputs):
            if '(s)' in input_str:
                encoded_str = input_str.replace('(s)', '$')
                self.code.append(i)  # remember which strings contained '(s)'
            else:
                encoded_str = input_str
            encoded_inputs.append(encoded_str)
        return encoded_inputs

    def decoder(self, inputs):
        # Restore '$' back to '(s)'
        decoded_inputs = inputs.copy()
        for i in self.code:
            decoded_inputs[i] = decoded_inputs[i].replace('$', '(s)')
        return decoded_inputs

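A round trip through the two placeholder filters above might look like the sketch below. Note the design assumption: the sentinels `*` and `$` only round-trip correctly if the translation model copies them through unchanged and the source text does not already contain them.

```python
sper = SperSignFilter()
parens = ParenSParenFilter()

encoded = parens.encoder(sper.encoder(["Delete %s record(s)?"]))
# after sper:   'Delete * record(s)?'
# after parens: 'Delete * record$?'

decoded = sper.decoder(parens.decoder(encoded))  # decode in reverse order
# -> ['Delete %s record(s)?']
```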
class ChevronsFilter(Filter):
    # Chevrons filter for strings containing '<...>' markup
    def __init__(self):
        self.name = 'chevrons filter'
        self.code = []

    def encoder(self, inputs):
        # Replace each chevron-delimited span with '#'
        encoded_inputs = []
        self.code = []
        pattern = re.compile(r'<.*?>')
        for i, input_str in enumerate(inputs):
            if pattern.search(input_str):
                matches = pattern.findall(input_str)
                encoded_str = pattern.sub('#', input_str)
                self.code.append((i, matches))  # remember position and the matched spans
            else:
                encoded_str = input_str
            encoded_inputs.append(encoded_str)
        return encoded_inputs

    def decoder(self, inputs):
        # Replace each '#' with the original chevron content, in order
        decoded_inputs = inputs.copy()
        for i, matches in self.code:
            for match in matches:
                decoded_inputs[i] = decoded_inputs[i].replace('#', match, 1)
        return decoded_inputs

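A quick sketch of the class above in action:

```python
f = ChevronsFilter()
encoded = f.encoder(["Click <b>Save</b> to continue"])
# -> ['Click #Save# to continue'], with ['<b>', '</b>'] remembered in f.code

decoded = f.decoder(encoded)
# -> ['Click <b>Save</b> to continue']  (each '#' refilled in order)
```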
class SimilarFilter(Filter):
    # Similar-string filter for strings that differ only in their digits
    def __init__(self):
        self.name = 'similar filter'
        self.code = []

    def is_similar(self, str1, str2):
        # Two strings are similar if they are equal once digits are removed
        pattern = re.compile(r'\d+')
        return pattern.sub('', str1) == pattern.sub('', str2)

    def encoder(self, inputs):
        # Detect runs of consecutive similar strings; keep only the first
        # of each run and record the run's index and contents
        encoded_inputs = []
        self.code = []
        i = 0
        while i < len(inputs):
            encoded_inputs.append(inputs[i])
            similar_strs = [inputs[i]]
            j = i + 1
            while j < len(inputs) and self.is_similar(inputs[i], inputs[j]):
                similar_strs.append(inputs[j])
                j += 1
            if len(similar_strs) > 1:
                self.code.append((i, similar_strs))
            i = j
        return encoded_inputs

    def decoder(self, inputs):
        # Re-insert the deduplicated strings, carrying over each variant's digits
        decoded_inputs = inputs
        for i, similar_strs in self.code:
            pattern = re.compile(r'\d+')
            for j in range(len(similar_strs)):
                if pattern.search(similar_strs[j]):
                    number = re.findall(r'\d+', similar_strs[j])[0]
                    new_str = pattern.sub(number, inputs[i])
                else:
                    new_str = inputs[i]
                if j > 0:
                    decoded_inputs.insert(i + j, new_str)
        return decoded_inputs

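A sketch of the dedupe/restore behaviour: only the first of a run of digit-variants is sent to the model, and the decoder fans the translation back out with each variant's own number (the German strings stand in for model output):

```python
f = SimilarFilter()
encoded = f.encoder(["Step 1", "Step 2", "Step 3", "Done"])
# -> ['Step 1', 'Done']; the run of digit variants is recorded in f.code

decoded = f.decoder(["Schritt 1", "Done"])  # translated list, same order
# -> ['Schritt 1', 'Schritt 2', 'Schritt 3', 'Done']
```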
class ChineseFilter:
    # Pinyin filter: detects and removes words that are Chinese pinyin
    def __init__(self, pinyin_lib_file='pinyin.txt'):
        self.name = 'chinese filter'
        self.code = []
        self.pinyin_lib = self.load_pinyin_lib(pinyin_lib_file)  # load the pinyin library

    def load_pinyin_lib(self, file_path):
        # Load the pinyin library file into memory
        # (script_dir is defined at module level below)
        with open(os.path.join(script_dir, file_path), 'r', encoding='utf-8') as f:
            return set(line.strip().lower() for line in f)

    def is_valid_chinese(self, word):
        # A word qualifies if it is a single word whose first letter is uppercase
        if len(word.split()) == 1 and word[0].isupper():
            return self.is_pinyin(word.lower())
        return False

    def encoder(self, inputs):
        # Remove words that look like pinyin, remembering their positions
        encoded_inputs = []
        self.code = []
        for i, word in enumerate(inputs):
            if self.is_valid_chinese(word):
                self.code.append((i, word))  # remember the pinyin word and its index
            else:
                encoded_inputs.append(word)
        return encoded_inputs

    def decoder(self, inputs):
        # Re-insert the pinyin words at their original positions
        decoded_inputs = inputs.copy()
        for i, word in self.code:
            decoded_inputs.insert(i, word)
        return decoded_inputs

    def is_pinyin(self, string):
        # Check whether the string segments into pinyin syllables
        # (or English words present in the library), greedily matching
        # the longest prefix first
        string = string.lower()
        stringlen = len(string)
        max_len = 6
        result = []
        n = 0
        while n < stringlen:
            matched = 0
            temp_result = []
            for i in range(max_len, 0, -1):
                s = string[0:i]
                if s in self.pinyin_lib:
                    temp_result.append(string[:i])
                    matched = i
                    break
                if i == 1 and len(temp_result) == 0:
                    return False
            result.extend(temp_result)
            string = string[matched:]
            n += matched
        return True

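The greedy longest-prefix matcher in `is_pinyin` can be exercised on its own; a sketch with a tiny stand-in syllable library (bypassing the `pinyin.txt` load that normally happens in `__init__`):

```python
f = ChineseFilter.__new__(ChineseFilter)        # skip loading pinyin.txt for this sketch
f.code = []
f.pinyin_lib = {"bei", "jing", "shang", "hai"}  # stand-in library, not the real file

f.is_pinyin("beijing")          # True:  segments as 'bei' + 'jing'
f.is_pinyin("shanghai")         # True:  segments as 'shang' + 'hai'
f.is_pinyin("between")          # False: no syllable prefix matches
f.is_valid_chinese("Beijing")   # True:  single capitalized word that segments
```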
# Path of this script's directory, used when loading the pinyin file
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(script_dir)))


class Model():
    def __init__(self, modelname, selected_lora_model, selected_gpu):
        def get_gpu_index(gpu_info, target_gpu_name):
            """
            Get the index of the target GPU from the GPU info list.

            Args:
                gpu_info (list): list of GPU names
                target_gpu_name (str): name of the target GPU

            Returns:
                int: index of the target GPU, or -1 if not found
            """
            for i, name in enumerate(gpu_info):
                if target_gpu_name.lower() in name.lower():
                    return i
            return -1

        if selected_gpu != "cpu":
            gpu_count = torch.cuda.device_count()
            gpu_info = [torch.cuda.get_device_name(i) for i in range(gpu_count)]
            selected_gpu_index = get_gpu_index(gpu_info, selected_gpu)
            self.device_name = f"cuda:{selected_gpu_index}"
        else:
            self.device_name = "cpu"
        print("device_name", self.device_name)
        self.model = AutoModelForCausalLM.from_pretrained(modelname).to(self.device_name)
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)
        # self.translator = pipeline('translation', model=self.original_model, tokenizer=self.tokenizer, src_lang=original_language, tgt_lang=target_language, device=device)

    def generate(self, inputs, original_language, target_languages, max_batch_size):
        filter_list = [SpecialTokenFilter(), ChevronsFilter(), SimilarFilter(), ChineseFilter()]
        filter_pipeline = FilterPipeline(filter_list)
        def language_mapping(original_language):
            d = {
                "Achinese (Arabic script)": "ace_Arab",
                "Achinese (Latin script)": "ace_Latn",
                "Mesopotamian Arabic": "acm_Arab",
                "Ta'izzi-Adeni Arabic": "acq_Arab",
                "Tunisian Arabic": "aeb_Arab",
                "Afrikaans": "afr_Latn",
                "South Levantine Arabic": "ajp_Arab",
                "Akan": "aka_Latn",
                "Amharic": "amh_Ethi",
                "North Levantine Arabic": "apc_Arab",
                "Standard Arabic": "arb_Arab",
                "Najdi Arabic": "ars_Arab",
                "Moroccan Arabic": "ary_Arab",
                "Egyptian Arabic": "arz_Arab",
                "Assamese": "asm_Beng",
                "Asturian": "ast_Latn",
                "Awadhi": "awa_Deva",
                "Central Aymara": "ayr_Latn",
                "South Azerbaijani": "azb_Arab",
                "North Azerbaijani": "azj_Latn",
                "Bashkir": "bak_Cyrl",
                "Bambara": "bam_Latn",
                "Balinese": "ban_Latn",
                "Belarusian": "bel_Cyrl",
                "Bemba": "bem_Latn",
                "Bengali": "ben_Beng",
                "Bhojpuri": "bho_Deva",
                "Banjar (Arabic script)": "bjn_Arab",
                "Banjar (Latin script)": "bjn_Latn",
                "Tibetan": "bod_Tibt",
                "Bosnian": "bos_Latn",
                "Buginese": "bug_Latn",
                "Bulgarian": "bul_Cyrl",
                "Catalan": "cat_Latn",
                "Cebuano": "ceb_Latn",
                "Czech": "ces_Latn",
                "Chokwe": "cjk_Latn",
                "Central Kurdish": "ckb_Arab",
                "Crimean Tatar": "crh_Latn",
                "Welsh": "cym_Latn",
                "Danish": "dan_Latn",
                "German": "deu_Latn",
                "Dinka": "dik_Latn",
                "Jula": "dyu_Latn",
                "Dzongkha": "dzo_Tibt",
                "Greek": "ell_Grek",
                "English": "eng_Latn",
                "Esperanto": "epo_Latn",
                "Estonian": "est_Latn",
                "Basque": "eus_Latn",
                "Ewe": "ewe_Latn",
                "Faroese": "fao_Latn",
                "Persian": "pes_Arab",
                "Fijian": "fij_Latn",
                "Finnish": "fin_Latn",
                "Fon": "fon_Latn",
                "French": "fra_Latn",
                "Friulian": "fur_Latn",
                "Nigerian Fulfulde": "fuv_Latn",
                "Scottish Gaelic": "gla_Latn",
                "Irish": "gle_Latn",
                "Galician": "glg_Latn",
                "Guarani": "grn_Latn",
                "Gujarati": "guj_Gujr",
                "Haitian Creole": "hat_Latn",
                "Hausa": "hau_Latn",
                "Hebrew": "heb_Hebr",
                "Hindi": "hin_Deva",
                "Chhattisgarhi": "hne_Deva",
                "Croatian": "hrv_Latn",
                "Hungarian": "hun_Latn",
                "Armenian": "hye_Armn",
                "Igbo": "ibo_Latn",
                "Iloko": "ilo_Latn",
                "Indonesian": "ind_Latn",
                "Icelandic": "isl_Latn",
                "Italian": "ita_Latn",
                "Javanese": "jav_Latn",
                "Japanese": "jpn_Jpan",
                "Kabyle": "kab_Latn",
                "Kachin": "kac_Latn",
                "Arabic": "ar_AR",
                "Chinese": "zho_Hans",
                "Spanish": "spa_Latn",
                "Dutch": "nld_Latn",
                "Kazakh": "kaz_Cyrl",
                "Korean": "kor_Hang",
                "Lithuanian": "lit_Latn",
                "Malayalam": "mal_Mlym",
                "Marathi": "mar_Deva",
                "Nepali": "ne_NP",
                "Polish": "pol_Latn",
                "Portuguese": "por_Latn",
                "Russian": "rus_Cyrl",
                "Sinhala": "sin_Sinh",
                "Tamil": "tam_Taml",
                "Turkish": "tur_Latn",
                "Ukrainian": "ukr_Cyrl",
                "Urdu": "urd_Arab",
                "Vietnamese": "vie_Latn",
                "Thai": "tha_Thai",
                "Khmer": "khm_Khmr"
            }
            return d[original_language]
        def process_gpu_translate_result(temp_outputs):
            # Re-group the per-language batch results into per-row records
            # and flush them to the temp Excel file as a checkpoint
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans['generated_translation'][i],
                        })
                    outputs.append(temp)
            excel_writer = ExcelFileWriter()
            excel_writer.write_text(os.path.join(parent_dir, r"temp/empty.xlsx"), outputs, 'A', 1, len(outputs))
        self.tokenizer.src_lang = language_mapping(original_language)
        if self.device_name == "cpu":
            # Tokenize input
            input_ids = self.tokenizer(inputs, return_tensors="pt", padding=True, max_length=128).to(self.device_name)
            output = []
            for target_language in target_languages:
                # Get the language code for the target language
                target_lang_code = self.tokenizer.lang_code_to_id[language_mapping(target_language)]
                # Generate translation
                generated_tokens = self.model.generate(
                    **input_ids,
                    forced_bos_token_id=target_lang_code,
                    max_length=128
                )
                generated_translation = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
                # Append result to output
                output.append({
                    "target_language": target_language,
                    "generated_translation": generated_translation,
                })
            outputs = []
            length = len(output[0]["generated_translation"])
            for i in range(length):
                temp = []
                for trans in output:
                    temp.append({
                        "target_language": trans["target_language"],
                        "generated_translation": trans['generated_translation'][i],
                    })
                outputs.append(temp)
            return outputs
        else:
            # max batch size = available GPU memory (bytes) / 4 / (tensor size + trainable parameters)
            # max_batch_size = 10
            # Ensure batch size is within model limits:
            print("length of inputs: ", len(inputs))
            batch_size = min(len(inputs), int(max_batch_size))
            batches = [inputs[i:i + batch_size] for i in range(0, len(inputs), batch_size)]
            print("length of batches size: ", len(batches))
            temp_outputs = []
            processed_num = 0
            for index, batch in enumerate(batches):
                # Tokenize input
                print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
                print(len(batch))
                print(batch)
                batch = filter_pipeline.batch_encoder(batch)
                print(batch)
                temp = []
                if len(batch) > 0:
                    for target_language in target_languages:
                        # System prompt (in Chinese): "You are an ERP-system Chinese-to-English
                        # translation expert. Your task is to take markdown-formatted text, keep its
                        # formatting, and translate it from {original_language} into {target_language}
                        # without adding anything extra."
                        batch_messages = [[
                            {"role": "system", "content": f"你是一个ERP系统中译英专家,你任务是把markdown格式的文本,保留其格式并从{original_language}翻译成{target_language},不要添加多余的内容。"},
                            {"role": "user", "content": input},
                        ] for input in batch]
                        batch_texts = [self.tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=True
                        ) for messages in batch_messages]
                        self.tokenizer.padding_side = "left"
                        model_inputs = self.tokenizer(
                            batch_texts,
                            return_tensors="pt",
                            padding="longest",
                            truncation=True,
                        ).to(self.device_name)
                        generated_ids = self.model.generate(
                            max_new_tokens=512,
                            **model_inputs
                        )
                        # Keep only the newly generated tokens of each sequence
                        new_tokens = [
                            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
                        ]
                        generated_translation = self.tokenizer.batch_decode(new_tokens, skip_special_tokens=True)
                        # Append result to output
                        temp.append({
                            "target_language": target_language,
                            "generated_translation": generated_translation,
                        })
                    # Release the tokenized batch from GPU memory
                    model_inputs.to('cpu')
                    del model_inputs
                else:
                    for target_language in target_languages:
                        generated_translation = filter_pipeline.batch_decoder(batch)
                        print(generated_translation)
                        print(len(generated_translation))
                        # Append result to output
                        temp.append({
                            "target_language": target_language,
                            "generated_translation": generated_translation,
                        })
                temp_outputs.append(temp)
                processed_num += len(batch)
                if (index + 1) * max_batch_size // 1000 - index * max_batch_size // 1000 == 1:
                    print("Already processed number: ", len(temp_outputs))
                    process_gpu_translate_result(temp_outputs)
            outputs = []
            for temp_output in temp_outputs:
                length = len(temp_output[0]["generated_translation"])
                for i in range(length):
                    temp = []
                    for trans in temp_output:
                        temp.append({
                            "target_language": trans["target_language"],
                            "generated_translation": trans['generated_translation'][i],
                        })
                    outputs.append(temp)
            return outputs
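For reference, a standalone sketch of the GPU-branch flow above. The checkpoint name is illustrative only (any chat-templated instruct model should behave similarly), and the prompt is a simplified English stand-in for the deleted file's system prompt:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")   # hypothetical choice
mdl = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct")

messages = [
    {"role": "system", "content": "Translate the user's text from Chinese to English. Output only the translation."},
    {"role": "user", "content": "保存成功"},  # "Saved successfully"
]
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
tok.padding_side = "left"
model_inputs = tok([text], return_tensors="pt", padding="longest", truncation=True)
generated = mdl.generate(max_new_tokens=512, **model_inputs)

# Strip the prompt tokens, exactly as the deleted code does
new_tokens = [out[len(inp):] for inp, out in zip(model_inputs.input_ids, generated)]
print(tok.batch_decode(new_tokens, skip_special_tokens=True))
```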