'''Some helper functions for PyTorch.'''
import os | |
import sys | |
import time | |
import math | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
def get_sub_image(mega_image, overlap=0.2, ratio=1, crop_size=512):
    '''Tile an image into overlapping square windows.

    Windows have side ``size = int(ratio * crop_size)`` and are stepped by
    ``stride = int(size * (1 - overlap))`` pixels along both axes. After the
    regularly spaced windows, one extra window flush with the bottom/right
    border is appended on each axis so the borders are always covered.

    Images smaller than the window (or than ``crop_size``) are zero-padded
    first; the returned coordinates then refer to the padded image.

    Note: a regularly spaced window that starts less than ``size`` pixels
    before the border is clipped by NumPy slicing and is therefore smaller
    than ``size x size``; the border-aligned window covers the same area.

    Args:
      mega_image: (ndarray) original image, indexed [row, col, channel].
      overlap: (float) fraction of a window shared with its neighbour.
      ratio: (float) scale applied to crop_size; per the original notes it
        compensates for images taken from different heights.
      crop_size: (int) base window side length in pixels.

    Returns:
      (list, list) the sub images and the [row, col] upper-left corner of
      each sub image, in matching order (row-major over the window grid).

    Raises:
      ZeroDivisionError: if the computed stride is 0 (e.g. overlap >= 1).
    '''
    size = int(ratio * crop_size)
    stride = int(size * (1 - overlap))

    # Pad up to the real window size as well, so ratio > 1 cannot produce
    # windows larger than the image.
    min_side = max(crop_size, size)
    rows, cols, _ = mega_image.shape
    if rows < min_side or cols < min_side:
        mega_image = image_padding(mega_image, min_side)
        # Re-read the shape: the border windows below must be positioned
        # relative to the padded image, not the original one.
        rows, cols, _ = mega_image.shape

    row_starts = _tile_starts(rows, size, stride)
    col_starts = _tile_starts(cols, size, stride)

    sub_image_list = []
    coor_list = []
    for top in row_starts:
        for left in col_starts:
            sub_image_list.append(mega_image[top:top + size, left:left + size, :])
            coor_list.append([top, left])
    return sub_image_list, coor_list


def _tile_starts(length, size, stride):
    '''Return window start offsets along one axis, border window last.'''
    regular = [stride * k for k in range(int(length / stride))]
    return regular + [length - size]


def image_padding(mega_image, min_size=512):
    '''Zero-pad an image at the bottom/right to at least min_size per side.

    Args:
      mega_image: (ndarray) image indexed [row, col, channel].
      min_size: (int) minimum number of rows and columns of the result.

    Returns:
      (ndarray) uint8 array with 3 channels holding the input in its
      upper-left corner and zeros elsewhere.
    '''
    rows, cols, _ = mega_image.shape
    result = np.zeros((max(min_size, rows), max(min_size, cols), 3), dtype=np.uint8)
    result[0:rows, 0:cols] = mega_image
    return result
def py_cpu_nms(dets, thresh):
    """Pure Python (NumPy) non-maximum suppression baseline.

    Boxes are visited in descending score order; each visited box is kept
    and every remaining box whose IoU with it exceeds ``thresh`` is dropped.
    Box extents use the inclusive-pixel convention (width = x2 - x1 + 1).

    Args:
      dets: array-like sized [N,5+]; columns 0-4 are x1, y1, x2, y2, score.
      thresh: (float) IoU threshold; boxes with IoU <= thresh survive.

    Returns:
      (list) indices into dets of the kept boxes, highest score first.
      An empty input yields an empty list.
    """
    dets = np.asarray(dets)
    if dets.size == 0:
        # Nothing to suppress; column indexing below would fail on a
        # 1-D empty array.
        return []

    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]

        # Intersection rectangle of the kept box with every remaining box.
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)

        order = rest[np.where(ovr <= thresh)[0]]
    return keep
def sort_key(row):
    """Return the last element of row; used as the ranking key."""
    return row[-1]


def _bbox_area(bbox):
    """Return (x2 - x1) * (y2 - y1) for a box laid out as [x1, y1, x2, y2, ...]."""
    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])


