MarkChen1214 committed on
Commit
84f66dc
·
1 Parent(s): c44ada1

feat: Add application file

Browse files
Files changed (2) hide show
  1. app.py +294 -3
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,7 +1,298 @@
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
  demo.launch()
 
1
+ import os
2
+ import requests
3
+ import json
4
+ import time
5
+
6
  import gradio as gr
7
+ from transformers import AutoTokenizer
8
+
9
+
10
+ import socket
11
+ hostname=socket.gethostname()
12
+ IPAddr=socket.gethostbyname(hostname)
13
+ print("Your Computer Name is:" + hostname)
14
+ print("Your Computer IP Address is:" + IPAddr)
15
+
16
+
17
+ DESCRIPTION = """
18
+ # Cloned from MediaTek Research Breeze-7B
19
+ MediaTek Research Breeze-7B (hereinafter referred to as Breeze-7B) is a language model family that builds on top of [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1), specifically intended for Traditional Chinese use.
20
+ [Breeze-7B-Base](https://huggingface.co/MediaTek-Research/Breeze-7B-Base-v1_0) is the base model for the Breeze-7B series.
21
+ It is suitable for use if you have substantial fine-tuning data to tune it for your specific use case.
22
+ [Breeze-7B-Instruct](https://huggingface.co/MediaTek-Research/Breeze-7B-Instruct-v1_0) derives from the base model Breeze-7B-Base, making the resulting model amenable to be used as-is for commonly seen tasks.
23
+ The current release version of Breeze-7B is v1.0.
24
+ *A project by the members (in alphabetical order): Chan-Jan Hsu 許湛然, Chang-Le Liu 劉昶樂, Feng-Ting Liao 廖峰挺, Po-Chun Hsu 許博竣, Yi-Chang Chen 陳宜昌, and the supervisor Da-Shan Shiu 許大山.*
25
+ **免責聲明: MediaTek Research Breeze-7B 並未針對問答進行安全保護,因此語言模型的任何回應不代表 MediaTek Research 立場。**
26
+ """
27
+
28
+ LICENSE = """
29
+ """
30
+
31
+ DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant built by MediaTek Research. The user you are helping speaks Traditional Chinese and comes from Taiwan."
32
+
33
+ API_URL = os.environ.get("API_URL")
34
+ TOKEN = os.environ.get("TOKEN")
35
+ TOKENIZER_REPO = "MediaTek-Research/Breeze-7B-Instruct-v1_0"
36
+ API_MODEL_TYPE = "breeze-7b-instruct-v10"
37
+
38
+ HEADERS = {
39
+ "Authorization": f"Bearer {TOKEN}",
40
+ "Content-Type": "application/json",
41
+ "accept": "application/json"
42
+ }
43
+
44
+ MAX_SEC = 30
45
+ MAX_INPUT_LENGTH = 5000
46
+
47
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO, use_auth_token=os.environ.get("HF_TOKEN"))
48
+
49
+
50
def refusal_condition(query):
    """Return True when `query` touches blocked cross-strait political topics.

    Matching is done on a lowercased copy of the query with ASCII spaces
    removed, so spacing and letter case cannot evade the substring checks.
    Triggers when the query mentions both a Taiwan-related and a
    China-related term, or any term from a fixed blocklist.
    """
    normalized = query.replace(' ', '').lower()

    tw_terms = ['台灣', '台湾', 'taiwan', 'tw', '中華民國', '中华民国']
    cn_terms = ['中國', '中国', 'cn', 'china', '大陸', '內地', '大陆', '内地', '中華人民共和國', '中华人民共和国']
    mentions_tw = any(term in normalized for term in tw_terms)
    mentions_cn = any(term in normalized for term in cn_terms)
    if mentions_tw and mentions_cn:
        return True

    blocked = ['一個中國', '兩岸', '一中原則', '一中政策', '一个中国', '两岸', '一中原则']
    return any(term in normalized for term in blocked)
