Update app.py
app.py
CHANGED
@@ -1,63 +1,67 @@
 import gradio as gr
-[old lines 2-57 removed; their contents are not shown in this diff view]
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+class MedS_Llama3:
+    def __init__(self, model_path: str):
+        # Load the model onto the CPU
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            device_map='cpu',           # keep the model on the CPU
+            torch_dtype=torch.float32   # use standard float32 precision
+        )
+        self.model.config.pad_token_id = self.model.config.eos_token_id = 128009
+
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path,
+            model_max_length=2048,
+            padding_side="right"
+        )
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.model.eval()
+        print('Model and tokenizer loaded on CPU!')
+
+    def chat(self, query: str, instruction: str = "If you are a doctor, please perform clinical consulting with the patient.") -> str:
+        input_sentence = f"{instruction}\n\n{query}"
+        input_tokens = self.tokenizer(
+            input_sentence,
+            return_tensors="pt",
+            padding=True,
+            truncation=True
+        )
+
+        output = self.model.generate(
+            **input_tokens,
+            max_new_tokens=512,  # cap the number of generated tokens to save memory
+            eos_token_id=128009
+        )
+
+        generated_text = self.tokenizer.decode(
+            output[0][input_tokens['input_ids'].shape[1]:],
+            skip_special_tokens=True
+        )
+
+        return generated_text.strip()
+
+# Instantiate the model
+model_path = "Henrychur/MMedS-Llama-3-8B"  # make sure this is the correct model path
+chat_model = MedS_Llama3(model_path)
+
+# Response function used by the Gradio interface
+def respond(message, system_message):
+    # No chat history is kept; each turn uses only the current input and the system instruction
+    response = chat_model.chat(query=message, instruction=system_message)
+    yield response
+
+# Set up the Gradio chat interface
+demo = gr.Interface(
+    fn=respond,
+    inputs=[
+        gr.Textbox(label="What is the treatment for diabetes?"),
+        gr.Textbox(value="If you are a doctor, please perform clinical consulting with the patient.", label="System message")
     ],
+    outputs="text"
 )
 
 if __name__ == "__main__":
     demo.launch()
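
Note: for the Space to build, the three imports above have to be installable. The commit shown here only touches app.py, so the dependency file is not visible; a minimal requirements.txt sketch (assumed, with no version pins) might look like:

    # requirements.txt -- assumed, not part of this commit
    gradio
    torch
    transformers
    accelerate   # from_pretrained(device_map=...) requires accelerate to be installed

The accelerate entry is there because transformers rejects the device_map argument when accelerate is missing.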
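
Once the Space is up, the endpoint that gr.Interface exposes can also be exercised from Python with gradio_client. A minimal sketch, assuming the default "/predict" api name and a placeholder Space id (substitute the real owner/space):

    from gradio_client import Client

    client = Client("owner/space-name")  # placeholder Space id
    result = client.predict(
        "What is the treatment for diabetes?",  # message
        "If you are a doctor, please perform clinical consulting with the patient.",  # system_message
        api_name="/predict",
    )
    print(result)

Because respond is a generator, the web UI can stream its output, while client.predict should simply return the final string.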