vikhyatk committed (verified)
Commit 8ef2cad · 1 Parent(s): fb2293a

Upload HfMoondream

Files changed (10):
  1. config.json +1 -1
  2. config.py +3 -0
  3. generation_config.json +1 -1
  4. hf_moondream.py +26 -7
  5. layers.py +2 -2
  6. model.safetensors +1 -1
  7. moondream.py +204 -136
  8. region.py +19 -12
  9. text.py +114 -76
  10. vision.py +21 -7
config.json CHANGED
@@ -9,5 +9,5 @@
   "config": {},
   "model_type": "moondream1",
   "torch_dtype": "float16",
-  "transformers_version": "4.44.0"
+  "transformers_version": "4.48.0"
 }
config.py CHANGED
@@ -5,10 +5,12 @@ from typing import Dict, List, Optional
 @dataclass(frozen=True)
 class TextConfig:
     dim: int = 2048
+    ff_dim: int = 8192
     n_layers: int = 24
     vocab_size: int = 51200
     max_context: int = 2048
     n_heads: int = 32
+    n_kv_heads: int = 32
     prefix_attn: int = 730
 
 
@@ -46,6 +48,7 @@ class TokenizerConfig:
         "caption": {
             "short": [198, 198, 16438, 8305, 25],
             "normal": [198, 198, 24334, 1159, 25],
+            "long": [198, 198, 14617, 8305, 25],
         },
         "query": {"prefix": [198, 198, 24361, 25], "suffix": [198, 198, 33706, 25]},
         "detect": {"prefix": [198, 198, 47504, 25], "suffix": [628]},
generation_config.json CHANGED
@@ -1,4 +1,4 @@
 {
   "_from_model_config": true,
-  "transformers_version": "4.44.0"
+  "transformers_version": "4.48.0"
 }
hf_moondream.py CHANGED
@@ -14,7 +14,7 @@ from .utils import *
 def extract_question(text):
     prefix = "<image>\n\nQuestion: "
     suffix = "\n\nAnswer:"
-
+
     if text.startswith(prefix) and text.endswith(suffix):
         return text[len(prefix) : -len(suffix)]
     else:
@@ -36,30 +36,44 @@ class HfMoondream(PreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        self.model = MoondreamModel(MoondreamConfig.from_dict(config.config))
+        self.model = MoondreamModel(
+            MoondreamConfig.from_dict(config.config), setup_caches=False
+        )
+        self._is_kv_cache_setup = False
+
+    def _setup_caches(self):
+        if not self._is_kv_cache_setup:
+            self.model._setup_caches()
+            self._is_kv_cache_setup = True
 
     @property
     def encode_image(self):
+        self._setup_caches()
         return self.model.encode_image
 
     @property
     def query(self):
+        self._setup_caches()
        return self.model.query
 
     @property
     def caption(self):
+        self._setup_caches()
        return self.model.caption
 
     @property
     def detect(self):
+        self._setup_caches()
        return self.model.detect
 
     @property
     def point(self):
+        self._setup_caches()
        return self.model.point
 
     @property
     def detect_gaze(self):
+        self._setup_caches()
        return self.model.detect_gaze
 
     def answer_question(
@@ -98,22 +112,27 @@ class HfMoondream(PreTrainedModel):
         """
         prompt_extracted = extract_question(prompt)
         if prompt_extracted is not None:
-            answer = self.model.query(image=image_embeds, question=prompt_extracted, stream=False)[
-                "answer"
-            ]
+            answer = self.model.query(
+                image=image_embeds, question=prompt_extracted, stream=False
+            )["answer"]
         else:
             image_embeds = self.encode_image(image_embeds)
             prompt_tokens = torch.tensor(
                 [self.model.tokenizer.encode(prompt).ids],
                 device=self.device,
            )
+
            def generator():
                for token in self.model._generate_text(
-                    prompt_tokens, image_embeds.kv_cache, image_embeds.pos, max_new_tokens
+                    prompt_tokens,
+                    image_embeds.kv_cache,
+                    image_embeds.pos,
+                    max_new_tokens,
                ):
                    yield token
+
            answer = "".join(list(generator()))
-
+
        return [answer]
 
    def get_input_embeddings(self):
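Note: a minimal usage sketch of the lazy cache setup added above, assuming the wrapper is loaded through transformers remote code; the repo id and image path are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch (repo id and image path are assumptions, not from this diff).
# KV caches are only allocated on the first property access that needs them,
# so simply loading the model stays cheap.
from PIL import Image
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream2", trust_remote_code=True  # assumed repo id
)

image = Image.open("example.jpg")
encoded = model.encode_image(image)  # first access triggers _setup_caches()
print(model.query(image=encoded, question="What is in this image?")["answer"])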
layers.py CHANGED
@@ -37,9 +37,9 @@ class MLPWeights:
 
 
 def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
-    x = linear(x, w.fc1)
+    x = w.fc1(x)
     x = gelu_approx(x)
-    x = linear(x, w.fc2)
+    x = w.fc2(x)
     return x
 
 
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:23e2e6498a058d12832e119dc97a1d2f14936b4ccf77b8492bc0fefba49ea8bb
+oid sha256:fadcffea8c17fe8a20ea68af3a013cf3184a63787ee4453cc9eb75206c7c1f9b
 size 3854538376
moondream.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import random
 
-from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional
+from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
 from PIL import Image
 from dataclasses import dataclass
 from tokenizers import Tokenizer
@@ -10,7 +10,7 @@ from tokenizers import Tokenizer
 from .config import MoondreamConfig
 from .image_crops import reconstruct_from_crops
 from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
-from .text import build_text_model, prefill, text_encoder, lm_head, decode_one_token
+from .text import build_text_model, text_encoder, lm_head, text_decoder
 from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
 from .utils import remove_outlier_points
 
@@ -21,53 +21,41 @@ SamplingSettings = TypedDict(
     total=False,
 )
 
-DEFAULT_MAX_TOKENS = 512
+DEFAULT_MAX_TOKENS = 768
 
 
 @dataclass(frozen=True)
 class EncodedImage:
     pos: int
-    kv_cache: torch.Tensor
-
-
-def _min_p_sampler(
-    logits: torch.Tensor,
-    min_p: float = 0.1,
-    filter_value: float = 0,
-    min_tokens_to_keep: int = 1,
-    temp=0.5,
-) -> torch.Tensor:
-    """
-    Min-p sampler adapted from https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
-    https://arxiv.org/pdf/2407.01082
-    """
-    logits = logits / temp
-    probs = torch.softmax(logits, dim=-1)
-    top_probs, _ = probs.max(dim=-1, keepdim=True)
-    scaled_min_p = min_p * top_probs
-    tokens_to_remove = probs < scaled_min_p
-    sorted_indices = torch.argsort(logits, descending=True, dim=-1)
-    sorted_indices_to_remove = torch.gather(
-        tokens_to_remove, dim=-1, index=sorted_indices
-    )
-    if min_tokens_to_keep > 1:
-        sorted_indices_to_remove[..., :min_tokens_to_keep] = False
-
-    indices_to_remove = sorted_indices_to_remove.scatter(
-        1, sorted_indices, sorted_indices_to_remove
-    )
-    logits = logits.masked_fill(indices_to_remove, filter_value)
-    token = torch.multinomial(logits, num_samples=1)
-    return token.squeeze(0)
+    caches: List[Tuple[torch.Tensor, torch.Tensor]]
+
+
+class KVCache(nn.Module):
+
+    def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
+        super().__init__()
+        cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
+        self.register_buffer(
+            "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
+        )
+
+    def update(self, pos_ids, k, v):
+        kout, vout = self.k_cache, self.v_cache
+        kout[:, :, pos_ids, :] = k
+        vout[:, :, pos_ids, :] = v
+        return kout, vout
 
 
 class MoondreamModel(nn.Module):
-    def __init__(self, config: MoondreamConfig, dtype=torch.float16):
+    def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
         super().__init__()
         self.config = config
 
         self.tokenizer = Tokenizer.from_pretrained(
-            "vikhyatk/moondream2", revision="2024-08-26"
+            "vikhyatk/moondream2", revision="2025-01-09"
         )
         self.vision = build_vision_model(config.vision, dtype)
         self.text = build_text_model(config.text, dtype)
@@ -114,35 +102,65 @@ class MoondreamModel(nn.Module):
             torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
         )
 
-        self.ops = {
-            "vision_encoder": vision_encoder,
-            "vision_projection": vision_projection,
-            "prefill": prefill,
-            "decode_one_token": decode_one_token,
-        }
+        attn_mask = torch.tril(
+            torch.ones(
+                1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
+            )
+        )
+        patch_w = config.vision.crop_size // config.vision.enc_patch_size
+        prefix_attn_len = 1 + patch_w**2
+        attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
+        self.register_buffer("attn_mask", attn_mask, persistent=False)
+
+        # Initialize KV caches.
+        if setup_caches:
+            self._setup_caches()
+
+    def _setup_caches(self):
+        c = self.config.text
+        for b in self.text.blocks:
+            b.kv_cache = KVCache(
+                c.n_heads,
+                c.n_kv_heads,
+                c.max_context,
+                c.dim,
+                device=self.device,
+                dtype=self.vision.pos_emb.dtype,
+            )
 
     @property
     def device(self):
         return self.vision.pos_emb.device
 
+    def _vis_enc(self, x: torch.Tensor):
+        return vision_encoder(x, self.vision, self.config.vision)
+
+    def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
+        return vision_projection(g, r, self.vision, self.config.vision)
+
+    def _prefill(self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor):
+        return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text)
+
+    def _decode_one_tok(
+        self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
+    ):
+        hidden = text_decoder(x[None], self.text, attn_mask, pos_ids, self.config.text)
+        logits = lm_head(hidden, self.text)
+        return logits, hidden
+
     def compile(self):
-        self.ops["vision_encoder"] = torch.compile(
-            self.ops["vision_encoder"], fullgraph=True
-        )
-        # Need to figure out how to mark the 'reconstructed' input shape as dynamic
-        # self.ops["vision_projection"] = torch.compile(
-        #     self.ops["vision_projection"], fullgraph=True
-        # )
-        self.ops["prefill"] = torch.compile(self.ops["prefill"], fullgraph=True)
-        self.ops["decode_one_token"] = torch.compile(
-            self.ops["decode_one_token"], fullgraph=True
+        # TODO: vision_projection is not being compiled
+        self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
+        self._prefill = torch.compile(self._prefill, fullgraph=True)
+        self._decode_one_tok = torch.compile(
+            self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
         )
 
     def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
         all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
         torch._dynamo.mark_dynamic(all_crops, 0)
 
-        outputs = self.ops["vision_encoder"](all_crops, self.vision, self.config.vision)
+        outputs = self._vis_enc(all_crops)
 
         global_features = outputs[0]
         local_features = outputs[1:].view(
@@ -159,9 +177,7 @@ class MoondreamModel(nn.Module):
             overlap_margin=self.config.vision.overlap_margin,
         )
 
-        return self.ops["vision_projection"](
-            global_features, reconstructed, self.vision, self.config.vision
-        )
+        return self._vis_proj(global_features, reconstructed)
 
     def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage:
         if isinstance(image, EncodedImage):
@@ -171,34 +187,35 @@ class MoondreamModel(nn.Module):
 
         # Run through text model in addition to the vision encoder, to minimize
         # re-computation if multiple queries are performed on this image.
-        kv_cache = torch.zeros(
-            self.config.text.n_layers,
-            2,  # k, v
-            1,  # batch size
-            self.config.text.n_heads,
-            self.config.text.max_context,  # static cache
-            self.config.text.dim // self.config.text.n_heads,  # head dim
-            device=self.device,
-            dtype=torch.float16,
-        )
-        with torch.no_grad():
+        with torch.inference_mode():
             img_emb = self._run_vision_encoder(image)
             bos_emb = text_encoder(
                 torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
                 self.text,
            )
            inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
-            self.ops["prefill"](inputs_embeds, kv_cache, 0, self.text, self.config.text)
-        return EncodedImage(pos=inputs_embeds.size(1), kv_cache=kv_cache)
+            mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
+            pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
+            self._prefill(inputs_embeds, mask, pos_ids)
+
+            return EncodedImage(
+                pos=inputs_embeds.size(1),
+                caches=[
+                    (
+                        b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
+                        b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
+                    )
+                    for b in self.text.blocks
+                ],
+            )
 
-    def _prefill_prompt(
-        self, kv_cache: torch.Tensor, prompt_tokens: torch.Tensor, pos: int
-    ):
-        with torch.no_grad():
+    def _prefill_prompt(self, prompt_tokens: torch.Tensor, pos: int):
+        with torch.inference_mode():
            prompt_emb = text_encoder(prompt_tokens, self.text)
-            hidden = self.ops["prefill"](
-                prompt_emb, kv_cache, pos, self.text, self.config.text
-            )
+            torch._dynamo.mark_dynamic(prompt_emb, 1)
+            mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
+            pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
+            hidden = self._prefill(prompt_emb, mask, pos_ids)
            logits = lm_head(hidden, self.text)
            next_token = torch.argmax(logits, dim=-1)
            pos = pos + prompt_emb.size(1)
@@ -207,33 +224,67 @@ class MoondreamModel(nn.Module):
     def _generate_text(
         self,
         prompt_tokens: torch.Tensor,
-        kv_cache: torch.Tensor,
         pos: int,
         max_tokens: int,
     ):
-        kv_cache = kv_cache.clone()
-        _, _, next_token, pos = self._prefill_prompt(kv_cache, prompt_tokens, pos)
+        _, _, next_token, pos = self._prefill_prompt(prompt_tokens, pos)
 
        def generator(next_token, pos):
+            mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
+            mask[:, :, :pos] = 1
+            pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
            generated_tokens = 0
 
+            # For properly handling token streaming with Unicode
+            token_cache = []
+            print_len = 0
+
            while (
                next_token_id := next_token.item()
            ) != self.config.tokenizer.eos_id and generated_tokens < max_tokens:
-                yield self.tokenizer.decode([next_token_id])
-
-                with torch.no_grad():
+                # Add token to our cache
+                token_cache.append(next_token_id)
+
+                # Decode all tokens collected so far
+                text = self.tokenizer.decode(token_cache)
+
+                # After a newline, we flush the cache completely
+                if text.endswith("\n"):
+                    printable_text = text[print_len:]
+                    token_cache = []
+                    print_len = 0
+                    if printable_text:
+                        yield printable_text
+                # If the last token is a CJK character, we can safely print it
+                elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
+                    printable_text = text[print_len:]
+                    print_len += len(printable_text)
+                    if printable_text:
+                        yield printable_text
+                # Otherwise, only print up to the last space to avoid cutting words
+                else:
+                    last_space_idx = text.rfind(" ", print_len)
+                    if last_space_idx >= print_len:
+                        printable_text = text[print_len : last_space_idx + 1]
+                        print_len += len(printable_text)
+                        if printable_text:
+                            yield printable_text
+
+                with torch.inference_mode():
                    next_emb = text_encoder(next_token, self.text)
-                    logits, _, kv_cache_update = self.ops["decode_one_token"](
-                        next_emb, kv_cache, pos, self.text, self.config.text
-                    )
-                    kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
-                        kv_cache_update
-                    )
+                    mask[:, :, pos], pos_ids[0] = 1, pos
+                    logits, _ = self._decode_one_tok(next_emb, mask, pos_ids)
                    pos += 1
                    next_token = torch.argmax(logits, dim=-1)
                    generated_tokens += 1
 
+            # Flush any remaining text in the cache
+            if token_cache:
+                text = self.tokenizer.decode(token_cache)
+                printable_text = text[print_len:]
+                if printable_text:
+                    yield printable_text
+
        return generator(next_token, pos)
 
     def query(
@@ -247,10 +298,12 @@ class MoondreamModel(nn.Module):
             raise NotImplementedError("Model does not support querying.")
 
         image = self.encode_image(image)
+        self.load_encoded_image(image)
+
         prompt_tokens = torch.tensor(
             [
                 self.config.tokenizer.templates["query"]["prefix"]
-                + self.tokenizer.encode(question).ids
+                + self.tokenizer.encode(" " + question).ids
                 + self.config.tokenizer.templates["query"]["suffix"]
             ],
             device=self.device,
@@ -261,9 +314,7 @@ class MoondreamModel(nn.Module):
         max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS)
 
         def generator():
-            for token in self._generate_text(
-                prompt_tokens, image.kv_cache, image.pos, max_tokens
-            ):
+            for token in self._generate_text(prompt_tokens, image.pos, max_tokens):
                 yield token
 
         if stream:
@@ -271,10 +322,15 @@ class MoondreamModel(nn.Module):
         else:
             return {"answer": "".join(list(generator()))}
 
+    def load_encoded_image(self, encoded_image: EncodedImage):
+        for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
+            b.kv_cache.k_cache[:, :, : k.size(2), :] = k
+            b.kv_cache.v_cache[:, :, : v.size(2), :] = v
+
     def caption(
         self,
         image: Union[Image.Image, EncodedImage],
-        length: Literal["normal", "short"] = "normal",
+        length: Literal["normal", "short", "long"] = "normal",
         stream: bool = False,
         settings: Optional[SamplingSettings] = None,
     ):
@@ -284,6 +340,8 @@ class MoondreamModel(nn.Module):
             raise ValueError(f"Model does not support caption length '{length}'.")
 
         image = self.encode_image(image)
+        self.load_encoded_image(image)
+
         prompt_tokens = torch.tensor(
             [self.config.tokenizer.templates["caption"][length]], device=self.device
         )
@@ -293,9 +351,7 @@ class MoondreamModel(nn.Module):
         max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS)
 
         def generator():
-            for token in self._generate_text(
-                prompt_tokens, image.kv_cache, image.pos, max_tokens
-            ):
+            for token in self._generate_text(prompt_tokens, image.pos, max_tokens):
                 yield token
 
         if stream:
@@ -306,15 +362,17 @@ class MoondreamModel(nn.Module):
     def _generate_points(
         self,
         hidden: torch.Tensor,
-        kv_cache: torch.Tensor,
         next_token: torch.Tensor,
         pos: int,
         include_size: bool = True,
         max_points: int = 50,
     ):
         out = []
+        mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
+        mask[:, :, :pos] = 1
+        pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
 
-        with torch.no_grad():
+        with torch.inference_mode():
             while (
                 next_token.item() != self.config.tokenizer.eos_id
                 and len(out) < max_points
@@ -326,12 +384,8 @@ class MoondreamModel(nn.Module):
                 )
 
                 # Decode y-coordinate
-                _, hidden, kv_cache_update = self.ops["decode_one_token"](
-                    next_emb, kv_cache, pos, self.text, self.config.text
-                )
-                kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
-                    kv_cache_update
-                )
+                mask[:, :, pos], pos_ids[0] = 1, pos
+                _, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                 pos += 1
                 y_logits = decode_coordinate(hidden, self.region)
                 y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
@@ -341,16 +395,20 @@ class MoondreamModel(nn.Module):
 
                 # Decode size
                 if include_size:
-                    logits, hidden, kv_cache_update = self.ops["decode_one_token"](
-                        next_emb, kv_cache, pos, self.text, self.config.text
-                    )
-                    kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
-                        kv_cache_update
-                    )
+                    mask[:, :, pos], pos_ids[0] = 1, pos
+                    logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                     pos += 1
                     size_logits = decode_size(hidden, self.region)
-                    w = torch.argmax(size_logits[0], dim=-1) / size_logits.size(-1)
-                    h = torch.argmax(size_logits[1], dim=-1) / size_logits.size(-1)
+
+                    # Get bin indices from the logits
+                    w_bin = torch.argmax(size_logits[0], dim=-1)
+                    h_bin = torch.argmax(size_logits[1], dim=-1)
+
+                    # Convert from bin indices to actual size values using the inverse of the log-scale mapping
+                    # Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0)
+                    w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
+                    h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
+
                     next_emb = encode_size(
                         torch.tensor(
                             [w, h], device=self.device, dtype=size_logits.dtype
@@ -371,12 +429,8 @@ class MoondreamModel(nn.Module):
                 out.append({"x": x_center.item(), "y": y_center.item()})
 
                 # Decode next token (x-coordinate, or eos)
-                logits, hidden, kv_cache_update = self.ops["decode_one_token"](
-                    next_emb, kv_cache, pos, self.text, self.config.text
-                )
-                kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
-                    kv_cache_update
-                )
+                mask[:, :, pos], pos_ids[0] = 1, pos
+                logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                 pos += 1
                 next_token = torch.argmax(logits, dim=-1)
 
@@ -392,23 +446,22 @@ class MoondreamModel(nn.Module):
             raise NotImplementedError("Model does not support object detection.")
 
         image = self.encode_image(image)
+        self.load_encoded_image(image)
+
         prompt_tokens = torch.tensor(
             [
                 self.config.tokenizer.templates["detect"]["prefix"]
-                + self.tokenizer.encode(object).ids
+                + self.tokenizer.encode(" " + object).ids
                 + self.config.tokenizer.templates["detect"]["suffix"]
             ],
             device=self.device,
        )
 
-        kv_cache = image.kv_cache.clone()
-        _, hidden, next_token, pos = self._prefill_prompt(
-            kv_cache, prompt_tokens, image.pos
-        )
+        _, hidden, next_token, pos = self._prefill_prompt(prompt_tokens, image.pos)
        hidden = hidden[:, -1:, :]
 
        objects = self._generate_points(
-            hidden, kv_cache, next_token, pos, include_size=True, max_points=50
+            hidden, next_token, pos, include_size=True, max_points=50
        )
 
        return {"objects": objects}
@@ -423,23 +476,22 @@ class MoondreamModel(nn.Module):
             raise NotImplementedError("Model does not support pointing.")
 
         image = self.encode_image(image)
+        self.load_encoded_image(image)
+
         prompt_tokens = torch.tensor(
             [
                 self.config.tokenizer.templates["point"]["prefix"]
-                + self.tokenizer.encode(object).ids
+                + self.tokenizer.encode(" " + object).ids
                 + self.config.tokenizer.templates["point"]["suffix"]
             ],
             device=self.device,
        )
 
-        kv_cache = image.kv_cache.clone()
-        _, hidden, next_token, pos = self._prefill_prompt(
-            kv_cache, prompt_tokens, image.pos
-        )
+        _, hidden, next_token, pos = self._prefill_prompt(prompt_tokens, image.pos)
        hidden = hidden[:, -1:, :]
 
        objects = self._generate_points(
-            hidden, kv_cache, next_token, pos, include_size=False, max_points=50
+            hidden, next_token, pos, include_size=False, max_points=50
        )
 
        return {"points": objects}
@@ -450,7 +502,7 @@ class MoondreamModel(nn.Module):
         source: Tuple[float, float],
         force_detect: bool = False,
     ):
-        with torch.no_grad():
+        with torch.inference_mode():
             before_emb = text_encoder(
                 torch.tensor(
                     [self.tokenizer.encode("\n\nPoint:").ids], device=self.device
@@ -474,10 +526,13 @@ class MoondreamModel(nn.Module):
 
             prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1)
 
-            kv_cache = image.kv_cache.clone()
-            hidden = self.ops["prefill"](
-                prompt_emb, kv_cache, image.pos, self.text, self.config.text
+            self.load_encoded_image(image)
+
+            mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :]
+            pos_ids = torch.arange(
+                image.pos, image.pos + prompt_emb.size(1), dtype=torch.long
            )
+            hidden = self._prefill(prompt_emb, mask, pos_ids)
            logits = lm_head(hidden, self.text)
            next_token = torch.argmax(logits, dim=-1)
            pos = image.pos + prompt_emb.size(1)
@@ -490,7 +545,7 @@ class MoondreamModel(nn.Module):
                 return None
 
             gaze = self._generate_points(
-                hidden, kv_cache, next_token, pos, include_size=False, max_points=1
+                hidden, next_token, pos, include_size=False, max_points=1
            )
            return gaze[0]
 
@@ -584,3 +639,16 @@ class MoondreamModel(nn.Module):
         )
 
         return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
+
+
+def _is_cjk_char(cp):
+    """Checks whether CP is the codepoint of a CJK character."""
+    # This defines a "chinese character" as anything in the CJK Unicode block:
+    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+    if (
+        (cp >= 0x4E00 and cp <= 0x9FFF)
+        or (cp >= 0x3400 and cp <= 0x4DBF)
+        or (cp >= 0x2F800 and cp <= 0x2FA1F)
+    ):
+        return True
+    return False
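Note: the single stacked kv_cache tensor is replaced here by a per-block KVCache module with registered k_cache/v_cache buffers, and EncodedImage now stores per-block (k, v) snapshots that load_encoded_image copies back before each request. The standalone sketch below mirrors that bookkeeping with assumed shapes (batch 1, 32 KV heads, 2048 context positions, head dim 64); it illustrates the idea rather than reusing the repo's classes.

# Minimal sketch of the per-block cache semantics (shapes are assumptions).
import torch

n_kv_heads, max_context, head_dim = 32, 2048, 64
k_cache = torch.zeros(1, n_kv_heads, max_context, head_dim)
v_cache = torch.zeros(1, n_kv_heads, max_context, head_dim)

# Prefill writes rows 0..T-1, exactly like KVCache.update(pos_ids, k, v).
T = 5
pos_ids = torch.arange(T)
k_cache[:, :, pos_ids, :] = torch.randn(1, n_kv_heads, T, head_dim)
v_cache[:, :, pos_ids, :] = torch.randn(1, n_kv_heads, T, head_dim)

# EncodedImage keeps a clone of the image-prefix rows per block...
snapshot = (k_cache[:, :, :T, :].clone(), v_cache[:, :, :T, :].clone())

# ...which load_encoded_image copies back before each new request,
# so several queries can reuse one image prefill.
k_cache[:, :, : snapshot[0].size(2), :] = snapshot[0]
v_cache[:, :, : snapshot[1].size(2), :] = snapshot[1]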
region.py CHANGED
@@ -1,7 +1,7 @@
 import torch
+import torch.nn as nn
 import math
 
-from .weights import RegionModel
 from .layers import linear, mlp
 
 
@@ -25,7 +25,7 @@ def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
     return torch.cat([f.cos(), f.sin()], dim=-1)
 
 
-def encode_coordinate(coord: torch.Tensor, w: RegionModel) -> torch.Tensor:
+def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
     Takes as input a tensor containing a single float coordinate value (x or y)
     and encodes it into hidden states for input to the text model.
@@ -39,7 +39,7 @@ def encode_coordinate(coord: torch.Tensor, w: RegionModel) -> torch.Tensor:
     return linear(fourier_features(coord, w.coord_features), w.coord_encoder)
 
 
-def decode_coordinate(hidden_state: torch.Tensor, w: RegionModel) -> torch.Tensor:
+def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
     Takes as input the last hidden state from the text model and outputs a single logit
     representing either an x or y coordinate prediction.
@@ -53,13 +53,13 @@ def decode_coordinate(hidden_state: torch.Tensor, w: RegionModel) -> torch.Tensor:
     return mlp(hidden_state, w.coord_decoder)
 
 
-def encode_size(size: torch.Tensor, w: RegionModel) -> torch.Tensor:
+def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
-    Takes a tensor containing normalized width and height values in range [0,1]
-    and encodes them into hidden states for input to the text model.
+    Takes a tensor containing width and height values and encodes them into
+    hidden states for input to the text model.
 
     Args:
-        size: Tensor with two floats for width and height in range [0,1]
+        size: Tensor with two floats for width and height
 
     Returns:
         Encoded hidden states tensor for input to text model
@@ -67,16 +67,23 @@ def encode_size(size: torch.Tensor, w: RegionModel) -> torch.Tensor:
     return linear(fourier_features(size, w.size_features), w.size_encoder)
 
 
-def decode_size(hidden_state: torch.Tensor, w: RegionModel) -> torch.Tensor:
+def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
-    Takes as input the last hidden state from the text model and outputs two logits
-    for width and height respectively.
+    Takes as input the last hidden state from the text model and outputs logits
+    for 1024 bins representing width and height in log-scale.
+
+    The bins are distributed according to the formula:
+        bin = (log2(size) + 10.0) / 10.0 * 1023.0
+    where size values are clamped to be at least 1/1024.
+
+    To convert from bin back to size:
+        size = 2^((bin / 1023.0) * 10.0 - 10.0)
 
     Args:
         hidden_state: The final hidden state tensor from the text model.
 
     Returns:
-        A tensor containing two logits - one for predicted width and one for
-        predicted height.
+        A tensor containing logits for 1024 bins for width and height.
+        Shape is (2, 1024) where the first dimension corresponds to width and height.
     """
     return mlp(hidden_state, w.size_decoder).view(2, -1)
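Note: a quick worked check of the log-scale bin mapping described in the new decode_size docstring. The helper names below are made up for illustration; only the two formulas come from the diff.

# Round-trip check of the bin <-> size mapping (helper names are illustrative).
import math

def size_to_bin(size: float) -> float:
    size = max(size, 1 / 1024)  # clamp as described in the docstring
    return (math.log2(size) + 10.0) / 10.0 * 1023.0

def bin_to_size(b: float) -> float:
    return 2 ** ((b / 1023.0) * 10.0 - 10.0)

for s in (1 / 1024, 0.05, 0.5, 1.0):
    b = size_to_bin(s)
    print(f"size={s:.4f} -> bin={b:7.1f} -> size={bin_to_size(b):.4f}")
# Bin 0 corresponds to 2**-10 = 1/1024 of the image, and bin 1023 to the full extent (1.0).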
text.py CHANGED
@@ -1,10 +1,10 @@
 import torch
 import torch.nn as nn
+
 from torch.nn import functional as F
 
-from .layers import layer_norm, linear, mlp
+from .layers import layer_norm, mlp
 from .rope import apply_rotary_emb, precompute_freqs_cis
-from .weights import AttentionWeights
 from .config import TextConfig
 
 
@@ -14,106 +14,153 @@ def text_encoder(input_ids: torch.Tensor, w: nn.Module):
 
 def attn(
     x: torch.Tensor,
-    w: AttentionWeights,
+    w: nn.Module,
     freqs_cis: torch.Tensor,
-    layer_kv_cache: torch.Tensor,
+    kv_cache: nn.Module,
     attn_mask: torch.Tensor,
     n_heads: int,
-    pos: int,
+    n_kv_heads: int,
+    position_ids: torch.Tensor,
 ):
     bsz, q_len, d_model = x.shape
     head_dim = d_model // n_heads
 
-    q, k, v = [
-        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
-        for t in linear(x, w.qkv).chunk(3, dim=-1)
-    ]
+    qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
+    q_dim = n_heads * head_dim
+    kv_dim = n_kv_heads * head_dim
+
+    q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+    k = (
+        qkv_out[..., q_dim : q_dim + kv_dim]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+    v = (
+        qkv_out[..., q_dim + kv_dim :]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
 
-    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
     q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
-    k = apply_rotary_emb(k, freqs_cis, position_ids, n_heads)
-
-    k_, v_ = k, v
-    if layer_kv_cache is not None:
-        k = torch.cat([layer_kv_cache[0, :, :, :pos, :], k], dim=2)
-        v = torch.cat([layer_kv_cache[1, :, :, :pos, :], v], dim=2)
-
-    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask).to(
-        # This type conversion isn't needed when running in PyTorch directly, but the
-        # ONNX export runs attention in float32 because the attention mask is cast to
-        # float32.
-        x.dtype
+    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
+
+    if kv_cache is not None:
+        k, v = kv_cache.update(position_ids, k, v)
+
+    out = F.scaled_dot_product_attention(
+        q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
     )
     out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
-    out = linear(out, w.proj)
-    return out, torch.stack([k_, v_])
+    out = w.proj(out)
+    return out
 
 
-def text_decoder(
-    inputs_embeds: torch.Tensor,
-    w: nn.Module,
-    kv_cache: torch.Tensor,
-    pos: int,
-    config: TextConfig,
+def _attn(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    freqs_cis: torch.Tensor,
+    attn_mask: torch.Tensor,
+    n_heads: int,
+    n_kv_heads: int,
 ):
+    bsz, q_len, d_model = x.shape
+    head_dim = d_model // n_heads
+    pos = 0
+
+    qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
+    q_dim = n_heads * head_dim
+    kv_dim = n_kv_heads * head_dim
+
+    q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+    k = (
+        qkv_out[..., q_dim : q_dim + kv_dim]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+    v = (
+        qkv_out[..., q_dim + kv_dim :]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+
+    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
+    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
+    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
+    out = F.scaled_dot_product_attention(
+        q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
+    )
+    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
+    out = w.proj(out)
+    return out
+
+
+def _produce_hidden(inputs_embeds: torch.Tensor, w: nn.Module, config: TextConfig):
     hidden_BTC = inputs_embeds
-    new_kv_cache = [torch.empty(0)] * len(w.blocks)
 
-    attn_mask = w.attn_mask[
-        :, :, pos : pos + hidden_BTC.size(1), : pos + hidden_BTC.size(1)
-    ]
+    bsz, q_len, d_model = inputs_embeds.shape
+    attn_mask = torch.zeros(q_len, q_len)
+    attn_mask[:730, :730] = 1
+    for i in range(730, q_len):
+        attn_mask[i, : i + 1] = 1
+    attn_mask = attn_mask.to(dtype=torch.bool)
 
     for i, block in enumerate(w.blocks):
         l_in = layer_norm(hidden_BTC, block.ln)
-        l_attn, new_kv_cache[i] = attn(
+        l_attn = _attn(
+            x=l_in,
+            w=block.attn,
+            freqs_cis=w.freqs_cis,
+            attn_mask=attn_mask,
+            n_heads=config.n_heads,
+            n_kv_heads=config.n_kv_heads,
+        )
+        l_mlp = mlp(l_in, block.mlp)
+        hidden_BTC = hidden_BTC + l_attn + l_mlp
+
+    return hidden_BTC
+
+
+def text_decoder(
+    x: torch.Tensor,
+    w: nn.Module,
+    attn_mask: torch.Tensor,
+    position_ids: torch.Tensor,
+    config: TextConfig,
+):
+    for i, block in enumerate(w.blocks):
+        l_in = layer_norm(x, block.ln)
+        l_attn = attn(
             l_in,
             block.attn,
             freqs_cis=w.freqs_cis,
-            layer_kv_cache=kv_cache[i],
+            kv_cache=block.kv_cache,
             attn_mask=attn_mask,
             n_heads=config.n_heads,
-            pos=pos,
+            n_kv_heads=config.n_kv_heads,
+            position_ids=position_ids,
         )
         l_mlp = mlp(l_in, block.mlp)
-        hidden_BTC = hidden_BTC + l_attn + l_mlp
+        x = x + l_attn + l_mlp
 
-    return hidden_BTC, torch.stack(new_kv_cache)
+    return x
 
 
 def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
     hidden_BC = hidden_BTC[:, -1, :]
     hidden_BC = layer_norm(hidden_BC, w.post_ln)
-    logits = linear(hidden_BC, w.lm_head)
+    logits = w.lm_head(hidden_BC)
     return logits
 
 
-def prefill(
-    inputs_embeds: torch.Tensor,
-    kv_cache: torch.Tensor,
-    pos: int,
-    w: nn.Module,
-    config: TextConfig,
-):
-    # Updates kv_cache in-place
-    hidden, kv_cache[:, :, :, :, pos : pos + inputs_embeds.size(1), :] = text_decoder(
-        inputs_embeds, w, kv_cache, pos, config
-    )
-    return hidden
-
-
-def decode_one_token(
-    token_emb: torch.Tensor,
-    kv_cache: torch.Tensor,
-    pos: int,
-    w: nn.Module,
-    config: TextConfig,
-):
-    hidden, kv_cache_update = text_decoder(token_emb[None], w, kv_cache, pos, config)
-    logits = lm_head(hidden, w)
-    return logits, hidden, kv_cache_update
+def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
+    hidden_BTC = layer_norm(hidden_BTC, w.post_ln)
+    logits = w.lm_head(hidden_BTC)
+    return logits
 
 
 def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
+    qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
+
     text = nn.ModuleDict(
         {
             "blocks": nn.ModuleList(
@@ -123,9 +170,7 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
                         "ln": nn.LayerNorm(config.dim, dtype=dtype),
                         "attn": nn.ModuleDict(
                             {
-                                "qkv": nn.Linear(
-                                    config.dim, 3 * config.dim, dtype=dtype
-                                ),
+                                "qkv": nn.Linear(config.dim, qkv_dim, dtype=dtype),
                                 "proj": nn.Linear(
                                     config.dim, config.dim, dtype=dtype
                                 ),
@@ -134,10 +179,10 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
                         "mlp": nn.ModuleDict(
                             {
                                 "fc1": nn.Linear(
-                                    config.dim, 4 * config.dim, dtype=dtype
+                                    config.dim, config.ff_dim, dtype=dtype
                                 ),
                                 "fc2": nn.Linear(
-                                    4 * config.dim, config.dim, dtype=dtype
+                                    config.ff_dim, config.dim, dtype=dtype
                                 ),
                             }
                         ),
@@ -157,11 +202,4 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
         persistent=False,
     )
 
-    attn_mask = torch.tril(
-        torch.ones(1, 1, config.max_context, config.max_context, dtype=torch.bool)
-    )
-    if config.prefix_attn != 0:
-        attn_mask[..., : config.prefix_attn, : config.prefix_attn] = 1
-    text.register_buffer("attn_mask", attn_mask, persistent=False)
-
     return text
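Note: with this commit's config (dim=2048, n_heads=32, n_kv_heads=32) the fused qkv projection degenerates to the old 3 * dim layout (qkv_dim = 6144), and the enable_gqa flag passed to scaled_dot_product_attention only changes behavior once the KV head count drops below the query head count; that flag is only available in newer PyTorch releases. The sketch below just walks through the shape arithmetic under those config values and is not the repo's code.

# Shape arithmetic for the fused qkv split (config values taken from this commit).
import torch

dim, n_heads, n_kv_heads = 2048, 32, 32
head_dim = dim // n_heads                            # 64
qkv_dim = int(dim * (1 + 2 * n_kv_heads / n_heads))  # 6144 when heads == kv heads

x = torch.randn(1, 7, dim)                           # (bsz, q_len, dim)
qkv_out = torch.nn.Linear(dim, qkv_dim)(x)

q_dim, kv_dim = n_heads * head_dim, n_kv_heads * head_dim
q = qkv_out[..., :q_dim].view(1, 7, n_heads, head_dim).transpose(1, 2)
k = qkv_out[..., q_dim : q_dim + kv_dim].view(1, 7, n_kv_heads, head_dim).transpose(1, 2)
v = qkv_out[..., q_dim + kv_dim :].view(1, 7, n_kv_heads, head_dim).transpose(1, 2)
print(q.shape, k.shape, v.shape)  # torch.Size([1, 32, 7, 64]) for all three here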
vision.py CHANGED
@@ -4,7 +4,6 @@ import torch.nn.functional as F
 import numpy as np
 
 from typing import Union, Tuple
-from einops import rearrange
 from PIL import Image
 
 from .layers import attn, layer_norm, linear, mlp
@@ -42,13 +41,28 @@ def prepare_crops(
     return all_crops, overlap_crops["tiling"]
 
 
+def create_patches(x, patch_size):
+    # Original shape: [B, C, H, W]
+    B, C, H, W = x.shape
+    P1 = P2 = patch_size
+
+    # Step 1: Split H and W dimensions into patches
+    # [B, C, H/P1, P1, W/P2, P2]
+    x = x.reshape(B, C, H // P1, P1, W // P2, P2)
+
+    # Step 2: Rearrange dimensions to match target shape
+    # [B, H/P1, W/P2, C, P1, P2]
+    x = x.permute(0, 2, 4, 1, 3, 5)
+
+    # Step 3: Combine dimensions to get final shape
+    # [B, (H/P1)*(W/P2), C*P1*P2]
+    x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)
+
+    return x
+
+
 def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
-    x = rearrange(
-        input_BCHW,
-        "b c (h p1) (w p2) -> b (h w) (c p1 p2)",
-        p1=config.enc_patch_size,
-        p2=config.enc_patch_size,
-    )  # B3HW -> B(HxW)(3xP1xP2), aka BTC
+    x = create_patches(input_BCHW, config.enc_patch_size)
 
     x = linear(x, w.patch_emb)
     x = x + w.pos_emb
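Note: a small sanity check, with assumed toy shapes, that the reshape/permute in create_patches reproduces the layout of the removed einops pattern "b c (h p1) (w p2) -> b (h w) (c p1 p2)". This is an illustrative test, not part of the commit.

# Toy check of the patchification layout (shapes are assumptions).
import torch

def create_patches(x, patch_size):
    B, C, H, W = x.shape
    P = patch_size
    x = x.reshape(B, C, H // P, P, W // P, P)
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(B, (H // P) * (W // P), C * P * P)

x = torch.randn(2, 3, 8, 8)
patches = create_patches(x, 4)
print(patches.shape)  # torch.Size([2, 4, 48])

# Patch (row 0, col 1) should equal the corresponding crop flattened channel-first.
manual = x[:, :, 0:4, 4:8].reshape(2, -1)
assert torch.equal(patches[:, 1, :], manual)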