"""Gradio demo: detect objects in an uploaded image with the DETR model."""

import io
from typing import Optional

import gradio as gr
import matplotlib
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor

# The drawing code only renders to an in-memory PNG, so no GUI is needed.
# Forcing the non-interactive Agg backend avoids GUI-backend failures when
# the callback is invoked outside the main thread (as a web UI may do).
matplotlib.use("Agg")

# Hugging Face checkpoint used for both the processor and the model.
MODEL_NAME = "facebook/detr-resnet-50"

# Detections scoring below this confidence are discarded by default.
DEFAULT_THRESHOLD = 0.9

# Load model and processor once, at import time, so every request reuses them.
processor = DetrImageProcessor.from_pretrained(MODEL_NAME)
model = DetrForObjectDetection.from_pretrained(MODEL_NAME)


def detect_objects(
    image: Optional[Image.Image],
    threshold: float = DEFAULT_THRESHOLD,
) -> Image.Image:
    """Run DETR on *image* and return a copy annotated with the detections.

    Each detection that passes *threshold* is drawn as a red rectangle with
    a yellow "<label>: <score>" caption, and the figure is rendered to a PNG.

    Args:
        image: PIL image to analyse. ``None`` (no upload) is rejected.
        threshold: Minimum confidence score for a detection to be kept.

    Returns:
        A PIL image, decoded from the rendered PNG, showing the input with
        bounding boxes and labels drawn on top.

    Raises:
        gr.Error: If *image* is ``None``.
    """
    if image is None:
        # Surface a readable message in the UI instead of an obscure failure
        # from inside the image processor.
        raise gr.Error("Please upload an image before running detection.")

    # Normalise grayscale / RGBA / palette inputs to three-channel RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Run DETR model. Inference only: skip autograd bookkeeping.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    # Draw boxes on image. Only the explicit fig/ax objects are used (never
    # pyplot's implicit "current figure") so overlapping calls cannot draw
    # on or save each other's figures.
    fig, ax = plt.subplots(1)
    try:
        ax.imshow(image)
        for score, label, box in zip(
            results["scores"], results["labels"], results["boxes"]
        ):
            xmin, ymin, xmax, ymax = box.tolist()
            ax.add_patch(
                patches.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    linewidth=2,
                    edgecolor='red',
                    facecolor='none',
                )
            )
            ax.text(
                xmin,
                ymin,
                f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}",
                bbox=dict(facecolor='yellow', alpha=0.5),
                fontsize=8,
            )
        ax.axis("off")

        # Save output to bytes buffer
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if drawing or saving failed;
        # pyplot keeps every open figure alive until it is closed.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)


# Create Gradio interface
interface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title="DETR Object Detection",
    description="Upload an image to detect objects using Facebook's DETR model.",
)

# Launch the app. NOTE: share=True also asks Gradio for a public share link,
# so the demo is not restricted to this machine; set share=False to keep it
# local only.
if __name__ == "__main__":
    interface.launch(share=True)