def filter_small_fp(bbox_list):
    """Remove predictions whose area is far from that of the top-ranked boxes.

    Boxes are ranked by their last element in descending order (presumably
    the confidence score - confirm against the caller). The mean area of
    the top 5% (at least one box) is the reference; a box is kept when its
    area differs from the reference by less than 80% of the reference, so
    both much smaller and much larger boxes are dropped.

    The input list is not modified.

    Args:
      bbox_list: (list) boxes laid out as [x1, y1, x2, y2, ..., key].

    Returns:
      (list) the kept boxes, ordered by descending key. Empty input gives
      an empty list.
    """
    if len(bbox_list) == 0:
        return []

    # Rank on a copy so the caller's list keeps its order.
    ranked = sorted(bbox_list, key=sort_key, reverse=True)

    num_reference = max(int(0.05 * len(ranked)), 1)
    average_area = np.mean([_bbox_area(bbox) for bbox in ranked[:num_reference]])

    return [
        bbox for bbox in ranked
        if abs(_bbox_area(bbox) - average_area) / average_area < 0.8
    ]
def get_mean_and_std(dataset, max_load=10000):
    '''Estimate per-channel mean and std by averaging over loaded samples.

    For each of ``min(max_load, len(dataset))`` iterations one item is
    fetched with ``dataset.load(1)`` and the mean and std of each of its
    three channels are accumulated; the sums are then divided by the number
    of iterations. The reported std is therefore the average of the
    per-load stds, not the std over all pixels.

    Args:
      dataset: object providing __len__ and load(n); load must return a
        3-tuple whose first item is indexed [batch, channel, row, col].
      max_load: (int) upper bound on the number of loads.

    Returns:
      (tensor, tensor) mean and std, each sized [3,].
    '''
    print('==> Computing mean and std..')
    num_loads = min(max_load, len(dataset))
    channel_mean = torch.zeros(3)
    channel_std = torch.zeros(3)

    for load_idx in range(num_loads):
        print(load_idx)
        images, _, _ = dataset.load(1)
        for channel in range(3):
            plane = images[:, channel, :, :]
            channel_mean[channel] += plane.mean()
            channel_std[channel] += plane.std()

    channel_mean /= num_loads
    channel_std /= num_loads
    return channel_mean, channel_std
def mask_select(input, mask, dim=0):
    '''Pick the slices of a tensor along one dim where a mask is non-zero.

    Args:
      input: (tensor) input tensor, sized [N,M].
      mask: (tensor) 1-D mask tensor, sized [N,] or [M,].
      dim: (int) dimension of input the mask applies to.

    Returns:
      (tensor) the selected rows (dim=0) or columns (dim=1).

    Example:
      >>> a = torch.tensor([[1, 2], [3, 4], [5, 6]])
      >>> mask_select(a, torch.tensor([True, False, True]), 0).tolist()
      [[1, 2], [5, 6]]
    '''
    # nonzero() on a 1-D mask yields shape [K,1]; drop the trailing axis
    # to obtain a plain index vector.
    selected = torch.nonzero(mask).squeeze(1)
    return torch.index_select(input, dim, selected)
def meshgrid(x, y, row_major=True):
    '''Enumerate all integer grid points in [0,x) x [0,y).

    Args:
      x: (int) first dim range.
      y: (int) second dim range.
      row_major: (bool) if True each row is (x, y) with x varying fastest;
        if False the two columns are swapped.

    Returns:
      (tensor) grid points, sized [x*y,2].

    Example:
      >>> meshgrid(3, 2).tolist()
      [[0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1]]
      >>> meshgrid(3, 2, row_major=False).tolist()
      [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]
    '''
    x_range = torch.arange(0, x)
    y_range = torch.arange(0, y)

    # x cycles through its range once per y value ...
    x_column = x_range.repeat(y).view(-1, 1)
    # ... while each y value is held for a full x cycle.
    y_column = y_range.view(-1, 1).repeat(1, x).view(-1, 1)

    if row_major:
        columns = [x_column, y_column]
    else:
        columns = [y_column, x_column]
    return torch.cat(columns, 1)
def change_box_order(boxes, order):
    '''Convert boxes between corner form and center/size form.

    'xyxy2xywh' maps (xmin,ymin,xmax,ymax) to (xcenter,ycenter,width,height)
    with width/height computed as max - min + 1. 'xywh2xyxy' maps back using
    center -/+ size/2 (without the +1 term).

    Args:
      boxes: (tensor) bounding boxes, sized [N,4].
      order: (str) either 'xyxy2xywh' or 'xywh2xyxy'.

    Returns:
      (tensor) converted bounding boxes, sized [N,4].
    '''
    assert order in ('xyxy2xywh', 'xywh2xyxy')
    first_pair = boxes[:, :2]
    second_pair = boxes[:, 2:]

    if order == 'xyxy2xywh':
        centers = (first_pair + second_pair) / 2
        sizes = second_pair - first_pair + 1
        return torch.cat([centers, sizes], 1)

    # xywh2xyxy: first_pair holds the centers, second_pair the sizes.
    half_sizes = second_pair / 2
    return torch.cat([first_pair - half_sizes, first_pair + half_sizes], 1)
