unsubscribe committed on
Commit 973cec1 · verified · 1 Parent(s): 80dd0d6

Update app.py

Files changed (1):
  app.py +68 -184
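
In short: this commit drops the file's vendored copy of LMDeploy's Gradio demo (the InterFace engine holder, the streaming chat_stream_local handler, and the run_local launcher) and instead star-imports lmdeploy.serve.gradio.turbomind_coupled, wiring the engine and the UI up at module level. Condensed, the new layout looks like the sketch below; that the star-import re-exports InterFace, AsyncEngine, gr, CSS and THEME is an assumption about turbomind_coupled's namespace, and the Blocks body is elided here:

    # Condensed sketch of the new app.py -- not the verbatim file.
    from lmdeploy.serve.gradio.turbomind_coupled import *
    from lmdeploy.messages import TurbomindEngineConfig

    backend_config = TurbomindEngineConfig(
        max_batch_size=1,            # serve a single request at a time
        cache_max_entry_count=0.05,  # keep the KV-cache memory share small
        model_format='awq')          # 4-bit AWQ weights
    InterFace.async_engine = AsyncEngine(
        model_path='internlm/internlm2-chat-20b-4bits',
        backend='turbomind',
        backend_config=backend_config,
        tp=1)

    with gr.Blocks(css=CSS, theme=THEME) as demo:
        ...  # chatbot, textbox, sliders and button wiring, as in the diff

    demo.queue(concurrency_count=InterFace.async_engine.instance_num,
               max_size=100).launch()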
app.py CHANGED
@@ -1,76 +1,18 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import random
-import spaces
-from threading import Lock
-from typing import Literal, Optional, Sequence, Union
-
-import gradio as gr
-
-from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
-                               TurbomindEngineConfig)
-from lmdeploy.model import ChatTemplateConfig
-from lmdeploy.serve.async_engine import AsyncEngine
-from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
-
-
-class InterFace:
-    async_engine: AsyncEngine = None
-    global_session_id: int = 0
-    lock = Lock()
-
-@spaces.GPU
-async def chat_stream_local(instruction: str, state_chatbot: Sequence,
-                            cancel_btn: gr.Button, reset_btn: gr.Button,
-                            session_id: int, top_p: float, temperature: float,
-                            request_output_len: int):
-    """Chat with AI assistant.
-
-    Args:
-        instruction (str): user's prompt
-        state_chatbot (Sequence): the chatting history
-        cancel_btn (gr.Button): the cancel button
-        reset_btn (gr.Button): the reset button
-        session_id (int): the session id
-    """
-    state_chatbot = state_chatbot + [(instruction, None)]
-
-    yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
-    gen_config = GenerationConfig(max_new_tokens=request_output_len,
-                                  top_p=top_p,
-                                  top_k=40,
-                                  temperature=temperature,
-                                  random_seed=random.getrandbits(64)
-                                  if len(state_chatbot) == 1 else None)
-
-    async for outputs in InterFace.async_engine.generate(
-            instruction,
-            session_id,
-            gen_config=gen_config,
-            stream_response=True,
-            sequence_start=(len(state_chatbot) == 1),
-            sequence_end=False):
-        response = outputs.response
-        if outputs.finish_reason == 'length':
-            gr.Warning('WARNING: exceed session max length.'
-                       ' Please restart the session by reset button.')
-        if outputs.generate_token_len < 0:
-            gr.Warning('WARNING: running on the old session.'
-                       ' Please restart the session by reset button.')
-        if state_chatbot[-1][-1] is None:
-            state_chatbot[-1] = (state_chatbot[-1][0], response)
-        else:
-            state_chatbot[-1] = (state_chatbot[-1][0],
-                                 state_chatbot[-1][1] + response
-                                 )  # piece by piece
-        yield (state_chatbot, state_chatbot, enable_btn, disable_btn)
-
-    yield (state_chatbot, state_chatbot, disable_btn, enable_btn)
+from lmdeploy.serve.gradio.turbomind_coupled import *
+from lmdeploy.messages import TurbomindEngineConfig
+
+backend_config = TurbomindEngineConfig(max_batch_size=1, cache_max_entry_count=0.05, model_format='awq')
+model_path = 'internlm/internlm2-chat-20b-4bits'
+
+InterFace.async_engine = AsyncEngine(
+    model_path=model_path,
+    backend='turbomind',
+    backend_config=backend_config,
+    tp=1)
 
-@spaces.GPU
 async def reset_local_func(instruction_txtbox: gr.Textbox,
                            state_chatbot: Sequence, session_id: int):
     """reset the session.
-
     Args:
         instruction_txtbox (str): user's prompt
         state_chatbot (Sequence): the chatting history
@@ -83,11 +25,9 @@ async def reset_local_func(instruction_txtbox: gr.Textbox,
     session_id = InterFace.global_session_id
     return (state_chatbot, state_chatbot, gr.Textbox.update(value=''), session_id)
 
-@spaces.GPU
 async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
                             reset_btn: gr.Button, session_id: int):
     """stop the session.
-
     Args:
         instruction_txtbox (str): user's prompt
        state_chatbot (Sequence): the chatting history
@@ -95,11 +35,11 @@ async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
         reset_btn (gr.Button): the reset button
         session_id (int): the session id
     """
-    yield (state_chatbot, disable_btn, disable_btn)
+    yield (state_chatbot, disable_btn, disable_btn, session_id)
     InterFace.async_engine.stop_session(session_id)
     # pytorch backend does not support resume chat history now
    if InterFace.async_engine.backend == 'pytorch':
-        yield (state_chatbot, disable_btn, enable_btn)
+        yield (state_chatbot, disable_btn, enable_btn, session_id)
     else:
         with InterFace.lock:
             InterFace.global_session_id += 1
@@ -119,118 +59,62 @@ async def cancel_local_func(state_chatbot: Sequence, cancel_btn: gr.Button,
         pass
     yield (state_chatbot, disable_btn, enable_btn, session_id)
 
-@spaces.GPU
-def run_local(model_path: str,
-              model_name: Optional[str] = None,
-              backend: Literal['turbomind', 'pytorch'] = 'turbomind',
-              backend_config: Optional[Union[PytorchEngineConfig,
-                                             TurbomindEngineConfig]] = None,
-              chat_template_config: Optional[ChatTemplateConfig] = None,
-              server_name: str = 'localhost',
-              server_port: int = 6006,
-              tp: int = 1,
-              **kwargs):
-    """chat with AI assistant through web ui.
-
-    Args:
-        model_path (str): the path of a model.
-            It could be one of the following options:
-                - i) A local directory path of a turbomind model which is
-                    converted by `lmdeploy convert` command or download from
-                    ii) and iii).
-                - ii) The model_id of a lmdeploy-quantized model hosted
-                    inside a model repo on huggingface.co, such as
-                    "InternLM/internlm-chat-20b-4bit",
-                    "lmdeploy/llama2-chat-70b-4bit", etc.
-                - iii) The model_id of a model hosted inside a model repo
-                    on huggingface.co, such as "internlm/internlm-chat-7b",
-                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
-                    and so on.
-        model_name (str): needed when model_path is a pytorch model on
-            huggingface.co, such as "internlm/internlm-chat-7b",
-            "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat" and so on.
-        backend (str): either `turbomind` or `pytorch` backend. Default to
-            `turbomind` backend.
-        backend_config (TurbomindEngineConfig | PytorchEngineConfig): beckend
-            config instance. Default to none.
-        chat_template_config (ChatTemplateConfig): chat template configuration.
-            Default to None.
-        server_name (str): the ip address of gradio server
-        server_port (int): the port of gradio server
-        tp (int): tensor parallel for Turbomind
-    """
-    InterFace.async_engine = AsyncEngine(
-        model_path=model_path,
-        backend=backend,
-        backend_config=backend_config,
-        chat_template_config=chat_template_config,
-        model_name=model_name,
-        tp=tp,
-        **kwargs)
-
-    with gr.Blocks(css=CSS, theme=THEME) as demo:
-        state_chatbot = gr.State([])
-        state_session_id = gr.State(0)
-
-        with gr.Column(elem_id='container'):
-            gr.Markdown('## LMDeploy Playground')
-
-            chatbot = gr.Chatbot(
-                elem_id='chatbot',
-                label=InterFace.async_engine.engine.model_name)
-            instruction_txtbox = gr.Textbox(
-                placeholder='Please input the instruction',
-                label='Instruction')
-            with gr.Row():
-                cancel_btn = gr.Button(value='Cancel', interactive=False)
-                reset_btn = gr.Button(value='Reset')
-            with gr.Row():
-                request_output_len = gr.Slider(1,
-                                               2048,
-                                               value=512,
-                                               step=1,
-                                               label='Maximum new tokens')
-                top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
-                temperature = gr.Slider(0.01,
-                                        1.5,
-                                        value=0.7,
-                                        step=0.01,
-                                        label='Temperature')
-
-        send_event = instruction_txtbox.submit(chat_stream_local, [
-            instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
-            state_session_id, top_p, temperature, request_output_len
-        ], [state_chatbot, chatbot, cancel_btn, reset_btn])
-        instruction_txtbox.submit(
-            lambda: gr.Textbox.update(value=''),
-            [],
-            [instruction_txtbox],
-        )
-        cancel_btn.click(
-            cancel_local_func,
-            [state_chatbot, cancel_btn, reset_btn, state_session_id],
-            [state_chatbot, cancel_btn, reset_btn, state_session_id],
-            cancels=[send_event])
-
-        reset_btn.click(reset_local_func,
-                        [instruction_txtbox, state_chatbot, state_session_id],
-                        [state_chatbot, chatbot, instruction_txtbox],
-                        cancels=[send_event])
-
-        def init():
-            with InterFace.lock:
-                InterFace.global_session_id += 1
-            new_session_id = InterFace.global_session_id
-            return new_session_id
-
-        demo.load(init, inputs=None, outputs=[state_session_id])
-
-    demo.queue(concurrency_count=InterFace.async_engine.instance_num,
-               max_size=100).launch()
-
-
-
-backend_config = TurbomindEngineConfig(max_batch_size=1, cache_max_entry_count=0.05, model_format='awq')
-model_path = 'internlm/internlm2-chat-20b-4bits'
-
-run_local(model_path, backend_config=backend_config)
+with gr.Blocks(css=CSS, theme=THEME) as demo:
+    state_chatbot = gr.State([])
+    state_session_id = gr.State(0)
+
+    with gr.Column(elem_id='container'):
+        gr.Markdown('## LMDeploy Playground')
+
+        chatbot = gr.Chatbot(
+            elem_id='chatbot',
+            label=InterFace.async_engine.engine.model_name)
+        instruction_txtbox = gr.Textbox(
+            placeholder='Please input the instruction',
+            label='Instruction')
+        with gr.Row():
+            cancel_btn = gr.Button(value='Cancel', interactive=False)
+            reset_btn = gr.Button(value='Reset')
+        with gr.Row():
+            request_output_len = gr.Slider(1,
+                                           2048,
+                                           value=512,
+                                           step=1,
+                                           label='Maximum new tokens')
+            top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
+            temperature = gr.Slider(0.01,
+                                    1.5,
+                                    value=0.7,
+                                    step=0.01,
+                                    label='Temperature')
+
+    send_event = instruction_txtbox.submit(chat_stream_local, [
+        instruction_txtbox, state_chatbot, cancel_btn, reset_btn,
+        state_session_id, top_p, temperature, request_output_len
+    ], [state_chatbot, chatbot, cancel_btn, reset_btn])
+    instruction_txtbox.submit(
+        lambda: gr.Textbox.update(value=''),
+        [],
+        [instruction_txtbox],
+    )
+    cancel_btn.click(
+        cancel_local_func,
+        [state_chatbot, cancel_btn, reset_btn, state_session_id],
+        [state_chatbot, cancel_btn, reset_btn, state_session_id],
+        cancels=[send_event])
+
+    reset_btn.click(reset_local_func,
+                    [instruction_txtbox, state_chatbot, state_session_id],
+                    [state_chatbot, chatbot, instruction_txtbox, state_session_id],
+                    cancels=[send_event])
+
+    def init():
+        with InterFace.lock:
+            InterFace.global_session_id += 1
+        new_session_id = InterFace.global_session_id
+        return new_session_id

+    demo.load(init, inputs=None, outputs=[state_session_id])
+
+demo.queue(concurrency_count=InterFace.async_engine.instance_num,
+           max_size=100).launch()
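
Net effect: the Space now boots internlm/internlm2-chat-20b-4bits through the stock turbomind_coupled demo, sized conservatively: max_batch_size=1 admits one request at a time, and cache_max_entry_count=0.05 gives the KV cache only a roughly 5% share of GPU memory. To sanity-check the same engine settings outside Gradio, LMDeploy's high-level pipeline API should work; a minimal sketch, assuming an lmdeploy version that exports pipeline and TurbomindEngineConfig at the top level:

    # Hypothetical smoke test: same engine config, no Gradio UI.
    from lmdeploy import pipeline, TurbomindEngineConfig

    backend_config = TurbomindEngineConfig(max_batch_size=1,
                                           cache_max_entry_count=0.05,
                                           model_format='awq')
    pipe = pipeline('internlm/internlm2-chat-20b-4bits',
                    backend_config=backend_config)
    # pipeline() accepts a batch of prompts and returns one Response per prompt.
    print(pipe(['Briefly introduce yourself.'])[0].text)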