qminh369 commited on
Commit
10f85ab
·
verified ·
1 Parent(s): 5573dde

Upload 4 files

Browse files
app.py CHANGED
@@ -1,12 +1,13 @@
1
  import gradio as gr
2
  import json
3
- from llmlingua import PromptCompressor
 
4
  import tiktoken
5
 
6
  compressors = {
7
  "xlm-roberta": PromptCompressor(
8
- model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
9
- #model_name='qminh369/token-classification-llmlingua2-xlm-roberta-42k_merge_1_epoch',
10
  use_llmlingua2=True,
11
  device_map="cpu"
12
  )
@@ -26,7 +27,8 @@ def compress(original_prompt, compression_rate, base_model="xlm-roberta", force_
26
  force_tokens=force_tokens,
27
  chunk_end_tokens=chunk_end_tokens,
28
  return_word_label=True,
29
- drop_consecutive=True
 
30
  )
31
 
32
  compressed_prompt = results["compressed_prompt"]
 
1
  import gradio as gr
2
  import json
3
+ #from llmlingua import PromptCompressor
4
+ from utils_llmlingua2_test import PromptCompressor
5
  import tiktoken
6
 
7
  compressors = {
8
  "xlm-roberta": PromptCompressor(
9
+ #model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
10
+ model_name='qminh369/token-classification-llmlingua2-xlm-roberta-42k_merge_1_epoch',
11
  use_llmlingua2=True,
12
  device_map="cpu"
13
  )
 
27
  force_tokens=force_tokens,
28
  chunk_end_tokens=chunk_end_tokens,
29
  return_word_label=True,
30
+ drop_consecutive=True,
31
+ force_reserve_digit=True,
32
  )
33
 
34
  compressed_prompt = results["compressed_prompt"]
core_utils_llmlingua2.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import string
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+
9
+ class TokenClfDataset(Dataset): # Hàm tạo custom dataset
10
+ def __init__(
11
+ self,
12
+ texts,
13
+ max_len=512, # 256 (phobert) 512 (xlm-roberta)
14
+ tokenizer=None,
15
+ model_name="m_bert",
16
+ ):
17
+ self.len = len(texts)
18
+ self.texts = texts
19
+ self.tokenizer = tokenizer
20
+ self.max_len = max_len
21
+ self.model_name = model_name
22
+ if "m_bert" in model_name:
23
+ self.cls_token = "[CLS]"
24
+ self.sep_token = "[SEP]"
25
+ self.unk_token = "[UNK]"
26
+ self.pad_token = "[PAD]"
27
+ self.mask_token = "[MASK]"
28
+ elif "xlm-roberta-large" in model_name:
29
+ self.bos_token = "<s>"
30
+ self.eos_token = "</s>"
31
+ self.sep_token = "</s>"
32
+ self.cls_token = "<s>"
33
+ self.unk_token = "<unk>"
34
+ self.pad_token = "<pad>"
35
+ self.mask_token = "<mask>"
36
+ elif "xlm-roberta" in model_name:
37
+ self.bos_token = "<s>"
38
+ self.eos_token = "</s>"
39
+ self.sep_token = "</s>"
40
+ self.cls_token = "<s>"
41
+ self.unk_token = "<unk>"
42
+ self.pad_token = "<pad>"
43
+ self.mask_token = "<mask>"
44
+ elif "phobert" in model_name:
45
+ self.bos_token = "<s>"
46
+ self.eos_token = "</s>"
47
+ self.sep_token = "</s>"
48
+ self.cls_token = "<s>"
49
+ self.unk_token = "<unk>"
50
+ self.pad_token = "<pad>"
51
+ self.mask_token = "<mask>"
52
+ #else: raise NotImplementedError()
53
+
54
+ def __getitem__(self, index):
55
+ text = self.texts[index]
56
+ tokenized_text = self.tokenizer.tokenize(text)
57
+
58
+ tokenized_text = (
59
+ [self.cls_token] + tokenized_text + [self.sep_token]
60
+ ) # add special tokens
61
+
62
+ if len(tokenized_text) > self.max_len:
63
+ tokenized_text = tokenized_text[: self.max_len]
64
+ else:
65
+ tokenized_text = tokenized_text + [
66
+ self.pad_token for _ in range(self.max_len - len(tokenized_text))
67
+ ]
68
+
69
+ attn_mask = [1 if tok != self.pad_token else 0 for tok in tokenized_text]
70
+
71
+ ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)
72
+
73
+ return {
74
+ "ids": torch.tensor(ids, dtype=torch.long),
75
+ "mask": torch.tensor(attn_mask, dtype=torch.long),
76
+ }
77
+
78
+ def __len__(self):
79
+ return self.len
80
+
81
+
82
+ def seed_everything(seed: int):
83
+ random.seed(seed)
84
+ os.environ["PYTHONHASHSEED"] = str(seed)
85
+ np.random.seed(seed)
86
+ torch.manual_seed(seed)
87
+ torch.cuda.manual_seed(seed)
88
+ torch.backends.cudnn.deterministic = True
89
+ torch.backends.cudnn.benchmark = False
90
+
91
+
92
+ def is_begin_of_new_word(token, model_name, force_tokens, token_map): # Thêm kí tự bắt đầu vào từ mới
93
+ if "m_bert" in model_name:
94
+ if token.lstrip("##") in force_tokens or token.lstrip("##") in set(
95
+ token_map.values()
96
+ ):
97
+ return True
98
+ return not token.startswith("##")
99
+ elif "xlm-roberta-large" in model_name:
100
+ #print("xlm-roberta-large")
101
+ if (
102
+ token in string.punctuation
103
+ or token in force_tokens
104
+ or token in set(token_map.values())
105
+ ):
106
+ return True
107
+ return token.startswith("▁") # check xem token có bắt đầu bằng kí tự "_" hay ko -> Trả về False
108
+ elif "xlm-roberta" in model_name:
109
+ #print("xlm-roberta-large")
110
+ if (
111
+ token in string.punctuation
112
+ or token in force_tokens
113
+ or token in set(token_map.values())
114
+ ):
115
+ return True
116
+ return token.startswith("▁")
117
+ elif "phobert" in model_name:
118
+ #print("minh phobert")
119
+ #print("xlm-roberta-large")
120
+ if (
121
+ token in string.punctuation # điều kiện hoặc
122
+ or token in force_tokens
123
+ or token in set(token_map.values())
124
+ ):
125
+ return True
126
+ #return token.startswith("▁") #
127
+ #return not token.startswith("▁")
128
+ #return not token.startswith("@@")
129
+ return not token.endswith("@@")
130
+ #return token.startswith("@@")
131
+ #else: raise NotImplementedError()
132
+
133
+ def replace_added_token(token, token_map):
134
+ for ori_token, new_token in token_map.items():
135
+ token = token.replace(new_token, ori_token)
136
+ return token
137
+
138
+ def get_pure_token(token, model_name): # hàm get pure token trả về token gốc (sau khi loại bỏ kí tự đặc biệt subword)
139
+ if "m_bert" in model_name:
140
+ return token.lstrip("##")
141
+ elif "xlm-roberta-large" in model_name:
142
+ return token.lstrip("▁") # bỏ kí tự "_" ở phía bên trái của từ
143
+ elif "xlm-roberta" in model_name:
144
+ return token.lstrip("▁") # bỏ kí tự "_" ở ph��a bên trái của từ
145
+ elif "phobert" in model_name:
146
+ #return token.lstrip("▁")
147
+ #return token.lstrip("@@")
148
+ return token.rstrip("@@")
149
+ # else: raise NotImplementedError()
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
  gradio
2
  accelerate
3
- llmlingua==0.2.1
4
  tiktoken
 
1
  gradio
2
  accelerate
 
3
  tiktoken
utils_llmlingua2_test.py ADDED
The diff for this file is too large to render. See raw diff