venkyvicky commited on
Commit
169303f
·
verified ·
1 Parent(s): 77b922f

Delete detr_gradio_app.py

Browse files
Files changed (1) hide show
  1. detr_gradio_app.py +0 -49
detr_gradio_app.py DELETED
@@ -1,49 +0,0 @@
1
- import torch
2
- from PIL import Image
3
- from transformers import DetrImageProcessor, DetrForObjectDetection
4
- import gradio as gr
5
- import matplotlib.pyplot as plt
6
- import matplotlib.patches as patches
7
- import io
8
-
9
- # Load model and processor once
10
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
11
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
12
-
13
- def detect_objects(image):
14
- # Run DETR model
15
- inputs = processor(images=image, return_tensors="pt")
16
- outputs = model(**inputs)
17
- target_sizes = torch.tensor([image.size[::-1]])
18
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
19
-
20
- # Draw boxes on image
21
- fig, ax = plt.subplots(1)
22
- ax.imshow(image)
23
-
24
- for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
25
- xmin, ymin, xmax, ymax = box.tolist()
26
- ax.add_patch(patches.Rectangle(
27
- (xmin, ymin), xmax - xmin, ymax - ymin,
28
- linewidth=2, edgecolor='red', facecolor='none'))
29
- ax.text(xmin, ymin, f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}",
30
- bbox=dict(facecolor='yellow', alpha=0.5), fontsize=8)
31
-
32
- # Save output to bytes buffer
33
- buf = io.BytesIO()
34
- plt.axis("off")
35
- plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
36
- plt.close(fig)
37
- buf.seek(0)
38
- return Image.open(buf)
39
-
40
- # Create Gradio interface
41
- interface = gr.Interface(fn=detect_objects,
42
- inputs=gr.Image(type="pil"),
43
- outputs="image",
44
- title="DETR Object Detection",
45
- description="Upload an image to detect objects using Facebook's DETR model.")
46
-
47
- # Launch the app locally
48
- if __name__ == "__main__":
49
- interface.launch(share=True)