import json
import time

import faiss
import gradio as gr
import joblib
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from tensorflow.keras.models import load_model
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
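# App flow: a CNN classifies the lesion from the image plus patient metadata,
# FAISS retrieves matching knowledge-base chunks, and a Qwen chat model writes
# a cited, lay-friendly explanation (RAG).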
### Variable definitions
sex_map = ['female', 'male']
localization_map = ['back', 'lower extremity', 'trunk', 'upper extremity', 'abdomen', 'face',
                    'chest', 'foot', 'neck', 'scalp', 'hand', 'ear', 'genital', 'acral', 'unknown']
dx_type_map = ['histo', 'follow_up', 'consensus', 'confocal']
labels = ['nv', 'mel', 'bkl', 'bcc', 'akiec', 'vasc', 'df']
# Preprocessing artifacts fitted at training time
age_scale = joblib.load("age_scaler.joblib")
meta_features = joblib.load("meta_features_list.joblib")
dic_lesion = {"bcc": "basal cell carcinoma", "mel": "melanoma",
              "nv": "melanocytic nevi", "bkl": "seborrheic keratosis",
              "akiec": "actinic keratosis", "vasc": "vascular lesion", "df": "dermatofibroma"}
cnn_model = load_model("rgb_8_model.keras")
index = faiss.read_index("kb_faiss_index.idx")
embedder = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb')
LLM_ID = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(LLM_ID, use_fast=True)
mdl = AutoModelForCausalLM.from_pretrained(LLM_ID, torch_dtype="auto", low_cpu_mem_usage=True)
generator = pipeline("text-generation", model=mdl, tokenizer=tokenizer, device=-1)
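# Note: device=-1 runs generation on CPU (typical for a free Space); on a GPU
# Space you could pass device=0 to the pipeline, or load the model with
# device_map="auto" instead.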
SYSTEM_PROMPT = """
You are a knowledgeable, empathetic medical assistant.
Given a skin-lesion classification and relevant medical text,
provide a concise, lay-friendly summary of what the diagnosis means,
recommended next steps (e.g., see a dermatologist, biopsy considerations),
and cite your sources in square brackets (e.g., [1], [2]).
Do not repeat the "Prediction:" or "Patient:" lines in your answer.
"""
def load_json(path):
    # Despite the chunked knowledge base, the metadata file is a regular JSON
    # array, so json.load (not line-by-line JSONL parsing) is correct here.
    with open(path, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    return chunks

chunks = load_json("kb_chunks_metadata.json")
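# Each chunk is expected to be a dict with at least a "text" field and an
# optional "tags" list (see retrieve_top_chunks and build_rag below).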
def preprocess_image(img):
    # The winning model was trained on 8x8 RGB inputs ("rgb_8"), hence the resize
    img = img.resize((8, 8))
    # Convert to a numpy array so the model can consume it
    img_arr = np.array(img)
    if img_arr.shape[-1] == 4:
        img_arr = img_arr[..., :3]  # drop the alpha channel if the image is RGBA
    img_arr = img_arr / 255.0
    img_arr = img_arr.reshape((-1, 8, 8, 3))
    return img_arr
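# Shape sanity check (illustrative; "lesion.jpg" is a placeholder path):
#   preprocess_image(PIL.Image.open("lesion.jpg")).shape == (1, 8, 8, 3)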
def encode_meta(age, sex, localization, dx_type):
    user_input = pd.DataFrame([{
        "age": age,
        "sex": sex,
        "localization": localization,
        "dx_type": dx_type
    }])
    user_input['age_missing'] = user_input['age'].isna().astype(int)
    # With a single row, median() of a missing age is itself NaN, so fillna(median)
    # would leave the NaN in place; fall back to 0 instead (the training median is
    # not shipped with the app) and let the age_missing flag carry the signal.
    user_input["age"] = user_input["age"].fillna(0)
    user_input["age"] = age_scale.transform(user_input[["age"]])
    user_input = pd.get_dummies(user_input, columns=["sex", "dx_type", "localization"])
    # Align one-hot columns with the feature layout used at training time
    user_input = user_input.reindex(columns=meta_features, fill_value=0)
    return user_input.values.astype("float32")
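# Illustrative call: encode_meta(60, "male", "back", "histo") returns a single
# float32 row whose columns follow the meta_features order saved at training.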
def pred_img(img, age, sex, localization, dx_type):
    img_batch = preprocess_image(img)
    data_meta = encode_meta(age, sex, localization, dx_type)
    preds = cnn_model.predict([img_batch, data_meta])
    idx = np.argmax(preds, axis=1)[0]
    pred_label = labels[idx]
    confidence = float(preds[0][idx])
    return pred_label, confidence
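# Example (values illustrative): pred_img(img, 60, "male", "back", "histo")
# might return ("nv", 0.87): the short label plus its predicted probability.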
# Legacy helper, not wired into the Gradio app; encode_meta above is the version
# predict_step uses.
def df(age, sex, localization, dx_type):
    user_input = pd.DataFrame([{
        "age": age,
        "sex": sex,
        "localization": localization,
        "dx_type": dx_type
    }])
    user_input['age_missing'] = user_input['age'].isna().astype(int)
    # Fitting a fresh StandardScaler on a single row would always map the age to
    # 0; use the scaler fitted at training time instead.
    user_input['age'] = age_scale.transform(user_input[['age']])
    # Apply the same one-hot encoding as training and align to its columns
    user_dummies = pd.get_dummies(user_input, columns=["sex", "dx_type", "localization"])
    user_dummies = user_dummies.reindex(columns=meta_features, fill_value=0)
    return user_dummies.values.astype("float32")
def retrieve_top_chunks(query, label, top_k=3, pool=80):
    q_emb = embedder.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)
    # Over-retrieve a pool of candidates, then re-rank with a tag bonus
    D, I = index.search(q_emb, pool)
    results = []
    for score, idx in zip(D[0], I[0]):
        c = chunks[idx]
        tag_match = 1 if label in c.get('tags', []) else 0
        # Small fixed boost (0.2) for chunks tagged with the predicted label
        adjusted = float(score) + (0.2 * tag_match)
        results.append((adjusted, idx))
    results.sort(reverse=True)
    out = []
    for adj, idx in results[:top_k]:
        item = dict(chunks[idx])
        item['score'] = adj
        out.append(item)
    return out
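# Illustrative usage: retrieve_top_chunks("melanoma treatment options", label="mel",
# top_k=2) returns up to two chunk dicts, each with its adjusted 'score' attached.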
def build_rag(pred_label, confidence, age, sex, location, retrieve_chunks, user_question):
    context_strings = []
    for i, chunk in enumerate(retrieve_chunks, start=1):
        # Keep only the text field of each chunk, numbered "[1]", "[2]", ... so
        # the model can cite sources by index
        context_strings.append(f'[{i}] {chunk["text"]}')
    contexts_block = '\n\n'.join(context_strings)
    # Build the user message
    user_message = (
        f"Prediction: {pred_label} ({confidence:.1%} confidence)\n"
        f'Patient: {age}-year-old {sex}, lesion on {location}.\n\n'
        f'Question: {user_question}\n\n'
        'Using the information below, explain in plain language:\n'
        '1. What this diagnosis means.\n'
        '2. Recommended next steps.\n'
        '3. Cite sources by their number.\n\n'
        f'{contexts_block}')
    # Build a chat-style prompt using the model's chat template
    messages = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': user_message}
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    return prompt
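# Generation settings below: temperature=0.2 with do_sample=True keeps answers
# near-deterministic, and return_full_text=False makes the pipeline return only
# the newly generated tokens rather than echoing the prompt.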
def generate_explanation(pred_label, confidence, age, sex, location, retrieve_chunks, user_question):
    prompt = build_rag(pred_label, confidence, age, sex, location, retrieve_chunks, user_question)
    resp = generator(
        prompt,
        max_new_tokens=300,
        temperature=0.2,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False
    )[0]['generated_text']
    return resp.strip()
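# Gradio callbacks: predict_step runs the CNN and unlocks the question box and
# Explain button; explain_step then streams a retrieval status and the answer.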
def predict_step(img, age, sex, localization, dx_type):
    if img is None:
        return ('Please upload an image', None, None,
                gr.update(interactive=False), gr.update(interactive=False))
    label, confidence = pred_img(img, age, sex, localization, dx_type)
    disease = dic_lesion.get(label, label)
    pred_text = f'{disease} ({label})'
    return (pred_text, label, confidence,
            gr.update(interactive=True), gr.update(interactive=True))
def explain_step(label, confidence, age, sex, localization, user_question):
    if not label:
        # This is a generator, so a bare `return value` would be swallowed;
        # yield the message first, then stop
        yield 'Run predictions first'
        return
    user_question = (user_question or 'overview').strip()
    yield "Retrieving context…"
    time.sleep(0.05)
    disease = dic_lesion.get(label, label)
    query = f'{disease} {user_question} {age}-year-old {sex} lesion {localization}'
    top_chunks = retrieve_top_chunks(query, label=label, top_k=2)
    if not top_chunks:
        time.sleep(0.3)
        yield 'Sorry, no relevant references found'
        return
    answer = generate_explanation(label, confidence or 0.0, age, sex, localization, top_chunks, user_question)
    yield answer
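### Gradio UI
# Layout: image upload beside the metadata inputs, a Predict button, then a
# question box and Explain button that stay disabled until a prediction exists.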
with gr.Blocks(title='Skin Lesion Assistant - Educational only, not medical advice') as demo:
    gr.Markdown('**Lesion Image**')
    with gr.Row():
        img = gr.Image(type='pil')
        with gr.Column():
            age = gr.Number(label='Age', maximum=120, placeholder="e.g., 60")
            sex = gr.Dropdown(choices=sex_map, label="Sex")
            loc = gr.Dropdown(choices=localization_map, label="Localization")
            dx = gr.Dropdown(choices=dx_type_map, label="Dx Type")
    btn_predict = gr.Button("Predict")
    q = gr.Textbox(label='Question', placeholder='treatments?', interactive=False)
    btn_explain = gr.Button('Explain', interactive=False)
    pred_out = gr.Textbox(label='Prediction')
    expl_out = gr.Markdown(label='Explanation')
    # hidden state passed from Predict → Explain
    s_label = gr.State()
    s_conf = gr.State()
    btn_predict.click(
        predict_step,
        inputs=[img, age, sex, loc, dx],
        outputs=[pred_out, s_label, s_conf, q, btn_explain]
    )
    btn_explain.click(
        explain_step,
        inputs=[s_label, s_conf, age, sex, loc, q],
        outputs=[expl_out]
    )

demo.queue().launch(ssr_mode=False)