Shreyansh Khaitan commited on
Commit
e24bf63
·
unverified ·
1 Parent(s): dcdde74
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -20,13 +20,14 @@ CONFIDENCE_THRESHOLD = 0.3 # Confidence threshold for filtering predictions
20
  GRID_SIZE = (3, 3) # 3x3 segmentation
21
 
22
  def detect_components(image):
23
- """ Detect components in an uploaded image with three passes. """
24
 
25
  original_image = image.convert("RGB")
26
  width, height = original_image.size
27
  seg_w, seg_h = width // GRID_SIZE[1], height // GRID_SIZE[0]
28
 
29
  def process_detection(image, pass_num):
 
30
  final_image = image.copy()
31
  draw_final = ImageDraw.Draw(final_image)
32
  total_counts = defaultdict(int)
@@ -41,6 +42,7 @@ def detect_components(image):
41
  segment_path = f"segment_{row}_{col}_pass{pass_num}.png"
42
  segment.save(segment_path)
43
 
 
44
  result = CLIENT.infer(segment_path, model_id=MODEL_ID)
45
  filtered_predictions = [pred for pred in result["predictions"] if pred["confidence"] >= CONFIDENCE_THRESHOLD]
46
 
@@ -48,10 +50,13 @@ def detect_components(image):
48
  sx, sy, sw, sh = obj["x"], obj["y"], obj["width"], obj["height"]
49
  class_name = obj["class"]
50
  total_counts[class_name] += 1
51
-
 
52
  x_min_full, y_min_full = x1 + sx - sw // 2, y1 + sy - sh // 2
53
  x_max_full, y_max_full = x1 + sx + sw // 2, y1 + sy + sh // 2
54
  detected_boxes.append((x_min_full, y_min_full, x_max_full, y_max_full))
 
 
55
  draw_final.rectangle([x_min_full, y_min_full, x_max_full, y_max_full], outline="green", width=2)
56
 
57
  return final_image, total_counts, detected_boxes
@@ -59,7 +64,7 @@ def detect_components(image):
59
  # First pass detection
60
  image_after_pass1, counts_pass1, detected_boxes = process_detection(original_image, pass_num=1)
61
 
62
- # Remove detected areas for second pass
63
  image_after_removal1 = original_image.copy()
64
  draw_removal1 = ImageDraw.Draw(image_after_removal1)
65
  for box in detected_boxes:
@@ -68,7 +73,7 @@ def detect_components(image):
68
  # Second pass detection
69
  image_after_pass2, counts_pass2, detected_boxes = process_detection(image_after_removal1, pass_num=2)
70
 
71
- # Remove detected areas for third pass
72
  image_after_removal2 = image_after_removal1.copy()
73
  draw_removal2 = ImageDraw.Draw(image_after_removal2)
74
  for box in detected_boxes:
@@ -81,11 +86,8 @@ def detect_components(image):
81
  final_counts = defaultdict(int)
82
  for label in set(counts_pass1) | set(counts_pass2) | set(counts_pass3):
83
  final_counts[label] = counts_pass1.get(label, 0) + counts_pass2.get(label, 0) + counts_pass3.get(label, 0)
84
-
85
- # Ensure actual counts are displayed as numbers (no text formatting)
86
- formatted_counts = {k: v for k, v in final_counts.items()}
87
 
88
- return image_after_pass1, image_after_pass2, image_after_pass3, formatted_counts
89
 
90
  # Gradio Interface
91
  interface = gr.Interface(
 
20
  GRID_SIZE = (3, 3) # 3x3 segmentation
21
 
22
  def detect_components(image):
23
+ """ Detect components in an uploaded image with three passes, removing detected areas after each pass. """
24
 
25
  original_image = image.convert("RGB")
26
  width, height = original_image.size
27
  seg_w, seg_h = width // GRID_SIZE[1], height // GRID_SIZE[0]
28
 
29
  def process_detection(image, pass_num):
30
+ """ Detect objects in an image segment and remove them if found. """
31
  final_image = image.copy()
32
  draw_final = ImageDraw.Draw(final_image)
33
  total_counts = defaultdict(int)
 
42
  segment_path = f"segment_{row}_{col}_pass{pass_num}.png"
43
  segment.save(segment_path)
44
 
45
+ # Run inference
46
  result = CLIENT.infer(segment_path, model_id=MODEL_ID)
47
  filtered_predictions = [pred for pred in result["predictions"] if pred["confidence"] >= CONFIDENCE_THRESHOLD]
48
 
 
50
  sx, sy, sw, sh = obj["x"], obj["y"], obj["width"], obj["height"]
51
  class_name = obj["class"]
52
  total_counts[class_name] += 1
53
+
54
+ # Convert segment coordinates to full image coordinates
55
  x_min_full, y_min_full = x1 + sx - sw // 2, y1 + sy - sh // 2
56
  x_max_full, y_max_full = x1 + sx + sw // 2, y1 + sy + sh // 2
57
  detected_boxes.append((x_min_full, y_min_full, x_max_full, y_max_full))
58
+
59
+ # Draw bounding box
60
  draw_final.rectangle([x_min_full, y_min_full, x_max_full, y_max_full], outline="green", width=2)
61
 
62
  return final_image, total_counts, detected_boxes
 
64
  # First pass detection
65
  image_after_pass1, counts_pass1, detected_boxes = process_detection(original_image, pass_num=1)
66
 
67
+ # Mask detected areas for the second pass
68
  image_after_removal1 = original_image.copy()
69
  draw_removal1 = ImageDraw.Draw(image_after_removal1)
70
  for box in detected_boxes:
 
73
  # Second pass detection
74
  image_after_pass2, counts_pass2, detected_boxes = process_detection(image_after_removal1, pass_num=2)
75
 
76
+ # Mask detected areas for the third pass
77
  image_after_removal2 = image_after_removal1.copy()
78
  draw_removal2 = ImageDraw.Draw(image_after_removal2)
79
  for box in detected_boxes:
 
86
  final_counts = defaultdict(int)
87
  for label in set(counts_pass1) | set(counts_pass2) | set(counts_pass3):
88
  final_counts[label] = counts_pass1.get(label, 0) + counts_pass2.get(label, 0) + counts_pass3.get(label, 0)
 
 
 
89
 
90
+ return image_after_pass1, image_after_pass2, image_after_pass3, dict(final_counts)
91
 
92
  # Gradio Interface
93
  interface = gr.Interface(