Upload HfMoondream
Changed files:
- config.json +1 -1
- config.py +3 -0
- generation_config.json +1 -1
- hf_moondream.py +26 -7
- layers.py +2 -2
- model.safetensors +1 -1
- moondream.py +204 -136
- region.py +19 -12
- text.py +114 -76
- vision.py +21 -7
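For orientation, a minimal usage sketch of the checkpoint this commit uploads. It assumes the repository is loaded with trust_remote_code=True so that the HfMoondream class in hf_moondream.py is used; the repo id and image path below are placeholders, not taken from this commit.

    from PIL import Image
    from transformers import AutoModelForCausalLM

    # Placeholder repo id; substitute the repository this commit belongs to.
    model = AutoModelForCausalLM.from_pretrained(
        "vikhyatk/moondream2", trust_remote_code=True
    )

    image = Image.open("example.jpg")  # placeholder image path
    print(model.caption(image, length="short"))
    print(model.query(image, "What is in this image?")["answer"])
    print(model.detect(image, "face"))   # returns {"objects": [...]}
    print(model.point(image, "person"))  # returns {"points": [...]}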
config.json
CHANGED
@@ -9,5 +9,5 @@
   "config": {},
   "model_type": "moondream1",
   "torch_dtype": "float16",
-  "transformers_version": "4.
+  "transformers_version": "4.48.0"
 }
config.py
CHANGED
@@ -5,10 +5,12 @@ from typing import Dict, List, Optional
 @dataclass(frozen=True)
 class TextConfig:
     dim: int = 2048
+    ff_dim: int = 8192
     n_layers: int = 24
     vocab_size: int = 51200
     max_context: int = 2048
     n_heads: int = 32
+    n_kv_heads: int = 32
     prefix_attn: int = 730
 
 
@@ -46,6 +48,7 @@ class TokenizerConfig:
         "caption": {
             "short": [198, 198, 16438, 8305, 25],
             "normal": [198, 198, 24334, 1159, 25],
+            "long": [198, 198, 14617, 8305, 25],
         },
         "query": {"prefix": [198, 198, 24361, 25], "suffix": [198, 198, 33706, 25]},
         "detect": {"prefix": [198, 198, 47504, 25], "suffix": [628]},
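A quick check of what the new TextConfig fields imply, using the fused-QKV width formula that build_text_model applies later in this commit (standalone sketch, not part of the repo):

    # Values from this commit's TextConfig defaults.
    dim, n_heads, n_kv_heads, ff_dim = 2048, 32, 32, 8192

    head_dim = dim // n_heads                            # 64
    qkv_dim = int(dim * (1 + 2 * n_kv_heads / n_heads))  # 2048 * 3 = 6144
    print(head_dim, qkv_dim, ff_dim)                     # 64 6144 8192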
generation_config.json
CHANGED
@@ -1,4 +1,4 @@
 {
   "_from_model_config": true,
-  "transformers_version": "4.
+  "transformers_version": "4.48.0"
 }
hf_moondream.py
CHANGED
@@ -14,7 +14,7 @@ from .utils import *
 def extract_question(text):
     prefix = "<image>\n\nQuestion: "
     suffix = "\n\nAnswer:"
-
+
     if text.startswith(prefix) and text.endswith(suffix):
         return text[len(prefix) : -len(suffix)]
     else:
@@ -36,30 +36,44 @@ class HfMoondream(PreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        self.model = MoondreamModel(
+        self.model = MoondreamModel(
+            MoondreamConfig.from_dict(config.config), setup_caches=False
+        )
+        self._is_kv_cache_setup = False
+
+    def _setup_caches(self):
+        if not self._is_kv_cache_setup:
+            self.model._setup_caches()
+            self._is_kv_cache_setup = True
 
     @property
     def encode_image(self):
+        self._setup_caches()
         return self.model.encode_image
 
     @property
     def query(self):
+        self._setup_caches()
         return self.model.query
 
     @property
     def caption(self):
+        self._setup_caches()
        return self.model.caption
 
     @property
     def detect(self):
+        self._setup_caches()
         return self.model.detect
 
     @property
     def point(self):
+        self._setup_caches()
         return self.model.point
 
     @property
     def detect_gaze(self):
+        self._setup_caches()
         return self.model.detect_gaze
 
     def answer_question(
@@ -98,22 +112,27 @@ class HfMoondream(PreTrainedModel):
         """
         prompt_extracted = extract_question(prompt)
         if prompt_extracted is not None:
-            answer = self.model.query(
-            ]
+            answer = self.model.query(
+                image=image_embeds, question=prompt_extracted, stream=False
+            )["answer"]
         else:
             image_embeds = self.encode_image(image_embeds)
             prompt_tokens = torch.tensor(
                 [self.model.tokenizer.encode(prompt).ids],
                 device=self.device,
             )
+
             def generator():
                 for token in self.model._generate_text(
-                    prompt_tokens,
+                    prompt_tokens,
+                    image_embeds.kv_cache,
+                    image_embeds.pos,
+                    max_new_tokens,
                 ):
                     yield token
+
             answer = "".join(list(generator()))
-
+
         return [answer]
 
     def get_input_embeddings(self):
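The wrapper now defers KV-cache allocation until the first call that needs it. A minimal sketch of the same lazy-initialization pattern in isolation (names here are illustrative, not part of the repo):

    class LazyCaches:
        """Illustrative only: allocate an expensive buffer on first use."""

        def __init__(self):
            self._ready = False

        def _setup(self):
            if not self._ready:
                print("allocating caches once")
                self._ready = True

        @property
        def query(self):
            # Accessing the property guarantees the buffers exist before use.
            self._setup()
            return lambda q: f"answered: {q}"


    m = LazyCaches()
    print(m.query("hi"))     # triggers allocation
    print(m.query("again"))  # no re-allocation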
layers.py
CHANGED
@@ -37,9 +37,9 @@ class MLPWeights:
 
 
 def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
-    x =
+    x = w.fc1(x)
     x = gelu_approx(x)
-    x =
+    x = w.fc2(x)
     return x
 
 
model.safetensors
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:fadcffea8c17fe8a20ea68af3a013cf3184a63787ee4453cc9eb75206c7c1f9b
 size 3854538376
moondream.py
CHANGED
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import random
 
-from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional
+from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
 from PIL import Image
 from dataclasses import dataclass
 from tokenizers import Tokenizer
@@ -10,7 +10,7 @@ from tokenizers import Tokenizer
 from .config import MoondreamConfig
 from .image_crops import reconstruct_from_crops
 from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
-from .text import build_text_model,
+from .text import build_text_model, text_encoder, lm_head, text_decoder
 from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
 from .utils import remove_outlier_points
 
@@ -21,53 +21,41 @@ SamplingSettings = TypedDict(
     total=False,
 )
 
-DEFAULT_MAX_TOKENS =
+DEFAULT_MAX_TOKENS = 768
 
 
 @dataclass(frozen=True)
 class EncodedImage:
     pos: int
-        sorted_indices_to_remove = torch.gather(
-            tokens_to_remove, dim=-1, index=sorted_indices
-        )
-        if min_tokens_to_keep > 1:
-            sorted_indices_to_remove[..., :min_tokens_to_keep] = False
-
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
-        logits = logits.masked_fill(indices_to_remove, filter_value)
-        token = torch.multinomial(logits, num_samples=1)
-        return token.squeeze(0)
+    caches: List[Tuple[torch.Tensor, torch.Tensor]]
+
+
+class KVCache(nn.Module):
+
+    def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
+        super().__init__()
+        cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
+        self.register_buffer(
+            "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
+        )
+        self.register_buffer(
+            "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
+        )
+
+    def update(self, pos_ids, k, v):
+        kout, vout = self.k_cache, self.v_cache
+        kout[:, :, pos_ids, :] = k
+        vout[:, :, pos_ids, :] = v
+        return kout, vout
 
 
 class MoondreamModel(nn.Module):
-    def __init__(self, config: MoondreamConfig, dtype=torch.float16):
+    def __init__(self, config: MoondreamConfig, dtype=torch.float16, setup_caches=True):
         super().__init__()
         self.config = config
 
         self.tokenizer = Tokenizer.from_pretrained(
-            "vikhyatk/moondream2", revision="
+            "vikhyatk/moondream2", revision="2025-01-09"
         )
         self.vision = build_vision_model(config.vision, dtype)
         self.text = build_text_model(config.text, dtype)
@@ -114,35 +102,65 @@ class MoondreamModel(nn.Module):
             torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
         )
 
+        attn_mask = torch.tril(
+            torch.ones(
+                1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
+            )
+        )
+        patch_w = config.vision.crop_size // config.vision.enc_patch_size
+        prefix_attn_len = 1 + patch_w**2
+        attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
+        self.register_buffer("attn_mask", attn_mask, persistent=False)
+
+        # Initialize KV caches.
+        if setup_caches:
+            self._setup_caches()
+
+    def _setup_caches(self):
+        c = self.config.text
+        for b in self.text.blocks:
+            b.kv_cache = KVCache(
+                c.n_heads,
+                c.n_kv_heads,
+                c.max_context,
+                c.dim,
+                device=self.device,
+                dtype=self.vision.pos_emb.dtype,
+            )
 
     @property
     def device(self):
         return self.vision.pos_emb.device
 
+    def _vis_enc(self, x: torch.Tensor):
+        return vision_encoder(x, self.vision, self.config.vision)
+
+    def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
+        return vision_projection(g, r, self.vision, self.config.vision)
+
+    def _prefill(self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor):
+        return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text)
+
+    def _decode_one_tok(
+        self, x: torch.Tensor, attn_mask: torch.Tensor, pos_ids: torch.Tensor
+    ):
+        hidden = text_decoder(x[None], self.text, attn_mask, pos_ids, self.config.text)
+        logits = lm_head(hidden, self.text)
+        return logits, hidden
+
     def compile(self):
-            # self.ops["vision_projection"], fullgraph=True
-            # )
-        self.ops["prefill"] = torch.compile(self.ops["prefill"], fullgraph=True)
-        self.ops["decode_one_token"] = torch.compile(
-            self.ops["decode_one_token"], fullgraph=True
+        # TODO: vision_projection is not being compiled
+        self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
+        self._prefill = torch.compile(self._prefill, fullgraph=True)
+        self._decode_one_tok = torch.compile(
+            self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
         )
 
     def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
         all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
         torch._dynamo.mark_dynamic(all_crops, 0)
 
-        outputs = self.
+        outputs = self._vis_enc(all_crops)
 
         global_features = outputs[0]
         local_features = outputs[1:].view(
@@ -159,9 +177,7 @@ class MoondreamModel(nn.Module):
             overlap_margin=self.config.vision.overlap_margin,
         )
 
-        return self.
-            global_features, reconstructed, self.vision, self.config.vision
-        )
+        return self._vis_proj(global_features, reconstructed)
 
     def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage:
         if isinstance(image, EncodedImage):
@@ -171,34 +187,35 @@ class MoondreamModel(nn.Module):
 
         # Run through text model in addition to the vision encoder, to minimize
         # re-computation if multiple queries are performed on this image.
-            self.config.text.n_layers,
-            2,  # k, v
-            1,  # batch size
-            self.config.text.n_heads,
-            self.config.text.max_context,  # static cache
-            self.config.text.dim // self.config.text.n_heads,  # head dim
-            device=self.device,
-            dtype=torch.float16,
-        )
-        with torch.no_grad():
+        with torch.inference_mode():
             img_emb = self._run_vision_encoder(image)
             bos_emb = text_encoder(
                 torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
                 self.text,
            )
             inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
-            self.
+            mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
+            pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
+            self._prefill(inputs_embeds, mask, pos_ids)
+
+        return EncodedImage(
+            pos=inputs_embeds.size(1),
+            caches=[
+                (
+                    b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
+                    b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
+                )
+                for b in self.text.blocks
+            ],
+        )
 
-    def _prefill_prompt(
-    ):
-        with torch.no_grad():
+    def _prefill_prompt(self, prompt_tokens: torch.Tensor, pos: int):
+        with torch.inference_mode():
             prompt_emb = text_encoder(prompt_tokens, self.text)
-        )
+            torch._dynamo.mark_dynamic(prompt_emb, 1)
+            mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
+            pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
+            hidden = self._prefill(prompt_emb, mask, pos_ids)
             logits = lm_head(hidden, self.text)
             next_token = torch.argmax(logits, dim=-1)
             pos = pos + prompt_emb.size(1)
@@ -207,33 +224,67 @@ class MoondreamModel(nn.Module):
     def _generate_text(
         self,
         prompt_tokens: torch.Tensor,
-        kv_cache: torch.Tensor,
         pos: int,
         max_tokens: int,
     ):
-        _, _, next_token, pos = self._prefill_prompt(kv_cache, prompt_tokens, pos)
+        _, _, next_token, pos = self._prefill_prompt(prompt_tokens, pos)
 
         def generator(next_token, pos):
+            mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
+            mask[:, :, :pos] = 1
+            pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
             generated_tokens = 0
 
+            # For properly handling token streaming with Unicode
+            token_cache = []
+            print_len = 0
+
             while (
                 next_token_id := next_token.item()
             ) != self.config.tokenizer.eos_id and generated_tokens < max_tokens:
+                # Add token to our cache
+                token_cache.append(next_token_id)
+
+                # Decode all tokens collected so far
+                text = self.tokenizer.decode(token_cache)
+
+                # After a newline, we flush the cache completely
+                if text.endswith("\n"):
+                    printable_text = text[print_len:]
+                    token_cache = []
+                    print_len = 0
+                    if printable_text:
+                        yield printable_text
+                # If the last token is a CJK character, we can safely print it
+                elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
+                    printable_text = text[print_len:]
+                    print_len += len(printable_text)
+                    if printable_text:
+                        yield printable_text
+                # Otherwise, only print up to the last space to avoid cutting words
+                else:
+                    last_space_idx = text.rfind(" ", print_len)
+                    if last_space_idx >= print_len:
+                        printable_text = text[print_len : last_space_idx + 1]
+                        print_len += len(printable_text)
+                        if printable_text:
+                            yield printable_text
+
+                with torch.inference_mode():
                     next_emb = text_encoder(next_token, self.text)
-                )
-                kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
-                    kv_cache_update
-                )
+                    mask[:, :, pos], pos_ids[0] = 1, pos
+                    logits, _ = self._decode_one_tok(next_emb, mask, pos_ids)
                     pos += 1
                     next_token = torch.argmax(logits, dim=-1)
                 generated_tokens += 1
 
+            # Flush any remaining text in the cache
+            if token_cache:
+                text = self.tokenizer.decode(token_cache)
+                printable_text = text[print_len:]
+                if printable_text:
+                    yield printable_text
+
         return generator(next_token, pos)
 
     def query(
@@ -247,10 +298,12 @@ class MoondreamModel(nn.Module):
             raise NotImplementedError("Model does not support querying.")
 
         image = self.encode_image(image)
+        self.load_encoded_image(image)
+
         prompt_tokens = torch.tensor(
             [
                 self.config.tokenizer.templates["query"]["prefix"]
-                + self.tokenizer.encode(question).ids
+                + self.tokenizer.encode(" " + question).ids
                 + self.config.tokenizer.templates["query"]["suffix"]
             ],
             device=self.device,
@@ -261,9 +314,7 @@ class MoondreamModel(nn.Module):
         max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS)
 
         def generator():
-            for token in self._generate_text(
-                prompt_tokens, image.kv_cache, image.pos, max_tokens
-            ):
+            for token in self._generate_text(prompt_tokens, image.pos, max_tokens):
                 yield token
 
         if stream:
@@ -271,10 +322,15 @@ class MoondreamModel(nn.Module):
         else:
             return {"answer": "".join(list(generator()))}
 
+    def load_encoded_image(self, encoded_image: EncodedImage):
+        for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
+            b.kv_cache.k_cache[:, :, : k.size(2), :] = k
+            b.kv_cache.v_cache[:, :, : v.size(2), :] = v
+
     def caption(
         self,
         image: Union[Image.Image, EncodedImage],
-        length: Literal["normal", "short"] = "normal",
+        length: Literal["normal", "short", "long"] = "normal",
         stream: bool = False,
         settings: Optional[SamplingSettings] = None,
     ):
@@ -284,6 +340,8 @@ class MoondreamModel(nn.Module):
             raise ValueError(f"Model does not support caption length '{length}'.")
 
         image = self.encode_image(image)
+        self.load_encoded_image(image)
+
         prompt_tokens = torch.tensor(
             [self.config.tokenizer.templates["caption"][length]], device=self.device
        )
@@ -293,9 +351,7 @@ class MoondreamModel(nn.Module):
         max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS)
 
         def generator():
-            for token in self._generate_text(
-                prompt_tokens, image.kv_cache, image.pos, max_tokens
-            ):
+            for token in self._generate_text(prompt_tokens, image.pos, max_tokens):
                 yield token
 
         if stream:
@@ -306,15 +362,17 @@ class MoondreamModel(nn.Module):
     def _generate_points(
         self,
         hidden: torch.Tensor,
-        kv_cache: torch.Tensor,
         next_token: torch.Tensor,
         pos: int,
         include_size: bool = True,
         max_points: int = 50,
     ):
         out = []
+        mask = torch.zeros(1, 1, 2048, device=self.device, dtype=torch.bool)
+        mask[:, :, :pos] = 1
+        pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
 
-        with torch.
+        with torch.inference_mode():
             while (
                 next_token.item() != self.config.tokenizer.eos_id
                 and len(out) < max_points
@@ -326,12 +384,8 @@ class MoondreamModel(nn.Module):
                 )
 
                 # Decode y-coordinate
-                )
-                kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
-                    kv_cache_update
-                )
+                mask[:, :, pos], pos_ids[0] = 1, pos
+                _, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                 pos += 1
                 y_logits = decode_coordinate(hidden, self.region)
                 y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
@@ -341,16 +395,20 @@ class MoondreamModel(nn.Module):
 
                 # Decode size
                 if include_size:
-                    )
-                    kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
-                        kv_cache_update
-                    )
+                    mask[:, :, pos], pos_ids[0] = 1, pos
+                    logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                     pos += 1
                     size_logits = decode_size(hidden, self.region)
+
+                    # Get bin indices from the logits
+                    w_bin = torch.argmax(size_logits[0], dim=-1)
+                    h_bin = torch.argmax(size_logits[1], dim=-1)
+
+                    # Convert from bin indices to actual size values using the inverse of the log-scale mapping
+                    # Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0)
+                    w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
+                    h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)
+
                     next_emb = encode_size(
                         torch.tensor(
                             [w, h], device=self.device, dtype=size_logits.dtype
@@ -371,12 +429,8 @@ class MoondreamModel(nn.Module):
                 out.append({"x": x_center.item(), "y": y_center.item()})
 
                 # Decode next token (x-coordinate, or eos)
-                )
-                kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
-                    kv_cache_update
-                )
+                mask[:, :, pos], pos_ids[0] = 1, pos
+                logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids)
                 pos += 1
                 next_token = torch.argmax(logits, dim=-1)
 
@@ -392,23 +446,22 @@ class MoondreamModel(nn.Module):
             raise NotImplementedError("Model does not support object detection.")
 
         image = self.encode_image(image)
+        self.load_encoded_image(image)
+
         prompt_tokens = torch.tensor(
             [
                 self.config.tokenizer.templates["detect"]["prefix"]
-                + self.tokenizer.encode(object).ids
+                + self.tokenizer.encode(" " + object).ids
                 + self.config.tokenizer.templates["detect"]["suffix"]
             ],
             device=self.device,
        )
 
-        _, hidden, next_token, pos = self._prefill_prompt(
-            kv_cache, prompt_tokens, image.pos
-        )
+        _, hidden, next_token, pos = self._prefill_prompt(prompt_tokens, image.pos)
         hidden = hidden[:, -1:, :]
 
         objects = self._generate_points(
-            hidden,
+            hidden, next_token, pos, include_size=True, max_points=50
        )
 
         return {"objects": objects}
@@ -423,23 +476,22 @@ class MoondreamModel(nn.Module):
             raise NotImplementedError("Model does not support pointing.")
 
         image = self.encode_image(image)
+        self.load_encoded_image(image)
+
         prompt_tokens = torch.tensor(
             [
                 self.config.tokenizer.templates["point"]["prefix"]
-                + self.tokenizer.encode(object).ids
+                + self.tokenizer.encode(" " + object).ids
                 + self.config.tokenizer.templates["point"]["suffix"]
             ],
             device=self.device,
        )
 
-        _, hidden, next_token, pos = self._prefill_prompt(
-            kv_cache, prompt_tokens, image.pos
-        )
+        _, hidden, next_token, pos = self._prefill_prompt(prompt_tokens, image.pos)
         hidden = hidden[:, -1:, :]
 
         objects = self._generate_points(
-            hidden,
+            hidden, next_token, pos, include_size=False, max_points=50
        )
 
         return {"points": objects}
@@ -450,7 +502,7 @@ class MoondreamModel(nn.Module):
         source: Tuple[float, float],
         force_detect: bool = False,
     ):
-        with torch.
+        with torch.inference_mode():
             before_emb = text_encoder(
                 torch.tensor(
                     [self.tokenizer.encode("\n\nPoint:").ids], device=self.device
@@ -474,10 +526,13 @@ class MoondreamModel(nn.Module):
 
             prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1)
 
+            self.load_encoded_image(image)
+
+            mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :]
+            pos_ids = torch.arange(
+                image.pos, image.pos + prompt_emb.size(1), dtype=torch.long
             )
+            hidden = self._prefill(prompt_emb, mask, pos_ids)
             logits = lm_head(hidden, self.text)
             next_token = torch.argmax(logits, dim=-1)
             pos = image.pos + prompt_emb.size(1)
@@ -490,7 +545,7 @@ class MoondreamModel(nn.Module):
             return None
 
         gaze = self._generate_points(
-            hidden,
+            hidden, next_token, pos, include_size=False, max_points=1
        )
         return gaze[0]
 
@@ -584,3 +639,16 @@ class MoondreamModel(nn.Module):
        )
 
         return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
+
+
+def _is_cjk_char(cp):
+    """Checks whether CP is the codepoint of a CJK character."""
+    # This defines a "chinese character" as anything in the CJK Unicode block:
+    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+    if (
+        (cp >= 0x4E00 and cp <= 0x9FFF)
+        or (cp >= 0x3400 and cp <= 0x4DBF)
+        or (cp >= 0x2F800 and cp <= 0x2FA1F)
+    ):
+        return True
+    return False
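A standalone sketch of the cache snapshot/restore idea behind EncodedImage.caches and load_encoded_image above; the shapes are illustrative (1 batch, 32 KV heads, 2048 context positions, 64 head dim), not read from the checkpoint:

    import torch

    k_cache = torch.zeros(1, 32, 2048, 64)
    v_cache = torch.zeros(1, 32, 2048, 64)

    # After prefilling the BOS token plus image embeddings up to position `pos`,
    # snapshot just that prefix so later prompts can reuse it without re-encoding.
    pos = 730
    snapshot = (k_cache[:, :, :pos, :].clone(), v_cache[:, :, :pos, :].clone())

    # Before answering another question about the same image, copy the prefix back in.
    k, v = snapshot
    k_cache[:, :, : k.size(2), :] = k
    v_cache[:, :, : v.size(2), :] = v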
region.py
CHANGED
@@ -1,7 +1,7 @@
 import torch
+import torch.nn as nn
 import math
 
-from .weights import RegionModel
 from .layers import linear, mlp
 
 
@@ -25,7 +25,7 @@ def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
     return torch.cat([f.cos(), f.sin()], dim=-1)
 
 
-def encode_coordinate(coord: torch.Tensor, w:
+def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
     Takes as input a tensor containing a single float coordinate value (x or y)
     and encodes it into hidden states for input to the text model.
@@ -39,7 +39,7 @@ def encode_coordinate(coord: torch.Tensor, w: RegionModel) -> torch.Tensor:
     return linear(fourier_features(coord, w.coord_features), w.coord_encoder)
 
 
-def decode_coordinate(hidden_state: torch.Tensor, w:
+def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
     Takes as input the last hidden state from the text model and outputs a single logit
     representing either an x or y coordinate prediction.
@@ -53,13 +53,13 @@ def decode_coordinate(hidden_state: torch.Tensor, w: RegionModel) -> torch.Tensor:
     return mlp(hidden_state, w.coord_decoder)
 
 
-def encode_size(size: torch.Tensor, w:
+def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
-    Takes a tensor containing
+    Takes a tensor containing width and height values and encodes them into
+    hidden states for input to the text model.
 
     Args:
-        size: Tensor with two floats for width and height
+        size: Tensor with two floats for width and height
 
     Returns:
         Encoded hidden states tensor for input to text model
@@ -67,16 +67,23 @@ def encode_size(size: torch.Tensor, w: RegionModel) -> torch.Tensor:
     return linear(fourier_features(size, w.size_features), w.size_encoder)
 
 
-def decode_size(hidden_state: torch.Tensor, w:
+def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
     """
-    Takes as input the last hidden state from the text model and outputs
-    for width and height
+    Takes as input the last hidden state from the text model and outputs logits
+    for 1024 bins representing width and height in log-scale.
+
+    The bins are distributed according to the formula:
+    bin = (log2(size) + 10.0) / 10.0 * 1023.0
+    where size values are clamped to be at least 1/1024.
+
+    To convert from bin back to size:
+    size = 2^((bin / 1023.0) * 10.0 - 10.0)
 
     Args:
         hidden_state: The final hidden state tensor from the text model.
 
     Returns:
-        A tensor containing
+        A tensor containing logits for 1024 bins for width and height.
+        Shape is (2, 1024) where the first dimension corresponds to width and height.
     """
     return mlp(hidden_state, w.size_decoder).view(2, -1)
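A quick round-trip check of the log-scale size binning described in decode_size's docstring (standalone sketch, not part of the repo):

    import math

    def size_to_bin(size: float) -> float:
        size = max(size, 1 / 1024)  # clamp as described in the docstring
        return (math.log2(size) + 10.0) / 10.0 * 1023.0

    def bin_to_size(b: float) -> float:
        return 2.0 ** ((b / 1023.0) * 10.0 - 10.0)

    for s in (1 / 1024, 0.05, 0.5, 1.0):
        b = size_to_bin(s)
        # e.g. size=0.5000 -> bin ~920.70 -> size=0.5000
        print(f"size={s:.4f} -> bin={b:7.2f} -> size={bin_to_size(b):.4f}")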
text.py
CHANGED
@@ -1,10 +1,10 @@
 import torch
 import torch.nn as nn
+
 from torch.nn import functional as F
 
-from .layers import layer_norm,
+from .layers import layer_norm, mlp
 from .rope import apply_rotary_emb, precompute_freqs_cis
-from .weights import AttentionWeights
 from .config import TextConfig
 
 
@@ -14,106 +14,153 @@ def text_encoder(input_ids: torch.Tensor, w: nn.Module):
 
 def attn(
     x: torch.Tensor,
-    w:
+    w: nn.Module,
     freqs_cis: torch.Tensor,
+    kv_cache: nn.Module,
     attn_mask: torch.Tensor,
     n_heads: int,
+    n_kv_heads: int,
+    position_ids: torch.Tensor,
 ):
     bsz, q_len, d_model = x.shape
     head_dim = d_model // n_heads
 
-    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
+    qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
+    q_dim = n_heads * head_dim
+    kv_dim = n_kv_heads * head_dim
+
+    q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+    k = (
+        qkv_out[..., q_dim : q_dim + kv_dim]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+    v = (
+        qkv_out[..., q_dim + kv_dim :]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+
     q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
-    k = apply_rotary_emb(k, freqs_cis, position_ids,
-    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask).to(
-        # This type conversion isn't needed when running in PyTorch directly, but the
-        # ONNX export runs attention in float32 because the attention mask is cast to
-        # float32.
-        x.dtype
+    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
+
+    if kv_cache is not None:
+        k, v = kv_cache.update(position_ids, k, v)
+
+    out = F.scaled_dot_product_attention(
+        q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
     )
     out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
-    out =
+    out = w.proj(out)
     return out
 
 
-def
-    w:
+def _attn(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    freqs_cis: torch.Tensor,
+    attn_mask: torch.Tensor,
+    n_heads: int,
+    n_kv_heads: int,
 ):
+    bsz, q_len, d_model = x.shape
+    head_dim = d_model // n_heads
+    pos = 0
+
+    qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
+    q_dim = n_heads * head_dim
+    kv_dim = n_kv_heads * head_dim
+
+    q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
+    k = (
+        qkv_out[..., q_dim : q_dim + kv_dim]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+    v = (
+        qkv_out[..., q_dim + kv_dim :]
+        .view(bsz, q_len, n_kv_heads, head_dim)
+        .transpose(1, 2)
+    )
+
+    position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
+    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
+    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
+    out = F.scaled_dot_product_attention(
+        q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
+    )
+    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
+    out = w.proj(out)
+    return out
+
+
+def _produce_hidden(inputs_embeds: torch.Tensor, w: nn.Module, config: TextConfig):
     hidden_BTC = inputs_embeds
-    new_kv_cache = [torch.empty(0)] * len(w.blocks)
-    ]
+
+    bsz, q_len, d_model = inputs_embeds.shape
+    attn_mask = torch.zeros(q_len, q_len)
+    attn_mask[:730, :730] = 1
+    for i in range(730, q_len):
+        attn_mask[i, : i + 1] = 1
+    attn_mask = attn_mask.to(dtype=torch.bool)
 
     for i, block in enumerate(w.blocks):
         l_in = layer_norm(hidden_BTC, block.ln)
-        l_attn
+        l_attn = _attn(
+            x=l_in,
+            w=block.attn,
+            freqs_cis=w.freqs_cis,
+            attn_mask=attn_mask,
+            n_heads=config.n_heads,
+            n_kv_heads=config.n_kv_heads,
+        )
+        l_mlp = mlp(l_in, block.mlp)
+        hidden_BTC = hidden_BTC + l_attn + l_mlp
+
+    return hidden_BTC
+
+
+def text_decoder(
+    x: torch.Tensor,
+    w: nn.Module,
+    attn_mask: torch.Tensor,
+    position_ids: torch.Tensor,
+    config: TextConfig,
+):
+    for i, block in enumerate(w.blocks):
+        l_in = layer_norm(x, block.ln)
+        l_attn = attn(
             l_in,
             block.attn,
             freqs_cis=w.freqs_cis,
+            kv_cache=block.kv_cache,
             attn_mask=attn_mask,
             n_heads=config.n_heads,
+            n_kv_heads=config.n_kv_heads,
+            position_ids=position_ids,
        )
         l_mlp = mlp(l_in, block.mlp)
+        x = x + l_attn + l_mlp
 
-    return
+    return x
 
 
 def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
     hidden_BC = hidden_BTC[:, -1, :]
     hidden_BC = layer_norm(hidden_BC, w.post_ln)
-    logits =
+    logits = w.lm_head(hidden_BC)
     return logits
 
 
-def
-    w: nn.Module,
-    config: TextConfig,
-):
-    # Updates kv_cache in-place
-    hidden, kv_cache[:, :, :, :, pos : pos + inputs_embeds.size(1), :] = text_decoder(
-        inputs_embeds, w, kv_cache, pos, config
-    )
-    return hidden
-
-
-def decode_one_token(
-    token_emb: torch.Tensor,
-    kv_cache: torch.Tensor,
-    pos: int,
-    w: nn.Module,
-    config: TextConfig,
-):
-    hidden, kv_cache_update = text_decoder(token_emb[None], w, kv_cache, pos, config)
-    logits = lm_head(hidden, w)
-    return logits, hidden, kv_cache_update
+def _lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
+    hidden_BTC = layer_norm(hidden_BTC, w.post_ln)
+    logits = w.lm_head(hidden_BTC)
+    return logits
 
 
 def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
+    qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
+
     text = nn.ModuleDict(
         {
             "blocks": nn.ModuleList(
@@ -123,9 +170,7 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
                         "ln": nn.LayerNorm(config.dim, dtype=dtype),
                         "attn": nn.ModuleDict(
                             {
-                                "qkv": nn.Linear(
-                                    config.dim, 3 * config.dim, dtype=dtype
-                                ),
+                                "qkv": nn.Linear(config.dim, qkv_dim, dtype=dtype),
                                 "proj": nn.Linear(
                                     config.dim, config.dim, dtype=dtype
                                 ),
@@ -134,10 +179,10 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
                         "mlp": nn.ModuleDict(
                             {
                                 "fc1": nn.Linear(
-                                    config.dim,
+                                    config.dim, config.ff_dim, dtype=dtype
                                 ),
                                 "fc2": nn.Linear(
+                                    config.ff_dim, config.dim, dtype=dtype
                                 ),
                             }
                         ),
@@ -157,11 +202,4 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
         persistent=False,
     )
 
-    attn_mask = torch.tril(
-        torch.ones(1, 1, config.max_context, config.max_context, dtype=torch.bool)
-    )
-    if config.prefix_attn != 0:
-        attn_mask[..., : config.prefix_attn, : config.prefix_attn] = 1
-    text.register_buffer("attn_mask", attn_mask, persistent=False)
-
     return text
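A shape check of the fused-QKV split used by attn()/_attn() above (standalone sketch, not part of the repo). It uses a grouped-query setting (n_kv_heads=8) purely for illustration; the shipped config uses 32/32, and the enable_gqa argument to scaled_dot_product_attention requires a recent PyTorch (2.5+).

    import torch
    import torch.nn.functional as F

    bsz, q_len, dim, n_heads, n_kv_heads = 1, 8, 2048, 32, 8
    head_dim = dim // n_heads
    qkv_dim = int(dim * (1 + 2 * n_kv_heads / n_heads))  # 3072 for 32/8

    qkv_out = torch.randn(bsz, q_len, qkv_dim)
    q_dim, kv_dim = n_heads * head_dim, n_kv_heads * head_dim

    # Same slicing/reshaping as the functions above.
    q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
    k = qkv_out[..., q_dim : q_dim + kv_dim].view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
    v = qkv_out[..., q_dim + kv_dim :].view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

    out = F.scaled_dot_product_attention(q, k, v, enable_gqa=n_heads != n_kv_heads)
    print(q.shape, k.shape, out.shape)
    # torch.Size([1, 32, 8, 64]) torch.Size([1, 8, 8, 64]) torch.Size([1, 32, 8, 64])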
vision.py
CHANGED
@@ -4,7 +4,6 @@ import torch.nn.functional as F
 import numpy as np
 
 from typing import Union, Tuple
-from einops import rearrange
 from PIL import Image
 
 from .layers import attn, layer_norm, linear, mlp
@@ -42,13 +41,28 @@ def prepare_crops(
     return all_crops, overlap_crops["tiling"]
 
 
+def create_patches(x, patch_size):
+    # Original shape: [B, C, H, W]
+    B, C, H, W = x.shape
+    P1 = P2 = patch_size
+
+    # Step 1: Split H and W dimensions into patches
+    # [B, C, H/P1, P1, W/P2, P2]
+    x = x.reshape(B, C, H // P1, P1, W // P2, P2)
+
+    # Step 2: Rearrange dimensions to match target shape
+    # [B, H/P1, W/P2, C, P1, P2]
+    x = x.permute(0, 2, 4, 1, 3, 5)
+
+    # Step 3: Combine dimensions to get final shape
+    # [B, (H/P1)*(W/P2), C*P1*P2]
+    x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)
+
+    return x
+
+
 def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
-    x =
-        input_BCHW,
-        "b c (h p1) (w p2) -> b (h w) (c p1 p2)",
-        p1=config.enc_patch_size,
-        p2=config.enc_patch_size,
-    )  # B3HW -> B(HxW)(3xP1xP2), aka BTC
+    x = create_patches(input_BCHW, config.enc_patch_size)
 
     x = linear(x, w.patch_emb)
     x = x + w.pos_emb
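A quick equivalence check between the new create_patches and the einops rearrange pattern it replaces (standalone sketch; einops is only needed for this check, not by the repo anymore):

    import torch
    from einops import rearrange

    def create_patches(x, patch_size):
        B, C, H, W = x.shape
        P1 = P2 = patch_size
        x = x.reshape(B, C, H // P1, P1, W // P2, P2)
        x = x.permute(0, 2, 4, 1, 3, 5)
        return x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)

    x = torch.randn(2, 3, 28, 28)
    a = create_patches(x, 14)
    b = rearrange(x, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=14, p2=14)
    print(a.shape, torch.equal(a, b))  # torch.Size([2, 4, 588]) True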