File size: 339 Bytes
9b4ed9c 80427a0 |
1 2 3 4 5 6 7 8 9 10 11 |
import torch
def dequantize_tensor(W_q: torch.Tensor, scale, zero, orig_shape, dtype=torch.bfloat16) -> torch.Tensor:
    """Unpack two 4-bit values per element of ``W_q`` and dequantize them.

    Each element of the 2-D tensor ``W_q`` is treated as a pair of 4-bit
    codes. The upper nibbles fill the first ``W_q.shape[0]`` rows of the
    unpacked tensor and the lower nibbles fill the remaining rows, so the
    unpacked tensor has shape ``(2 * W_q.shape[0], W_q.shape[1])``. The
    result is then computed as ``(unpacked - zero) * scale`` and reshaped.

    Args:
        W_q: 2-D integer tensor of packed nibbles (both ``shape[0]`` and
            ``shape[1]`` are read). Unsigned and signed 8-bit storage give
            the same result.
        scale: Multiplier applied after the zero-point is subtracted. It is
            applied in place, so it must broadcast against the unpacked
            ``(2 * rows, cols)`` tensor without enlarging it.
        zero: Zero-point subtracted from the unpacked codes; same
            broadcasting constraint as ``scale``.
        orig_shape: Shape passed to ``reshape`` for the returned tensor; it
            must hold ``2 * W_q.numel()`` elements.
        dtype: Floating dtype of the returned tensor.

    Returns:
        A tensor of ``dtype`` on ``W_q.device`` with shape ``orig_shape``.
    """
    n_rows = W_q.shape[0]
    W_r = torch.empty([2 * n_rows, W_q.shape[1]], dtype=dtype, device=W_q.device)

    # Shift first, then mask: a right shift on a signed tensor sign-extends,
    # so masking afterwards is what guarantees a code in the range 0..15.
    # For unsigned storage this is identical to masking 0xF0 and shifting.
    W_r[:n_rows] = (W_q >> 4) & 0x0F
    W_r[n_rows:] = W_q & 0x0F

    # In-place ops avoid allocating further temporaries of the full size.
    W_r.sub_(zero).mul_(scale)
    return W_r.reshape(orig_shape)
|