import torch


def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
    """Unpack two 4-bit codes per element of ``W_q`` and dequantize them.

    Every element of the 2-D tensor ``W_q`` contributes two codes:
    - its upper nibble (bits 4-7) goes to row ``i`` of the unpacked buffer;
    - its lower nibble (bits 0-3) goes to row ``i + W_q.shape[0]``.
    The unpacked codes are then mapped through ``(code - zero) * scale``.
    This arithmetic happens in place in ``dtype``.

    Args:
        W_q: 2-D packed tensor. Presumably ``uint8``; TODO confirm with callers.
        zero: value (or broadcastable tensor) subtracted from every code.
        scale: value (or broadcastable tensor) multiplied in after ``zero``
            is removed.
        orig_shape: shape of the returned tensor. It must describe exactly
            ``2 * W_q.shape[0] * W_q.shape[1]`` elements.
        dtype: element type of the unpacked/dequantized buffer
            (``torch.bfloat16`` by default).

    Returns:
        The dequantized tensor, on ``W_q``'s device, reshaped to ``orig_shape``.
    """
    rows, cols = W_q.shape[0], W_q.shape[1]

    # One buffer of the target dtype. The nibbles are cast into it on copy.
    unpacked = torch.empty((2 * rows, cols), dtype=dtype, device=W_q.device)
    upper_half = unpacked[:rows]
    lower_half = unpacked[rows:]

    upper_half.copy_((W_q & 0xF0) >> 4)  # high nibble -> first half of rows
    lower_half.copy_(W_q & 0x0F)  # low nibble  -> second half of rows

    # In-place ops keep the result in `dtype` regardless of zero/scale dtype.
    unpacked.sub_(zero)
    unpacked.mul_(scale)
    return unpacked.reshape(orig_shape)