shrey14 commited on
Commit
0facd8f
·
verified ·
1 Parent(s): 29d6038

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inference_sdk import InferenceHTTPClient
2
+ from PIL import Image, ImageDraw, ImageFont, ImageEnhance
3
+ import matplotlib.pyplot as plt
4
+ import os
5
+ import gradio as gr
6
+ from collections import defaultdict
7
# Roboflow API key, read from the environment (set ROBOFLOW_API_KEY before launch).
API_KEY = os.getenv("ROBOFLOW_API_KEY")


# Initialize the Roboflow hosted-inference client
CLIENT = InferenceHTTPClient(
    api_url="https://detect.roboflow.com",
    api_key=API_KEY
)

# Set model details
MODEL_ID = "hvacsym/5"  # Roboflow model id ("project/version")
IMAGE_PATH = "../image1.png"  # demo image processed at import time; path is relative to the CWD
CONFIDENCE_THRESHOLD = 0.3  # Confidence threshold for filtering predictions
GRID_SIZE = (3, 3)  # 3x3 segmentation grid as (rows, cols) — the original "4x4" comment did not match the value
21
+
22
def enhance_image(image):
    """Boost brightness and contrast of *image* for clearer symbol detection.

    The image is converted to grayscale ('L') if needed, brightened by a
    factor of 1.3 and contrast-enhanced by 1.2, then converted back to RGB
    so colored bounding boxes can be drawn on it afterwards.
    """
    grayscale = image if image.mode == 'L' else image.convert('L')
    brightened = ImageEnhance.Brightness(grayscale).enhance(1.3)
    adjusted = ImageEnhance.Contrast(brightened).enhance(1.2)
    # RGB output lets the callers draw colored boxes/labels on the result.
    return adjusted.convert('RGB')
32
+
33
# Ensure the hard-coded demo image exists before doing any work.
if not os.path.exists(IMAGE_PATH):
    raise FileNotFoundError(f"Error: The image file '{IMAGE_PATH}' was not found.")

# Load and enhance the original image, then derive the per-segment size
# from the (rows, cols) grid: seg_w from columns, seg_h from rows.
original_image = Image.open(IMAGE_PATH)
original_image = enhance_image(original_image)
width, height = original_image.size
seg_w, seg_h = width // GRID_SIZE[1], height // GRID_SIZE[0]

# Copy of the full image on which all detections are drawn.
final_image = original_image.copy()
draw_final = ImageDraw.Draw(final_image)

# Load a font for labeling; fall back to PIL's built-in bitmap font.
# BUGFIX: ImageFont.truetype raises OSError when the font file cannot be
# found or read — catch that specifically instead of a bare `except:`,
# which would also swallow KeyboardInterrupt/SystemExit.
try:
    font = ImageFont.truetype("arial.ttf", 10)
except OSError:
    font = ImageFont.load_default()

# Running totals of detected symbol classes across all segments.
total_counts = defaultdict(int)

# Colors for boxes and label text.
RED = (255, 0, 0)
GREEN = (0, 255, 0)
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
61
+
62
# Process each grid segment independently: crop it, run inference on it,
# draw boxes on both the segment and the full image, and accumulate counts.
for row in range(GRID_SIZE[0]):
    for col in range(GRID_SIZE[1]):
        # Segment bounds in full-image coordinates.
        x1, y1 = col * seg_w, row * seg_h
        x2, y2 = (col + 1) * seg_w, (row + 1) * seg_h

        # Crop the segment and save it so the HTTP client can upload it.
        # BUGFIX: the original wrote to the Colab-only "/content/" directory,
        # which does not exist on other hosts; use the working directory.
        segment = original_image.crop((x1, y1, x2, y2))
        draw_segment = ImageDraw.Draw(segment)
        segment_path = f"segment_{row}_{col}.png"
        segment.save(segment_path)

        # Run inference on the segment via the Roboflow hosted API.
        result = CLIENT.infer(segment_path, model_id=MODEL_ID)

        # BUGFIX: `confidence` is already a 0-1 fraction; the original compared
        # `confidence * 100 >= 0.3`, i.e. a 0.3% cutoff that kept essentially
        # every prediction. Compare the fraction directly to the threshold.
        filtered_predictions = [
            pred for pred in result["predictions"]
            if pred["confidence"] >= CONFIDENCE_THRESHOLD
        ]

        # Per-segment label counts (reported after each segment).
        segment_counts = defaultdict(int)

        # Draw bounding boxes on both the segment and the final image.
        for obj in filtered_predictions:
            sx, sy, sw, sh = obj["x"], obj["y"], obj["width"], obj["height"]
            class_name = obj["class"]
            confidence = obj["confidence"]

            # Update per-segment and global counts.
            segment_counts[class_name] += 1
            total_counts[class_name] += 1

            # Box corners relative to the segment (center point + half extents).
            x_min_seg, y_min_seg = sx - sw // 2, sy - sh // 2
            x_max_seg, y_max_seg = sx + sw // 2, sy + sh // 2

            # Draw on the segment in RED.
            draw_segment.rectangle([x_min_seg, y_min_seg, x_max_seg, y_max_seg], outline=RED, width=2)

            # Label: class name plus confidence, on a black background strip.
            text = f"{class_name} {confidence:.2f}"
            text_w, text_h = draw_segment.textbbox((0, 0), text, font=font)[2:]
            draw_segment.rectangle([x_min_seg, y_min_seg - text_h, x_min_seg + text_w + 4, y_min_seg], fill=BLACK)
            draw_segment.text((x_min_seg + 2, y_min_seg - text_h), text, fill=WHITE, font=font)

            # Same box translated into full-image coordinates, drawn in GREEN.
            x_min_full, y_min_full = x1 + x_min_seg, y1 + y_min_seg
            x_max_full, y_max_full = x1 + x_max_seg, y1 + y_max_seg
            draw_final.rectangle([x_min_full, y_min_full, x_max_full, y_max_full], outline=GREEN, width=2)
            draw_final.rectangle([x_min_full, y_min_full - text_h, x_min_full + text_w + 4, y_min_full], fill=BLACK)
            draw_final.text((x_min_full + 2, y_min_full - text_h), text, fill=WHITE, font=font)

        # Display the annotated segment (RGB, so no grayscale colormap needed).
        plt.figure(figsize=(5, 5))
        plt.imshow(segment)
        plt.axis("off")
        plt.title(f"Segment ({row}, {col}) with Detected Symbols")
        plt.show()

        # Report counts for this segment.
        print(f"Counts in Segment ({row}, {col}):")
        for label, count in segment_counts.items():
            print(f"  {label}: {count}")
        print("-" * 30)
