venkyvicky commited on
Commit
9483059
·
verified ·
1 Parent(s): 169303f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from transformers import DetrImageProcessor, DetrForObjectDetection
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib.patches as patches
7
+ import io
8
+
9
+ # Load model and processor once
10
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
11
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
12
+
13
+ def detect_objects(image):
14
+ # Run DETR model
15
+ inputs = processor(images=image, return_tensors="pt")
16
+ outputs = model(**inputs)
17
+ target_sizes = torch.tensor([image.size[::-1]])
18
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
19
+
20
+ # Draw boxes on image
21
+ fig, ax = plt.subplots(1)
22
+ ax.imshow(image)
23
+
24
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
25
+ xmin, ymin, xmax, ymax = box.tolist()
26
+ ax.add_patch(patches.Rectangle(
27
+ (xmin, ymin), xmax - xmin, ymax - ymin,
28
+ linewidth=2, edgecolor='red', facecolor='none'))
29
+ ax.text(xmin, ymin, f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}",
30
+ bbox=dict(facecolor='yellow', alpha=0.5), fontsize=8)
31
+
32
+ # Save output to bytes buffer
33
+ buf = io.BytesIO()
34
+ plt.axis("off")
35
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
36
+ plt.close(fig)
37
+ buf.seek(0)
38
+ return Image.open(buf)
39
+
40
+ # Create Gradio interface
41
+ interface = gr.Interface(fn=detect_objects,
42
+ inputs=gr.Image(type="pil"),
43
+ outputs="image",
44
+ title="DETR Object Detection",
45
+ description="Upload an image to detect objects using Facebook's DETR model.")
46
+
47
+ # Launch the app locally
48
+ if __name__ == "__main__":
49
+ interface.launch(share=True)