70
+
71
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    # Editable system prompt shown above the chat area.
    system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=1)

    # Generation knobs, collapsed by default.
    with gr.Accordion(label='Advanced options', open=False):
        max_new_tokens = gr.Slider(label='Max new tokens', minimum=32, maximum=2048, step=1, value=1024)
        temperature = gr.Slider(label='Temperature', minimum=0.01, maximum=0.5, step=0.01, value=0.01)
        top_p = gr.Slider(label='Top-p (nucleus sampling)', minimum=0.01, maximum=0.99, step=0.01, value=0.01)

    # Conversation view plus the message entry row.
    chatbot = gr.Chatbot(show_copy_button=True, show_share_button=True)
    with gr.Row():
        msg = gr.Textbox(container=False, show_label=False, placeholder='Type a message...', scale=10, lines=6)
        submit_button = gr.Button('Submit', variant='primary', scale=1, min_width=0)

    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear = gr.Button('🗑️ Clear', variant='secondary')

    # Holds the user message removed by Retry/Undo so it can be restored.
    saved_input = gr.State()
122
+
123
def user(user_message, history):
    """Append the submitted text as a new chat turn and clear the input box.

    The reply slot is left as None; `bot` fills it in while streaming.
    Returns ("", updated_history).
    """
    updated = history + [[user_message, None]]
    return "", updated
125
+
126
+
127
def connect_server(data):
    """POST `data` to the inference endpoint and return the streaming response.

    Tries up to 3 times, sleeping 1 s between attempts. Returns the
    `requests.Response` (opened with stream=True) on HTTP 200, or None when
    every attempt fails.

    Fixes vs. the original: the 1 s sleep no longer delays the success path,
    Session objects from failed attempts are closed instead of leaked, and a
    transport error (timeout, connection refused) triggers a retry instead of
    propagating out of the Gradio handler.
    """
    for _attempt in range(3):
        session = requests.Session()
        try:
            r = session.post(API_URL, headers=HEADERS, json=data, stream=True, timeout=30)
        except requests.RequestException:
            session.close()
            time.sleep(1)
            continue
        if r.status_code == 200:
            # Session stays open so the caller can iterate the stream.
            return r
        session.close()
        time.sleep(1)
    return None
135
+
136
+
137
def stream_response_from_server(r):
    """Yield text deltas from the server's streaming JSON-lines response.

    Each non-empty line is parsed as JSON; the text fragment lives at
    result.fragment.data.text. A line whose "result" lacks a "fragment" key
    marks end-of-stream. Lines are skipped entirely when the response status
    is not 200.
    """
    for raw_line in r.iter_lines():
        if not raw_line:
            continue  # keep-alive / empty chunk
        if r.status_code != 200:
            continue  # error responses carry no usable fragments
        payload = json.loads(raw_line)
        result = payload["result"]
        if "fragment" not in result:
            break  # end-of-stream marker
        yield result["fragment"]["data"]["text"]
