# app.py — Hugging Face Space by Harshithtd
# Last commit: 409bde7 "Update app.py" (verified), 2.48 kB
from typing import List
import os
import numpy as np
import torch
import gradio as gr
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection
import supervision as sv
import spaces
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face Hub checkpoint used for both the image processor and the model.
MODEL_ID = "PekingU/rtdetr_r50vd_coco_o365"
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForObjectDetection.from_pretrained(MODEL_ID).to(device)

# supervision renamed BoundingBoxAnnotator to BoxAnnotator; the old name was
# deprecated and later removed. Prefer the old name while it exists, because
# in older releases `BoxAnnotator` is a different class that also draws labels
# (which would duplicate LABEL_ANNOTATOR's output).
BOUNDING_BOX_ANNOTATOR = (
    sv.BoundingBoxAnnotator()
    if hasattr(sv, "BoundingBoxAnnotator")
    else sv.BoxAnnotator()
)
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
# Multi-object tracker. It keeps state between update calls, so it suits video
# frames rather than independent still images.
TRACKER = sv.ByteTrack()
def annotate_image(input_image, detections, labels) -> np.ndarray:
    """Draw ``detections`` onto ``input_image`` and return the result.

    The module-level annotators are applied in a fixed order: mask, then
    bounding box, then text label. Each one draws on the previous one's output.

    Args:
        input_image: Image to draw on.
        detections: Detections to visualise.
        labels: Label strings for the detections, forwarded to the label
            annotator.

    Returns:
        The image returned by the final (label) annotator.
    """
    canvas = input_image
    for layer in (MASK_ANNOTATOR, BOUNDING_BOX_ANNOTATOR):
        canvas = layer.annotate(canvas, detections)
    return LABEL_ANNOTATOR.annotate(canvas, detections, labels=labels)
def process_image(
    input_image,
    confidence_threshold,
):
    """Detect objects in a single image and return an annotated copy.

    Args:
        input_image: Image to run detection on; passed unchanged to ``query``
            and ``annotate_image``.
        confidence_threshold: Forwarded to ``query`` as the score threshold
            below which detections are discarded.

    Returns:
        The output of ``annotate_image``: the input with the detections and
        their class-name labels drawn on it.
    """
    results = query(input_image, confidence_threshold)
    detections = sv.Detections.from_transformers(results[0])
    # Every call handles an independent still image, so detections are drawn
    # as-is. Routing them through the shared ByteTrack tracker carried state
    # across unrelated uploads (and users). It also only returns detections
    # matched to already-confirmed tracks, so objects in a new image could
    # silently vanish from the output.
    final_labels = [
        model.config.id2label[class_id] for class_id in detections.class_id.tolist()
    ]
    return annotate_image(input_image, detections, final_labels)
def _image_hw(image):
    """Return ``(height, width)`` of a NumPy image array or a PIL-like image."""
    if isinstance(image, np.ndarray):
        # NumPy images are (height, width[, channels]); their ``size`` is the
        # total element count, not the dimensions.
        height, width = image.shape[:2]
    else:
        # PIL exposes ``size`` as (width, height).
        width, height = image.size
    return int(height), int(width)


def query(image, confidence_threshold):
    """Run the object-detection model on one image and post-process the output.

    Args:
        image: A NumPy image array (Gradio's default image type) or a PIL-like
            image exposing ``size`` as (width, height).
        confidence_threshold: Score threshold passed to
            ``post_process_object_detection``.

    Returns:
        The list returned by ``processor.post_process_object_detection``, with
        one entry for the single input image. Boxes are rescaled to the
        image's pixel size.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # The post-processor expects one (height, width) pair per image.
    target_sizes = torch.tensor([_image_hw(image)])
    results = processor.post_process_object_detection(
        outputs=outputs,
        threshold=confidence_threshold,
        target_sizes=target_sizes,
    )
    return results
def run_demo():
    """Build the Gradio interface for single-image detection and launch it.

    Uses the ``gr.Image`` / ``gr.Slider`` component API. The legacy
    ``gr.inputs`` / ``gr.outputs`` namespaces and the ``capture_session``
    argument are not available in current Gradio releases. The legacy slider
    also took ``default=`` rather than ``value=``.
    """
    # type="numpy" matches what the legacy gr.inputs.Image delivered by default.
    input_image = gr.Image(type="numpy", label="Input Image")
    conf = gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.6,
        step=0.05,
        label="Confidence Threshold",
    )
    output_image = gr.Image(label="Output Image")
    gr.Interface(
        fn=process_image,
        inputs=[input_image, conf],
        outputs=output_image,
        title="Real Time Object Detection with RT-DETR",
        description="This Demo uses RT-DETR for object detection in images. Adjust the confidence threshold to see different results.",
    ).launch()