kernel-luso-comfort commited on
Commit
699e2ed
·
1 Parent(s): e8983fc

Enhance predict function to return targets not found and update main.py to display results in the UI

Browse files
Files changed (2) hide show
  1. inference_utils/init_predict_mock.py +42 -2
  2. main.py +4 -1
inference_utils/init_predict_mock.py CHANGED
@@ -11,9 +11,49 @@
11
  # limitations under the License.
12
 
13
 
 
 
 
 
 
 
14
  def init_model():
15
  return None
16
 
17
 
18
- def predict(image, modality_type: str, targets: list[str]):
19
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # limitations under the License.
12
 
13
 
14
+ from typing import Tuple
15
+ from PIL import Image, ImageDraw, ImageFont
16
+ import gradio as gr
17
+ import random
18
+
19
+
20
  def init_model():
21
  return None
22
 
23
 
24
+ def predict(
25
+ image: Image, modality_type: str, targets: list[str]
26
+ ) -> Tuple[gr.Image, str]:
27
+ # Randomly split targets into found and not found
28
+ targets_found = random.sample(targets, k=len(targets) // 2)
29
+ targets_not_found = [t for t in targets if t not in targets_found]
30
+
31
+ # Create a copy of the image to draw on
32
+ image_with_text = image.copy()
33
+ draw = ImageDraw.Draw(image_with_text)
34
+
35
+ # Draw found targets on the image with larger font
36
+ font_size = 36
37
+ try:
38
+ font = ImageFont.truetype("DejaVuSans.ttf", font_size)
39
+ except OSError:
40
+ # Fallback to default font if DejaVuSans is not available
41
+ font = ImageFont.load_default()
42
+
43
+ # Calculate starting position from bottom
44
+ line_height = 50
45
+ total_height = len(targets_found) * line_height
46
+ padding = 20
47
+
48
+ # Start from bottom and work upwards
49
+ y_position = image_with_text.height - total_height - padding
50
+ for target in targets_found:
51
+ draw.text((20, y_position), target, fill="red", font=font)
52
+ y_position += line_height
53
+
54
+ # Format targets_not_found as a string with one target per line
55
+ targets_not_found_str = (
56
+ "\n".join(targets_not_found) if targets_not_found else "All targets were found!"
57
+ )
58
+
59
+ return image_with_text, targets_not_found_str
main.py CHANGED
@@ -114,6 +114,9 @@ def run():
114
  )
115
  with gr.Column():
116
  output_image = gr.Image(type="pil", label="Prediction")
 
 
 
117
 
118
  input_modality_type.change(
119
  fn=update_input_targets,
@@ -125,7 +128,7 @@ def run():
125
  submit_btn.click(
126
  fn=predict,
127
  inputs=[input_image, input_modality_type, input_targets],
128
- outputs=output_image,
129
  )
130
 
131
  gr.Examples(
 
114
  )
115
  with gr.Column():
116
  output_image = gr.Image(type="pil", label="Prediction")
117
+ output_targets_not_found = gr.Textbox(
118
+ label="Targets Not Found", lines=4, max_lines=10
119
+ )
120
 
121
  input_modality_type.change(
122
  fn=update_input_targets,
 
128
  submit_btn.click(
129
  fn=predict,
130
  inputs=[input_image, input_modality_type, input_targets],
131
+ outputs=[output_image, output_targets_not_found],
132
  )
133
 
134
  gr.Examples(