vikhyatk committed · verified
Commit 80427a0 · 1 Parent(s): 4f19e94

Upload HfMoondream

Files changed (3)

1. layers.py +9 -6
2. model.safetensors +1 -1
3. packing.py +7 -32
layers.py CHANGED

@@ -39,15 +39,17 @@ class QuantizedLinear(nn.Module):
             {
                 "packed": nn.Parameter(
                     torch.empty(
-                        out_features, in_features // 128, 64, dtype=torch.uint8
+                        out_features * in_features // (128 * 2), 128, dtype=torch.uint8
                     ),
                     requires_grad=False,
                 ),
                 "scale": nn.Parameter(
-                    torch.empty(out_features, in_features // 128), requires_grad=False
+                    torch.empty(out_features * in_features // 128, 1),
+                    requires_grad=False,
                 ),
                 "zero_point": nn.Parameter(
-                    torch.empty(out_features, in_features // 128), requires_grad=False
+                    torch.empty(out_features * in_features // 128, 1),
+                    requires_grad=False,
                 ),
             }
         )
@@ -57,13 +59,13 @@ class QuantizedLinear(nn.Module):
     def unpack(self):
         if self.unpacked:
             return
+
         self.weight = nn.Parameter(
             dequantize_tensor(
                 self.weight["packed"],
                 self.weight["scale"],
                 self.weight["zero_point"],
-                (self.weight["packed"].shape[0], self.weight["packed"].shape[1] * 128),
-                128,
+                (self.out_features, self.in_features),
                 torch.bfloat16,
             )
         )
@@ -75,10 +77,11 @@ class QuantizedLinear(nn.Module):
         self.linear.bias = nn.Parameter(
             self.bias.to(torch.bfloat16), requires_grad=False
         )
+
         del self.weight, self.bias
         quantize_(self, int4_weight_only(group_size=128))
-        torch.cuda.empty_cache()
         self.unpacked = True
+        torch.cuda.empty_cache()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if not self.unpacked:
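The layers.py change flattens the packed weight from a per-output-row layout ([out_features, in_features // 128, 64]) to one 128-value quantization group per row, with two groups nibble-packed into each uint8 row. A minimal shape check of the new layout (a sketch with hypothetical sizes; it assumes the module records self.in_features / self.out_features, as unpack() now uses them):

import torch

out_features, in_features = 2048, 8192        # hypothetical layer dimensions
n_groups = out_features * in_features // 128  # one scale/zero_point per 128-value group

packed = torch.empty(n_groups // 2, 128, dtype=torch.uint8)  # two int4 groups per byte row
scale = torch.empty(n_groups, 1)
zero_point = torch.empty(n_groups, 1)

# Each byte holds two int4 values, so the packed buffer covers the full weight.
assert 2 * packed.numel() == out_features * in_features
assert scale.shape[0] == zero_point.shape[0] == n_groups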
model.safetensors CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:dfce186edf359fff98d0c077ae389b980b6cae99279d157fc00b2d03ca65968f
+oid sha256:b839cdbd6716eef6242536929c05243d58af49929a12c198d3913caa05c7c3ee
 size 2032380848
packing.py CHANGED

@@ -1,35 +1,10 @@
 import torch
 
 
-def unpack_int4(packed: torch.Tensor, original_length: int) -> torch.Tensor:
-    orig_shape = packed.shape
-    last_dim = orig_shape[-1]
-    batch_shape = orig_shape[:-1]
-    flat_packed = packed.reshape(-1, last_dim)
-    batch_size = flat_packed.shape[0]
-    flat_bytes = flat_packed.reshape(-1)
-    lower = flat_bytes & 0xF
-    upper = (flat_bytes >> 4) & 0xF
-    unpacked = torch.stack([lower, upper], dim=1).reshape(batch_size, last_dim * 2)
-    unpacked = unpacked[:, :original_length]
-    unpacked = unpacked.reshape(*batch_shape, original_length)
-    return unpacked.to(torch.int8)
-
-
-def dequantize_tensor(
-    packed: torch.Tensor,
-    scales: torch.Tensor,
-    zero_points: torch.Tensor,
-    orig_shape: torch.Size,
-    block_size: int,
-    dtype: torch.dtype = torch.bfloat16,
-):
-    out_features, num_blocks, _ = packed.shape
-    unpacked = unpack_int4(packed, block_size)
-    scales_view = scales.unsqueeze(2)  # Shape: [out_features, num_blocks, 1]
-    zero_points_view = zero_points.unsqueeze(2)  # Shape: [out_features, num_blocks, 1]
-    dequantized = (unpacked.float() - zero_points_view) * scales_view
-    dequantized = dequantized.reshape(out_features, num_blocks * block_size)
-    dequantized = dequantized[:, : orig_shape[1]]
-    dequantized = dequantized.reshape(orig_shape)
-    return dequantized.to(dtype)
+def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
+    _step = W_q.shape[0]
+    W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
+    W_r[:_step] = (W_q & 0b11110000) >> 4
+    W_r[_step:] = W_q & 0b00001111
+    W_r.sub_(zero).mul_(scale)
+    return W_r.reshape(orig_shape)
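The rewritten dequantize_tensor splits each packed byte row into a high-nibble group (first half of the output rows) and a low-nibble group (second half), then applies the per-group scale and zero point in place. A round-trip sketch of the format: pack_int4 below is a hypothetical inverse written for illustration (it is not part of this commit), and the per-group asymmetric 4-bit quantization is likewise an assumed scheme.

import torch

from packing import dequantize_tensor


def pack_int4(q: torch.Tensor) -> torch.Tensor:
    # Hypothetical inverse of dequantize_tensor's unpacking: q is
    # [n_groups, 128] uint8 with values in [0, 15]; the first half of
    # the rows become high nibbles, the second half low nibbles.
    step = q.shape[0] // 2
    return (q[:step] << 4) | q[step:]


out_features, in_features = 8, 256   # hypothetical sizes
W = torch.randn(out_features, in_features)
groups = W.reshape(-1, 128)          # one 128-value quantization group per row

# Assumed per-group asymmetric 4-bit quantization.
w_min = groups.min(dim=1, keepdim=True).values
w_max = groups.max(dim=1, keepdim=True).values
scale = ((w_max - w_min) / 15).clamp(min=1e-8)
zero = (-w_min / scale).round()
q = (groups / scale + zero).round().clamp(0, 15).to(torch.uint8)

# Cast scale/zero to bfloat16 so the in-place sub_/mul_ on the bfloat16
# output buffer type-checks.
W_hat = dequantize_tensor(
    pack_int4(q),
    scale.to(torch.bfloat16),
    zero.to(torch.bfloat16),
    (out_features, in_features),
)
assert W_hat.shape == W.shape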