158
+
159
+
160
def bot(history, max_new_tokens, temperature, top_p, system_prompt):
    """Generate and stream the assistant reply for the last turn in `history`.

    Renders the system prompt plus the whole history through the tokenizer's
    chat template, posts it to the remote inference server, and yields
    updated `history` states as text fragments arrive.

    Yields: chatbot history (list of [user_text, assistant_text] pairs).
    Raises: Exception when the rendered prompt exceeds MAX_INPUT_LENGTH.
    """
    chat_data = []
    system_prompt = system_prompt.strip()
    if system_prompt:
        chat_data.append({"role": "system", "content": system_prompt})
    for user_msg, assistant_msg in history:
        chat_data.append({"role": "user", "content": user_msg if user_msg is not None else ''})
        chat_data.append({"role": "assistant", "content": assistant_msg if assistant_msg is not None else ''})

    message = tokenizer.apply_chat_template(chat_data, tokenize=False)
    message = message[3:]  # remove SOT token

    if len(message) > MAX_INPUT_LENGTH:
        # Fix: the original raised a bare Exception() with no diagnostic text.
        raise Exception(f'prompt length {len(message)} exceeds MAX_INPUT_LENGTH={MAX_INPUT_LENGTH}')

    response = '[ERROR]'
    if refusal_condition(history[-1][0]):
        # Sensitive query: replace the conversation with a canned refusal.
        history = [['[安全拒答啟動]', '[安全拒答啟動] 請清除再開啟對話']]
        response = '[REFUSAL]'
        yield history
    else:
        data = {
            "model_type": API_MODEL_TYPE,
            "prompt": str(message),
            "parameters": {
                "temperature": float(temperature),
                "top_p": float(top_p),
                "max_new_tokens": int(max_new_tokens),
                "repetition_penalty": 1.1
            }
        }

        r = connect_server(data)
        if r is not None:
            for delta in stream_response_from_server(r):
                if history[-1][1] is None:
                    history[-1][1] = ''
                history[-1][1] += delta
                yield history

            # Fix: the stream can close without yielding a single fragment,
            # leaving the reply slot as None and crashing .endswith below.
            if history[-1][1] is None:
                history[-1][1] = ''

            if history[-1][1].endswith('</s>'):
                history[-1][1] = history[-1][1][:-4]  # strip the EOS token
                yield history

            response = history[-1][1]

            if refusal_condition(history[-1][1]):
                # The reply itself hit a sensitive topic: append the disclaimer.
                history[-1][1] = history[-1][1] + '\n\n**[免責聲明: 此模型並未針對問答進行安全保護,因此語言模型的任何回應不代表 MediaTek Research 立場。]**'
                yield history
        else:
            # Server unreachable after retries: drop the pending user turn.
            del history[-1]
            yield history

    # Fix: log the `response` accumulator instead of history[-1][1] — after
    # the failure branch above `history` can be empty, and the original
    # print raised IndexError. `response` was assigned but never used.
    print('== Record ==\nQuery: {query}\nResponse: {response}'.format(query=repr(message), response=repr(response)))
214
+
215
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
216
+ fn=bot,
217
+ inputs=[
218
+ chatbot,
219
+ max_new_tokens,
220
+ temperature,
221
+ top_p,
222
+ system_prompt,
223
+ ],
224
+ outputs=chatbot
225
+ )
226
+ submit_button.click(
227
+ user, [msg, chatbot], [msg, chatbot], queue=False
228
+ ).then(
229
+ fn=bot,
230
+ inputs=[
231
+ chatbot,
232
+ max_new_tokens,
233
+ temperature,
234
+ top_p,
235
+ system_prompt,
236
+ ],
237
+ outputs=chatbot
238
+ )
239
+
240
+
241
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Remove the newest (user, assistant) pair from `history` in place.

    Returns the shortened history and the removed user message; '' when the
    history was already empty or the stored message was falsy.
    """
    if not history:
        return history, ''
    message, _ = history.pop()
    return history, message or ''
248
+
249
+
250
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Append `message` as a fresh user turn with an empty assistant slot.

    Mutates `history` in place and returns the same list object.
    """
    history += [(message, '')]
    return history
254
+
255
+ retry_button.click(
256
+ fn=delete_prev_fn,
257
+ inputs=chatbot,
258
+ outputs=[chatbot, saved_input],
259
+ api_name=False,
260
+ queue=False,
261
+ ).then(
262
+ fn=display_input,
263
+ inputs=[saved_input, chatbot],
264
+ outputs=chatbot,
265
+ api_name=False,
266
+ queue=False,
267
+ ).then(
268
+ fn=bot,
269
+ inputs=[
270
+ chatbot,
271
+ max_new_tokens,
272
+ temperature,
273
+ top_p,
274
+ system_prompt,
275
+ ],
276
+ outputs=chatbot,
277
+ )
278
+
279
+ undo_button.click(
280
+ fn=delete_prev_fn,
281
+ inputs=chatbot,
282
+ outputs=[chatbot, saved_input],
283
+ api_name=False,
284
+ queue=False,
285
+ ).then(
286
+ fn=lambda x: x,
287
+ inputs=[saved_input],
288
+ outputs=msg,
289
+ api_name=False,
290
+ queue=False,
291
+ )
292
+
293
+ clear.click(lambda: None, None, chatbot, queue=False)
294
 
295
+ gr.Markdown(LICENSE)
 
296
 
297
+ demo.queue(concurrency_count=4, max_size=128)
298
  demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers==4.38.2
2
+ sentencepiece==0.2.0