File size: 5,507 Bytes
a1c3375
bbb864a
087b944
bbb864a
 
087b944
fb92444
764612d
d5b7158
6ed0886
764612d
6ed0886
764612d
 
6ed0886
764612d
d5b7158
764612d
 
0f0951f
764612d
 
6ed0886
381026e
764612d
381026e
c95dafb
6ed0886
764612d
 
3270264
eaf4f7d
764612d
0632907
 
764612d
381026e
 
6ed0886
 
c95dafb
0632907
c95dafb
 
 
 
 
 
 
0632907
3270264
087b944
bbb864a
 
 
 
 
 
 
 
 
 
 
 
087b944
bbb864a
 
087b944
bbb864a
 
 
087b944
bbb864a
 
087b944
bbb864a
 
087b944
bbb864a
 
 
 
 
 
 
 
 
 
 
087b944
bbb864a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
087b944
bbb864a
 
 
 
 
087b944
bbb864a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
087b944
bbb864a
 
087b944
 
bbb864a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr
import torch
import re


from transformers import AutoConfig, LlamaConfig, AutoModelForSequenceClassification, AutoTokenizer
import torch

# ==== 自定义配置类 ====
class CustomLlamaConfig(LlamaConfig):
    model_type = "custom_llama"  # 新名称
    
    def _rope_scaling_validation(self):
        pass  # 禁用RoPE验证

# ==== 注册配置 ====
AutoConfig.register("custom_llama", CustomLlamaConfig)  # 使用新名称

# ==== 加载模型 ====
# 1. 加载配置
config = CustomLlamaConfig.from_pretrained("unsloth/DeepSeek-R1-Distill-Llama-8B")

# 2. 加载模型(关键修改)
model = AutoModelForSequenceClassification.from_pretrained(
    "unsloth/DeepSeek-R1-Distill-Llama-8B",
    config=config,
    trust_remote_code=True,
    _config_class=CustomLlamaConfig  # 明确指定配置类
)

# 3. 加载适配器
model.load_adapter("yxccai/ds-ai-app")

# 4. 加载分词器
tokenizer = AutoTokenizer.from_pretrained("yxccai/ds-ai-app")



# 2. 加载你的适配器
# model.load_adapter("yxccai/ds-ai-app")  # 替换为你的仓库名

# model = LlamaForSequenceClassification.from_pretrained(
#     "yxccai/ds-ai-model",
#     trust_remote_code=True  # 添加这行
# )
# tokenizer = LlamaTokenizer.from_pretrained("yxccai/ds-ai-model")

# tokenizer = AutoTokenizer.from_pretrained("yxccai/ds-ai-app")


# 疾病标签映射(必须与训练时完全一致!)
disease_labels = [
    "脑梗死", 
    "动脉狭窄",
    "动脉闭塞",
    "脑缺血",
    "其他脑血管病变",
    "脑出血",
    "动脉瘤",
    "动脉壶腹",
    # 根据实际标签补充完整列表...
]

# 标准化输入模板(与训练时完全一致)
MEDICAL_PROMPT = """以下是描述任务的指令,请写出一个适当完成请求的回答。

### 指令:
你是一位专业医生,需要根据患者的主诉和检查结果给出诊断结论。回答必须严格按照以下格式:
诊断结论:[具体疾病名称]

### 问题:
{}

### 回答:
{}"""  # 第二个占位符保留用于兼容性

def medical_diagnosis(symptoms):
    try:
        # 输入预处理
        symptoms = symptoms.strip()
        if not symptoms:
            return "⚠️ 请输入有效的症状描述"
            
        # 检测危险关键词
        emergency_keywords = ["昏迷", "胸痛", "呼吸困难", "意识丧失"]
        if any(kw in symptoms for kw in emergency_keywords):
            return "🚨 检测到危急症状!请立即前往急诊科就诊!"

        # 构建标准化输入
        formatted_input = MEDICAL_PROMPT.format(symptoms, "")
        
        # 模型推理
        inputs = tokenizer(
            formatted_input,
            max_length=1024,
            truncation=True,
            padding=True,
            return_tensors="pt"
        ).to("cuda")
        
        with torch.no_grad():
            outputs = model(**inputs)
        
        # 后处理
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        predicted_class = torch.argmax(probabilities).item()
        confidence = probabilities[0][predicted_class].item()
        
        # 生成结构化报告
        diagnosis_report = f"""
🩺 诊断报告:
-----------------------------
▪ 主要症状:{extract_key_symptoms(symptoms)}
▪ 最可能诊断:{disease_labels[predicted_class]} 
▪ 置信度:{confidence*100:.1f}%
▪ 鉴别诊断:{get_differential_diagnosis(predicted_class)}
-----------------------------
⚠️ 注意:本结果仅供参考,请以临床检查为准
"""
        return diagnosis_report
    
    except Exception as e:
        return f"❌ 诊断过程中出现错误:{str(e)}"

def extract_key_symptoms(text):
    """提取关键症状"""
    keywords = ["头晕", "肢体无力", "言语不利", "麻木", "呕吐"]
    found = [kw for kw in keywords if kw in text]
    return "、".join(found[:3]) + "等" if len(found) > 3 else "、".join(found)

def get_differential_diagnosis(disease_id):
    """获取鉴别诊断"""
    differential_map = {
        0: ["脑出血", "短暂性脑缺血发作", "颅内肿瘤"],
        1: ["动脉粥样硬化", "血管炎", "纤维肌性发育不良"],
        2: ["动脉栓塞", "大动脉炎", "血栓形成"],
        3: ["梅尼埃病", "前庭神经炎", "低血糖反应"],
    }
    return " | ".join(differential_map.get(disease_id, []))

# 创建医疗专用界面
interface = gr.Interface(
    fn=medical_diagnosis,
    inputs=gr.Textbox(
        label="患者症状描述",
        placeholder="请输入详细症状(示例:持续头痛三天,伴随恶心呕吐)",
        lines=5
    ),
    outputs=gr.Markdown(
        label="AI辅助诊断报告",
        show_copy_button=True
    ),
    title="神经内科疾病辅助诊断系统",
    description="**专业提示**:请输入完整的症状描述,包括:\n- 主要症状及持续时间\n- 伴随症状\n- 既往病史\n- 检查结果",
    examples=[
        ["主诉:左侧肢体无力3天,伴言语不清。既往脑梗死病史5年..."],
        ["头晕伴行走不稳2天,MRI显示小脑梗死灶..."],
        ["突发右侧肢体麻木,CTA显示颈动脉狭窄..."],
    ],
    allow_flagging="never",
    theme="soft"
)

# 安全设置
interface.launch(
    server_name="0.0.0.0",
    server_port=7860,
    auth=("doctor", "dsaimodel")  # 建议修改为自定义账号密码
)