File size: 5,507 Bytes
a1c3375 bbb864a 087b944 bbb864a 087b944 fb92444 764612d d5b7158 6ed0886 764612d 6ed0886 764612d 6ed0886 764612d d5b7158 764612d 0f0951f 764612d 6ed0886 381026e 764612d 381026e c95dafb 6ed0886 764612d 3270264 eaf4f7d 764612d 0632907 764612d 381026e 6ed0886 c95dafb 0632907 c95dafb 0632907 3270264 087b944 bbb864a 087b944 bbb864a 087b944 bbb864a 087b944 bbb864a 087b944 bbb864a 087b944 bbb864a 087b944 bbb864a 087b944 bbb864a 087b944 bbb864a 087b944 bbb864a 087b944 bbb864a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr
import torch
import re
from transformers import AutoConfig, LlamaConfig, AutoModelForSequenceClassification, AutoTokenizer
import torch
# ==== 自定义配置类 ====
class CustomLlamaConfig(LlamaConfig):
model_type = "custom_llama" # 新名称
def _rope_scaling_validation(self):
pass # 禁用RoPE验证
# ==== 注册配置 ====
AutoConfig.register("custom_llama", CustomLlamaConfig) # 使用新名称
# ==== 加载模型 ====
# 1. 加载配置
config = CustomLlamaConfig.from_pretrained("unsloth/DeepSeek-R1-Distill-Llama-8B")
# 2. 加载模型(关键修改)
model = AutoModelForSequenceClassification.from_pretrained(
"unsloth/DeepSeek-R1-Distill-Llama-8B",
config=config,
trust_remote_code=True,
_config_class=CustomLlamaConfig # 明确指定配置类
)
# 3. 加载适配器
model.load_adapter("yxccai/ds-ai-app")
# 4. 加载分词器
tokenizer = AutoTokenizer.from_pretrained("yxccai/ds-ai-app")
# 2. 加载你的适配器
# model.load_adapter("yxccai/ds-ai-app") # 替换为你的仓库名
# model = LlamaForSequenceClassification.from_pretrained(
# "yxccai/ds-ai-model",
# trust_remote_code=True # 添加这行
# )
# tokenizer = LlamaTokenizer.from_pretrained("yxccai/ds-ai-model")
# tokenizer = AutoTokenizer.from_pretrained("yxccai/ds-ai-app")
# 疾病标签映射(必须与训练时完全一致!)
disease_labels = [
"脑梗死",
"动脉狭窄",
"动脉闭塞",
"脑缺血",
"其他脑血管病变",
"脑出血",
"动脉瘤",
"动脉壶腹",
# 根据实际标签补充完整列表...
]
# 标准化输入模板(与训练时完全一致)
MEDICAL_PROMPT = """以下是描述任务的指令,请写出一个适当完成请求的回答。
### 指令:
你是一位专业医生,需要根据患者的主诉和检查结果给出诊断结论。回答必须严格按照以下格式:
诊断结论:[具体疾病名称]
### 问题:
{}
### 回答:
{}""" # 第二个占位符保留用于兼容性
def medical_diagnosis(symptoms):
try:
# 输入预处理
symptoms = symptoms.strip()
if not symptoms:
return "⚠️ 请输入有效的症状描述"
# 检测危险关键词
emergency_keywords = ["昏迷", "胸痛", "呼吸困难", "意识丧失"]
if any(kw in symptoms for kw in emergency_keywords):
return "🚨 检测到危急症状!请立即前往急诊科就诊!"
# 构建标准化输入
formatted_input = MEDICAL_PROMPT.format(symptoms, "")
# 模型推理
inputs = tokenizer(
formatted_input,
max_length=1024,
truncation=True,
padding=True,
return_tensors="pt"
).to("cuda")
with torch.no_grad():
outputs = model(**inputs)
# 后处理
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(probabilities).item()
confidence = probabilities[0][predicted_class].item()
# 生成结构化报告
diagnosis_report = f"""
🩺 诊断报告:
-----------------------------
▪ 主要症状:{extract_key_symptoms(symptoms)}
▪ 最可能诊断:{disease_labels[predicted_class]}
▪ 置信度:{confidence*100:.1f}%
▪ 鉴别诊断:{get_differential_diagnosis(predicted_class)}
-----------------------------
⚠️ 注意:本结果仅供参考,请以临床检查为准
"""
return diagnosis_report
except Exception as e:
return f"❌ 诊断过程中出现错误:{str(e)}"
def extract_key_symptoms(text):
"""提取关键症状"""
keywords = ["头晕", "肢体无力", "言语不利", "麻木", "呕吐"]
found = [kw for kw in keywords if kw in text]
return "、".join(found[:3]) + "等" if len(found) > 3 else "、".join(found)
def get_differential_diagnosis(disease_id):
"""获取鉴别诊断"""
differential_map = {
0: ["脑出血", "短暂性脑缺血发作", "颅内肿瘤"],
1: ["动脉粥样硬化", "血管炎", "纤维肌性发育不良"],
2: ["动脉栓塞", "大动脉炎", "血栓形成"],
3: ["梅尼埃病", "前庭神经炎", "低血糖反应"],
}
return " | ".join(differential_map.get(disease_id, []))
# 创建医疗专用界面
interface = gr.Interface(
fn=medical_diagnosis,
inputs=gr.Textbox(
label="患者症状描述",
placeholder="请输入详细症状(示例:持续头痛三天,伴随恶心呕吐)",
lines=5
),
outputs=gr.Markdown(
label="AI辅助诊断报告",
show_copy_button=True
),
title="神经内科疾病辅助诊断系统",
description="**专业提示**:请输入完整的症状描述,包括:\n- 主要症状及持续时间\n- 伴随症状\n- 既往病史\n- 检查结果",
examples=[
["主诉:左侧肢体无力3天,伴言语不清。既往脑梗死病史5年..."],
["头晕伴行走不稳2天,MRI显示小脑梗死灶..."],
["突发右侧肢体麻木,CTA显示颈动脉狭窄..."],
],
allow_flagging="never",
theme="soft"
)
# 安全设置
interface.launch(
server_name="0.0.0.0",
server_port=7860,
auth=("doctor", "dsaimodel") # 建议修改为自定义账号密码
) |