|
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import gradio as gr |
|
import torch |
|
import re |
|
|
|
|
|
from transformers import AutoConfig, LlamaConfig, AutoModelForSequenceClassification, AutoTokenizer |
|
import torch |
|
|
|
|
|
class CustomLlamaConfig(LlamaConfig): |
|
model_type = "custom_llama" |
|
|
|
def _rope_scaling_validation(self): |
|
pass |
|
|
|
|
|
AutoConfig.register("custom_llama", CustomLlamaConfig) |
|
|
|
|
|
|
|
config = CustomLlamaConfig.from_pretrained("unsloth/DeepSeek-R1-Distill-Llama-8B") |
|
|
|
|
|
model = AutoModelForSequenceClassification.from_pretrained( |
|
"unsloth/DeepSeek-R1-Distill-Llama-8B", |
|
config=config, |
|
trust_remote_code=True, |
|
_config_class=CustomLlamaConfig |
|
) |
|
|
|
|
|
model.load_adapter("yxccai/ds-ai-app") |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained("yxccai/ds-ai-app") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
disease_labels = [ |
|
"脑梗死", |
|
"动脉狭窄", |
|
"动脉闭塞", |
|
"脑缺血", |
|
"其他脑血管病变", |
|
"脑出血", |
|
"动脉瘤", |
|
"动脉壶腹", |
|
|
|
] |
|
|
|
|
|
MEDICAL_PROMPT = """以下是描述任务的指令,请写出一个适当完成请求的回答。 |
|
|
|
### 指令: |
|
你是一位专业医生,需要根据患者的主诉和检查结果给出诊断结论。回答必须严格按照以下格式: |
|
诊断结论:[具体疾病名称] |
|
|
|
### 问题: |
|
{} |
|
|
|
### 回答: |
|
{}""" |
|
|
|
def medical_diagnosis(symptoms): |
|
try: |
|
|
|
symptoms = symptoms.strip() |
|
if not symptoms: |
|
return "⚠️ 请输入有效的症状描述" |
|
|
|
|
|
emergency_keywords = ["昏迷", "胸痛", "呼吸困难", "意识丧失"] |
|
if any(kw in symptoms for kw in emergency_keywords): |
|
return "🚨 检测到危急症状!请立即前往急诊科就诊!" |
|
|
|
|
|
formatted_input = MEDICAL_PROMPT.format(symptoms, "") |
|
|
|
|
|
inputs = tokenizer( |
|
formatted_input, |
|
max_length=1024, |
|
truncation=True, |
|
padding=True, |
|
return_tensors="pt" |
|
).to("cuda") |
|
|
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
|
|
|
|
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1) |
|
predicted_class = torch.argmax(probabilities).item() |
|
confidence = probabilities[0][predicted_class].item() |
|
|
|
|
|
diagnosis_report = f""" |
|
🩺 诊断报告: |
|
----------------------------- |
|
▪ 主要症状:{extract_key_symptoms(symptoms)} |
|
▪ 最可能诊断:{disease_labels[predicted_class]} |
|
▪ 置信度:{confidence*100:.1f}% |
|
▪ 鉴别诊断:{get_differential_diagnosis(predicted_class)} |
|
----------------------------- |
|
⚠️ 注意:本结果仅供参考,请以临床检查为准 |
|
""" |
|
return diagnosis_report |
|
|
|
except Exception as e: |
|
return f"❌ 诊断过程中出现错误:{str(e)}" |
|
|
|
def extract_key_symptoms(text): |
|
"""提取关键症状""" |
|
keywords = ["头晕", "肢体无力", "言语不利", "麻木", "呕吐"] |
|
found = [kw for kw in keywords if kw in text] |
|
return "、".join(found[:3]) + "等" if len(found) > 3 else "、".join(found) |
|
|
|
def get_differential_diagnosis(disease_id): |
|
"""获取鉴别诊断""" |
|
differential_map = { |
|
0: ["脑出血", "短暂性脑缺血发作", "颅内肿瘤"], |
|
1: ["动脉粥样硬化", "血管炎", "纤维肌性发育不良"], |
|
2: ["动脉栓塞", "大动脉炎", "血栓形成"], |
|
3: ["梅尼埃病", "前庭神经炎", "低血糖反应"], |
|
} |
|
return " | ".join(differential_map.get(disease_id, [])) |
|
|
|
|
|
interface = gr.Interface( |
|
fn=medical_diagnosis, |
|
inputs=gr.Textbox( |
|
label="患者症状描述", |
|
placeholder="请输入详细症状(示例:持续头痛三天,伴随恶心呕吐)", |
|
lines=5 |
|
), |
|
outputs=gr.Markdown( |
|
label="AI辅助诊断报告", |
|
show_copy_button=True |
|
), |
|
title="神经内科疾病辅助诊断系统", |
|
description="**专业提示**:请输入完整的症状描述,包括:\n- 主要症状及持续时间\n- 伴随症状\n- 既往病史\n- 检查结果", |
|
examples=[ |
|
["主诉:左侧肢体无力3天,伴言语不清。既往脑梗死病史5年..."], |
|
["头晕伴行走不稳2天,MRI显示小脑梗死灶..."], |
|
["突发右侧肢体麻木,CTA显示颈动脉狭窄..."], |
|
], |
|
allow_flagging="never", |
|
theme="soft" |
|
) |
|
|
|
|
|
interface.launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
auth=("doctor", "dsaimodel") |
|
) |