import pandas as pd import time from PIL import Image import os import numpy as np from tensorflow.keras.models import load_model import gradio as gr from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import StandardScaler import joblib import re import json from sentence_transformers import SentenceTransformer import faiss import openai import pandas as pd from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import StandardScaler from docx import Document from openai import OpenAI from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline import torch ### Variables definition sex_map = ['female','male'] localization_map=['back','lower extremity','trunk', 'upper extremity', 'abdomen','face', 'chest', 'foot', 'neck', 'scalp', 'hand', 'ear','genital', 'acral','unknown'] dx_type_map=['histo', 'follow_up', 'consensus', 'confocal'] labels = ['nv','mel','bkl','bcc','akiec','vasc','df'] age_scale = joblib.load("age_scaler.joblib") meta_features = joblib.load("meta_features_list.joblib") dic_lesion = {"bcc": "basal cell carcinoma", "mel": "melanoma", "nv": "melanocytic nevi", "bkl":"seborrheic keratosis", "akiec":"actinic keratosis", "vasc":"vascular lesion", "df":"dermatofibroma"} cnn_model = load_model("rgb_8_model.keras") index = faiss.read_index("kb_faiss_index.idx") embedder = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb') LLM_ID = "Qwen/Qwen2.5-3B-Instruct" tokenizer = AutoTokenizer.from_pretrained(LLM_ID, use_fast=True) mdl = AutoModelForCausalLM.from_pretrained(LLM_ID, torch_dtype="auto", low_cpu_mem_usage=True) generator = pipeline("text-generation", model=mdl, tokenizer=tokenizer, device=-1) SYSTEM_PROMPT = """ You are a knowledgeable, empathetic medical assistant. Given a skin-lesion classification and relevant medical text, provide a concise, lay-friendly summary of what the diagnosis means, recommended next steps (e.g., see a dermatologist, biopsy considerations), and cite your sources in square brackets (e.g., [1], [2]). Do not repeat Prediction:, Patient: in your answer. """ def load_jsonl(path): with open(path, "r", encoding="utf-8") as f: chunks = json.load(f) return chunks chunks = load_jsonl("kb_chunks_metadata.json") def preprocess_image(img): #Winning model for prediction was 8_rgb , that is the reason why we resize (8,8) img = img.resize((8,8)) #we need to transform into a numpy array, so the model is able to handle that type img_arr = np.array(img) if img_arr.shape[-1] == 4: img_arr = img_arr[...,:3] # in case that the images presents a fourth channel img_arr = img_arr/255 img_arr = img_arr.reshape((-1,8,8,3)) return img_arr def encode_meta(age,sex,localization,dx_type): user_input = pd.DataFrame([{ "age": age, "sex": sex, "localization": localization, "dx_type": dx_type }]) user_input['age_missing'] = user_input['age'].isna().astype(int) user_input["age"] = user_input["age"].fillna(user_input["age"].median()) user_input["age"] = age_scale.transform(user_input[["age"]]) user_input = pd.get_dummies(user_input, columns=["sex","dx_type","localization"]) user_input = user_input.reindex(columns=meta_features, fill_value=0) return user_input.values.astype("float32") def pred_img(img,age,sex,localization,dx_type): imagen = preprocess_image(img) data_meta = encode_meta(age,sex,localization,dx_type) preds = cnn_model.predict([imagen, data_meta]) idx=np.argmax(preds,axis=1)[0] pred_label = labels[idx] confidence = float (preds[0][idx]) return pred_label, confidence def df(age,sex,localization,dx_type): user_input = pd.DataFrame([{ "age": age, "sex": sex, "localization": localization, "dx_type": dx_type }]) user_input['age_missing'] = user_input['age'].isna().astype(int) user_input['age'] = StandardScaler().fit_transform(user_input[['age']]) # Apply same preprocessing as training user_dummies = pd.get_dummies(user_input) user_dummies = user_dummies.reindex(columns=user_input.columns, fill_value=0) meta_input = user_dummies.values return meta_input def retrieve_top_chunks(query, label, top_k=3, pool=80): q_emb = embedder.encode([query], convert_to_numpy=True) faiss.normalize_L2(q_emb) D, I = index.search(q_emb, pool) results = [] for score, idx in zip(D[0], I[0]): c = chunks[idx] tag_match = 1 if label in c.get('tags', []) else 0 adjusted = float(score) + (0.2 * tag_match) results.append((adjusted, idx)) results.sort(reverse=True) out=[] for adj, idx in results[:top_k]: item = dict(chunks[idx]) item['score']=adj out.append(item) return out def build_rag(pred_label, confidence, age,sex,location, retrieve_chunks,user_question): context_strings = [] #retrieve_chunks = retrieve_top_chunks(query,2) for i, chunk in enumerate(retrieve_chunks, start=1): context_strings.append(f'[{i}] {chunk["text"]}') #me quedo solamente con la parte del texto del diccionario . y lo guardo en contexts_block. Tambien indexo la posicion i en formato [1] para darle un orden y mejor visibilidad contexts_block = '\n\n'.join(context_strings) # 2. Build the user message user_message=( f"Prediction: {pred_label} ({confidence:.1%} confidence)\n" f'Patient: {age}-year-old {sex}, lesion on {location}.\n\n' f'Question: {user_question}\n\n' 'Using the information below, explain in plain language:\n' '1. What this diagnosis means.\n' '2. Recommended next steps.\n' '3. Cite sources by their number.\n\n' f'{contexts_block}') # Build a chat-style prompt using the model's chat template messages=[ {'role':'system', 'content':SYSTEM_PROMPT}, {'role':'user', 'content':user_message} ] prompt= tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) return prompt def generate_explanation(pred_label, confidence, age,sex,location, retrieve_chunks,user_question): prompt = build_rag(pred_label, confidence, age,sex,location, retrieve_chunks,user_question) resp = generator( prompt, max_new_tokens=300, temperature=0.2, do_sample=True, pad_token_id=tokenizer.eos_token_id, return_full_text = False )[0]['generated_text'] reply = resp.strip() return reply def predict_step(img,age,sex, localization,dx_type): if img is None: return ('Please upload an image', None,None, gr.update(interactive=False), gr.update(interactive=False)) label, confidence = pred_img(img,age,sex,localization,dx_type) disease = dic_lesion.get(label,label) pred_text = f'{disease} ({label})' x_img = preprocess_image(img) x_meta= encode_meta(age,sex,localization,dx_type) return (pred_text, label, confidence, gr.update(interactive=True), gr.update(interactive=True)) def explain_step(label,confidence,age,sex,localization,user_question,progress=gr.Progress()): if not label: return 'Run predictions first' user_question = (user_question or 'overview').strip() yield "Retrieving context…!!!" time.sleep(0.05) # progress(0.10, desc="Preparing query…") disease = dic_lesion.get(label,label) query=f'{disease} {user_question} {age}-years-old {sex} lesion {localization}' # yield "Retrieving context…", gr.update(value="Working…", interactive=False) # progress(0.35, desc="Retrieving context…") top_chunks=retrieve_top_chunks(query,label=label,top_k=2) if not top_chunks: time.sleep(0.3) yield 'Sorry, no relevant references found' return answer = generate_explanation(label,confidence or 0.0,age,sex,localization,top_chunks,user_question) # progress(1.0, desc="Done") # return answer yield answer with gr.Blocks(title='Skin Lesion Assistant - Educational only, not medical advice') as demo: gr.Markdown('**Lesion Image**') with gr.Row(): img = gr.Image(type='pil') with gr.Column(): age=gr.Number(label='Age', maximum=120, placeholder="e.g., 60") sex = gr.Dropdown(choices=sex_map, label="Sex") loc = gr.Dropdown(choices=localization_map, label="Localization") dx = gr.Dropdown(choices=dx_type_map, label="Dx Type") btn_predict = gr.Button("Predict") q = gr.Textbox(label='Question',placeholder='treatments?',interactive=False) btn_explain=gr.Button('Explain',interactive=False) pred_out = gr.Textbox(label='Prediction') expl_out = gr.Markdown(label='Explanation') # hidden state passed from Predict → Explain s_label = gr.State() s_conf = gr.State() btn_predict.click( predict_step, inputs=[img, age, sex, loc, dx], outputs=[pred_out, s_label, s_conf, q, btn_explain] ) btn_explain.click( explain_step, inputs=[s_label, s_conf, age, sex, loc, q], outputs=[expl_out] ) demo.queue().launch(ssr_mode=False)