pipeline_server / server_start.py
zy984764389's picture
Update server_start.py
62c1b39 verified
import gradio as gr
import PIL.Image as Image
import spaces
import super_gradients
from tools.tools import py_cpu_nms,get_sub_image,filter_small_fp
import cv2
import numpy as np
import os
from classifiers.MixMatch.mixmatch_classification import mixmatch_classifier_inference
@spaces.GPU
def inference_mega_image_yolonas(img, conf_threshold, iou_threshold, height):
    """Run tiled YOLO-NAS detection over a large image and merge the results.

    The image is split into sub-images by ``get_sub_image`` (overlap 0.2,
    ratio 1). Each sub-image is passed to the model, detections whose
    confidence exceeds ``conf_threshold`` are shifted by that sub-image's
    entry in the coordinate list, and the merged boxes are filtered with
    ``py_cpu_nms``.

    Args:
        img: Image to process. It is handed to ``get_sub_image`` unchanged
            and, when any detection survives, passed to ``draw_image``.
        conf_threshold: A detection is kept only if its confidence is
            strictly greater than this value.
        iou_threshold: Threshold forwarded to ``py_cpu_nms``.
        height: Value formatted into the checkpoint file name
            (``ckpt_best{height}.pth``).

    Returns:
        A tuple ``(image, boxes)``. ``boxes`` is a list of rows
        ``[x1, y1, x2, y2, confidence]`` sorted by descending confidence,
        or an empty list when no detection passes the confidence threshold.
        ``image`` is the value returned by ``draw_image`` in the first case
        and the untouched input in the second.
    """
    # NOTE: the checkpoint is loaded from disk and moved to CUDA on every call.
    model_dir = './checkpoint/yolonas/height_varient/ckpt_best{}.pth'.format(height)
    model = super_gradients.training.models.get(
        'yolo_nas_m', num_classes=1, checkpoint_path=model_dir
    ).cuda()

    mega_image = img
    sub_image_list, coor_list = get_sub_image(mega_image, overlap=0.2, ratio=1)

    bbox_list = []
    for index, sub_image in enumerate(sub_image_list):
        # Element 1 of the coordinate entry is added to the x values and
        # element 0 to the y values of the xyxy box -- presumably the
        # (row, col) origin of the sub-image; confirm against get_sub_image.
        origin = coor_list[index]
        offset_x, offset_y = origin[1], origin[0]

        # predict() yields an iterable of per-image results; only the first
        # one is needed because a single sub-image is submitted.
        prediction = next(iter(model.predict(sub_image))).prediction

        for confidence, bbox in zip(prediction.confidence, prediction.bboxes_xyxy):
            if confidence > conf_threshold:
                bbox_list.append([
                    int(offset_x + bbox[0]),
                    int(offset_y + bbox[1]),
                    int(offset_x + bbox[2]),
                    int(offset_y + bbox[3]),
                    confidence,
                ])

    if not bbox_list:
        return mega_image, []

    # Overlapping sub-images can report the same object more than once, so
    # the merged boxes go through NMS before being ranked by confidence.
    merged_boxes = np.asarray(bbox_list)
    keep_idx = py_cpu_nms(merged_boxes, iou_threshold)
    selected_bbox = sorted(merged_boxes[keep_idx], key=lambda row: row[4], reverse=True)
    mega_image = draw_image(mega_image, selected_bbox)
    return mega_image, selected_bbox
def draw_image(img, bboxes):
    """Annotate ``img`` with one rectangle and a 'bird' caption per box.

    Args:
        img: Image passed straight to the cv2 drawing calls.
        bboxes: Iterable of indexable boxes; elements 0-3 are used as
            ``x1, y1, x2, y2`` after conversion to ``int``.

    Returns:
        The same ``img`` object that was passed in.
    """
    for detection in bboxes:
        top_left = (int(detection[0]), int(detection[1]))
        bottom_right = (int(detection[2]), int(detection[3]))
        # Box outline: colour (0, 255, 0), 3 px thick.
        cv2.rectangle(img, top_left, bottom_right, (0, 255, 0), 3)
        # Caption anchored at the box's first corner; colour (0, 0, 255)
        # (red or blue depending on the image's channel order).
        caption_anchor = (int(detection[0]), int(detection[1]))
        cv2.putText(img, 'bird', caption_anchor, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
    return img
@spaces.GPU
def predict_image(img, conf_threshold, iou_threshold, height):
    """Gradio callback: detect on the image, then classify the detections.

    Args:
        img: Image supplied by the interface.
        conf_threshold: Confidence threshold forwarded to the detector.
        iou_threshold: IoU threshold forwarded to the detector.
        height: Height option forwarded to the detector.

    Returns:
        A tuple of the detector's output image and whatever
        ``mixmatch_classifier_inference`` returns for it.
    """
    classifier_checkpoint = './checkpoint/classifier/mixmatch/model_best.pth.tar'

    detection_output = inference_mega_image_yolonas(
        img, conf_threshold, iou_threshold, height
    )
    annotated_image, detections = detection_output

    # NOTE(review): the classifier is given the detector's output image, which
    # has already been through draw_image when boxes were found -- confirm the
    # classifier is meant to see the annotated image rather than the raw one.
    class_summary = mixmatch_classifier_inference(
        classifier_checkpoint, annotated_image, detections
    )
    return annotated_image, class_summary
# Input widgets, listed in the positional order of predict_image's parameters:
# image, confidence threshold, IoU threshold, height option.
_input_widgets = [
    gr.Image(type="numpy", label="Upload Image"),
    gr.Slider(minimum=0, maximum=1, value=0.7, label="Confidence threshold"),
    gr.Slider(minimum=0, maximum=1, value=0.3, label="IoU threshold"),
    gr.Radio(
        ["15m", "30m", "60m", "90m"],
        value="15m",
        label="Height",
        info="The image taken height",
    ),
]

# Output widgets, matching the two values returned by predict_image.
_output_widgets = [
    gr.Image(type="numpy", label="Result"),
    gr.Image(type="numpy", label="BarChart"),
]

iface = gr.Interface(
    fn=predict_image,
    inputs=_input_widgets,
    outputs=_output_widgets,
    title="Waterfowl detection with YOLONAS",
    description="Upload images for Waterfowl object detection.",
)

# Start the web UI as soon as this module is executed.
iface.launch(share=True)