File size: 1,836 Bytes
9483059 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import torch
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io
# Build the DETR image processor and detection model a single time, at module
# import, so detect_objects() reuses the same instances on every call.
_DETR_CHECKPOINT = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)
def detect_objects(image, threshold=0.9):
    """Detect objects in *image* with DETR and return an annotated rendering.

    Runs the module-level ``processor``/``model`` pair on the input, filters
    detections by *threshold*, and draws a red box plus a "label: score"
    caption for each remaining detection.

    Args:
        image: A PIL image (as produced by ``gr.Image(type="pil")``), or
            ``None`` when no image was supplied.
        threshold: Confidence cut-off passed to
            ``post_process_object_detection``. Defaults to 0.9, the value
            that was previously hard-coded.

    Returns:
        A PIL image of the matplotlib rendering with the detections drawn,
        or ``None`` if *image* is ``None``.
    """
    if image is None:
        # Nothing to detect (e.g. the form was submitted without an image).
        return None
    # Normalise to 3-channel RGB: the processor's normalisation presumably
    # expects 3 channels, and a single-channel image would otherwise be
    # shown with a false-colour colormap by imshow.
    image = image.convert("RGB")
    results = _run_detection(image, threshold)
    return _render_detections(image, results)


def _run_detection(image, threshold):
    """Run DETR on *image* and return the post-processed detections dict."""
    inputs = processor(images=image, return_tensors="pt")
    # Pure inference: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL's .size is (width, height); post-processing wants (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    return processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]


def _render_detections(image, results):
    """Draw *results* (scores/labels/boxes) over *image*; return a PIL image."""
    fig, ax = plt.subplots(1)
    try:
        ax.imshow(image)
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            xmin, ymin, xmax, ymax = box.tolist()
            ax.add_patch(patches.Rectangle(
                (xmin, ymin), xmax - xmin, ymax - ymin,
                linewidth=2, edgecolor='red', facecolor='none'))
            ax.text(xmin, ymin, f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}",
                    bbox=dict(facecolor='yellow', alpha=0.5), fontsize=8)
        # Target this figure/axes explicitly instead of pyplot's implicit
        # "current figure" global state.
        ax.axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if drawing fails, so figures don't pile up.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Create the Gradio UI: a single PIL image in, the annotated image out.
interface = gr.Interface(fn=detect_objects,
                         inputs=gr.Image(type="pil"),
                         outputs="image",
                         title="DETR Object Detection",
                         description="Upload an image to detect objects using Facebook's DETR model.")

# Launch the app when run as a script. NOTE: share=True makes Gradio create a
# public share link in addition to the local server, so anyone with that URL
# can use the app; use share=False to keep it local-only.
if __name__ == "__main__":
    interface.launch(share=True)
|