|
|
|
|
|
|
|
|
|
import torch
import sys

import torch.nn as nn
import numpy as np

from torch.autograd import Function
from functions import quantization, clamping_qa, clamping_hw, calc_out_shift


class shallow_base_layer(nn.Module):
    """Common base class for the hardware-aware layers defined below.

    A layer is an optional 2x2 max-pool, followed by a single operation
    (convolution or linear), an optional batchnorm and an optional activation.
    Quantization behaviour is selected by `quantization_mode`, one of
    'fpt', 'qat', 'qat_ap' or 'eval'.
    """
    def __init__(
            self,
            quantization_mode = 'fpt',   # 'fpt', 'qat', 'qat_ap' or 'eval'
            pooling_flag      = None,    # insert a 2x2 max-pool before the operation
            operation_module  = None,    # nn.Conv2d / nn.Linear instance
            operation_fcnl    = None,    # matching functional form (conv2d / linear)
            activation_module = None,    # e.g. nn.ReLU, or None
            batchnorm_module  = None,    # e.g. nn.BatchNorm2d, or None
            output_width_30b  = False    # wide 30-bit output path (takes effect only when there is no activation)
    ):
        super().__init__()

        if pooling_flag:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        else:
            self.pool = None

        # Sub-modules and the functional form of the main operation
        self.op     = operation_module
        self.op_fcn = operation_fcnl
        self.act    = activation_module
        self.bn     = batchnorm_module
        self.wide   = output_width_30b

        self.mode = quantization_mode

        # Quantization and clamping operators; they are instantiated in
        # configure_layer_base() once the bit widths are known.
        self.quantize_Q_ud_8b   = None
        self.quantize_Q_ud_wb   = None
        self.quantize_Q_ud_bb   = None
        self.quantize_Q_ud_ap   = None
        self.quantize_Q_d_8b    = None
        self.quantize_Q_u_wb    = None
        self.quantize_Q_ud_wide = None
        self.quantize_Q_d_wide  = None
        self.clamp_C_qa_8b      = None
        self.clamp_C_qa_bb      = None
        self.clamp_C_qa_wb      = None
        self.clamp_C_hw_8b      = None
        self.clamp_C_qa_wide    = None
        self.clamp_C_hw_wide    = None

        # Layer state kept as non-trainable parameters so that it is saved in
        # and restored from checkpoints along with the weights.
        self.output_shift        = nn.Parameter(torch.Tensor([0]), requires_grad=False)
        self.weight_bits         = nn.Parameter(torch.Tensor([8]), requires_grad=False)
        self.bias_bits           = nn.Parameter(torch.Tensor([8]), requires_grad=False)
        self.quantize_activation = nn.Parameter(torch.Tensor([0]), requires_grad=False)
        self.adjust_output_shift = nn.Parameter(torch.Tensor([1]), requires_grad=False)
        self.shift_quantile      = nn.Parameter(torch.Tensor([1]), requires_grad=False)

        # Build the quantization/clamping operators for the default bit widths
        weight_bits    = self.weight_bits
        bias_bits      = self.bias_bits
        shift_quantile = self.shift_quantile
        self.configure_layer_base(weight_bits, bias_bits, shift_quantile)

    def configure_layer_base(self, weight_bits, bias_bits, shift_quantile):
        """(Re)build the quantization and clamping operators for the given bit widths."""
        # Quantization operators (see functions.py); the 'updown' variants are
        # used in the 'qat'/'qat_ap' forward paths, the 'up'/'down' variants in 'eval'.
        self.quantize_Q_ud_8b   = quantization(xb=8,           mode='updown',    wide=False)
        self.quantize_Q_ud_wb   = quantization(xb=weight_bits, mode='updown',    wide=False)
        self.quantize_Q_ud_bb   = quantization(xb=bias_bits,   mode='updown',    wide=False)
        self.quantize_Q_ud_ap   = quantization(xb=2,           mode='updown_ap', wide=False)
        self.quantize_Q_d_8b    = quantization(xb=8,           mode='down',      wide=False)
        self.quantize_Q_u_wb    = quantization(xb=weight_bits, mode='up',        wide=False)
        self.quantize_Q_ud_wide = quantization(xb=8,           mode='updown',    wide=True)
        self.quantize_Q_d_wide  = quantization(xb=8,           mode='down',      wide=True)

        # Clamping operators; 'qa' variants are used in 'fpt'/'qat'/'qat_ap',
        # 'hw' variants in 'eval' mode.
        self.clamp_C_qa_8b   = clamping_qa(xb=8,           wide=False)
        self.clamp_C_qa_bb   = clamping_qa(xb=bias_bits,   wide=False)
        self.clamp_C_qa_wb   = clamping_qa(xb=weight_bits, wide=False)
        self.clamp_C_hw_8b   = clamping_hw(xb=8,           wide=False)
        self.clamp_C_qa_wide = clamping_qa(xb=None,        wide=True)
        self.clamp_C_hw_wide = clamping_hw(xb=None,        wide=True)

        # Store the configuration so that it travels with the checkpoint
        self.weight_bits    = nn.Parameter(torch.Tensor([weight_bits]), requires_grad=False)
        self.bias_bits      = nn.Parameter(torch.Tensor([bias_bits]), requires_grad=False)
        self.shift_quantile = nn.Parameter(torch.Tensor([shift_quantile]), requires_grad=False)
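    # The bit widths above can be changed after construction by calling
    # configure_layer_base() again; for example (illustrative values only,
    # not a recommendation from this repo):
    #
    #     layer.configure_layer_base(weight_bits=2, bias_bits=8, shift_quantile=1.0)
    #
    # This rebuilds the quantization/clamping operators and overwrites the
    # stored weight_bits / bias_bits / shift_quantile parameters.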
|
|
|
|
|
    def mode_fpt2qat(self, quantization_mode):
        """Switch the layer into QAT mode.

        If a batchnorm is attached, it is folded into the operation's weights
        and bias (together with the /4 scaling used in 'fpt' mode) and then
        removed.
        """
        if self.bn is not None:
            w_fp = self.op.weight.data
            b_fp = self.op.bias.data

            running_mean_mu     = self.bn.running_mean
            running_var         = self.bn.running_var
            running_stdev_sigma = torch.sqrt(running_var + 1e-20)

            # Fold BN into the preceding operation:
            #   w_hat = w / (4*sigma),  b_hat = (b - mu) / (4*sigma)
            w_hat = w_fp * (1.0 / (running_stdev_sigma * 4.0)).reshape((w_fp.shape[0],) + (1,) * (len(w_fp.shape) - 1))
            b_hat = (b_fp - running_mean_mu) / (running_stdev_sigma * 4.0)

            self.op.weight.data = w_hat
            self.op.bias.data   = b_hat
            self.bn = None

        self.mode = quantization_mode
        self.quantize_activation = nn.Parameter(torch.Tensor([1]), requires_grad=False)
        self.adjust_output_shift = nn.Parameter(torch.Tensor([1]), requires_grad=False)

    def mode_qat2hw(self, quantization_mode):
        """Convert the QAT weights/biases to the integer representation used in 'eval' mode."""
        w_hat = self.op.weight.data
        b_hat = self.op.bias.data

        shift = -self.output_shift.data
        s_o   = 2**shift
        wb    = self.weight_bits.data.cpu().numpy()[0]

        # Integer clamp ranges for weights and biases
        w_clamp = [-2**(wb - 1),     2**(wb - 1) - 1]
        b_clamp = [-2**(wb + 8 - 2), 2**(wb + 8 - 2) - 1]

        # Round-to-nearest integerization (floor(x + 0.5)) with the output shift applied
        w = w_hat.mul(2**(wb - 1)).mul(s_o).add(0.5).floor()
        w = w.clamp(min=w_clamp[0], max=w_clamp[1])

        b = b_hat.mul(2**(wb - 1 + 7)).mul(s_o).add(0.5).floor()
        b = b.clamp(min=b_clamp[0], max=b_clamp[1])

        self.op.weight.data = w
        self.op.bias.data   = b
        self.mode = quantization_mode
        self.quantize_activation = nn.Parameter(torch.Tensor([1]), requires_grad=False)
        self.adjust_output_shift = nn.Parameter(torch.Tensor([0]), requires_grad=False)
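    # Worked example for the conversion in mode_qat2hw() (illustrative numbers,
    # not taken from a trained model): with weight_bits = 8 and output_shift = 0
    # (so s_o = 1), a QAT weight w_hat = 0.5 becomes
    #     floor(0.5 * 2**7 * 1 + 0.5) = floor(64.5) = 64,
    # inside the clamp range [-128, 127]; a QAT bias b_hat = 0.01 becomes
    #     floor(0.01 * 2**14 * 1 + 0.5) = 164,
    # inside [-2**14, 2**14 - 1].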
|
|
|
    def mode_qat_ap2hw(self, quantization_mode):
        """Convert weights/biases to integers for 'eval' mode, with a special
        path for 2-bit weights trained in 'qat_ap' mode."""
        w_hat = self.op.weight.data
        b_hat = self.op.bias.data

        shift = -self.output_shift.data
        s_o   = 2**shift
        wb    = self.weight_bits.data.cpu().numpy()[0]

        if wb == 2:
            # 2-bit weights trained with the 'updown_ap' quantizer
            w = self.quantize_Q_ud_ap(w_hat).mul(2.0)
        else:
            w_clamp = [-2**(wb - 1), 2**(wb - 1) - 1]
            w = w_hat.mul(2**(wb - 1)).mul(s_o).add(0.5).floor()
            w = w.clamp(min=w_clamp[0], max=w_clamp[1])

        # Biases are converted the same way in both cases
        b_clamp = [-2**(wb + 8 - 2), 2**(wb + 8 - 2) - 1]
        b = b_hat.mul(2**(wb - 1 + 7)).mul(s_o).add(0.5).floor()
        b = b.clamp(min=b_clamp[0], max=b_clamp[1])

        self.op.weight.data = w
        self.op.bias.data   = b
        self.mode = quantization_mode
        self.quantize_activation = nn.Parameter(torch.Tensor([1]), requires_grad=False)
        self.adjust_output_shift = nn.Parameter(torch.Tensor([0]), requires_grad=False)

    def forward(self, x):
        if self.pool is not None:
            x = self.pool(x)

        if self.mode == 'fpt':
            # Full-precision training: no quantization; batchnorm (if any) is
            # still a separate module. The extra /4 is compensated for when the
            # batchnorm is folded in mode_fpt2qat(), which scales by 1/(4*sigma).
            w_fp = self.op.weight
            b_fp = self.op.bias

            x = self.op_fcn(x, w_fp, b_fp, self.op.stride, self.op.padding)
            if self.bn is not None:
                x = self.bn(x)
            x = x / 4.0
            if self.act is not None:
                x = self.act(x)
            if self.wide and (self.act is None):
                x = self.clamp_C_qa_wide(x)
            else:
                x = self.clamp_C_qa_8b(x)

            # Bookkeeping so that checkpoints saved in this mode are consistent
            self.output_shift        = nn.Parameter(torch.Tensor([0]), requires_grad=False)
            self.quantize_activation = nn.Parameter(torch.Tensor([0]), requires_grad=False)
            self.adjust_output_shift = nn.Parameter(torch.Tensor([1]), requires_grad=False)

        elif self.mode == 'qat':
            # Quantization-aware training: weights/biases are fake-quantized to
            # weight_bits/bias_bits. They are scaled by 2**(-los) before the
            # operation and the result is scaled back by 2**(los), where the
            # per-layer output shift 'los' is derived from the weight/bias
            # statistics via calc_out_shift().
            w_hat = self.op.weight
            b_hat = self.op.bias
            los = calc_out_shift(w_hat.detach(), b_hat.detach(), self.shift_quantile.detach())
            s_w = 2**(-los)
            s_o = 2**(los)
            w_hat_q = self.clamp_C_qa_wb(self.quantize_Q_ud_wb(w_hat * s_w))
            b_hat_q = self.clamp_C_qa_bb(self.quantize_Q_ud_bb(b_hat * s_w))

            x = self.op_fcn(x, w_hat_q, b_hat_q, self.op.stride, self.op.padding)
            x = x * s_o
            if self.act is not None:
                x = self.act(x)
            if self.wide and (self.act is None):
                x = self.quantize_Q_ud_wide(x)
                x = self.clamp_C_qa_wide(x)
            else:
                x = self.quantize_Q_ud_8b(x)
                x = self.clamp_C_qa_8b(x)

            # Remember the shift so that mode_qat2hw() can bake it into the weights
            self.output_shift = nn.Parameter(torch.Tensor([los]), requires_grad=False)

        elif self.mode == 'qat_ap':
            # Same as 'qat', except that 2-bit weights go through the
            # 'updown_ap' quantizer instead of the generic weight quantizer.
            w_hat = self.op.weight
            b_hat = self.op.bias
            los = calc_out_shift(w_hat.detach(), b_hat.detach(), self.shift_quantile.detach())
            s_w = 2**(-los)
            s_o = 2**(los)

            if self.weight_bits.data == 2:
                w_hat_q = self.quantize_Q_ud_ap(w_hat * s_w)
            else:
                w_hat_q = self.clamp_C_qa_wb(self.quantize_Q_ud_wb(w_hat * s_w))

            b_hat_q = self.clamp_C_qa_bb(self.quantize_Q_ud_bb(b_hat * s_w))

            x = self.op_fcn(x, w_hat_q, b_hat_q, self.op.stride, self.op.padding)
            x = x * s_o
            if self.act is not None:
                x = self.act(x)
            if self.wide and (self.act is None):
                x = self.quantize_Q_ud_wide(x)
                x = self.clamp_C_qa_wide(x)
            else:
                x = self.quantize_Q_ud_8b(x)
                x = self.clamp_C_qa_8b(x)

            self.output_shift = nn.Parameter(torch.Tensor([los]), requires_grad=False)

        elif self.mode == 'eval':
            # Hardware-evaluation mode: weights/biases are expected to already
            # be integers (see mode_qat2hw / mode_qat_ap2hw); apply the stored
            # output shift and the hardware clamping behaviour.
            w = self.op.weight
            b = self.op.bias
            los = self.output_shift
            s_o = 2**(los)
            w_q = self.quantize_Q_u_wb(w)
            b_q = self.quantize_Q_u_wb(b)

            x = self.op_fcn(x, w_q, b_q, self.op.stride, self.op.padding)
            x = x * s_o
            if self.act is not None:
                x = self.act(x)
            if self.wide and (self.act is None):
                x = self.quantize_Q_d_wide(x)
                x = self.clamp_C_hw_wide(x)
            else:
                x = self.quantize_Q_d_8b(x)
                x = self.clamp_C_hw_8b(x)

        else:
            print('wrong quantization mode. should have been one of {fpt, qat, qat_ap, eval}. exiting')
            sys.exit()

        return x


class conv(shallow_base_layer):
    """2-D convolution layer (stride 1) with optional 2x2 max-pool, batchnorm and ReLU."""
    def __init__(
            self,
            C_in_channels      = None,   # number of input channels
            D_out_channels     = None,   # number of output channels
            K_kernel_dimension = None,   # square kernel size
            padding            = None,
            pooling            = False,
            batchnorm          = False,
            activation         = None,   # None or 'relu'
            output_width_30b   = False
    ):
        pooling_flag = pooling

        if activation is None:
            activation_fcn = None
        elif activation == 'relu':
            activation_fcn = nn.ReLU(inplace=True)
        else:
            print('wrong activation type in model. only {relu} is acceptable. exiting')
            sys.exit()

        if batchnorm:
            batchnorm_mdl = nn.BatchNorm2d(D_out_channels, eps=1e-05, momentum=0.05, affine=False)
        else:
            batchnorm_mdl = None

        operation_mdl = nn.Conv2d(C_in_channels, D_out_channels, kernel_size=K_kernel_dimension,
                                  stride=1, padding=padding, bias=True)
        operation_fcn = nn.functional.conv2d

        super().__init__(
            pooling_flag      = pooling_flag,
            activation_module = activation_fcn,
            operation_module  = operation_mdl,
            operation_fcnl    = operation_fcn,
            batchnorm_module  = batchnorm_mdl,
            output_width_30b  = output_width_30b
        )


def linear_functional(x, weight, bias, _stride, _padding):
    # Wrapper that matches the op_fcn(x, w, b, stride, padding) signature used
    # in shallow_base_layer.forward(); stride and padding are ignored.
    return nn.functional.linear(x, weight, bias)


class fullyconnected(shallow_base_layer):
    """Fully connected (linear) layer with optional 2x2 max-pool, batchnorm and ReLU."""
    def __init__(
            self,
            in_features      = None,
            out_features     = None,
            pooling          = False,
            batchnorm        = False,
            activation       = None,   # None or 'relu'
            output_width_30b = False
    ):
        pooling_flag = pooling

        if activation is None:
            activation_fcn = None
        elif activation == 'relu':
            activation_fcn = nn.ReLU(inplace=True)
        else:
            print('wrong activation type in model. only {relu} is acceptable. exiting')
            sys.exit()

        if batchnorm:
            # Note: nn.BatchNorm2d expects 4-D input; for flattened (N, features)
            # activations nn.BatchNorm1d would be needed instead.
            batchnorm_mdl = nn.BatchNorm2d(out_features, eps=1e-05, momentum=0.05, affine=False)
        else:
            batchnorm_mdl = None

        operation_mdl = nn.Linear(in_features, out_features, bias=True)
        operation_fcn = linear_functional

        super().__init__(
            pooling_flag      = pooling_flag,
            activation_module = activation_fcn,
            operation_module  = operation_mdl,
            operation_fcnl    = operation_fcn,
            batchnorm_module  = batchnorm_mdl,
            output_width_30b  = output_width_30b
        )

        # nn.Linear has no stride/padding; define them so that the shared
        # forward() can pass them to linear_functional, which ignores them.
        self.op.stride  = None
        self.op.padding = None
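

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API): it assumes the
# `functions` module imported at the top is available, and shows the intended
# life cycle of a layer -- full-precision training ('fpt'), switching to
# quantization-aware training, then converting to the hardware/eval
# representation. All shapes and hyperparameters below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    layer = conv(
        C_in_channels      = 3,
        D_out_channels     = 8,
        K_kernel_dimension = 3,
        padding            = 1,
        pooling            = True,
        batchnorm          = True,
        activation         = 'relu',
    )

    x = torch.randn(1, 3, 32, 32)

    y_fpt = layer(x)              # 'fpt': full precision, batchnorm still separate

    layer.mode_fpt2qat('qat')     # fold batchnorm into the weights, enter QAT
    y_qat = layer(x)              # fake-quantized weights and activations

    layer.mode_qat2hw('eval')     # integerize weights/biases for hardware-style eval
    y_hw = layer(x)

    print(y_fpt.shape, y_qat.shape, y_hw.shape)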
|
|