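# Skin Lesion Assistant (educational demo, not medical advice).
# Combines a small multimodal CNN (8x8 RGB image + patient metadata) with FAISS
# retrieval over a medical knowledge base and a locally hosted Qwen2.5-3B-Instruct
# model to produce RAG-style explanations through a Gradio interface.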
import time
import json

import joblib
import numpy as np
import pandas as pd
import faiss
import torch
import gradio as gr
from PIL import Image
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.models import load_model
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
### Variable definitions
sex_map = ['female', 'male']
localization_map = ['back', 'lower extremity', 'trunk', 'upper extremity', 'abdomen', 'face',
                    'chest', 'foot', 'neck', 'scalp', 'hand', 'ear', 'genital', 'acral', 'unknown']
dx_type_map = ['histo', 'follow_up', 'consensus', 'confocal']
labels = ['nv', 'mel', 'bkl', 'bcc', 'akiec', 'vasc', 'df']  # class order must match the CNN's output
age_scale = joblib.load("age_scaler.joblib")              # pre-fitted scaler for the age feature
meta_features = joblib.load("meta_features_list.joblib")  # feature/column order used at training time
dic_lesion = {"bcc": "basal cell carcinoma", "mel": "melanoma",
              "nv": "melanocytic nevi", "bkl": "seborrheic keratosis",
              "akiec": "actinic keratosis", "vasc": "vascular lesion", "df": "dermatofibroma"}
cnn_model = load_model("rgb_8_model.keras")
index = faiss.read_index("kb_faiss_index.idx")
embedder = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb')
LLM_ID = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(LLM_ID, use_fast=True)
mdl = AutoModelForCausalLM.from_pretrained(LLM_ID, torch_dtype="auto", low_cpu_mem_usage=True)
generator = pipeline("text-generation", model=mdl, tokenizer=tokenizer, device=-1)  # device=-1 -> run on CPU
SYSTEM_PROMPT = """
You are a knowledgeable, empathetic medical assistant.
Given a skin-lesion classification and relevant medical text,
provide a concise, lay-friendly summary of what the diagnosis means,
recommended next steps (e.g., see a dermatologist, biopsy considerations),
and cite your sources in square brackets (e.g., [1], [2]).
Do not repeat the 'Prediction:' or 'Patient:' lines in your answer.
"""
def load_chunks(path):
    # The knowledge-base chunks and their metadata are stored as a single plain JSON file.
    with open(path, "r", encoding="utf-8") as f:
        chunks = json.load(f)
    return chunks

chunks = load_chunks("kb_chunks_metadata.json")
def preprocess_image(img):
    # The winning model ("rgb_8") was trained on 8x8 RGB inputs, so resize to match.
    img = img.resize((8, 8))
    # Convert to a numpy array so the Keras model can consume it.
    img_arr = np.array(img)
    if img_arr.shape[-1] == 4:
        img_arr = img_arr[..., :3]  # drop the alpha channel if the image is RGBA
    img_arr = img_arr / 255.0
    img_arr = img_arr.reshape((-1, 8, 8, 3))
    return img_arr
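# Encode the patient metadata the same way as at training time: scaled age, a
# missing-age flag, and one-hot sex / dx_type / localization reindexed to the
# saved meta_features column order.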
def encode_meta(age, sex, localization, dx_type):
    user_input = pd.DataFrame([{
        "age": age,
        "sex": sex,
        "localization": localization,
        "dx_type": dx_type
    }])
    user_input['age_missing'] = user_input['age'].isna().astype(int)
    # With a single row, .median() is NaN whenever age is missing, so fall back to the
    # scaler's training mean (assuming age_scale is a fitted StandardScaler); a missing
    # age then transforms to ~0 while age_missing carries the signal.
    fill_age = float(getattr(age_scale, "mean_", [0.0])[0])
    user_input["age"] = user_input["age"].fillna(fill_age)
    user_input["age"] = age_scale.transform(user_input[["age"]])
    user_input = pd.get_dummies(user_input, columns=["sex", "dx_type", "localization"])
    user_input = user_input.reindex(columns=meta_features, fill_value=0)
    return user_input.values.astype("float32")
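# Run the multimodal CNN on the preprocessed image plus encoded metadata and
# return the top label together with its softmax confidence.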
def pred_img(img, age, sex, localization, dx_type):
    imagen = preprocess_image(img)
    data_meta = encode_meta(age, sex, localization, dx_type)
    preds = cnn_model.predict([imagen, data_meta])
    idx = np.argmax(preds, axis=1)[0]
    pred_label = labels[idx]
    confidence = float(preds[0][idx])
    return pred_label, confidence
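# NOTE: the helper below is not called anywhere in this app; encode_meta above is the
# encoder actually used. It is kept for reference, but fitting a StandardScaler on a
# single row and reindexing to the raw input columns do not reproduce the training encoding.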
def df(age,sex,localization,dx_type):
user_input = pd.DataFrame([{
"age": age,
"sex": sex,
"localization": localization,
"dx_type": dx_type
}])
user_input['age_missing'] = user_input['age'].isna().astype(int)
user_input['age'] = StandardScaler().fit_transform(user_input[['age']])
# Apply same preprocessing as training
user_dummies = pd.get_dummies(user_input)
user_dummies = user_dummies.reindex(columns=user_input.columns, fill_value=0)
meta_input = user_dummies.values
return meta_input
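# Retrieve the most relevant knowledge-base chunks: search a wide candidate pool,
# boost chunks tagged with the predicted label, then keep the top_k highest-scoring ones.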
def retrieve_top_chunks(query, label, top_k=3, pool=80):
    # Embed the query; L2-normalization assumes the indexed vectors were normalized
    # the same way (cosine-style similarity search).
    q_emb = embedder.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, pool)
    results = []
    for score, idx in zip(D[0], I[0]):
        c = chunks[idx]
        # Small bonus for chunks whose tags mention the predicted label.
        tag_match = 1 if label in c.get('tags', []) else 0
        adjusted = float(score) + (0.2 * tag_match)
        results.append((adjusted, idx))
    results.sort(reverse=True)
    out = []
    for adj, idx in results[:top_k]:
        item = dict(chunks[idx])
        item['score'] = adj
        out.append(item)
    return out
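# Assemble the RAG prompt: prediction, patient context, the user's question, and the
# numbered retrieved passages, wrapped with the system prompt via the model's chat template.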
def build_rag(pred_label, confidence, age, sex, location, retrieve_chunks, user_question):
    context_strings = []
    # Keep only the 'text' field of each retrieved chunk, numbered as [1], [2], ...
    # so the model can cite its sources by index.
    for i, chunk in enumerate(retrieve_chunks, start=1):
        context_strings.append(f'[{i}] {chunk["text"]}')
    contexts_block = '\n\n'.join(context_strings)
    # Build the user message from the prediction, patient context, and retrieved passages
user_message=(
f"Prediction: {pred_label} ({confidence:.1%} confidence)\n"
f'Patient: {age}-year-old {sex}, lesion on {location}.\n\n'
f'Question: {user_question}\n\n'
'Using the information below, explain in plain language:\n'
'1. What this diagnosis means.\n'
'2. Recommended next steps.\n'
'3. Cite sources by their number.\n\n'
f'{contexts_block}')
# Build a chat-style prompt using the model's chat template
messages=[
{'role':'system', 'content':SYSTEM_PROMPT},
{'role':'user', 'content':user_message}
]
prompt= tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
return prompt
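# Generate the answer with the local LLM; return_full_text=False keeps only the
# newly generated text, not the prompt.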
def generate_explanation(pred_label, confidence, age, sex, location, retrieve_chunks, user_question):
    prompt = build_rag(pred_label, confidence, age, sex, location, retrieve_chunks, user_question)
    resp = generator(
        prompt,
        max_new_tokens=300,
        temperature=0.2,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        return_full_text=False
    )[0]['generated_text']
    reply = resp.strip()
    return reply
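# Gradio callback for the Predict button: classify the lesion, show its readable name,
# stash the label and confidence in State, and unlock the question box and Explain button.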
def predict_step(img, age, sex, localization, dx_type):
    if img is None:
        return ('Please upload an image', None, None, gr.update(interactive=False), gr.update(interactive=False))
    label, confidence = pred_img(img, age, sex, localization, dx_type)
    disease = dic_lesion.get(label, label)
    pred_text = f'{disease} ({label})'
    return (pred_text, label, confidence, gr.update(interactive=True), gr.update(interactive=True))
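# Gradio callback for the Explain button, written as a generator so a status message
# streams to the UI before the final answer replaces it.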
def explain_step(label, confidence, age, sex, localization, user_question, progress=gr.Progress()):
    if not label:
        # This function is a generator, so the message must be yielded to reach the UI.
        yield 'Run predictions first'
        return
    user_question = (user_question or 'overview').strip()
    yield "Retrieving context…"
    time.sleep(0.05)
    # progress(0.10, desc="Preparing query…")
    disease = dic_lesion.get(label, label)
    query = f'{disease} {user_question} {age}-year-old {sex} lesion {localization}'
    # yield "Retrieving context…", gr.update(value="Working…", interactive=False)
    # progress(0.35, desc="Retrieving context…")
    top_chunks = retrieve_top_chunks(query, label=label, top_k=2)
    if not top_chunks:
        time.sleep(0.3)
        yield 'Sorry, no relevant references found'
        return
    answer = generate_explanation(label, confidence or 0.0, age, sex, localization, top_chunks, user_question)
    # progress(1.0, desc="Done")
    yield answer
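### Gradio UI: image and metadata inputs, a two-step Predict → Explain flow, and hidden State carrying the prediction between steps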
with gr.Blocks(title='Skin Lesion Assistant - Educational only, not medical advice') as demo:
gr.Markdown('**Lesion Image**')
with gr.Row():
img = gr.Image(type='pil')
with gr.Column():
age=gr.Number(label='Age', maximum=120, placeholder="e.g., 60")
sex = gr.Dropdown(choices=sex_map, label="Sex")
loc = gr.Dropdown(choices=localization_map, label="Localization")
dx = gr.Dropdown(choices=dx_type_map, label="Dx Type")
btn_predict = gr.Button("Predict")
q = gr.Textbox(label='Question',placeholder='treatments?',interactive=False)
btn_explain=gr.Button('Explain',interactive=False)
pred_out = gr.Textbox(label='Prediction')
expl_out = gr.Markdown(label='Explanation')
# hidden state passed from Predict → Explain
s_label = gr.State()
s_conf = gr.State()
btn_predict.click(
predict_step,
inputs=[img, age, sex, loc, dx],
outputs=[pred_out, s_label, s_conf, q, btn_explain]
)
btn_explain.click(
explain_step,
inputs=[s_label, s_conf, age, sex, loc, q],
outputs=[expl_out]
)
demo.queue().launch(ssr_mode=False)