# Source: Hugging Face Space by shrey14, commit 45a3a9d ("Update app.py", verified).
import os
import tempfile
from collections import defaultdict

import gradio as gr
from inference_sdk import InferenceHTTPClient
from PIL import Image, ImageDraw
# The Roboflow API key comes from the environment (a Hugging Face Spaces
# secret). Fail fast at import time rather than on the first request.
API_KEY = os.environ.get("ROBOFLOW_API_KEY")
if not API_KEY:
    raise ValueError("API Key is missing! Set it in HF Space Secrets.")

# Client for the hosted Roboflow detection endpoint, authenticated with the key.
CLIENT = InferenceHTTPClient(
    api_key=API_KEY,
    api_url="https://detect.roboflow.com",
)

MODEL_ID = "hvacsym/5"  # model id passed to CLIENT.infer
CONFIDENCE_THRESHOLD = 0.2  # predictions scoring below this are discarded
GRID_SIZE = (3, 3)  # (rows, cols) of tiles the image is split into
def format_counts_as_table(counts, pass_num):
    """Format per-class detection counts as a Markdown table for Gradio.

    Args:
        counts: Mapping of component class name to detection count. Rows are
            written in the mapping's iteration order.
        pass_num: Label placed after "Pass " in the heading (an int pass
            number, or a string such as "Final").

    Returns:
        A Markdown string. When ``counts`` is empty, this is a single
        "No components detected." heading. Otherwise it is a heading followed
        by a two-column ``Component | Count`` table, with every line
        newline-terminated.
    """
    if not counts:
        return f"### Pass {pass_num}: No components detected."

    lines = [
        f"### Pass {pass_num} Detection Results:",
        "",
        "| Component | Count |",
        "|-----------|-------|",
    ]
    for component, count in counts.items():
        # Without escaping, a literal "|" in a class name would split the row
        # into extra table cells. Markdown renders "\|" as a plain pipe.
        cell = f"{component}".replace("|", r"\|")
        lines.append(f"| {cell} | {count} |")
    # The trailing newline matches the original line-by-line concatenation.
    return "\n".join(lines) + "\n"
def _tile_bounds(width, height, grid):
    """Yield (row, col, (left, top, right, bottom)) for each non-empty grid tile.

    ``grid`` is (rows, cols). Tiles are produced row by row. The last row and
    column extend to the image edge, so pixels left over by the integer
    division are still scanned. Zero-area tiles (image smaller than the grid)
    are skipped because an empty crop has nothing to run inference on.
    """
    rows, cols = grid
    tile_w, tile_h = width // cols, height // rows
    for row in range(rows):
        top = row * tile_h
        bottom = height if row == rows - 1 else top + tile_h
        for col in range(cols):
            left = col * tile_w
            right = width if col == cols - 1 else left + tile_w
            if right > left and bottom > top:
                yield row, col, (left, top, right, bottom)


def _run_detection_pass(image, pass_num, workdir):
    """Run tiled inference over ``image`` and annotate the kept predictions.

    Each tile is saved as a PNG inside ``workdir`` and sent to ``CLIENT.infer``.
    Predictions with confidence below ``CONFIDENCE_THRESHOLD`` are dropped.

    Returns:
        (annotated, counts, boxes):
        - annotated: a copy of ``image`` with a green outline around each kept
          prediction.
        - counts: a defaultdict(int) mapping class name to count.
        - boxes: the kept boxes as (x_min, y_min, x_max, y_max) in full-image
          coordinates.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    counts = defaultdict(int)
    boxes = []
    width, height = image.size

    for row, col, (left, top, right, bottom) in _tile_bounds(width, height, GRID_SIZE):
        segment_path = os.path.join(workdir, f"segment_{row}_{col}_pass{pass_num}.png")
        image.crop((left, top, right, bottom)).save(segment_path)
        result = CLIENT.infer(segment_path, model_id=MODEL_ID)

        for pred in result["predictions"]:
            if pred["confidence"] < CONFIDENCE_THRESHOLD:
                continue
            counts[pred["class"]] += 1
            # x/y are treated as the box centre within the tile. Shift by the
            # tile origin to get full-image corners. True division keeps
            # odd-sized boxes symmetric instead of flooring the half-extent.
            half_w, half_h = pred["width"] / 2, pred["height"] / 2
            box = (
                left + pred["x"] - half_w,
                top + pred["y"] - half_h,
                left + pred["x"] + half_w,
                top + pred["y"] + half_h,
            )
            boxes.append(box)
            draw.rectangle(box, outline="green", width=2)

    return annotated, counts, boxes


def _mask_boxes(image, boxes):
    """Return a copy of ``image`` with every box in ``boxes`` filled solid white."""
    masked = image.copy()
    draw = ImageDraw.Draw(masked)
    for box in boxes:
        draw.rectangle(box, fill=(255, 255, 255))
    return masked


def detect_components(image):
    """Detect components in an uploaded image using three masked passes.

    Pass 1 runs tiled inference on the image. Every box it finds is painted
    white, and pass 2 runs on the result. Pass 2's boxes are painted over in
    the same way before pass 3. Because of this masking, later passes can only
    report objects outside the areas already covered.

    Args:
        image: The uploaded PIL image (from ``gr.Image(type="pil")``). It is
            converted to RGB before processing.

    Returns:
        A 7-tuple in the order of the Gradio outputs: (pass-1 image,
        pass-1 table, pass-2 image, pass-2 table, pass-3 image, pass-3 table,
        final table). Each image shows only that pass's boxes. The final table
        sums the counts over all passes.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        # gr.Image supplies None when the form is submitted without an upload.
        raise gr.Error("Please upload an image first.")

    num_passes = 3  # must match the per-pass outputs declared in the interface
    current = image.convert("RGB")
    pass_images, pass_tables = [], []
    final_counts = defaultdict(int)  # first-seen order keeps the final table stable

    # A per-call temp directory stops concurrent requests from overwriting each
    # other's segment files, and the files are removed when the call finishes.
    with tempfile.TemporaryDirectory() as workdir:
        for pass_num in range(1, num_passes + 1):
            annotated, counts, boxes = _run_detection_pass(current, pass_num, workdir)
            pass_images.append(annotated)
            pass_tables.append(format_counts_as_table(counts, pass_num))
            for label, count in counts.items():
                final_counts[label] += count
            if pass_num < num_passes:
                # Masks accumulate: the next pass sees every box found so far.
                current = _mask_boxes(current, boxes)

    return (
        pass_images[0], pass_tables[0],
        pass_images[1], pass_tables[1],
        pass_images[2], pass_tables[2],
        format_counts_as_table(final_counts, "Final"),
    )
# Web UI: one PIL image in. Out come an annotated image and a Markdown count
# table for each pass, followed by the combined table. The order of `outputs`
# must match the tuple returned by detect_components.
interface = gr.Interface(
    title="HVAC Component Detector",
    description="Upload an image to detect HVAC components using Roboflow API across three passes.",
    fn=detect_components,
    inputs=gr.Image(type="pil"),
    outputs=[
        # pass 1: annotated image + its counts
        gr.Image(type="pil", label="Detection Pass 1"),
        gr.Markdown(label="Counts After Pass 1"),
        # pass 2: annotated image + its counts
        gr.Image(type="pil", label="Detection Pass 2"),
        gr.Markdown(label="Counts After Pass 2"),
        # pass 3: annotated image + its counts
        gr.Image(type="pil", label="Detection Pass 3"),
        gr.Markdown(label="Counts After Pass 3"),
        # counts summed across all passes
        gr.Markdown(label="Final Detected Components"),
    ],
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()