def box_iou(box1, box2, order='xyxy'):
    '''Compute the pairwise intersection over union of two sets of boxes.

    Extents use the inclusive-pixel convention (width = xmax - xmin + 1).
    Boxes given as 'xywh' are first converted with change_box_order.

    Args:
      box1: (tensor) bounding boxes, sized [N,4].
      box2: (tensor) bounding boxes, sized [M,4].
      order: (str) box order, either 'xyxy' or 'xywh'.

    Return:
      (tensor) iou, sized [N,M].

    Reference:
      https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
    '''
    if order == 'xywh':
        box1 = change_box_order(box1, 'xywh2xyxy')
        box2 = change_box_order(box2, 'xywh2xyxy')

    # Broadcasting [N,1,2] against [M,2] compares every box1 with every box2.
    top_left = torch.max(box1[:, None, :2], box2[:, :2])      # [N,M,2]
    bottom_right = torch.min(box1[:, None, 2:], box2[:, 2:])  # [N,M,2]

    # Disjoint boxes give negative extents; clamp them to zero overlap.
    extent = (bottom_right - top_left + 1).clamp(min=0)       # [N,M,2]
    intersection = extent[:, :, 0] * extent[:, :, 1]          # [N,M]

    width1 = box1[:, 2] - box1[:, 0] + 1
    height1 = box1[:, 3] - box1[:, 1] + 1
    width2 = box2[:, 2] - box2[:, 0] + 1
    height2 = box2[:, 3] - box2[:, 1] + 1
    area1 = width1 * height1                                  # [N,]
    area2 = width2 * height2                                  # [M,]

    union = area1[:, None] + area2 - intersection
    return intersection / union
def box_nms(bboxes, scores, threshold=0.5, mode='union'):
    '''Non maximum suppression.

    Boxes are visited from highest to lowest score. Each visited box is
    kept and every remaining box overlapping it by more than threshold is
    discarded. Extents use the inclusive-pixel convention (+1).

    Args:
      bboxes: (tensor) bounding boxes, sized [N,4]; a single box sized [4,]
        is also accepted.
      scores: (tensor) bbox scores, sized [N,].
      threshold: (float) overlap threshold.
      mode: (str) 'union' divides the intersection by the union area,
        'min' divides it by the smaller of the two box areas.

    Returns:
      keep: (tensor) selected indices, highest score first.

    Raises:
      TypeError: if mode is neither 'union' nor 'min' (raised once an
        overlap has to be computed, i.e. with at least two boxes).

    Reference:
      https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py
    '''
    if bboxes.dim() == 1:
        bboxes = bboxes.unsqueeze(0)

    left = bboxes[:, 0]
    top = bboxes[:, 1]
    right = bboxes[:, 2]
    bottom = bboxes[:, 3]
    areas = (right - left + 1) * (bottom - top + 1)

    ranking = scores.sort(0, descending=True)[1]
    # Keep the candidate list 1-D throughout so no 0-dim case arises.
    candidates = ranking.reshape(-1)

    keep = []
    while candidates.numel() > 0:
        best = int(candidates[0])
        keep.append(best)

        others = candidates[1:]
        if others.numel() == 0:
            break

        # Intersection of the kept box with each remaining candidate.
        inter_left = left[others].clamp(min=left[best])
        inter_top = top[others].clamp(min=top[best])
        inter_right = right[others].clamp(max=right[best])
        inter_bottom = bottom[others].clamp(max=bottom[best])

        inter_w = (inter_right - inter_left + 1).clamp(min=0)
        inter_h = (inter_bottom - inter_top + 1).clamp(min=0)
        inter = inter_w * inter_h

        if mode == 'union':
            overlap = inter / (areas[best] + areas[others] - inter)
        elif mode == 'min':
            overlap = inter / areas[others].clamp(max=areas[best])
        else:
            raise TypeError('Unknown nms mode: %s.' % mode)

        survivors = (overlap <= threshold).nonzero().squeeze(1)
        candidates = others[survivors]
    return torch.LongTensor(keep)
def softmax(x):
    '''Row-wise softmax of a 2-D tensor.

    The row maximum is subtracted before exponentiation so the largest
    exponent is exp(0); this leaves the result unchanged mathematically.

    Args:
      x: (tensor) input tensor, sized [N,D].

    Returns:
      (tensor) softmaxed tensor, sized [N,D]; each row sums to 1.
    '''
    row_max = x.max(1)[0]
    exponentials = torch.exp(x - row_max.view(-1, 1))
    row_totals = exponentials.sum(1).view(-1, 1)
    return exponentials / row_totals
