Spaces:
Sleeping
Sleeping
Commit
·
699e2ed
1
Parent(s):
e8983fc
Enhance predict function to return targets not found and update main.py to display results in the UI
Browse files- inference_utils/init_predict_mock.py +42 -2
- main.py +4 -1
inference_utils/init_predict_mock.py
CHANGED
@@ -11,9 +11,49 @@
|
|
11 |
# limitations under the License.
|
12 |
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
def init_model():
|
15 |
return None
|
16 |
|
17 |
|
18 |
-
def predict(
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
# limitations under the License.
|
12 |
|
13 |
|
14 |
+
from typing import Tuple
|
15 |
+
from PIL import Image, ImageDraw, ImageFont
|
16 |
+
import gradio as gr
|
17 |
+
import random
|
18 |
+
|
19 |
+
|
20 |
def init_model():
|
21 |
return None
|
22 |
|
23 |
|
24 |
+
def predict(
|
25 |
+
image: Image, modality_type: str, targets: list[str]
|
26 |
+
) -> Tuple[gr.Image, str]:
|
27 |
+
# Randomly split targets into found and not found
|
28 |
+
targets_found = random.sample(targets, k=len(targets) // 2)
|
29 |
+
targets_not_found = [t for t in targets if t not in targets_found]
|
30 |
+
|
31 |
+
# Create a copy of the image to draw on
|
32 |
+
image_with_text = image.copy()
|
33 |
+
draw = ImageDraw.Draw(image_with_text)
|
34 |
+
|
35 |
+
# Draw found targets on the image with larger font
|
36 |
+
font_size = 36
|
37 |
+
try:
|
38 |
+
font = ImageFont.truetype("DejaVuSans.ttf", font_size)
|
39 |
+
except OSError:
|
40 |
+
# Fallback to default font if DejaVuSans is not available
|
41 |
+
font = ImageFont.load_default()
|
42 |
+
|
43 |
+
# Calculate starting position from bottom
|
44 |
+
line_height = 50
|
45 |
+
total_height = len(targets_found) * line_height
|
46 |
+
padding = 20
|
47 |
+
|
48 |
+
# Start from bottom and work upwards
|
49 |
+
y_position = image_with_text.height - total_height - padding
|
50 |
+
for target in targets_found:
|
51 |
+
draw.text((20, y_position), target, fill="red", font=font)
|
52 |
+
y_position += line_height
|
53 |
+
|
54 |
+
# Format targets_not_found as a string with one target per line
|
55 |
+
targets_not_found_str = (
|
56 |
+
"\n".join(targets_not_found) if targets_not_found else "All targets were found!"
|
57 |
+
)
|
58 |
+
|
59 |
+
return image_with_text, targets_not_found_str
|
main.py
CHANGED
@@ -114,6 +114,9 @@ def run():
|
|
114 |
)
|
115 |
with gr.Column():
|
116 |
output_image = gr.Image(type="pil", label="Prediction")
|
|
|
|
|
|
|
117 |
|
118 |
input_modality_type.change(
|
119 |
fn=update_input_targets,
|
@@ -125,7 +128,7 @@ def run():
|
|
125 |
submit_btn.click(
|
126 |
fn=predict,
|
127 |
inputs=[input_image, input_modality_type, input_targets],
|
128 |
-
outputs=output_image,
|
129 |
)
|
130 |
|
131 |
gr.Examples(
|
|
|
114 |
)
|
115 |
with gr.Column():
|
116 |
output_image = gr.Image(type="pil", label="Prediction")
|
117 |
+
output_targets_not_found = gr.Textbox(
|
118 |
+
label="Targets Not Found", lines=4, max_lines=10
|
119 |
+
)
|
120 |
|
121 |
input_modality_type.change(
|
122 |
fn=update_input_targets,
|
|
|
128 |
submit_btn.click(
|
129 |
fn=predict,
|
130 |
inputs=[input_image, input_modality_type, input_targets],
|
131 |
+
outputs=[output_image, output_targets_not_found],
|
132 |
)
|
133 |
|
134 |
gr.Examples(
|