Spaces:
Sleeping
Sleeping
Shreyansh Khaitan
committed on
changes
Browse files
app.py
CHANGED
@@ -20,13 +20,14 @@ CONFIDENCE_THRESHOLD = 0.3 # Confidence threshold for filtering predictions
|
|
20 |
GRID_SIZE = (3, 3) # 3x3 segmentation
|
21 |
|
22 |
def detect_components(image):
|
23 |
-
""" Detect components in an uploaded image with three passes. """
|
24 |
|
25 |
original_image = image.convert("RGB")
|
26 |
width, height = original_image.size
|
27 |
seg_w, seg_h = width // GRID_SIZE[1], height // GRID_SIZE[0]
|
28 |
|
29 |
def process_detection(image, pass_num):
|
|
|
30 |
final_image = image.copy()
|
31 |
draw_final = ImageDraw.Draw(final_image)
|
32 |
total_counts = defaultdict(int)
|
@@ -41,6 +42,7 @@ def detect_components(image):
|
|
41 |
segment_path = f"segment_{row}_{col}_pass{pass_num}.png"
|
42 |
segment.save(segment_path)
|
43 |
|
|
|
44 |
result = CLIENT.infer(segment_path, model_id=MODEL_ID)
|
45 |
filtered_predictions = [pred for pred in result["predictions"] if pred["confidence"] >= CONFIDENCE_THRESHOLD]
|
46 |
|
@@ -48,10 +50,13 @@ def detect_components(image):
|
|
48 |
sx, sy, sw, sh = obj["x"], obj["y"], obj["width"], obj["height"]
|
49 |
class_name = obj["class"]
|
50 |
total_counts[class_name] += 1
|
51 |
-
|
|
|
52 |
x_min_full, y_min_full = x1 + sx - sw // 2, y1 + sy - sh // 2
|
53 |
x_max_full, y_max_full = x1 + sx + sw // 2, y1 + sy + sh // 2
|
54 |
detected_boxes.append((x_min_full, y_min_full, x_max_full, y_max_full))
|
|
|
|
|
55 |
draw_final.rectangle([x_min_full, y_min_full, x_max_full, y_max_full], outline="green", width=2)
|
56 |
|
57 |
return final_image, total_counts, detected_boxes
|
@@ -59,7 +64,7 @@ def detect_components(image):
|
|
59 |
# First pass detection
|
60 |
image_after_pass1, counts_pass1, detected_boxes = process_detection(original_image, pass_num=1)
|
61 |
|
62 |
-
#
|
63 |
image_after_removal1 = original_image.copy()
|
64 |
draw_removal1 = ImageDraw.Draw(image_after_removal1)
|
65 |
for box in detected_boxes:
|
@@ -68,7 +73,7 @@ def detect_components(image):
|
|
68 |
# Second pass detection
|
69 |
image_after_pass2, counts_pass2, detected_boxes = process_detection(image_after_removal1, pass_num=2)
|
70 |
|
71 |
-
#
|
72 |
image_after_removal2 = image_after_removal1.copy()
|
73 |
draw_removal2 = ImageDraw.Draw(image_after_removal2)
|
74 |
for box in detected_boxes:
|
@@ -81,11 +86,8 @@ def detect_components(image):
|
|
81 |
final_counts = defaultdict(int)
|
82 |
for label in set(counts_pass1) | set(counts_pass2) | set(counts_pass3):
|
83 |
final_counts[label] = counts_pass1.get(label, 0) + counts_pass2.get(label, 0) + counts_pass3.get(label, 0)
|
84 |
-
|
85 |
-
# Ensure actual counts are displayed as numbers (no text formatting)
|
86 |
-
formatted_counts = {k: v for k, v in final_counts.items()}
|
87 |
|
88 |
-
return image_after_pass1, image_after_pass2, image_after_pass3,
|
89 |
|
90 |
# Gradio Interface
|
91 |
interface = gr.Interface(
|
|
|
20 |
GRID_SIZE = (3, 3) # 3x3 segmentation
|
21 |
|
22 |
def detect_components(image):
|
23 |
+
""" Detect components in an uploaded image with three passes, removing detected areas after each pass. """
|
24 |
|
25 |
original_image = image.convert("RGB")
|
26 |
width, height = original_image.size
|
27 |
seg_w, seg_h = width // GRID_SIZE[1], height // GRID_SIZE[0]
|
28 |
|
29 |
def process_detection(image, pass_num):
|
30 |
+
""" Detect objects in an image segment and remove them if found. """
|
31 |
final_image = image.copy()
|
32 |
draw_final = ImageDraw.Draw(final_image)
|
33 |
total_counts = defaultdict(int)
|
|
|
42 |
segment_path = f"segment_{row}_{col}_pass{pass_num}.png"
|
43 |
segment.save(segment_path)
|
44 |
|
45 |
+
# Run inference
|
46 |
result = CLIENT.infer(segment_path, model_id=MODEL_ID)
|
47 |
filtered_predictions = [pred for pred in result["predictions"] if pred["confidence"] >= CONFIDENCE_THRESHOLD]
|
48 |
|
|
|
50 |
sx, sy, sw, sh = obj["x"], obj["y"], obj["width"], obj["height"]
|
51 |
class_name = obj["class"]
|
52 |
total_counts[class_name] += 1
|
53 |
+
|
54 |
+
# Convert segment coordinates to full image coordinates
|
55 |
x_min_full, y_min_full = x1 + sx - sw // 2, y1 + sy - sh // 2
|
56 |
x_max_full, y_max_full = x1 + sx + sw // 2, y1 + sy + sh // 2
|
57 |
detected_boxes.append((x_min_full, y_min_full, x_max_full, y_max_full))
|
58 |
+
|
59 |
+
# Draw bounding box
|
60 |
draw_final.rectangle([x_min_full, y_min_full, x_max_full, y_max_full], outline="green", width=2)
|
61 |
|
62 |
return final_image, total_counts, detected_boxes
|
|
|
64 |
# First pass detection
|
65 |
image_after_pass1, counts_pass1, detected_boxes = process_detection(original_image, pass_num=1)
|
66 |
|
67 |
+
# Mask detected areas for the second pass
|
68 |
image_after_removal1 = original_image.copy()
|
69 |
draw_removal1 = ImageDraw.Draw(image_after_removal1)
|
70 |
for box in detected_boxes:
|
|
|
73 |
# Second pass detection
|
74 |
image_after_pass2, counts_pass2, detected_boxes = process_detection(image_after_removal1, pass_num=2)
|
75 |
|
76 |
+
# Mask detected areas for the third pass
|
77 |
image_after_removal2 = image_after_removal1.copy()
|
78 |
draw_removal2 = ImageDraw.Draw(image_after_removal2)
|
79 |
for box in detected_boxes:
|
|
|
86 |
final_counts = defaultdict(int)
|
87 |
for label in set(counts_pass1) | set(counts_pass2) | set(counts_pass3):
|
88 |
final_counts[label] = counts_pass1.get(label, 0) + counts_pass2.get(label, 0) + counts_pass3.get(label, 0)
|
|
|
|
|
|
|
89 |
|
90 |
+
return image_after_pass1, image_after_pass2, image_after_pass3, dict(final_counts)
|
91 |
|
92 |
# Gradio Interface
|
93 |
interface = gr.Interface(
|