132
+
133
# Show the fully annotated blueprint (already RGB — no colormap required).
fig = plt.figure(figsize=(10, 10))
plt.imshow(final_image)
plt.axis("off")
plt.title("Final Image with Detected Symbols")
plt.show()

# Summarize symbol counts accumulated over every segment.
print("\nTotal Counts Across All Segments:")
for symbol, qty in total_counts.items():
    print(f"{symbol}: {qty}")
144
+
145
def process_uploaded_image(image_path):
    """Run the segmented HVAC-symbol detection pipeline on an uploaded image.

    BUGFIX: the original body called ``process_image``, which is not defined
    anywhere in this file, so every Gradio invocation raised ``NameError``.
    The detection pipeline (enhance -> grid-crop -> infer -> annotate) is now
    implemented here directly, reusing the module-level configuration.

    Args:
        image_path: Filesystem path of the uploaded image (Gradio "filepath").

    Returns:
        Tuple of (path to the annotated result image, newline-separated
        "label: count" summary text).
    """
    image = enhance_image(Image.open(image_path))
    w, h = image.size
    # Per-segment size from the (rows, cols) grid.
    cell_w, cell_h = w // GRID_SIZE[1], h // GRID_SIZE[0]

    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    counts = defaultdict(int)

    for row in range(GRID_SIZE[0]):
        for col in range(GRID_SIZE[1]):
            left, top = col * cell_w, row * cell_h
            segment = image.crop((left, top, left + cell_w, top + cell_h))
            seg_path = f"segment_{row}_{col}.png"
            segment.save(seg_path)

            result = CLIENT.infer(seg_path, model_id=MODEL_ID)
            for pred in result["predictions"]:
                # confidence is a 0-1 fraction; filter against the threshold.
                if pred["confidence"] < CONFIDENCE_THRESHOLD:
                    continue
                counts[pred["class"]] += 1

                # Box corners in full-image coordinates (center + half extents).
                bx1 = left + pred["x"] - pred["width"] // 2
                by1 = top + pred["y"] - pred["height"] // 2
                bx2 = left + pred["x"] + pred["width"] // 2
                by2 = top + pred["y"] + pred["height"] // 2

                draw.rectangle([bx1, by1, bx2, by2], outline=GREEN, width=2)
                label = f"{pred['class']} {pred['confidence']:.2f}"
                text_w, text_h = draw.textbbox((0, 0), label, font=font)[2:]
                draw.rectangle([bx1, by1 - text_h, bx1 + text_w + 4, by1], fill=BLACK)
                draw.text((bx1 + 2, by1 - text_h), label, fill=WHITE, font=font)

    out_path = "annotated_result.png"
    annotated.save(out_path)

    # Convert count dictionary to readable text.
    count_text = "\n".join(f"{label}: {count}" for label, count in counts.items())
    return out_path, count_text
152
+
153
# Deploy with Gradio: the user uploads a blueprint image (delivered to the
# callback as a file path) and receives the annotated image plus count text.
iface = gr.Interface(
    fn=process_uploaded_image,
    inputs=gr.Image(type="filepath"),  # "filepath" hands the callback a path string, not an array
    outputs=[gr.Image(type="filepath"), gr.Text()],
    title="HVAC Symbol Detector",
    description="Upload an HVAC blueprint image. The model will segment it, detect symbols, and return the final image with bounding boxes along with symbol counts."
)

# Launch the Gradio app (debug=True surfaces callback tracebacks in the console).
iface.launch(debug=True)