|
import torch |
|
from PIL import Image |
|
from transformers import DetrImageProcessor, DetrForObjectDetection |
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import matplotlib.patches as patches |
|
import io |
|
|
|
|
|
# Single checkpoint identifier so the pre-processor and the detector are
# always loaded from the same pretrained weights.
_DETR_CHECKPOINT = "facebook/detr-resnet-50"

# Loaded once at import time and reused by every request.
processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)
|
|
|
def detect_objects(image, threshold=0.9):
    """Detect objects in ``image`` with DETR and return an annotated picture.

    Args:
        image: PIL image to analyse. ``None`` is rejected with a user-facing
            error instead of failing on attribute access.
        threshold: Minimum confidence score a detection needs in order to be
            drawn. Defaults to 0.9.

    Returns:
        A fully decoded PIL image (rendered as PNG) showing the input with a
        red rectangle and a "label: score" caption for each kept detection.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Function-scope import: the figure is built directly rather than through
    # pyplot, so nothing is registered in pyplot's global figure manager. That
    # avoids leaking figures on errors and avoids relying on the "current
    # figure" when this handler may be invoked concurrently.
    from matplotlib.figure import Figure

    if image is None:
        raise gr.Error("Please upload an image before running detection.")

    # Normalise the colour mode so grayscale / alpha inputs are handled the
    # same way as plain RGB ones.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's ``size`` is (width, height); it is reversed to (height, width)
    # before being handed to the post-processing step.
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    fig = Figure()
    ax = fig.subplots()
    ax.imshow(image)
    ax.axis("off")

    for score, label, box in zip(
        results["scores"], results["labels"], results["boxes"]
    ):
        xmin, ymin, xmax, ymax = box.tolist()
        ax.add_patch(
            patches.Rectangle(
                (xmin, ymin),
                xmax - xmin,
                ymax - ymin,
                linewidth=2,
                edgecolor='red',
                facecolor='none',
            )
        )
        ax.text(
            xmin,
            ymin,
            f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}",
            bbox=dict(facecolor='yellow', alpha=0.5),
            fontsize=8,
        )

    # Render into memory and hand back a decoded image, so the result does
    # not depend on lazily reading the buffer later.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    annotated = Image.open(buf)
    annotated.load()
    return annotated
|
|
|
|
|
# Web UI: one PIL image in, one annotated image out, served by detect_objects.
interface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title="DETR Object Detection",
    description="Upload an image to detect objects using Facebook's DETR model.",
)
|
|
|
|
|
# Start the app only when executed as a script, not when imported.
if __name__ == "__main__":
    # share=True asks Gradio to also expose a shareable link for the app.
    interface.launch(share=True)
|
|