###########################################################################
# Computer vision - Binary neural networks demo software by HyperbeeAI. #
# Copyrights © 2023 Hyperbee.AI Inc. All rights reserved. [email protected] #
###########################################################################
import torch, sys
import torch.nn as nn
import numpy as np
from torch.autograd import Function
from functions import quantization, clamping_qa, clamping_hw, calc_out_shift
###################################################
### Base layer for conv/linear,
### enabling quantization-related mechanisms
class shallow_base_layer(nn.Module):
def __init__(
self,
quantization_mode = 'fpt', # 'fpt', 'qat', 'qat_ap' and 'eval'
        pooling_flag = None, # boolean flag for now; only 2x2 max pooling with stride 2 is supported
operation_module = None, # torch nn module for keeping and updating conv/linear parameters
operation_fcnl = None, # torch nn.functional for actually doing the operation
activation_module = None, # torch nn module for relu/abs
batchnorm_module = None, # torch nn module for batchnorm, see super
output_width_30b = False # boolean flag that chooses between "bigdata" (32b) and normal (8b) activation modes for MAX78000
):
super().__init__()
###############################################################################
# Initialize stuff that won't change throughout the model's lifetime here
# since this place will only be run once (first time the model is declared)
if(pooling_flag==True):
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
else:
self.pool = None
### Burak: we have to access and change (forward pass) and also train (backward pass) parameters .weight and .bias for the operations
### therefore we keep both a functional and a module for Conv2d/Linear. The name "op" is mandatory for keeping params in Maxim
### checkpoint format.
self.op = operation_module
self.op_fcn = operation_fcnl
self.act = activation_module
self.bn = batchnorm_module
self.wide = output_width_30b
###############################################################################
# Initialize stuff that will change during mode progression (FPT->QAT->Eval/HW).
self.mode = quantization_mode;
self.quantize_Q_ud_8b = None
self.quantize_Q_ud_wb = None
self.quantize_Q_ud_bb = None
self.quantize_Q_ud_ap = None
self.quantize_Q_d_8b = None
self.quantize_Q_u_wb = None
self.quantize_Q_ud_wide = None
self.quantize_Q_d_wide = None
self.clamp_C_qa_8b = None
self.clamp_C_qa_bb = None
self.clamp_C_qa_wb = None
self.clamp_C_hw_8b = None
self.clamp_C_qa_wide = None
self.clamp_C_hw_wide = None
### Burak: these aren't really trainable parameters, but they're logged in the Maxim checkpoint format. It seems they marked
### them as "non-trainable parameters" to get them automatically saved in the state_dict
self.output_shift = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False) ### Burak: we called this los, this varies, default:0
self.weight_bits = nn.Parameter(torch.Tensor([ 8 ]), requires_grad=False) ### Burak: we called this wb, this varies, default:8
self.bias_bits = nn.Parameter(torch.Tensor([ 8 ]), requires_grad=False) ### Burak: this is always 8
self.quantize_activation = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False) ### Burak: this is 0 in FPT, 1 in QAT & eval/hardware, default: fpt
self.adjust_output_shift = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False) ### Burak: this is 1 in FPT & QAT, 0 in eval/hardware, default: fpt
self.shift_quantile = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False) ### Burak: this varies, default:1 (naive)
###############################################################################
# Do first mode progression (to the default)
### Burak: this recognizes that layer configuration is done via a function,
### thus, can be done again in training time for mode progression
weight_bits = self.weight_bits
bias_bits = self.bias_bits
shift_quantile = self.shift_quantile
self.configure_layer_base( weight_bits, bias_bits, shift_quantile )
# This will be called during mode progression to set fields,
# check workflow-training-modes.png in doc for further info.
# sets functions for all modes though, not just the selected mode
def configure_layer_base(self, weight_bits, bias_bits, shift_quantile):
# quantization operators
self.quantize_Q_ud_8b = quantization(xb = 8, mode ='updown' , wide=False) # 8 here is activation bits
self.quantize_Q_ud_wb = quantization(xb = weight_bits, mode ='updown' , wide=False)
self.quantize_Q_ud_bb = quantization(xb = bias_bits, mode ='updown' , wide=False)
        self.quantize_Q_ud_ap = quantization(xb = 2, mode ='updown_ap' , wide=False) # 2 here is a dummy value; the antipodal mode overrides xb
self.quantize_Q_d_8b = quantization(xb = 8, mode ='down' , wide=False) # 8 here is activation bits
self.quantize_Q_u_wb = quantization(xb = weight_bits, mode ='up' , wide=False)
self.quantize_Q_ud_wide = quantization(xb = 8, mode ='updown' , wide=True) # 8 here is activation bits, but its wide, so check inside
self.quantize_Q_d_wide = quantization(xb = 8, mode ='down' , wide=True) # 8 here is activation bits, but its wide, so check inside
# clamping operators
self.clamp_C_qa_8b = clamping_qa(xb = 8, wide=False) # 8 here is activation bits
self.clamp_C_qa_bb = clamping_qa(xb = bias_bits, wide=False)
self.clamp_C_qa_wb = clamping_qa(xb = weight_bits, wide=False)
self.clamp_C_hw_8b = clamping_hw(xb = 8, wide=False) # 8 here is activation bits
self.clamp_C_qa_wide = clamping_qa(xb = None, wide=True) # None to avoid misleading info on the # of bits, check inside
self.clamp_C_hw_wide = clamping_hw(xb = None, wide=True) # None to avoid misleading info on the # of bits, check inside
# state variables
self.weight_bits = nn.Parameter(torch.Tensor([ weight_bits ]), requires_grad=False)
self.bias_bits = nn.Parameter(torch.Tensor([ bias_bits ]), requires_grad=False)
self.shift_quantile = nn.Parameter(torch.Tensor([ shift_quantile ]), requires_grad=False)
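        # Example (hypothetical values, a minimal sketch; not taken from the original
        # training scripts): re-configuring a layer for 2-bit weights before QAT could look like:
        #   layer.configure_layer_base(weight_bits=2, bias_bits=8, shift_quantile=0.985)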
# This will be called during mode progression, during training
def mode_fpt2qat(self, quantization_mode):
# just fold batchnorms
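        # The fpt forward pass applies (x - mu) / (4 * sigma) after the convolution
        # (batchnorm with affine=False, followed by the fixed 1/4 scaling), so the
        # same effect is absorbed into the parameters here:
        #   w_hat = w_fp / (4 * sigma)
        #   b_hat = (b_fp - mu) / (4 * sigma)
        # where mu / sigma are the batchnorm running mean / stdev.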
if(self.bn is not None):
w_fp = self.op.weight.data
b_fp = self.op.bias.data
running_mean_mu = self.bn.running_mean
running_var = self.bn.running_var
running_stdev_sigma = torch.sqrt(running_var + 1e-20)
w_hat = w_fp * (1.0 / (running_stdev_sigma*4.0)).reshape((w_fp.shape[0],) + (1,) * (len(w_fp.shape) - 1))
b_hat = (b_fp - running_mean_mu)/(running_stdev_sigma*4.0)
self.op.weight.data = w_hat
self.op.bias.data = b_hat
self.bn = None
else:
pass
#print('This layer does not have batchnorm')
self.mode = quantization_mode;
self.quantize_activation = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False) ### Burak: this is 0 in FPT, 1 in QAT & eval/hardware
self.adjust_output_shift = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False) ### Burak: this is 1 in FPT & QAT, 0 in eval/hardware
# This will be called during mode progression after training, for eval
def mode_qat2hw(self, quantization_mode):
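        # Convert the QAT fractional parameters to the integer representation used in
        # eval/hardware mode: weights are scaled by 2^(wb-1), biases by 2^(wb+6), the
        # learned output shift is baked in via s_o, and the results are rounded
        # (add 0.5, floor) and clamped to their representable ranges.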
w_hat = self.op.weight.data
b_hat = self.op.bias.data
shift = -self.output_shift.data;
s_o = 2**(shift)
wb = self.weight_bits.data.cpu().numpy()[0]
w_clamp = [-2**(wb-1) , 2**(wb-1)-1 ]
b_clamp = [-2**(wb+8-2), 2**(wb+8-2)-1] # 8 here is activation bits
w = w_hat.mul(2**(wb -1)).mul(s_o).add(0.5).floor()
w = w.clamp(min=w_clamp[0],max=w_clamp[1])
b = b_hat.mul(2**(wb -1 + 7)).mul(s_o).add(0.5).floor()
b = b.clamp(min=b_clamp[0],max=b_clamp[1])
self.op.weight.data = w
self.op.bias.data = b
self.mode = quantization_mode;
self.quantize_activation = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False) ### Burak: this is 0 in FPT, 1 in QAT & eval/hardware
self.adjust_output_shift = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False) ### Burak: this is 1 in FPT & QAT, 0 in eval/hardware
def mode_qat_ap2hw(self, quantization_mode):
w_hat = self.op.weight.data
b_hat = self.op.bias.data
shift = -self.output_shift.data;
s_o = 2**(shift)
wb = self.weight_bits.data.cpu().numpy()[0]
if(wb==2):
w = self.quantize_Q_ud_ap(w_hat).mul(2.0)
else:
w_clamp = [-2**(wb-1) , 2**(wb-1)-1 ]
w = w_hat.mul(2**(wb -1)).mul(s_o).add(0.5).floor()
w = w.clamp(min=w_clamp[0],max=w_clamp[1])
b_clamp = [-2**(wb+8-2), 2**(wb+8-2)-1] # 8 here is activation bits
b = b_hat.mul(2**(wb -1 + 7)).mul(s_o).add(0.5).floor()
b = b.clamp(min=b_clamp[0],max=b_clamp[1])
self.op.weight.data = w
self.op.bias.data = b
self.mode = quantization_mode;
self.quantize_activation = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False) ### Burak: this is 0 in FPT, 1 in QAT & eval/hardware
self.adjust_output_shift = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False) ### Burak: this is 1 in FPT & QAT, 0 in eval/hardware
def forward(self, x):
if(self.pool is not None):
x = self.pool(x)
if(self.mode == 'fpt'):
# pre-compute stuff
w_fp = self.op.weight
b_fp = self.op.bias
# actual forward pass
x = self.op_fcn(x, w_fp, b_fp, self.op.stride, self.op.padding)
if(self.bn is not None):
x = self.bn(x) # make sure var=1 and mean=0
                x = x / 4.0 # since BN only enforces var=1 and mean=0, the 1/4 keeps activations within [-1,1] with high probability
if(self.act is not None):
x = self.act(x)
if((self.wide) and (self.act is None)):
x = self.clamp_C_qa_wide(x)
else:
x = self.clamp_C_qa_8b(x)
# save stuff (los is deactivated in fpt)
self.output_shift = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False) # functional, used in Maxim-friendly checkpoints
self.quantize_activation = nn.Parameter(torch.Tensor([ 0 ]), requires_grad=False) # ceremonial, for Maxim-friendly checkpoints
self.adjust_output_shift = nn.Parameter(torch.Tensor([ 1 ]), requires_grad=False) # ceremonial, for Maxim-friendly checkpoints
elif(self.mode == 'qat'):
###############################################################################
## ASSUMPTION: batchnorms are already folded before coming here. Check doc, ##
## the parameters with _fp and with _hat are of different magnitude ##
###############################################################################
# pre-compute stuff
w_hat = self.op.weight
b_hat = self.op.bias
los = calc_out_shift(w_hat.detach(), b_hat.detach(), self.shift_quantile.detach())
s_w = 2**(-los)
s_o = 2**(los)
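            # los is the layer's output shift: weights/biases are scaled by s_w = 2^(-los)
            # before quantization/clamping so they use the representable range, and the
            # layer output is scaled back by s_o = 2^(los) to compensate.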
w_hat_q = self.clamp_C_qa_wb(self.quantize_Q_ud_wb(w_hat*s_w));
b_hat_q = self.clamp_C_qa_bb(self.quantize_Q_ud_bb(b_hat*s_w));
# actual forward pass
x = self.op_fcn(x, w_hat_q, b_hat_q, self.op.stride, self.op.padding)
x = x*s_o
if(self.act is not None):
x = self.act(x)
if((self.wide) and (self.act is None)):
x = self.quantize_Q_ud_wide(x)
x = self.clamp_C_qa_wide(x)
else:
x = self.quantize_Q_ud_8b(x)
x = self.clamp_C_qa_8b(x)
# save stuff
self.output_shift = nn.Parameter(torch.Tensor([ los ]), requires_grad=False) # functional, used in Maxim-friendly checkpoints
elif(self.mode == 'qat_ap'):
###############################################################################
## ASSUMPTION: batchnorms are already folded before coming here. Check doc, ##
## the parameters with _fp and with _hat are of different magnitude ##
###############################################################################
# pre-compute stuff
w_hat = self.op.weight
b_hat = self.op.bias
los = calc_out_shift(w_hat.detach(), b_hat.detach(), self.shift_quantile.detach())
s_w = 2**(-los)
s_o = 2**(los)
##############################################
# This is the only difference from qat
if(self.weight_bits.data==2):
w_hat_q = self.quantize_Q_ud_ap(w_hat*s_w);
else:
w_hat_q = self.clamp_C_qa_wb(self.quantize_Q_ud_wb(w_hat*s_w));
##############################################
b_hat_q = self.clamp_C_qa_bb(self.quantize_Q_ud_bb(b_hat*s_w));
# actual forward pass
x = self.op_fcn(x, w_hat_q, b_hat_q, self.op.stride, self.op.padding)
x = x*s_o
if(self.act is not None):
x = self.act(x)
if((self.wide) and (self.act is None)):
x = self.quantize_Q_ud_wide(x)
x = self.clamp_C_qa_wide(x)
else:
x = self.quantize_Q_ud_8b(x)
x = self.clamp_C_qa_8b(x)
# save stuff
self.output_shift = nn.Parameter(torch.Tensor([ los ]), requires_grad=False) # functional, used in Maxim-friendly checkpoints
elif(self.mode == 'eval'):
#####################################################################################
## ASSUMPTION: parameters are already converted to HW before coming here.Check doc ##
#####################################################################################
# pre-compute stuff
w = self.op.weight
b = self.op.bias
los = self.output_shift
s_o = 2**(los)
w_q = self.quantize_Q_u_wb(w);
b_q = self.quantize_Q_u_wb(b); # yes, wb, not a typo, they need to be on the same scale
# actual forward pass
x = self.op_fcn(x, w_q, b_q, self.op.stride, self.op.padding) # convolution / linear
x = x*s_o
if(self.act is not None):
x = self.act(x)
if((self.wide) and (self.act is None)):
x = self.quantize_Q_d_wide(x)
x = self.clamp_C_hw_wide(x)
else:
x = self.quantize_Q_d_8b(x)
x = self.clamp_C_hw_8b(x)
# nothing to save, this was a hardware-emulated evaluation pass
else:
            print('wrong quantization mode. should have been one of {fpt, qat, qat_ap, eval}. exiting')
sys.exit()
return x
class conv(shallow_base_layer):
def __init__(
self,
C_in_channels = None, # number of input channels
D_out_channels = None, # number of output channels
K_kernel_dimension = None, # square kernel dimension
padding = None, # amount of pixels to pad on one side (other side is symmetrically padded too)
        pooling = False, # boolean flag for now; only 2x2 max pooling with stride 2 is supported
batchnorm = False, # boolean flag for now, no trainable affine parameters
activation = None, # 'relu' is the only choice for now
output_width_30b = False # boolean flag that chooses between "bigdata" (32b) and normal (8b) activation modes for MAX78000
):
pooling_flag = pooling
if(activation is None):
activation_fcn = None;
elif(activation == 'relu'):
activation_fcn = nn.ReLU(inplace=True);
else:
print('wrong activation type in model. only {relu} is acceptable. exiting')
sys.exit()
### Burak: only a module is enough for BN since we neither need to access internals in forward pass, nor train anything (affine=False)
if(batchnorm):
batchnorm_mdl = nn.BatchNorm2d(D_out_channels, eps=1e-05, momentum=0.05, affine=False)
else:
batchnorm_mdl = None;
operation_mdl = nn.Conv2d(C_in_channels, D_out_channels, kernel_size=K_kernel_dimension, stride=1, padding=padding, bias=True);
operation_fcn = nn.functional.conv2d
super().__init__(
pooling_flag = pooling_flag,
activation_module = activation_fcn,
operation_module = operation_mdl,
operation_fcnl = operation_fcn,
batchnorm_module = batchnorm_mdl,
output_width_30b = output_width_30b
)
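# Example usage (a minimal sketch; the channel/kernel numbers below are hypothetical
# and not taken from the original demo models):
#   layer = conv(C_in_channels=3, D_out_channels=16, K_kernel_dimension=3,
#                padding=1, pooling=True, batchnorm=True, activation='relu')
#   y = layer(torch.randn(8, 3, 32, 32))   # runs in the default 'fpt' mode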
def linear_functional(x, weight, bias, _stride, _padding):
# dummy linear function that has same arguments as conv
return nn.functional.linear(x, weight, bias)
class fullyconnected(shallow_base_layer):
def __init__(
self,
        in_features = None, # number of input features
out_features = None, # number of output features
        pooling = False, # boolean flag for now; only 2x2 max pooling with stride 2 is supported
batchnorm = False, # boolean flag for now, no trainable affine parameters
activation = None, # 'relu' is the only choice for now
output_width_30b = False # boolean flag that chooses between "bigdata" (32b) and normal (8b) activation modes for MAX78000
):
pooling_flag = pooling
if(activation is None):
activation_fcn = None;
elif(activation == 'relu'):
activation_fcn = nn.ReLU(inplace=True);
else:
print('wrong activation type in model. only {relu} is acceptable. exiting')
sys.exit()
### Burak: only a module is enough for BN since we neither need to access internals in forward pass, nor train anything (affine=False)
if(batchnorm):
            batchnorm_mdl = nn.BatchNorm1d(out_features, eps=1e-05, momentum=0.05, affine=False) # 1d here: the linear layer output is (N, out_features)
else:
batchnorm_mdl = None;
operation_mdl = nn.Linear(in_features, out_features, bias=True);
operation_fcn = linear_functional
super().__init__(
pooling_flag = pooling_flag,
activation_module = activation_fcn,
operation_module = operation_mdl,
operation_fcnl = operation_fcn,
batchnorm_module = batchnorm_mdl,
output_width_30b = output_width_30b
)
        # Define dummy stride/padding attributes so Linear and conv can share the same forward() in shallow_base_layer.
        # the name "op" here refers to op in the superclass, i.e., shallow_base_layer
self.op.stride = None
self.op.padding = None
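# Example mode progression (a minimal sketch of the FPT -> QAT -> eval/HW flow;
# layer sizes and the training loop are hypothetical placeholders):
#   fc = fullyconnected(in_features=64, out_features=10)
#   ... train in the default 'fpt' mode ...
#   fc.mode_fpt2qat('qat')     # fold batchnorm (if any), enable activation quantization
#   ... fine-tune in 'qat' mode ...
#   fc.mode_qat2hw('eval')     # convert parameters to hardware-style integers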