File size: 339 Bytes
9b4ed9c
 
 
80427a0
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import torch


def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
    """Dequantize a 4-bit-packed tensor back to a floating-point tensor.

    Each byte of ``W_q`` holds two 4-bit codes: the high nibble and the low
    nibble.  The high nibbles form the first ``W_q.shape[0]`` rows of the
    unpacked tensor and the low nibbles the following rows, after which the
    affine map ``(code - zero) * scale`` is applied and the result is
    reshaped to ``orig_shape``.

    Args:
        W_q: packed quantized weights of shape ``(step, cols)``; each element
            stores two 4-bit codes.
        scale: dequantization scale (scalar or broadcastable tensor).
        zero: dequantization zero-point (scalar or broadcastable tensor).
        orig_shape: shape of the returned dequantized tensor.
        dtype: floating-point dtype of the result (default ``torch.bfloat16``).

    Returns:
        Dequantized tensor of shape ``orig_shape`` and dtype ``dtype`` on
        ``W_q``'s device.
    """
    # Mask before shifting so the extraction is independent of W_q's
    # signedness; convert each nibble plane to the target dtype up front.
    high_nibbles = ((W_q & 0b11110000) >> 4).to(dtype)
    low_nibbles = (W_q & 0b00001111).to(dtype)
    # High nibbles occupy the first half of rows, low nibbles the second.
    unpacked = torch.cat([high_nibbles, low_nibbles], dim=0)
    return ((unpacked - zero) * scale).reshape(orig_shape)