Update app.py
Browse files
app.py
CHANGED
@@ -1,64 +1,128 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
-
|
|
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
""
|
7 |
-
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
history: list[tuple[str, str]],
|
13 |
-
system_message,
|
14 |
-
max_tokens,
|
15 |
-
temperature,
|
16 |
-
top_p,
|
17 |
-
):
|
18 |
-
messages = [{"role": "system", "content": system_message}]
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
if val[1]:
|
24 |
-
messages.append({"role": "assistant", "content": val[1]})
|
25 |
|
26 |
-
|
|
|
27 |
|
28 |
-
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
38 |
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
],
|
|
|
|
|
60 |
)
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
2 |
import gradio as gr
|
3 |
+
import torch
|
4 |
+
import re
|
5 |
|
6 |
+
# 加载医学诊断模型
|
7 |
+
model = AutoModelForSequenceClassification.from_pretrained("yxccai/ds-ai-model")
|
8 |
+
tokenizer = AutoTokenizer.from_pretrained("yxccai/ds-ai-model")
|
|
|
9 |
|
10 |
+
# 疾病标签映射(必须与训练时完全一致!)
|
11 |
+
disease_labels = [
|
12 |
+
"脑梗死",
|
13 |
+
"动脉狭窄",
|
14 |
+
"动脉闭塞",
|
15 |
+
"脑缺血",
|
16 |
+
"其他脑血管病变",
|
17 |
+
"脑出血",
|
18 |
+
"动脉瘤",
|
19 |
+
"动脉壶腹",
|
20 |
+
# 根据实际标签补充完整列表...
|
21 |
+
]
|
22 |
|
23 |
+
# 标准化输入模板(与训练时完全一致)
|
24 |
+
MEDICAL_PROMPT = """以下是描述任务的指令,请写出一个适当完成请求的回答。
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
+
### 指令:
|
27 |
+
你是一位专业医生,需要根据患者的主诉和检查结果给出诊断结论。回答必须严格按照以下格式:
|
28 |
+
诊断结论:[具体疾病名称]
|
|
|
|
|
29 |
|
30 |
+
### 问题:
|
31 |
+
{}
|
32 |
|
33 |
+
### 回答:
|
34 |
+
{}""" # 第二个占位符保留用于兼容性
|
35 |
|
36 |
+
def medical_diagnosis(symptoms):
|
37 |
+
try:
|
38 |
+
# 输入预处理
|
39 |
+
symptoms = symptoms.strip()
|
40 |
+
if not symptoms:
|
41 |
+
return "⚠️ 请输入有效的症状描述"
|
42 |
+
|
43 |
+
# 检测危险关键词
|
44 |
+
emergency_keywords = ["昏迷", "胸痛", "呼吸困难", "意识丧失"]
|
45 |
+
if any(kw in symptoms for kw in emergency_keywords):
|
46 |
+
return "🚨 检测到危急症状!请立即前往急诊科就诊!"
|
47 |
|
48 |
+
# 构建标准化输入
|
49 |
+
formatted_input = MEDICAL_PROMPT.format(symptoms, "")
|
50 |
+
|
51 |
+
# 模型推理
|
52 |
+
inputs = tokenizer(
|
53 |
+
formatted_input,
|
54 |
+
max_length=1024,
|
55 |
+
truncation=True,
|
56 |
+
padding=True,
|
57 |
+
return_tensors="pt"
|
58 |
+
).to("cuda")
|
59 |
+
|
60 |
+
with torch.no_grad():
|
61 |
+
outputs = model(**inputs)
|
62 |
+
|
63 |
+
# 后处理
|
64 |
+
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
|
65 |
+
predicted_class = torch.argmax(probabilities).item()
|
66 |
+
confidence = probabilities[0][predicted_class].item()
|
67 |
+
|
68 |
+
# 生成结构化报告
|
69 |
+
diagnosis_report = f"""
|
70 |
+
🩺 诊断报告:
|
71 |
+
-----------------------------
|
72 |
+
▪ 主要症状:{extract_key_symptoms(symptoms)}
|
73 |
+
▪ 最可能诊断:{disease_labels[predicted_class]}
|
74 |
+
▪ 置信度:{confidence*100:.1f}%
|
75 |
+
▪ 鉴别诊断:{get_differential_diagnosis(predicted_class)}
|
76 |
+
-----------------------------
|
77 |
+
⚠️ 注意:本结果仅供参考,请以临床检查为准
|
78 |
+
"""
|
79 |
+
return diagnosis_report
|
80 |
+
|
81 |
+
except Exception as e:
|
82 |
+
return f"❌ 诊断过程中出现错误:{str(e)}"
|
83 |
|
84 |
+
def extract_key_symptoms(text):
|
85 |
+
"""提取关键症状"""
|
86 |
+
keywords = ["头晕", "肢体无力", "言语不利", "麻木", "呕吐"]
|
87 |
+
found = [kw for kw in keywords if kw in text]
|
88 |
+
return "、".join(found[:3]) + "等" if len(found) > 3 else "、".join(found)
|
89 |
|
90 |
+
def get_differential_diagnosis(disease_id):
|
91 |
+
"""获取鉴别诊断"""
|
92 |
+
differential_map = {
|
93 |
+
0: ["脑出血", "短暂性脑缺血发作", "颅内肿瘤"],
|
94 |
+
1: ["动脉粥样硬化", "血管炎", "纤维肌性发育不良"],
|
95 |
+
2: ["动脉栓塞", "大动脉炎", "血栓形成"],
|
96 |
+
3: ["梅尼埃病", "前庭神经炎", "低血糖反应"],
|
97 |
+
}
|
98 |
+
return " | ".join(differential_map.get(disease_id, []))
|
99 |
+
|
100 |
+
# 创建医疗专用界面
|
101 |
+
interface = gr.Interface(
|
102 |
+
fn=medical_diagnosis,
|
103 |
+
inputs=gr.Textbox(
|
104 |
+
label="患者症状描述",
|
105 |
+
placeholder="请输入详细症状(示例:持续头痛三天,伴随恶心呕吐)",
|
106 |
+
lines=5
|
107 |
+
),
|
108 |
+
outputs=gr.Markdown(
|
109 |
+
label="AI辅助诊断报告",
|
110 |
+
show_copy_button=True
|
111 |
+
),
|
112 |
+
title="神经内科疾病辅助诊断系统",
|
113 |
+
description="**专业提示**:请输入完整的症状描述,包括:\n- 主要症状及持续时间\n- 伴随症状\n- 既往病史\n- 检查结果",
|
114 |
+
examples=[
|
115 |
+
["主诉:左侧肢体无力3天,伴言语不清。既往脑梗死病史5年..."],
|
116 |
+
["头晕伴行走不稳2天,MRI显示小脑梗死灶..."],
|
117 |
+
["突发右侧肢体麻木,CTA显示颈动脉狭窄..."],
|
118 |
],
|
119 |
+
allow_flagging="never",
|
120 |
+
theme="soft"
|
121 |
)
|
122 |
|
123 |
+
# 安全设置
|
124 |
+
interface.launch(
|
125 |
+
server_name="0.0.0.0",
|
126 |
+
server_port=7860,
|
127 |
+
auth=("doctor", "dsaimodel") # 建议修改为自定义账号密码
|
128 |
+
)
|