Spaces:
Running
Running
import gradio as gr | |
import PIL.Image as Image | |
import spaces | |
import super_gradients | |
from tools.tools import py_cpu_nms,get_sub_image,filter_small_fp | |
import cv2 | |
import numpy as np | |
import os | |
from classifiers.MixMatch.mixmatch_classification import mixmatch_classifier_inference | |
def inference_mega_image_yolonas(img, conf_threshold, iou_threshold, height):
    """Detect objects on a large image by tiling it and merging the tile results.

    The image is cut into overlapping sub-images by ``get_sub_image``
    (``overlap=0.2``, ``ratio=1``). A single-class ``yolo_nas_m`` model is run
    on every sub-image, the tile-local ``xyxy`` boxes are shifted by the tile's
    entry in ``coor_list`` to get full-image coordinates, overlapping boxes are
    pruned with ``py_cpu_nms`` and the survivors are drawn with ``draw_image``.

    Args:
        img: Image passed to ``get_sub_image``; it is also the object handed
            to ``draw_image``, so the returned image is the annotated input.
        conf_threshold: Only detections scoring strictly above this value
            are kept.
        iou_threshold: Threshold passed to ``py_cpu_nms`` when merging boxes
            coming from different tiles.
        height: Value substituted into the checkpoint file name
            (``ckpt_best{height}.pth``).

    Returns:
        ``(image, boxes)`` where ``boxes`` is a list of
        ``[x1, y1, x2, y2, confidence]`` rows sorted by descending confidence,
        or ``[]`` when no detection passed the threshold (in which case the
        image is returned without any drawing).
    """
    checkpoint_path = './checkpoint/yolonas/height_varient/ckpt_best{}.pth'.format(height)
    # NOTE(review): the checkpoint is loaded and moved to the GPU on every
    # call; consider caching per height if the deployment allows it.
    model = super_gradients.training.models.get(
        'yolo_nas_m', num_classes=1, checkpoint_path=checkpoint_path
    ).cuda()

    mega_image = img
    sub_image_list, coor_list = get_sub_image(mega_image, overlap=0.2, ratio=1)

    bbox_list = []
    for index, sub_image in enumerate(sub_image_list):
        # Forward the requested threshold to the detector as well; otherwise
        # the model applies its own default confidence cut-off first and a
        # lower UI threshold could never surface additional boxes.
        images_predictions = model.predict(sub_image, conf=conf_threshold)
        prediction = next(iter(images_predictions)).prediction

        # presumably coor_list entries are (row, col) tile origins: index 1 is
        # added to x coordinates and index 0 to y coordinates — TODO confirm.
        origin = coor_list[index]
        for bbox, confidence in zip(prediction.bboxes_xyxy, prediction.confidence):
            if confidence > conf_threshold:
                bbox_list.append([
                    int(origin[1] + bbox[0]),
                    int(origin[0] + bbox[1]),
                    int(origin[1] + bbox[2]),
                    int(origin[0] + bbox[3]),
                    confidence,
                ])

    if not bbox_list:
        return mega_image, []

    # Tiles overlap, so the same object can be reported by several tiles;
    # merge those duplicates in full-image coordinates.
    boxes = np.asarray(bbox_list)
    keep = py_cpu_nms(boxes, iou_threshold)
    selected_bbox = sorted(boxes[keep], key=lambda row: row[4], reverse=True)
    mega_image = draw_image(mega_image, selected_bbox)
    return mega_image, selected_bbox
def draw_image(img, bboxes):
    """Annotate ``img`` with a rectangle and a 'bird' caption for every box.

    Only the first four entries of each box are used, as the two corner
    points ``(box[0], box[1])`` and ``(box[2], box[3])``. The OpenCV calls
    draw directly onto ``img``, and that same object is returned.
    """
    rectangle_color = (0, 255, 0)
    caption_color = (0, 0, 255)
    for box in bboxes:
        left, top, right, bottom = (int(box[k]) for k in range(4))
        cv2.rectangle(img, (left, top), (right, bottom), rectangle_color, 3)
        # The caption is anchored at the first corner of the rectangle.
        cv2.putText(img, 'bird', (left, top), cv2.FONT_HERSHEY_SIMPLEX, 0.5, caption_color, 1)
    return img
def predict_image(img, conf_threshold, iou_threshold, height):
    """Gradio callback: run detection, then classification, on an uploaded image.

    Args:
        img: Uploaded image, or ``None`` when the form is submitted without one.
        conf_threshold: Confidence threshold forwarded to the detector.
        iou_threshold: IoU threshold forwarded to the detector.
        height: Height option forwarded to the detector.

    Returns:
        ``(result_image, cla_dict)``: the image returned by
        ``inference_mega_image_yolonas`` and whatever
        ``mixmatch_classifier_inference`` returns for it.

    Raises:
        gr.Error: If no image was provided.
    """
    if img is None:
        # Gradio passes None for an empty image input; fail with a readable
        # message instead of an opaque error inside the tiling helper.
        raise gr.Error("Please upload an image before running detection.")

    result_image, bbox_list = inference_mega_image_yolonas(img, conf_threshold, iou_threshold, height)
    # NOTE(review): result_image already has the detection boxes drawn on it
    # when anything was found, so the classifier receives annotated pixels.
    # Confirm this is what mixmatch_classifier_inference expects.
    cla_dict = mixmatch_classifier_inference(
        './checkpoint/classifier/mixmatch/model_best.pth.tar', result_image, bbox_list
    )
    return result_image, cla_dict
# Gradio front end: one image upload plus three controls feed predict_image,
# whose two return values are shown as images.
input_components = [
    gr.Image(type="numpy", label="Upload Image"),
    gr.Slider(minimum=0, maximum=1, value=0.7, label="Confidence threshold"),
    gr.Slider(minimum=0, maximum=1, value=0.3, label="IoU threshold"),
    gr.Radio(["15m", "30m", "60m", "90m"], value="15m", label="Height", info="The image taken height"),
]
output_components = [
    gr.Image(type="numpy", label="Result"),
    gr.Image(type="numpy", label="BarChart"),
]

iface = gr.Interface(
    fn=predict_image,
    inputs=input_components,
    outputs=output_components,
    title="Waterfowl detection with YOLONAS",
    description="Upload images for Waterfowl object detection.",
)

# Starts the app as soon as this module is executed; share=True requests a
# public share link from Gradio.
iface.launch(share=True)