vikhyatk committed (verified) · Commit 9b4ed9c · 1 Parent(s): 4489767

Upload HfMoondream

Files changed (9):
  1. config.json +1 -1
  2. config.py +2 -2
  3. layers.py +52 -51
  4. model.safetensors +2 -2
  5. moondream.py +29 -24
  6. packing.py +52 -0
  7. region.py +3 -3
  8. text.py +87 -36
  9. vision.py +3 -3
config.json CHANGED
@@ -8,6 +8,6 @@
   },
   "config": {},
   "model_type": "moondream1",
-  "torch_dtype": "float16",
+  "torch_dtype": "bfloat16",
   "transformers_version": "4.44.0"
 }
config.py CHANGED
@@ -12,7 +12,7 @@ class TextConfig:
     n_heads: int = 32
     n_kv_heads: int = 32
     prefix_attn: int = 730
-    group_size: int = 128
+    group_size: Optional[int] = 128
 
 
 @dataclass(frozen=True)
@@ -38,7 +38,7 @@ class RegionConfig:
     size_feat_dim: int = 512
     size_out_dim: int = 2048
     inner_dim: int = 8192
-
+    group_size: Optional[int] = 128
 
 @dataclass(frozen=True)
 class TokenizerConfig:
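
Both group_size fields are now Optional: 128 keeps the int4 block layout, while None is the switch the code below (text.py, moondream.py) uses to fall back to plain nn.Linear. Illustrative sketch, not from this diff, assuming config.py imports Optional from typing and that every TextConfig field has a default:

from config import TextConfig  # inside the repo this is a relative import (.config)

quantized = TextConfig()                      # group_size=128 -> QuantizedLinear downstream
full_precision = TextConfig(group_size=None)  # -> plain nn.Linear downstream

for cfg in (quantized, full_precision):
    use_quantized = cfg.group_size is not None
    print(cfg.group_size, "QuantizedLinear" if use_quantized else "nn.Linear")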
layers.py CHANGED
@@ -1,11 +1,13 @@
-import bitblas
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from dataclasses import dataclass
 from typing import Literal
-from bitblas.cache import OperatorCache
-from torch.nn import functional as F
+from torchao import quantize_
+from torchao.quantization import int4_weight_only
+
+from .packing import dequantize_tensor
 
 
 def gelu_approx(x):
@@ -18,65 +20,65 @@ class LinearWeights:
     bias: torch.Tensor
 
 
-class Linear(nn.Module):
-    """
-    Linear layer with support for bitblas quantization.
-    If dtype is torch.int8, it uses bitblas for quantization.
-    Otherwise, it uses a standard nn.Linear layer.
-    """
-
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool = True,
-        dtype: torch.dtype = None,
-        group_size: int = 128,
-    ):
-        super().__init__()
-
-        if dtype == torch.int8:
-            self.linear = bitblas.Linear(
-                in_features=in_features,
-                out_features=out_features,
-                bias=bias,
-                with_zeros=True,
-                zeros_mode="original",
-                with_scaling=True,
-                A_dtype="float16",
-                W_dtype="uint4",
-                accum_dtype="float16",
-                out_dtype="float16",
-                fast_decoding=True,
-                enable_tuning=True,
-                group_size=group_size,
-            )
-        else:
-            self.linear = nn.Linear(
-                in_features=in_features,
-                out_features=out_features,
-                bias=bias,
-                dtype=torch.float16,
-            )
-
-    def forward(self, x):
-        return self.linear(x)
-
-    @property
-    def weight(self) -> torch.Tensor:
-        try:
-            return self.linear.weight
-        except AttributeError:
-            return self.linear.qweight
-
-    @property
-    def bias(self) -> torch.Tensor:
-        return self.linear.bias
-
-
-def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
-    return F.linear(x, w.weight, w.bias)
-
+def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
+    return F.linear(x, w.weight, w.bias)
+
+
+class QuantizedLinear(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        dtype: torch.dtype,
+    ):
+        # TODO: Take group_size as an input instead of hardcoding it here.
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = nn.ParameterDict(
+            {
+                "packed": nn.Parameter(
+                    torch.empty(
+                        out_features, in_features // 128, 64, dtype=torch.uint8
+                    ),
+                    requires_grad=False,
+                ),
+                "scales": nn.Parameter(
+                    torch.empty(out_features, in_features // 128), requires_grad=False
+                ),
+            }
+        )
+        self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
+        self.unpacked = False
+
+    def unpack(self):
+        self.weight = nn.Parameter(
+            dequantize_tensor(
+                self.weight["packed"],
+                self.weight["scales"],
+                (self.weight["packed"].shape[0], self.weight["packed"].shape[1] * 128),
+                128,
+                torch.bfloat16,
+            )
+        )
+        with torch.device("meta"):
+            self.linear = nn.Linear(
+                self.in_features, self.out_features, dtype=torch.bfloat16
+            )
+        self.linear.weight = self.weight
+        self.linear.bias = nn.Parameter(
+            self.bias.to(torch.bfloat16), requires_grad=False
+        )
+        del self.weight, self.bias
+        quantize_(self, int4_weight_only(group_size=128))
+        torch.cuda.empty_cache()
+        self.unpacked = True
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if not self.unpacked:
+            self.unpack()
+        return self.linear(x)
+
 
 @dataclass
 class LayerNormWeights:
@@ -96,7 +98,6 @@ class MLPWeights:
 
 
 def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
-
     x = w.fc1(x)
     x = gelu_approx(x)
     x = w.fc2(x)
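
For reference, a minimal sketch (not from this diff) of the storage contract QuantizedLinear expects from the checkpoint: int4 values packed two per byte in blocks of 128 along in_features, plus one scale per block; unpacking is deferred until the first forward call. The top-level import is an assumption, since inside the repo the class lives in layers.py and uses a relative import of packing.py:

import torch
from layers import QuantizedLinear  # in the repo: from .layers import QuantizedLinear

lin = QuantizedLinear(in_features=2048, out_features=2048, dtype=torch.bfloat16)
print(lin.weight["packed"].shape)  # torch.Size([2048, 16, 64]): 16 blocks of 128 int4s = 64 bytes each
print(lin.weight["scales"].shape)  # torch.Size([2048, 16]): one scale per block
print(lin.unpacked)                # False; the first forward() runs unpack(), which dequantizes
                                   # to bfloat16 and re-quantizes through torchao int4_weight_only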
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:73e9da0d1091d61630477994669a22011c830c7539e27e659fb63a4d6818f8a2
-size 2080370912
+oid sha256:97076df1a9a09ff4108a69ea59b4c9abf522b248e8425c9334bab98ddbaf4b33
+size 1838828672
moondream.py CHANGED
@@ -12,6 +12,7 @@ from .image_crops import reconstruct_from_crops
 from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
 from .text import build_text_model, text_encoder, lm_head, text_decoder
 from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
+from .layers import QuantizedLinear
 from .utils import remove_outlier_points
 
 
@@ -63,47 +64,49 @@ class KVCache(nn.Module):
 
 
 class MoondreamModel(nn.Module):
-    def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
+
+    def __init__(
+        self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True
+    ):
         super().__init__()
         self.config = config
-        self.dtype = dtype
-        self.setup_caches_flag = setup_caches
 
         self.tokenizer = Tokenizer.from_pretrained(
             "vikhyatk/moondream2", revision="2025-01-09"
         )
-
         self.vision = build_vision_model(config.vision, dtype)
-
-        self.text = build_text_model(config.text, torch.int8)
+        self.text = build_text_model(config.text, dtype)
 
         # Region Model
+        linear_cls = (
+            QuantizedLinear if config.region.group_size is not None else nn.Linear
+        )
         self.region = nn.ModuleDict(
             {
-                "coord_encoder": nn.Linear(
+                "coord_encoder": linear_cls(
                     config.region.coord_feat_dim, config.region.dim, dtype=dtype
                 ),
                 "coord_decoder": nn.ModuleDict(
                     {
-                        "fc1": nn.Linear(
+                        "fc1": linear_cls(
                             config.region.dim, config.region.inner_dim, dtype=dtype
                         ),
-                        "fc2": nn.Linear(
+                        "fc2": linear_cls(
                             config.region.inner_dim,
                             config.region.coord_out_dim,
                             dtype=dtype,
                         ),
                     }
                 ),
-                "size_encoder": nn.Linear(
+                "size_encoder": linear_cls(
                     config.region.size_feat_dim, config.region.dim, dtype=dtype
                 ),
                 "size_decoder": nn.ModuleDict(
                     {
-                        "fc1": nn.Linear(
+                        "fc1": linear_cls(
                             config.region.dim, config.region.inner_dim, dtype=dtype
                         ),
-                        "fc2": nn.Linear(
+                        "fc2": linear_cls(
                             config.region.inner_dim,
                             config.region.size_out_dim,
                             dtype=dtype,
@@ -129,11 +132,11 @@ class MoondreamModel(nn.Module):
         attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
         self.register_buffer("attn_mask", attn_mask, persistent=False)
 
-    def _setup_caches(self):
-        """Setup KV caches for the text model"""
-        if self.text is None:
-            return  # Can't set up caches without text model
+        # Initialize KV caches.
+        if setup_caches:
+            self._setup_caches()
 
+    def _setup_caches(self):
         c = self.config.text
         for b in self.text.blocks:
             b.kv_cache = KVCache(
@@ -166,12 +169,16 @@ class MoondreamModel(nn.Module):
         return logits, hidden
 
     def compile(self):
+        for module in self.modules():
+            if isinstance(module, QuantizedLinear):
+                module.unpack()
+
         # TODO: vision_projection is not being compiled
-        self._vis_enc = torch.compile(
-            self._vis_enc, fullgraph=False, mode="reduce-overhead"
+        self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
+        self._prefill = torch.compile(self._prefill, fullgraph=True)
+        self._decode_one_tok = torch.compile(
+            self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
         )
-        self._prefill = torch.compile(self._prefill)
-        self._decode_one_tok = torch.compile(self._decode_one_tok)
 
     def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
         all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
@@ -204,7 +211,6 @@ class MoondreamModel(nn.Module):
 
         # Run through text model in addition to the vision encoder, to minimize
         # re-computation if multiple queries are performed on this image.
-
         with torch.inference_mode():
             img_emb = self._run_vision_encoder(image)
             bos_emb = text_encoder(
@@ -240,7 +246,6 @@ class MoondreamModel(nn.Module):
     def _prefill_prompt(
         self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float
     ):
-
         with torch.inference_mode():
             prompt_emb = text_encoder(prompt_tokens, self.text)
             torch._dynamo.mark_dynamic(prompt_emb, 1)
@@ -585,11 +590,11 @@ class MoondreamModel(nn.Module):
             self.text,
         )
         x_emb = encode_coordinate(
-            torch.tensor([[[source[0]]]], device=self.device, dtype=torch.float16),
+            torch.tensor([[[source[0]]]], device=self.device, dtype=torch.bfloat16),
             self.region,
         )
         y_emb = encode_coordinate(
-            torch.tensor([[[source[1]]]], device=self.device, dtype=torch.float16),
+            torch.tensor([[[source[1]]]], device=self.device, dtype=torch.bfloat16),
            self.region,
        )
 
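A hypothetical end-to-end sketch (not from this diff): bfloat16 is now the default dtype, the KV caches are built in __init__ when setup_caches=True, and compile() unpacks every QuantizedLinear before the graphs are handed to torch.compile. The top-level imports and the no-argument MoondreamConfig() constructor are assumptions; in the repo the imports are relative and the packed weights come from model.safetensors:

import torch
from config import MoondreamConfig    # in the repo: .config
from moondream import MoondreamModel  # in the repo: .moondream

config = MoondreamConfig()                            # text/region group_size default to 128
model = MoondreamModel(config, dtype=torch.bfloat16)  # KV caches are set up here now
# ... load the int4-packed weights from model.safetensors via load_state_dict ...
model = model.cuda()
model.compile()  # unpacks QuantizedLinear layers up front, then torch.compile()s the hot paths
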
packing.py ADDED
@@ -0,0 +1,52 @@
+import torch
+
+
+def unpack_int4(packed: torch.Tensor, original_length: int) -> torch.Tensor:
+    """
+    Unpack a tensor of uint8 packed bytes (two 4-bit values per byte) into a 1D tensor of int8 values,
+    vectorized over the entire input.
+    """
+    lower = packed & 0xF
+    upper = (packed >> 4) & 0xF
+    # Interleave lower and upper nibbles
+    nibbles = torch.stack([lower, upper], dim=-1).view(-1)[:original_length]
+    nibbles = nibbles.to(torch.int8)
+    nibbles[nibbles >= 8] -= 16
+    return nibbles
+
+
+def dequantize_tensor(
+    packed: torch.Tensor,
+    scales: torch.Tensor,
+    orig_shape: torch.Size,
+    block_size: int,
+    dtype: torch.dtype,
+):
+    """
+    Dequantizes a packed int4 tensor (with given per-block scales) back to bfloat16,
+    using vectorized operations to avoid Python loops.
+    """
+    num_bytes_per_block = (block_size + 1) // 2  # number of packed bytes per block
+    num_blocks_total = packed.numel() // num_bytes_per_block
+    # Reshape to (num_blocks_total, num_bytes_per_block)
+    packed_rows = packed.view(num_blocks_total, num_bytes_per_block)
+
+    # Vectorized unpacking: compute lower and upper nibbles for all rows at once.
+    lower = packed_rows & 0xF
+    upper = (packed_rows >> 4) & 0xF
+    # Create a new dimension for the two nibbles and then flatten.
+    nibbles = torch.stack([lower, upper], dim=2).view(num_blocks_total, -1)
+    # Slice to get exactly block_size values per block.
+    quantized_flat = nibbles[:, :block_size].to(torch.int8)
+    quantized_flat[quantized_flat >= 8] -= 16
+
+    # Reshape to original block structure.
+    last_dim = orig_shape[-1]
+    num_blocks = last_dim // block_size
+    new_shape = orig_shape[:-1] + (num_blocks, block_size)
+    quantized = quantized_flat.view(new_shape)
+
+    # Dequantize using scales.
+    dequantized = quantized.to(torch.float32) * scales.unsqueeze(-1)
+    dequantized = dequantized.view(orig_shape)
+    return dequantized.to(dtype)
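
A round-trip sketch (not from this diff) of the layout dequantize_tensor assumes: 128 values per block, two nibbles per byte with the low nibble first, and one symmetric scale per block. pack_int4 below is a hypothetical inverse written only for illustration; dequantize_tensor is the function from this file, imported here as a top-level module (inside the repo it is .packing):

import torch
from packing import dequantize_tensor


def pack_int4(w: torch.Tensor, block_size: int = 128):
    # Hypothetical inverse of dequantize_tensor: symmetric per-block int4 quantization.
    out_f, in_f = w.shape
    blocks = w.float().view(out_f, in_f // block_size, block_size)
    scales = blocks.abs().amax(dim=-1) / 7.0               # one scale per block
    q = torch.round(blocks / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)
    nib = (q & 0xF).to(torch.uint8)                        # two's-complement nibbles
    packed = nib[..., 0::2] | (nib[..., 1::2] << 4)        # low nibble first
    return packed, scales


w = torch.randn(16, 256)
packed, scales = pack_int4(w)
print(packed.shape, scales.shape)       # torch.Size([16, 2, 64]) torch.Size([16, 2])
w_hat = dequantize_tensor(packed, scales, w.shape, 128, torch.bfloat16)
print((w - w_hat.float()).abs().max())  # within half a quantization step plus bf16 rounding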
region.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import math
 
-from .layers import linear, mlp
+from .layers import mlp
 
 
 def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
@@ -36,7 +36,7 @@ def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
     Returns:
         Encoded hidden states tensor for input to text model
     """
-    return linear(fourier_features(coord, w.coord_features), w.coord_encoder)
+    return w.coord_encoder(fourier_features(coord, w.coord_features))
 
 
 def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
@@ -64,7 +64,7 @@ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
     Returns:
         Encoded hidden states tensor for input to text model
     """
-    return linear(fourier_features(size, w.size_features), w.size_encoder)
+    return w.size_encoder(fourier_features(size, w.size_features))
 
 
 def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
text.py CHANGED
@@ -2,9 +2,8 @@ import torch
 import torch.nn as nn
 
 from torch.nn import functional as F
-from bitblas.cache import OperatorCache
 
-from .layers import layer_norm, mlp, Linear
+from .layers import layer_norm, mlp, QuantizedLinear
 from .rope import apply_rotary_emb, precompute_freqs_cis
 from .config import TextConfig
 
@@ -27,7 +26,6 @@ def attn(
     head_dim = d_model // n_heads
 
     qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
-
     q_dim = n_heads * head_dim
     kv_dim = n_kv_heads * head_dim
 
@@ -57,6 +55,71 @@ def attn(
     return out
 
 
+def _attn(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    freqs_cis: torch.Tensor,
+    attn_mask: torch.Tensor,
+    n_heads: int,
+    n_kv_heads: int,
+):
+    bsz, q_len, d_model = x.shape
+    head_dim = d_model // n_heads
+    pos = 0
+
+    qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
+    q_dim = n_heads * head_dim
+    kv_dim = n_kv_heads * head_dim
+
+    q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+    k = (
+        qkv_out[..., q_dim : q_dim + kv_dim]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+    v = (
+        qkv_out[..., q_dim + kv_dim :]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+
+    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
+    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
+    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
+    out = F.scaled_dot_product_attention(
+        q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
+    )
+    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
+    out = w.proj(out)
+    return out
+
+
+def _produce_hidden(inputs_embeds: torch.Tensor, w: nn.Module, config: TextConfig):
+    hidden_BTC = inputs_embeds
+
+    bsz, q_len, d_model = inputs_embeds.shape
+    attn_mask = torch.zeros(q_len, q_len)
+    attn_mask[:730, :730] = 1
+    for i in range(730, q_len):
+        attn_mask[i, : i + 1] = 1
+    attn_mask = attn_mask.to(dtype=torch.bool)
+
+    for i, block in enumerate(w.blocks):
+        l_in = layer_norm(hidden_BTC, block.ln)
+        l_attn = _attn(
+            x=l_in,
+            w=block.attn,
+            freqs_cis=w.freqs_cis,
+            attn_mask=attn_mask,
+            n_heads=config.n_heads,
+            n_kv_heads=config.n_kv_heads,
+        )
+        l_mlp = mlp(l_in, block.mlp)
+        hidden_BTC = hidden_BTC + l_attn + l_mlp
+
+    return hidden_BTC
+
+
 def text_decoder(
     x: torch.Tensor,
     w: nn.Module,
@@ -76,7 +139,6 @@ def text_decoder(
             n_kv_heads=config.n_kv_heads,
             position_ids=position_ids,
         )
-
         l_mlp = mlp(l_in, block.mlp)
         x = x + l_attn + l_mlp
 
@@ -90,30 +152,15 @@ def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
     return logits
 
 
-def build_text_model(
-    config: TextConfig,
-    linear_dtype: torch.dtype = torch.float16,
-    layernorm_dtype: torch.dtype = torch.float16,
-) -> nn.Module:
-    # note : layernorm dtype is used for layernorm, lm_head and wte not just layernorm
-    print(
-        "Initializing quantized backend. This only has to run once, but may take a few minutes."
-    )
-    qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
-
-    group_size = None
-    if linear_dtype == torch.int8:
-
-        group_size = config.group_size
-
-    def create_linear(in_features, out_features, dtype=linear_dtype):
-        # factory function for creating Linear layers so we dont have to pass everything again and again
-        return Linear(
-            in_features=in_features,
-            out_features=out_features,
-            dtype=dtype,
-            group_size=group_size,
-        )
-
+def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
+    hidden_BTC = layer_norm(hidden_BTC, w.post_ln)
+    logits = w.lm_head(hidden_BTC)
+    return logits
+
+
+def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
+    qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
+    linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear
+
     text = nn.ModuleDict(
         {
@@ -121,17 +168,23 @@ def build_text_model(
                 [
                     nn.ModuleDict(
                         {
-                            "ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype),
+                            "ln": nn.LayerNorm(config.dim, dtype=dtype),
                             "attn": nn.ModuleDict(
                                 {
-                                    "qkv": create_linear(config.dim, qkv_dim),
-                                    "proj": create_linear(config.dim, config.dim),
+                                    "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
+                                    "proj": linear_cls(
+                                        config.dim, config.dim, dtype=dtype
+                                    ),
                                 }
                             ),
                             "mlp": nn.ModuleDict(
                                 {
-                                    "fc1": create_linear(config.dim, config.ff_dim),
-                                    "fc2": create_linear(config.ff_dim, config.dim),
+                                    "fc1": linear_cls(
+                                        config.dim, config.ff_dim, dtype=dtype
+                                    ),
+                                    "fc2": linear_cls(
+                                        config.ff_dim, config.dim, dtype=dtype
+                                    ),
                                 }
                             ),
                         }
@@ -139,13 +192,11 @@
                     for _ in range(config.n_layers)
                 ]
             ),
-            "post_ln": nn.LayerNorm(config.dim, dtype=layernorm_dtype),
-            "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=layernorm_dtype),
+            "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
+            "lm_head": linear_cls(config.dim, config.vocab_size, dtype=dtype),
         }
     )
-    text.wte = nn.Parameter(
-        torch.empty(config.vocab_size, config.dim, dtype=layernorm_dtype)
-    )
+    text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
     text.register_buffer(
         "freqs_cis",
         precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
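
The new _produce_hidden path (a cache-free, full-sequence forward) builds its attention mask inline: full attention within the first 730 positions (prefix_attn in TextConfig) and causal attention afterwards. A standalone sketch of just that mask, shrunk to a toy prefix so it prints legibly:

import torch


def prefix_causal_mask(q_len: int, prefix: int = 730) -> torch.Tensor:
    # Mirrors the mask construction in _produce_hidden above.
    mask = torch.zeros(q_len, q_len)
    mask[:prefix, :prefix] = 1      # prefix tokens attend to each other freely
    for i in range(prefix, q_len):
        mask[i, : i + 1] = 1        # later tokens attend causally
    return mask.to(dtype=torch.bool)


print(prefix_causal_mask(6, prefix=3).int())
# rows 0-2: the prefix block is fully visible; rows 3-5: lower-triangular (causal)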
vision.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
 from typing import Union, Tuple
 from PIL import Image
 
-from .layers import attn, layer_norm, linear, mlp
+from .layers import attn, layer_norm, mlp
 from .image_crops import overlap_crop_image
 from .config import VisionConfig
 
@@ -33,7 +33,7 @@ def prepare_crops(
     all_crops = np.transpose(all_crops, (0, 3, 1, 2))
     all_crops = (
         torch.from_numpy(all_crops)
-        .to(device=device, dtype=torch.float16)
+        .to(device=device, dtype=torch.bfloat16)
        .div_(255.0)
        .sub_(0.5)
        .div_(0.5)
@@ -64,7 +64,7 @@ def create_patches(x, patch_size):
 def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
     x = create_patches(input_BCHW, config.enc_patch_size)
 
-    x = linear(x, w.patch_emb)
+    x = w.patch_emb(x)
     x = x + w.pos_emb
     for block in w.blocks:
         x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)