shrey14 commited on
Commit
f3e768f
·
verified ·
1 Parent(s): 23a0863

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -124
app.py CHANGED
@@ -15,34 +15,14 @@ CLIENT = InferenceHTTPClient(
15
 
16
  # Set model details
17
  MODEL_ID = "hvacsym/5"
18
- IMAGE_PATH = "image1.jpg"
19
  CONFIDENCE_THRESHOLD = 0.3 # Confidence threshold for filtering predictions
20
- GRID_SIZE = (3, 3) # 4x4 segmentation
21
 
22
- def enhance_image(image):
23
- """Enhance image by adjusting brightness and contrast"""
24
- if image.mode != 'L':
25
- image = image.convert('L')
26
- brightness = ImageEnhance.Brightness(image)
27
- image = brightness.enhance(1.3)
28
- contrast = ImageEnhance.Contrast(image)
29
- image = contrast.enhance(1.2)
30
- # Convert back to RGB for colored boxes
31
- return image.convert('RGB')
32
-
33
- # Ensure image exists before proceeding
34
- if not os.path.exists(IMAGE_PATH):
35
- raise FileNotFoundError(f"Error: The image file '{IMAGE_PATH}' was not found.")
36
-
37
- # Load and enhance the original image
38
- original_image = Image.open(IMAGE_PATH)
39
- original_image = enhance_image(original_image)
40
- width, height = original_image.size
41
- seg_w, seg_h = width // GRID_SIZE[1], height // GRID_SIZE[0]
42
-
43
- # Create a copy of the full image to draw bounding boxes
44
- final_image = original_image.copy()
45
- draw_final = ImageDraw.Draw(final_image)
46
 
47
  # Load font for labeling
48
  try:
@@ -50,114 +30,73 @@ try:
50
  except:
51
  font = ImageFont.load_default()
52
 
53
- # Dictionary to store total counts
54
- total_counts = defaultdict(int)
55
-
56
- # Colors for boxes
57
- RED = (255, 0, 0)
58
- GREEN = (0, 255, 0)
59
- WHITE = (255, 255, 255)
60
- BLACK = (0, 0, 0)
61
-
62
- # Process each segment
63
- for row in range(GRID_SIZE[0]):
64
- for col in range(GRID_SIZE[1]):
65
- # Define segment coordinates
66
- x1, y1 = col * seg_w, row * seg_h
67
- x2, y2 = (col + 1) * seg_w, (row + 1) * seg_h
68
-
69
- # Crop the segment
70
- segment = original_image.crop((x1, y1, x2, y2))
71
- draw_segment = ImageDraw.Draw(segment)
72
- segment_path = f"/content/segment_{row}_{col}.png"
73
- segment.save(segment_path)
74
-
75
- # Run inference on the segment
76
- result = CLIENT.infer(segment_path, model_id=MODEL_ID)
77
-
78
- # Filter predictions based on confidence
79
- filtered_predictions = [
80
- pred for pred in result["predictions"] if pred["confidence"] * 100 >= CONFIDENCE_THRESHOLD
81
- ]
82
-
83
- # Dictionary to count labels in this segment
84
- segment_counts = defaultdict(int)
85
-
86
- # Draw bounding boxes on both segment and final image
87
- for obj in filtered_predictions:
88
- sx, sy, sw, sh = obj["x"], obj["y"], obj["width"], obj["height"]
89
- class_name = obj["class"]
90
- confidence = obj["confidence"]
91
-
92
- # Update counts
93
- segment_counts[class_name] += 1
94
- total_counts[class_name] += 1
95
-
96
- # Bounding box coordinates relative to the segment
97
- x_min_seg, y_min_seg = sx - sw // 2, sy - sh // 2
98
- x_max_seg, y_max_seg = sx + sw // 2, sy + sh // 2
99
-
100
- # Draw on segment with RED
101
- draw_segment.rectangle([x_min_seg, y_min_seg, x_max_seg, y_max_seg], outline=RED, width=2)
102
-
103
- # Draw label on segment
104
- text = f"{class_name} {confidence:.2f}"
105
- text_w, text_h = draw_segment.textbbox((0, 0), text, font=font)[2:]
106
- draw_segment.rectangle([x_min_seg, y_min_seg - text_h, x_min_seg + text_w + 4, y_min_seg], fill=BLACK)
107
- draw_segment.text((x_min_seg + 2, y_min_seg - text_h), text, fill=WHITE, font=font)
108
-
109
- # Adjust coordinates for the final image
110
- x_min_full, y_min_full = x1 + x_min_seg, y1 + y_min_seg
111
- x_max_full, y_max_full = x1 + x_max_seg, y1 + y_max_seg
112
-
113
- # Draw on final image with GREEN
114
- draw_final.rectangle([x_min_full, y_min_full, x_max_full, y_max_full], outline=GREEN, width=2)
115
-
116
- # Draw label on final image
117
- draw_final.rectangle([x_min_full, y_min_full - text_h, x_min_full + text_w + 4, y_min_full], fill=BLACK)
118
- draw_final.text((x_min_full + 2, y_min_full - text_h), text, fill=WHITE, font=font)
119
-
120
- # Display the segment with bounding boxes
121
- plt.figure(figsize=(5, 5))
122
- plt.imshow(segment) # No need for cmap='gray' as image is now RGB
123
- plt.axis("off")
124
- plt.title(f"Segment ({row}, {col}) with Detected Symbols")
125
- plt.show()
126
-
127
- # Print counts for this segment
128
- print(f"Counts in Segment ({row}, {col}):")
129
- for label, count in segment_counts.items():
130
- print(f" {label}: {count}")
131
- print("-" * 30)
132
-
133
- # Display the final image with bounding boxes
134
- plt.figure(figsize=(10, 10))
135
- plt.imshow(final_image) # No need for cmap='gray' as image is now RGB
136
- plt.axis("off")
137
- plt.title("Final Image with Detected Symbols")
138
- plt.show()
139
-
140
- # Print total counts for all segments
141
- print("\nTotal Counts Across All Segments:")
142
- for label, count in total_counts.items():
143
- print(f"{label}: {count}")
144
 
145
  def process_uploaded_image(image_path):
146
- final_image_path, total_counts = process_image(image_path) # Calls your existing function
147
-
148
- # Convert count dictionary to readable text
149
  count_text = "\n".join([f"{label}: {count}" for label, count in total_counts.items()])
150
-
151
  return final_image_path, count_text
152
 
153
  # Deploy with Gradio
154
  iface = gr.Interface(
155
  fn=process_uploaded_image,
156
- inputs=gr.Image(type="filepath"), # Corrected from "file" to "filepath"
157
  outputs=[gr.Image(type="filepath"), gr.Text()],
158
  title="HVAC Symbol Detector",
159
  description="Upload an HVAC blueprint image. The model will segment it, detect symbols, and return the final image with bounding boxes along with symbol counts."
160
  )
161
 
162
- # Launch the Gradio app
163
- iface.launch(debug=True)
 
15
 
16
# Model and inference configuration
MODEL_ID = "hvacsym/5"        # Roboflow-hosted model id
CONFIDENCE_THRESHOLD = 0.3    # keep predictions at/above this confidence (0-1 fraction)
GRID_SIZE = (3, 3)            # split the sheet into a 3x3 grid of segments

# RGB colors used when drawing bounding boxes and labels
RED = (255, 0, 0)
GREEN = (0, 255, 0)
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # Load font for labeling
28
  try:
 
30
  except:
31
  font = ImageFont.load_default()
32
 
33
def enhance_image(image):
    """Boost brightness and contrast so faint symbols stand out.

    The image is converted to grayscale ('L') if needed, brightened by
    1.3x and contrast-boosted by 1.2x, then converted back to RGB so
    colored bounding boxes can be drawn on it afterwards.
    """
    gray = image if image.mode == 'L' else image.convert('L')
    brightened = ImageEnhance.Brightness(gray).enhance(1.3)
    contrasted = ImageEnhance.Contrast(brightened).enhance(1.2)
    # Back to RGB for colored box/label drawing
    return contrasted.convert('RGB')
43
def process_image(image_path):
    """Run symbol detection on an image and draw bounding boxes.

    The image is enhanced, split into a GRID_SIZE grid of segments, and
    each segment is sent to the inference endpoint separately (small
    symbols are detected more reliably on crops than on the full sheet).
    Detections are drawn back onto a copy of the full image in
    full-image coordinates.

    Args:
        image_path: Path to the input image file.

    Returns:
        Tuple of (path to the saved annotated image, defaultdict mapping
        class name to total detection count across all segments).
    """
    original_image = enhance_image(Image.open(image_path))
    width, height = original_image.size
    seg_w, seg_h = width // GRID_SIZE[1], height // GRID_SIZE[0]

    # Draw on a copy so the enhanced original stays clean for cropping.
    final_image = original_image.copy()
    draw_final = ImageDraw.Draw(final_image)
    total_counts = defaultdict(int)

    for row in range(GRID_SIZE[0]):
        for col in range(GRID_SIZE[1]):
            x1, y1 = col * seg_w, row * seg_h
            # Extend the last row/column to the image edge so remainder
            # pixels (when width/height is not divisible by the grid)
            # are still covered instead of silently dropped.
            x2 = width if col == GRID_SIZE[1] - 1 else (col + 1) * seg_w
            y2 = height if row == GRID_SIZE[0] - 1 else (row + 1) * seg_h

            segment = original_image.crop((x1, y1, x2, y2))
            segment_path = f"segment_{row}_{col}.png"
            segment.save(segment_path)

            # Run inference on the segment
            result = CLIENT.infer(segment_path, model_id=MODEL_ID)

            # CONFIDENCE_THRESHOLD is a 0-1 fraction, so compare the raw
            # confidence directly. (The previous code multiplied the
            # confidence by 100 first, which made the effective cutoff
            # 0.003 and let nearly every prediction through.)
            filtered_predictions = [
                pred for pred in result["predictions"]
                if pred["confidence"] >= CONFIDENCE_THRESHOLD
            ]

            # Draw bounding boxes and tally labels
            for obj in filtered_predictions:
                class_name = obj["class"]
                total_counts[class_name] += 1
                # The model returns a center x/y plus width/height,
                # relative to the segment; shift by (x1, y1) to map into
                # full-image coordinates.
                x_min = x1 + obj["x"] - obj["width"] // 2
                y_min = y1 + obj["y"] - obj["height"] // 2
                x_max = x1 + obj["x"] + obj["width"] // 2
                y_max = y1 + obj["y"] + obj["height"] // 2
                draw_final.rectangle([x_min, y_min, x_max, y_max], outline=GREEN, width=2)
                draw_final.text((x_min, y_min - 10), class_name, fill=WHITE, font=font)

    # Save the final processed image
    final_image_path = "processed_image.png"
    final_image.save(final_image_path)
    return final_image_path, total_counts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
def process_uploaded_image(image_path):
    """Gradio handler: run detection and format the counts for display."""
    final_image_path, total_counts = process_image(image_path)
    # One "label: count" line per detected symbol class
    lines = [f"{label}: {count}" for label, count in total_counts.items()]
    return final_image_path, "\n".join(lines)
92
 
93
# Deploy the detector as a Gradio web app: the user uploads an image
# (delivered to the handler as a temp-file path) and gets back the
# annotated image plus a text summary of the symbol counts.
iface = gr.Interface(
    fn=process_uploaded_image,
    inputs=gr.Image(type="filepath"),
    outputs=[gr.Image(type="filepath"), gr.Text()],
    title="HVAC Symbol Detector",
    description="Upload an HVAC blueprint image. The model will segment it, detect symbols, and return the final image with bounding boxes along with symbol counts.",
)

iface.launch(debug=True, share=True)