File size: 1,836 Bytes
9483059
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from PIL import Image
from transformers import DetrImageProcessor, DetrForObjectDetection
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import io

# Build the DETR pre/post-processor and detection model a single time at
# import, so every request handled below reuses the same instances rather
# than calling from_pretrained per image.
_DETR_CHECKPOINT = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)

def detect_objects(image, threshold=0.9):
    """Run DETR on *image* and return a rendering annotated with detections.

    Args:
        image: PIL image (as supplied by ``gr.Image(type="pil")``), or
            ``None`` when the user submits without uploading anything.
        threshold: Minimum confidence score a detection needs to be drawn.
            Defaults to 0.9, the value that used to be hard-coded.

    Returns:
        A PIL image of the matplotlib figure with a red rectangle and a
        ``"<label>: <score>"`` caption for every kept detection, or ``None``
        when no image was provided.
    """
    if image is None:
        # Empty input: there is nothing to annotate.
        return None
    # Normalize to 3 channels so RGBA / palette / grayscale uploads are
    # processed and displayed like ordinary color images.
    image = image.convert("RGB")

    # Run DETR in inference mode: no autograd graph is needed.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # PIL's size is (width, height); post-processing wants (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold)[0]

    # Draw boxes on image, using this figure/axes explicitly instead of
    # pyplot's global "current figure" state.
    fig, ax = plt.subplots(1)
    try:
        ax.imshow(image)
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            xmin, ymin, xmax, ymax = box.tolist()
            ax.add_patch(patches.Rectangle(
                (xmin, ymin), xmax - xmin, ymax - ymin,
                linewidth=2, edgecolor='red', facecolor='none'))
            ax.text(xmin, ymin, f"{model.config.id2label[label.item()]}: {round(score.item(), 2)}",
                    bbox=dict(facecolor='yellow', alpha=0.5), fontsize=8)

        # Save output to an in-memory PNG buffer.
        ax.axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)

    buf.seek(0)
    annotated = Image.open(buf)
    # Decode eagerly so the returned image does not depend on lazy reads.
    annotated.load()
    return annotated

# Wire detect_objects into a Gradio UI: one PIL image in, one image out.
interface = gr.Interface(
    fn=detect_objects,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title="DETR Object Detection",
    description="Upload an image to detect objects using Facebook's DETR model.",
)

# Start the Gradio server when this file is run as a script.
# NOTE: share=True asks Gradio for a temporary public share link in addition
# to the local server, so the app is reachable from outside this machine.
# Use share=False for a local-only launch.
if __name__ == "__main__":
    interface.launch(share=True)