File size: 7,160 Bytes
775f038 686d35e 775f038 686d35e 775f038 686d35e 14b9a5d 775f038 14b9a5d 775f038 072e754 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 686d35e 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 14b9a5d 775f038 686d35e 14b9a5d 686d35e 775f038 14b9a5d 686d35e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
#!/usr/bin/env python
# encoding: utf-8
import gradio as gr
from PIL import Image
import traceback
import re
import torch
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model # type: ignore
import argparse
from transformers import AutoModel, AutoTokenizer
# Configuration for image classification model
class_names = ['Calculus', 'Dental Caries', 'Gingivitis', 'Hypodontia', 'Tooth Discoloration']
cnn_model = load_model('new_model2.h5')
# Argparser
parser = argparse.ArgumentParser(description='app')
parser.add_argument('--device', type=str, default='cpu', help='cpu')
parser.add_argument('--dtype', type=str, default='fp32', help='fp32')
args = parser.parse_args()
device = args.device
assert device in ['cpu']
# Set dtype
if args.dtype == 'fp32':
dtype = torch.float32
else:
dtype = torch.float16
# Load model
model_path = 'openbmb/MiniCPM-V-2'
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to(device=device)
model.eval()
ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-V 2.0'
# UI Components
form_radio = {
'choices': ['Beam Search', 'Sampling'],
'value': 'Sampling',
'interactive': True,
'label': 'Decode Type'
}
# Sliders and their settings
num_beams_slider = {'minimum': 0, 'maximum': 5, 'value': 3, 'step': 1, 'interactive': True, 'label': 'Num Beams'}
repetition_penalty_slider = {'minimum': 0, 'maximum': 3, 'value': 1.2, 'step': 0.01, 'interactive': True, 'label': 'Repetition Penalty'}
repetition_penalty_slider2 = {'minimum': 0, 'maximum': 3, 'value': 1.05, 'step': 0.01, 'interactive': True, 'label': 'Repetition Penalty'}
max_new_tokens_slider = {'minimum': 1, 'maximum': 4096, 'value': 1024, 'step': 1, 'interactive': True, 'label': 'Max New Tokens'}
top_p_slider = {'minimum': 0, 'maximum': 1, 'value': 0.8, 'step': 0.05, 'interactive': True, 'label': 'Top P'}
top_k_slider = {'minimum': 0, 'maximum': 200, 'value': 100, 'step': 1, 'interactive': True, 'label': 'Top K'}
temperature_slider = {'minimum': 0, 'maximum': 2, 'value': 0.7, 'step': 0.05, 'interactive': True, 'label': 'Temperature'}
def classify_images(image):
# Check if the image is None
if image is None:
return "No image uploaded. Please upload a dental image."
# Resize and preprocess the image
try:
input_image = tf.image.resize(image, (180, 180)) # Resize to expected input size
input_image_array = tf.keras.utils.img_to_array(input_image)
input_image_exp_dim = tf.expand_dims(input_image_array, axis=0)
# Make predictions
predictions = cnn_model.predict(input_image_exp_dim)
result = tf.nn.softmax(predictions[0])
# Prepare the outcome message
outcome = f'The image belongs to {class_names[np.argmax(result)]} with a score of {np.max(result) * 100:.2f}%'
return outcome
except Exception as e:
return f"Error processing the image: {str(e)}"
def create_component(params, comp='Slider'):
if comp == 'Slider':
return gr.Slider(**params)
elif comp == 'Radio':
return gr.Radio(choices=params['choices'], value=params['value'], interactive=params['interactive'], label=params['label'])
elif comp == 'Button':
return gr.Button(value=params['value'], interactive=True)
def chat(img, msgs, ctx, params=None):
default_params = {"num_beams": 3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
if params is None:
params = default_params
if img is None:
return -1, "Error, invalid image, please upload a new image", None, None
try:
image = img.convert('RGB')
answer, context, _ = model.chat(image=image, msgs=msgs, context=None, tokenizer=tokenizer, **params)
res = re.sub(r'(<box>.*</box>)', '', answer).replace('<ref>', '').replace('</ref>', '').replace('<box>', '').replace('</box>', '')
return 0, res, None, None
except Exception as err:
print(err)
traceback.print_exc()
return -1, ERROR_MSG, None, None
def upload_img(image, _chatbot, _app_session):
image = Image.fromarray(image)
_app_session['sts'] = None
_app_session['ctx'] = []
_app_session['img'] = image
_chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
return _chatbot, _app_session
def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
if _app_cfg.get('ctx', None) is None:
_chat_bot.append((_question, 'Please upload an image to start'))
return '', _chat_bot, _app_cfg
_context = _app_cfg['ctx'].copy()
_context.append({"role": "user", "content": _question})
if params_form == 'Beam Search':
params = {'sampling': False, 'num_beams': num_beams, 'repetition_penalty': repetition_penalty, "max_new_tokens": 896}
else: # Ensure this block is executed for Sampling
params = {
'sampling': True,
'top_p': top_p,
'top_k': top_k,
'temperature': temperature,
'repetition_penalty': repetition_penalty_2,
"max_new_tokens": 896
}
code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
_context.append({"role": "assistant", "content": _answer})
_chat_bot.append((_question, _answer))
if code == 0:
_app_cfg['ctx'] = _context
_app_cfg['sts'] = sts
return '', _chat_bot, _app_cfg
def clear(chat_bot, app_session):
app_session['img'] = None
chat_bot.clear()
return chat_bot
with gr.Blocks() as app:
gr.Markdown("<h1 style='text-align: center;'>Medical Assistant</h1>")
with gr.Tab("Image Classification"):
with gr.Row():
image_input = gr.Image(type='numpy', label="Upload Dental Image")
classification_output = gr.Label(num_top_classes=5, label="Classification Results")
image_input.change(fn=classify_images, inputs=image_input, outputs=classification_output)
with gr.Tab("Medical Chatbot"):
with gr.Row():
with gr.Column(scale=2, min_width=300):
app_session = gr.State({'sts': None, 'ctx': None, 'img': None})
bt_pic = gr.Image(label="Upload an image to start")
txt_message = gr.Textbox(label="Ask your question...")
with gr.Column(scale=2, min_width=300):
chat_bot = gr.Chatbot(label=f"Chatbot")
clear_button = gr.Button(value='Clear')
txt_message.submit(
respond,
[txt_message, chat_bot, app_session],
[txt_message, chat_bot, app_session]
)
bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic, chat_bot, app_session], outputs=[chat_bot, app_session])
clear_button.click(clear, [chat_bot, app_session], chat_bot)
# Launch
app.launch(share=True, debug=True, show_api=False) |