Update app.py
app.py CHANGED
@@ -82,12 +82,7 @@ def draw_bbox(image, bbox, label="Lesion", color="red", width=3):
     return image_copy
 
 def predict_skin_lesion(image):
-
-    prompt = """Analyze this skin lesion image carefully. Look at the color, texture, shape, and borders.
-Provide your diagnosis and bounding box coordinates in this exact format:
-<diagnosis>condition_name</diagnosis> <bbox>[x1, y1, x2, y2]</bbox>
-
-Possible conditions: melanoma, melanocytic nevus, basal cell carcinoma, actinic keratosis, benign keratosis-like lesion, dermatofibroma, vascular lesion"""
+    prompt = "Analyze this skin lesion image. Respond in this exact format: <diagnosis>condition_name</diagnosis> <bbox>[x1, y1, x2, y2]</bbox>"
 
     conversation = [
         {
@@ -105,11 +100,11 @@ def predict_skin_lesion(image):
     with torch.no_grad():
         generated_ids = model.generate(
             **inputs,
-            max_new_tokens=
-            temperature=0.
-            top_p=0.
+            max_new_tokens=200,
+            temperature=0.3,
+            top_p=0.8,
+            # repetition_penalty=1.1,
             do_sample=True,
-            repetition_penalty=1.1, # Reduce repetition
             pad_token_id=processor.tokenizer.eos_token_id
         )
 
@@ -122,23 +117,11 @@ def predict_skin_lesion(image):
 
     result_image = image.copy()
     if bbox:
-        result_image = draw_bbox(result_image, bbox, "
-
-    # Better fallback handling
-    if diagnosis.lower() in dx_names:
-        full_diagnosis = dx_names[diagnosis.lower()]
-    elif any(dx in diagnosis.lower() for dx in dx_names.keys()):
-        for code, name in dx_names.items():
-            if code in diagnosis.lower():
-                full_diagnosis = name
-                break
-    else:
-        full_diagnosis = diagnosis.title()
+        result_image = draw_bbox(result_image, bbox, "Prediction", "red")
 
-
-    confidence = "High" if diagnosis and bbox else "Medium" if diagnosis else "Low"
+    full_diagnosis = dx_names.get(diagnosis.lower(), diagnosis)
 
-    return result_image, f"**Diagnosis:** {full_diagnosis}\n**
+    return result_image, f"**Diagnosis:** {full_diagnosis}\n**Bbox:** {bbox}\n\n**Raw Response:** {response}"
 
 def load_random_and_predict():
     image, gt_diagnosis, gt_bbox = get_random_image()
@@ -159,33 +142,21 @@ def load_random_and_predict():
 
     return gt_image, result_image, prediction, ground_truth
 
-with gr.Blocks(title="DrDiag Skin Disease Diagnosis"
-    gr.Markdown("# 🔬 DrDiag:
-    gr.Markdown("
+with gr.Blocks(title="DrDiag Skin Disease Diagnosis") as demo:
+    gr.Markdown("# 🔬 DrDiag: Qwen2-VL Skin Disease Diagnosis with Spatial Awareness")
+    gr.Markdown("Analyze skin lesions with AI-powered diagnosis and bounding box detection")
 
     with gr.Row():
-        with gr.Column(
-            input_image = gr.Image(type="pil", label="
-
+        with gr.Column():
+            input_image = gr.Image(type="pil", label="Input Image")
             with gr.Row():
-                upload_btn = gr.Button("
-                random_btn = gr.Button("
-
-            gr.Markdown("""
-            **Detectable Conditions:**
-            - Melanoma
-            - Melanocytic Nevus
-            - Basal Cell Carcinoma
-            - Actinic Keratosis
-            - Benign Keratosis-like Lesion
-            - Dermatofibroma
-            - Vascular Lesion
-            """)
+                upload_btn = gr.Button("Analyze Image", variant="primary")
+                random_btn = gr.Button("Random Sample", variant="secondary")
 
-        with gr.Column(
-            output_image = gr.Image(type="pil", label="
-            prediction_text = gr.Textbox(label="
-            ground_truth_text = gr.Textbox(label="
+        with gr.Column():
+            output_image = gr.Image(type="pil", label="Result with Bounding Box")
+            prediction_text = gr.Textbox(label="AI Prediction", lines=6)
+            ground_truth_text = gr.Textbox(label="Ground Truth", lines=4, visible=False)
 
     upload_btn.click(
         fn=predict_skin_lesion,