abaryan commited on
Commit
1a8cd16
Β·
verified Β·
1 Parent(s): bb933e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -48
app.py CHANGED
@@ -82,12 +82,7 @@ def draw_bbox(image, bbox, label="Lesion", color="red", width=3):
82
  return image_copy
83
 
84
  def predict_skin_lesion(image):
85
-
86
- prompt = """Analyze this skin lesion image carefully. Look at the color, texture, shape, and borders.
87
- Provide your diagnosis and bounding box coordinates in this exact format:
88
- <diagnosis>condition_name</diagnosis> <bbox>[x1, y1, x2, y2]</bbox>
89
-
90
- Possible conditions: melanoma, melanocytic nevus, basal cell carcinoma, actinic keratosis, benign keratosis-like lesion, dermatofibroma, vascular lesion"""
91
 
92
  conversation = [
93
  {
@@ -105,11 +100,11 @@ def predict_skin_lesion(image):
105
  with torch.no_grad():
106
  generated_ids = model.generate(
107
  **inputs,
108
- max_new_tokens=150,
109
- temperature=0.6, # Lower temperature for more focused predictions
110
- top_p=0.9,
 
111
  do_sample=True,
112
- repetition_penalty=1.1, # Reduce repetition
113
  pad_token_id=processor.tokenizer.eos_token_id
114
  )
115
 
@@ -122,23 +117,11 @@ def predict_skin_lesion(image):
122
 
123
  result_image = image.copy()
124
  if bbox:
125
- result_image = draw_bbox(result_image, bbox, "AI Prediction", "red")
126
-
127
- # Better fallback handling
128
- if diagnosis.lower() in dx_names:
129
- full_diagnosis = dx_names[diagnosis.lower()]
130
- elif any(dx in diagnosis.lower() for dx in dx_names.keys()):
131
- for code, name in dx_names.items():
132
- if code in diagnosis.lower():
133
- full_diagnosis = name
134
- break
135
- else:
136
- full_diagnosis = diagnosis.title()
137
 
138
- # Confidence indicator based on response quality
139
- confidence = "High" if diagnosis and bbox else "Medium" if diagnosis else "Low"
140
 
141
- return result_image, f"**Diagnosis:** {full_diagnosis}\n**Confidence:** {confidence}\n**Bbox:** {bbox}\n\n**Raw Response:** {response}"
142
 
143
  def load_random_and_predict():
144
  image, gt_diagnosis, gt_bbox = get_random_image()
@@ -159,33 +142,21 @@ def load_random_and_predict():
159
 
160
  return gt_image, result_image, prediction, ground_truth
161
 
162
- with gr.Blocks(title="DrDiag Skin Disease Diagnosis", theme=gr.themes.Soft()) as demo:
163
- gr.Markdown("# πŸ”¬ DrDiag: AI-Powered Skin Lesion Analysis")
164
- gr.Markdown("Advanced skin disease diagnosis with bounding box localization using Qwen2-VL")
165
 
166
  with gr.Row():
167
- with gr.Column(scale=1):
168
- input_image = gr.Image(type="pil", label="πŸ“€ Upload Skin Lesion Image", height=400)
169
-
170
  with gr.Row():
171
- upload_btn = gr.Button("πŸ” Analyze Image", variant="primary", size="lg")
172
- random_btn = gr.Button("🎲 Random Sample", variant="secondary", size="lg")
173
-
174
- gr.Markdown("""
175
- **πŸ“‹ Detectable Conditions:**
176
- - Melanoma
177
- - Melanocytic Nevus
178
- - Basal Cell Carcinoma
179
- - Actinic Keratosis
180
- - Benign Keratosis-like Lesion
181
- - Dermatofibroma
182
- - Vascular Lesion
183
- """)
184
 
185
- with gr.Column(scale=1):
186
- output_image = gr.Image(type="pil", label="🎯 Analysis Result", height=400)
187
- prediction_text = gr.Textbox(label="πŸ“Š AI Analysis", lines=8, show_copy_button=True)
188
- ground_truth_text = gr.Textbox(label="βœ… Ground Truth Comparison", lines=3, visible=False)
189
 
190
  upload_btn.click(
191
  fn=predict_skin_lesion,
 
82
  return image_copy
83
 
84
  def predict_skin_lesion(image):
85
+ prompt = "Analyze this skin lesion image. Respond in this exact format: <diagnosis>condition_name</diagnosis> <bbox>[x1, y1, x2, y2]</bbox>"
 
 
 
 
 
86
 
87
  conversation = [
88
  {
 
100
  with torch.no_grad():
101
  generated_ids = model.generate(
102
  **inputs,
103
+ max_new_tokens=200,
104
+ temperature=0.3,
105
+ top_p=0.8,
106
+ # repetition_penalty=1.1,
107
  do_sample=True,
 
108
  pad_token_id=processor.tokenizer.eos_token_id
109
  )
110
 
 
117
 
118
  result_image = image.copy()
119
  if bbox:
120
+ result_image = draw_bbox(result_image, bbox, "Prediction", "red")
 
 
 
 
 
 
 
 
 
 
 
121
 
122
+ full_diagnosis = dx_names.get(diagnosis.lower(), diagnosis)
 
123
 
124
+ return result_image, f"**Diagnosis:** {full_diagnosis}\n**Bbox:** {bbox}\n\n**Raw Response:** {response}"
125
 
126
  def load_random_and_predict():
127
  image, gt_diagnosis, gt_bbox = get_random_image()
 
142
 
143
  return gt_image, result_image, prediction, ground_truth
144
 
145
+ with gr.Blocks(title="DrDiag Skin Disease Diagnosis") as demo:
146
+ gr.Markdown("# πŸ”¬ DrDiag: Qwen2-VL Skin Disease Diagnosis with Spatial Awareness")
147
+ gr.Markdown("Analyze skin lesions with AI-powered diagnosis and bounding box detection")
148
 
149
  with gr.Row():
150
+ with gr.Column():
151
+ input_image = gr.Image(type="pil", label="Input Image")
 
152
  with gr.Row():
153
+ upload_btn = gr.Button("Analyze Image", variant="primary")
154
+ random_btn = gr.Button("Random Sample", variant="secondary")
 
 
 
 
 
 
 
 
 
 
 
155
 
156
+ with gr.Column():
157
+ output_image = gr.Image(type="pil", label="Result with Bounding Box")
158
+ prediction_text = gr.Textbox(label="AI Prediction", lines=6)
159
+ ground_truth_text = gr.Textbox(label="Ground Truth", lines=4, visible=False)
160
 
161
  upload_btn.click(
162
  fn=predict_skin_lesion,