def one_hot_embedding(labels, num_classes):
    '''Encode class labels as one-hot rows.

    Row k of the identity matrix is the one-hot vector of class k, so
    indexing the identity with the labels yields the encoding directly.

    Args:
      labels: (LongTensor) class labels, sized [N,].
      num_classes: (int) number of classes.

    Returns:
      (tensor) encoded labels, sized [N,#classes].
    '''
    identity = torch.eye(num_classes)  # [D,D]
    one_hot = identity[labels]         # [N,D]
    return one_hot
def msr_init(net):
    '''Initialize layer parameters in place.

    - Conv2d: weights drawn from N(0, sqrt(2/n)) with
      n = kernel_h * kernel_w * out_channels; bias set to zero.
    - BatchNorm2d: weight set to one, bias set to zero.
    - Linear: bias set to zero (weights are left untouched).

    Layers created without the respective parameter (``bias=False``,
    ``affine=False``) are handled by skipping the missing parameter.

    Args:
      net: iterable of layers (e.g. an nn.Sequential). Only the layers
        yielded directly by iteration are visited; nested containers are
        not descended into.
    '''
    for layer in net:
        if isinstance(layer, nn.Conv2d):
            n = layer.kernel_size[0] * layer.kernel_size[1] * layer.out_channels
            nn.init.normal_(layer.weight, 0, math.sqrt(2. / n))
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
        elif isinstance(layer, nn.BatchNorm2d):
            # weight/bias are None when the layer was built with affine=False.
            if layer.weight is not None:
                nn.init.ones_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
        elif isinstance(layer, nn.Linear):
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
# Module-level state shared by progress_bar().
# The terminal width is a fixed value; the former lookup via `stty size`
# is disabled.
TOTAL_BAR_LENGTH = 86.
term_width = 80
# Both timestamps start at import time; progress_bar() updates them.
begin_time = last_time = time.time()
def progress_bar(current, total, msg=None):
    '''Draw a one-line text progress bar on stdout.

    The line consists of the bar itself, the time since the previous call
    ("Step"), the time since the bar was started ("Tot"), an optional
    message and a "current+1/total" counter written back over the bar via
    backspaces. The line ends with a carriage return so the next call
    overwrites it, or with a newline on the last step.

    Reads the module globals TOTAL_BAR_LENGTH and term_width and updates
    last_time / begin_time.

    Args:
      current: (int) zero-based index of the step just finished; 0 restarts
        the total timer.
      total: (int) total number of steps.
      msg: (str) optional text appended after the timings.
    '''
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    done_len = int(TOTAL_BAR_LENGTH * current / total)
    todo_len = int(TOTAL_BAR_LENGTH - done_len) - 1
    # A negative repeat count yields an empty string, matching an empty loop.
    sys.stdout.write(' [' + '=' * done_len + '>' + '.' * todo_len + ']')

    now = time.time()
    step_time = now - last_time
    tot_time = now - begin_time
    last_time = now

    status = ' Step: %s' % format_time(step_time)
    status += ' | Tot: %s' % format_time(tot_time)
    if msg:
        status += ' | ' + msg
    sys.stdout.write(status)

    padding = term_width - int(TOTAL_BAR_LENGTH) - len(status) - 3
    sys.stdout.write(' ' * padding)

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2)))
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()
def format_time(seconds):
    '''Format a duration as a compact string of its two leading units.

    The duration is split into days (D), hours (h), minutes (m), seconds
    (s) and milliseconds (ms); the first two units with a positive value
    are concatenated, e.g. 3661 -> '1h1m'. If no unit is positive the
    result is '0ms'.

    Args:
      seconds: (number) duration in seconds.

    Returns:
      (str) the formatted duration.
    '''
    remaining = seconds
    days = int(remaining / 3600 / 24)
    remaining = remaining - days * 3600 * 24
    hours = int(remaining / 3600)
    remaining = remaining - hours * 3600
    minutes = int(remaining / 60)
    remaining = remaining - minutes * 60
    whole_seconds = int(remaining)
    remaining = remaining - whole_seconds
    millis = int(remaining * 1000)

    units = (
        (days, 'D'),
        (hours, 'h'),
        (minutes, 'm'),
        (whole_seconds, 's'),
        (millis, 'ms'),
    )
    # Only the two most significant positive units are shown.
    shown = [str(value) + suffix for value, suffix in units if value > 0][:2]
    if not shown:
        return '0ms'
    return ''.join(shown)