vikhyatk commited on
Commit
05d640e
·
verified ·
1 Parent(s): 97da4c7

Upload HfMoondream

Browse files
Files changed (14) hide show
  1. config.json +4 -6
  2. config.py +83 -0
  3. generation_config.json +0 -2
  4. hf_moondream.py +123 -0
  5. image_crops.py +208 -0
  6. layers.py +63 -0
  7. model.safetensors +2 -2
  8. moondream.py +535 -179
  9. region.py +82 -0
  10. rope.py +48 -0
  11. text.py +167 -0
  12. utils.py +41 -0
  13. vision.py +133 -0
  14. weights.py +292 -0
config.json CHANGED
@@ -1,15 +1,13 @@
1
  {
2
  "architectures": [
3
- "Moondream"
4
  ],
5
  "auto_map": {
6
- "AutoConfig": "configuration_moondream.MoondreamConfig",
7
- "AutoModelForCausalLM": "moondream.Moondream"
8
  },
 
9
  "model_type": "moondream1",
10
- "text_config": {
11
- "model_type": "phi"
12
- },
13
  "torch_dtype": "float16",
14
  "transformers_version": "4.44.0"
15
  }
 
1
  {
2
  "architectures": [
3
+ "HfMoondream"
4
  ],
5
  "auto_map": {
6
+ "AutoConfig": "hf_moondream.HfConfig",
7
+ "AutoModelForCausalLM": "hf_moondream.HfMoondream"
8
  },
9
+ "config": {},
10
  "model_type": "moondream1",
 
 
 
11
  "torch_dtype": "float16",
12
  "transformers_version": "4.44.0"
13
  }
config.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Dict, List, Optional
3
+
4
+
5
+ @dataclass(frozen=True)
6
+ class TextConfig:
7
+ dim: int = 2048
8
+ n_layers: int = 24
9
+ vocab_size: int = 51200
10
+ max_context: int = 2048
11
+ n_heads: int = 32
12
+ prefix_attn: int = 730
13
+
14
+
15
+ @dataclass(frozen=True)
16
+ class VisionConfig:
17
+ enc_dim: int = 1152
18
+ enc_patch_size: int = 14
19
+ enc_n_layers: int = 27
20
+ enc_ff_dim: int = 4304
21
+ enc_n_heads: int = 16
22
+ proj_out_dim: int = 2048
23
+ crop_size: int = 378
24
+ in_channels: int = 3
25
+ max_crops: int = 12
26
+ overlap_margin: int = 4
27
+ proj_inner_dim: int = 8192
28
+
29
+
30
+ @dataclass(frozen=True)
31
+ class RegionConfig:
32
+ dim: int = 2048
33
+ coord_feat_dim: int = 256
34
+ coord_out_dim: int = 1024
35
+ size_feat_dim: int = 512
36
+ size_out_dim: int = 2048
37
+ inner_dim: int = 8192
38
+
39
+
40
+ @dataclass(frozen=True)
41
+ class TokenizerConfig:
42
+ bos_id: int = 50256
43
+ eos_id: int = 50256
44
+ templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
45
+ default_factory=lambda: {
46
+ "caption": {
47
+ "short": [198, 198, 16438, 8305, 25],
48
+ "normal": [198, 198, 24334, 1159, 25],
49
+ },
50
+ "query": {"prefix": [198, 198, 24361, 25], "suffix": [198, 198, 33706, 25]},
51
+ "detect": {"prefix": [198, 198, 47504, 25], "suffix": [628]},
52
+ "point": {"prefix": [198, 198, 12727, 25], "suffix": [628]},
53
+ }
54
+ )
55
+
56
+
57
+ @dataclass(frozen=True)
58
+ class MoondreamConfig:
59
+ text: TextConfig = TextConfig()
60
+ vision: VisionConfig = VisionConfig()
61
+ region: RegionConfig = RegionConfig()
62
+ tokenizer: TokenizerConfig = TokenizerConfig()
63
+
64
+ @classmethod
65
+ def from_dict(cls, config_dict: dict):
66
+ text_config = TextConfig(**config_dict.get("text", {}))
67
+ vision_config = VisionConfig(**config_dict.get("vision", {}))
68
+ region_config = RegionConfig(**config_dict.get("region", {}))
69
+ tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {}))
70
+ return cls(
71
+ text=text_config,
72
+ vision=vision_config,
73
+ region=region_config,
74
+ tokenizer=tokenizer_config,
75
+ )
76
+
77
+ def to_dict(self):
78
+ return {
79
+ "text": self.text.__dict__,
80
+ "vision": self.vision.__dict__,
81
+ "region": self.region.__dict__,
82
+ "tokenizer": self.tokenizer.__dict__,
83
+ }
generation_config.json CHANGED
@@ -1,6 +1,4 @@
1
  {
2
  "_from_model_config": true,
3
- "bos_token_id": 1,
4
- "eos_token_id": 2,
5
  "transformers_version": "4.44.0"
6
  }
 
1
  {
2
  "_from_model_config": true,
 
 
3
  "transformers_version": "4.44.0"
4
  }
hf_moondream.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel, PretrainedConfig
2
+
3
+ from .config import MoondreamConfig
4
+ from .moondream import MoondreamModel
5
+
6
+ # Files sometimes don't get loaded without these...
7
+ from .image_crops import *
8
+ from .vision import *
9
+ from .text import *
10
+ from .region import *
11
+ from .utils import *
12
+
13
+
14
+ def extract_question(text):
15
+ prefix = "<image>\n\nQuestion: "
16
+ suffix = "\n\nAnswer:"
17
+
18
+ if text.startswith(prefix) and text.endswith(suffix):
19
+ return text[len(prefix) : -len(suffix)]
20
+ else:
21
+ return None
22
+
23
+
24
+ class HfConfig(PretrainedConfig):
25
+ _auto_class = "AutoConfig"
26
+ model_type = "moondream1"
27
+
28
+ def __init__(self, **kwargs):
29
+ super().__init__(**kwargs)
30
+ self.config = {}
31
+
32
+
33
+ class HfMoondream(PreTrainedModel):
34
+ _auto_class = "AutoModelForCausalLM"
35
+ config_class = HfConfig
36
+
37
+ def __init__(self, config):
38
+ super().__init__(config)
39
+ self.model = MoondreamModel(MoondreamConfig.from_dict(config.config))
40
+
41
+ @property
42
+ def encode_image(self):
43
+ return self.model.encode_image
44
+
45
+ @property
46
+ def query(self):
47
+ return self.model.query
48
+
49
+ @property
50
+ def caption(self):
51
+ return self.model.caption
52
+
53
+ @property
54
+ def detect(self):
55
+ return self.model.detect
56
+
57
+ @property
58
+ def point(self):
59
+ return self.model.point
60
+
61
+ @property
62
+ def detect_gaze(self):
63
+ return self.model.detect_gaze
64
+
65
+ def answer_question(
66
+ self,
67
+ image_embeds,
68
+ question,
69
+ tokenizer=None,
70
+ chat_history="",
71
+ result_queue=None,
72
+ max_new_tokens=256,
73
+ **kwargs
74
+ ):
75
+ answer = self.query(image_embeds, question)["answer"].strip()
76
+
77
+ if result_queue is not None:
78
+ result_queue.put(answer)
79
+ return answer
80
+
81
+ def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
82
+ answers = []
83
+ for image, prompt in zip(images, prompts):
84
+ answers.append(self.query(image, prompt)["answer"].strip())
85
+ return answers
86
+
87
+ def _unsupported_exception(self):
88
+ raise NotImplementedError(
89
+ "This method is not supported in the latest version of moondream. "
90
+ "Consider upgrading to the updated API spec, or alternately pin "
91
+ "to 'revision=2024-08-26'."
92
+ )
93
+
94
+ def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
95
+ """
96
+ Function definition remains unchanged for backwards compatibility.
97
+ Be aware that tokenizer, max_new_takens, and kwargs are ignored.
98
+ """
99
+ prompt_extracted = extract_question(prompt)
100
+ if prompt_extracted is not None:
101
+ answer = self.model.query(image=image_embeds, question=prompt_extracted, stream=False)[
102
+ "answer"
103
+ ]
104
+ else:
105
+ image_embeds = self.encode_image(image_embeds)
106
+ prompt_tokens = torch.tensor(
107
+ [self.model.tokenizer.encode(prompt).ids],
108
+ device=self.device,
109
+ )
110
+ def generator():
111
+ for token in self.model._generate_text(
112
+ prompt_tokens, image_embeds.kv_cache, image_embeds.pos, max_new_tokens
113
+ ):
114
+ yield token
115
+ answer = "".join(list(generator()))
116
+
117
+ return [answer]
118
+
119
+ def get_input_embeddings(self):
120
+ return super().get_input_embeddings()
121
+
122
+ def input_embeds(self, *args, **kwargs):
123
+ self._unsupported_exception()
image_crops.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import pyvips
5
+
6
+ from typing import TypedDict
7
+
8
+
9
+ def select_tiling(
10
+ height: int, width: int, crop_size: int, max_crops: int
11
+ ) -> tuple[int, int]:
12
+ """
13
+ Determine the optimal number of tiles to cover an image with overlapping crops.
14
+ """
15
+ if height <= crop_size or width <= crop_size:
16
+ return (1, 1)
17
+
18
+ # Minimum required tiles in each dimension
19
+ min_h = math.ceil(height / crop_size)
20
+ min_w = math.ceil(width / crop_size)
21
+
22
+ # If minimum required tiles exceed max_crops, return proportional distribution
23
+ if min_h * min_w > max_crops:
24
+ ratio = math.sqrt(max_crops / (min_h * min_w))
25
+ return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))
26
+
27
+ # Perfect aspect-ratio tiles that satisfy max_crops
28
+ h_tiles = math.floor(math.sqrt(max_crops * height / width))
29
+ w_tiles = math.floor(math.sqrt(max_crops * width / height))
30
+
31
+ # Ensure we meet minimum tile requirements
32
+ h_tiles = max(h_tiles, min_h)
33
+ w_tiles = max(w_tiles, min_w)
34
+
35
+ # If we exceeded max_crops, scale down the larger dimension
36
+ if h_tiles * w_tiles > max_crops:
37
+ if w_tiles > h_tiles:
38
+ w_tiles = math.floor(max_crops / h_tiles)
39
+ else:
40
+ h_tiles = math.floor(max_crops / w_tiles)
41
+
42
+ return (max(1, h_tiles), max(1, w_tiles))
43
+
44
+
45
+ class OverlapCropOutput(TypedDict):
46
+ crops: np.ndarray
47
+ tiling: tuple[int, int]
48
+
49
+
50
+ def overlap_crop_image(
51
+ image: np.ndarray,
52
+ overlap_margin: int,
53
+ max_crops: int,
54
+ base_size: tuple[int, int] = (378, 378),
55
+ patch_size: int = 14,
56
+ ) -> OverlapCropOutput:
57
+ """
58
+ Process an image using an overlap-and-resize cropping strategy with margin handling.
59
+
60
+ This function takes an input image and creates multiple overlapping crops with
61
+ consistent margins. It produces:
62
+ 1. A single global crop resized to base_size
63
+ 2. Multiple overlapping local crops that maintain high resolution details
64
+ 3. A patch ordering matrix that tracks correspondence between crops
65
+
66
+ The overlap strategy ensures:
67
+ - Smooth transitions between adjacent crops
68
+ - No loss of information at crop boundaries
69
+ - Proper handling of features that cross crop boundaries
70
+ - Consistent patch indexing across the full image
71
+
72
+ Args:
73
+ image (np.ndarray): Input image as numpy array with shape (H,W,C)
74
+ base_size (tuple[int,int]): Target size for crops, default (378,378)
75
+ patch_size (int): Size of patches in pixels, default 14
76
+ overlap_margin (int): Margin size in patch units, default 4
77
+ max_crops (int): Maximum number of crops allowed, default 12
78
+
79
+ Returns:
80
+ OverlapCropOutput: Dictionary containing:
81
+ - crops: A numpy array containing the global crop of the full image (index 0)
82
+ followed by the overlapping cropped regions (indices 1+)
83
+ - tiling: Tuple of (height,width) tile counts
84
+ """
85
+ original_h, original_w = image.shape[:2]
86
+
87
+ # Convert margin from patch units to pixels
88
+ margin_pixels = patch_size * overlap_margin
89
+ total_margin_pixels = margin_pixels * 2 # Both sides
90
+
91
+ # Calculate crop parameters
92
+ crop_patches = base_size[0] // patch_size # patches per crop dimension
93
+ crop_window_patches = crop_patches - (2 * overlap_margin) # usable patches
94
+ crop_window_size = crop_window_patches * patch_size # usable size in pixels
95
+
96
+ # Determine tiling
97
+ tiling = select_tiling(
98
+ original_h - total_margin_pixels,
99
+ original_w - total_margin_pixels,
100
+ crop_window_size,
101
+ max_crops,
102
+ )
103
+
104
+ # Pre-allocate crops.
105
+ n_crops = tiling[0] * tiling[1] + 1 # 1 = global crop
106
+ crops = np.zeros(
107
+ (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
108
+ )
109
+
110
+ # Resize image to fit tiling
111
+ target_size = (
112
+ tiling[0] * crop_window_size + total_margin_pixels,
113
+ tiling[1] * crop_window_size + total_margin_pixels,
114
+ )
115
+
116
+ # Convert to vips for resizing
117
+ vips_image = pyvips.Image.new_from_array(image)
118
+ scale_x = target_size[1] / image.shape[1]
119
+ scale_y = target_size[0] / image.shape[0]
120
+ resized = vips_image.resize(scale_x, vscale=scale_y)
121
+ image = resized.numpy()
122
+
123
+ # Create global crop
124
+ scale_x = base_size[1] / vips_image.width
125
+ scale_y = base_size[0] / vips_image.height
126
+ global_vips = vips_image.resize(scale_x, vscale=scale_y)
127
+ crops[0] = global_vips.numpy()
128
+
129
+ for i in range(tiling[0]):
130
+ for j in range(tiling[1]):
131
+ # Calculate crop coordinates
132
+ y0 = i * crop_window_size
133
+ x0 = j * crop_window_size
134
+
135
+ # Extract crop with padding if needed
136
+ y_end = min(y0 + base_size[0], image.shape[0])
137
+ x_end = min(x0 + base_size[1], image.shape[1])
138
+
139
+ crop_region = image[y0:y_end, x0:x_end]
140
+ crops[
141
+ 1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
142
+ ] = crop_region
143
+
144
+ return {"crops": crops, "tiling": tiling}
145
+
146
+
147
+ def reconstruct_from_crops(
148
+ crops: torch.Tensor,
149
+ tiling: tuple[int, int],
150
+ overlap_margin: int,
151
+ patch_size: int = 14,
152
+ ) -> torch.Tensor:
153
+ """
154
+ Reconstruct the original image from overlapping crops into a single seamless image.
155
+
156
+ Takes a list of overlapping image crops along with their positional metadata and
157
+ reconstructs them into a single coherent image by carefully stitching together
158
+ non-overlapping regions. Handles both numpy arrays and PyTorch tensors.
159
+
160
+ Args:
161
+ crops: List of image crops as numpy arrays or PyTorch tensors with shape
162
+ (H,W,C)
163
+ tiling: Tuple of (height,width) indicating crop grid layout
164
+ patch_size: Size in pixels of each patch, default 14
165
+ overlap_margin: Number of overlapping patches on each edge, default 4
166
+
167
+ Returns:
168
+ Reconstructed image as numpy array or PyTorch tensor matching input type,
169
+ with shape (H,W,C) where H,W are the original image dimensions
170
+ """
171
+ tiling_h, tiling_w = tiling
172
+ crop_height, crop_width = crops[0].shape[:2]
173
+ margin_pixels = overlap_margin * patch_size
174
+
175
+ # Calculate output size (only adding margins once)
176
+ output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
177
+ output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels
178
+
179
+ reconstructed = torch.zeros(
180
+ (output_h, output_w, crops[0].shape[2]),
181
+ device=crops[0].device,
182
+ dtype=crops[0].dtype,
183
+ )
184
+
185
+ for i, crop in enumerate(crops):
186
+ tile_y = i // tiling_w
187
+ tile_x = i % tiling_w
188
+
189
+ # For each tile, determine which part to keep
190
+ # Keep left margin only for first column
191
+ x_start = 0 if tile_x == 0 else margin_pixels
192
+ # Keep right margin only for last column
193
+ x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
194
+ # Keep top margin only for first row
195
+ y_start = 0 if tile_y == 0 else margin_pixels
196
+ # Keep bottom margin only for last row
197
+ y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels
198
+
199
+ # Calculate where this piece belongs in the output
200
+ out_x = tile_x * (crop_width - 2 * margin_pixels)
201
+ out_y = tile_y * (crop_height - 2 * margin_pixels)
202
+
203
+ # Place the piece
204
+ reconstructed[
205
+ out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
206
+ ] = crop[y_start:y_end, x_start:x_end]
207
+
208
+ return reconstructed
layers.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Literal
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def gelu_approx(x):
9
+ return F.gelu(x, approximate="tanh")
10
+
11
+
12
+ @dataclass
13
+ class LinearWeights:
14
+ weight: torch.Tensor
15
+ bias: torch.Tensor
16
+
17
+
18
+ def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
19
+ return F.linear(x, w.weight, w.bias)
20
+
21
+
22
+ @dataclass
23
+ class LayerNormWeights:
24
+ weight: torch.Tensor
25
+ bias: torch.Tensor
26
+
27
+
28
+ def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
29
+ return F.layer_norm(x, w.bias.shape, w.weight, w.bias)
30
+
31
+
32
+ @dataclass
33
+ class MLPWeights:
34
+ fc1: LinearWeights
35
+ fc2: LinearWeights
36
+ act: Literal["gelu_approx"] = "gelu_approx"
37
+
38
+
39
+ def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
40
+ x = linear(x, w.fc1)
41
+ x = gelu_approx(x)
42
+ x = linear(x, w.fc2)
43
+ return x
44
+
45
+
46
+ @dataclass
47
+ class AttentionWeights:
48
+ qkv: LinearWeights
49
+ proj: LinearWeights
50
+
51
+
52
+ def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
53
+ bsz, q_len, d_model = x.shape
54
+ head_dim = d_model // n_heads
55
+
56
+ q, k, v = [
57
+ t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
58
+ for t in linear(x, w.qkv).chunk(3, dim=-1)
59
+ ]
60
+ out = F.scaled_dot_product_attention(q, k, v)
61
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
62
+ out = linear(out, w.proj)
63
+ return out
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4bf7aed8ba4325d23fa7cd348d795a27f3b272682536f08aca4cdd62cde79293
3
- size 3736040266
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23e2e6498a058d12832e119dc97a1d2f14936b4ccf77b8492bc0fefba49ea8bb
3
+ size 3854538376
moondream.py CHANGED
@@ -1,230 +1,586 @@
1
  import torch
 
 
2
 
3
- from typing import List, Union, Literal, Optional
4
- from transformers import PreTrainedModel
5
  from PIL import Image
 
 
6
 
7
- from .configuration_moondream import PhiConfig
8
- from .configuration_moondream import MoondreamConfig
9
- from .vision_encoder import VisionEncoder
10
- from .region_model import RegionModel
11
- from .modeling_phi import PhiForCausalLM
 
12
 
13
- class Moondream(PreTrainedModel):
14
- config_class = MoondreamConfig
15
- _supports_flash_attn_2 = True
16
 
17
- def __init__(self, config):
18
- super().__init__(config)
19
- self.vision_encoder = VisionEncoder(
20
- use_flash_attn=config._attn_implementation == "flash_attention_2"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  )
22
- self.region_model = RegionModel()
 
23
 
24
- if type(config.text_config) == dict:
25
- phi_config = PhiConfig(
26
- **config.text_config, attn_implementation=config._attn_implementation
27
- )
28
- else:
29
- phi_config = config.text_config
30
- self.text_model = PhiForCausalLM(phi_config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  @property
33
  def device(self):
34
- return self.text_model.device
35
 
36
- def encode_image(self, image):
37
- with torch.no_grad():
38
- return self.vision_encoder(image)
 
 
 
 
 
 
 
 
 
39
 
40
- def input_embeds(self, prompt, image_embeds, tokenizer):
41
- def _tokenize(txt):
42
- return tokenizer(
43
- txt, return_tensors="pt", add_special_tokens=False
44
- ).input_ids.to(self.device)
45
 
46
- text_emb = self.text_model.get_input_embeddings()
47
 
48
- # Add BOS token
49
- embeds = []
50
- embeds.append(
51
- text_emb((torch.tensor([[tokenizer.bos_token_id]], device=self.device)))
 
 
52
  )
53
 
54
- if "<image>" not in prompt:
55
- embeds.append(text_emb(_tokenize(prompt)))
56
- else:
57
- assert prompt.count("<image>") == 1
58
- before, after = prompt.split("<image>")
59
- if len(before) > 0:
60
- embeds.append(text_emb(_tokenize(before)))
61
- embeds.append(image_embeds.to(self.device))
62
- if len(after) > 0:
63
- embeds.append(text_emb(_tokenize(after)))
64
 
65
- return torch.cat(embeds, dim=1)
 
 
66
 
67
- def get_input_embeddings(self):
68
- return self.text_model.get_input_embeddings()
 
 
 
69
 
70
- def generate(
71
- self,
72
- image_embeds,
73
- prompt,
74
- tokenizer,
75
- max_new_tokens=128,
76
- **kwargs,
77
- ):
78
- generate_config = {
79
- "eos_token_id": tokenizer.eos_token_id,
80
- "bos_token_id": tokenizer.bos_token_id,
81
- "pad_token_id": tokenizer.bos_token_id,
82
- "max_new_tokens": max_new_tokens,
83
- **kwargs,
84
- }
 
 
 
 
 
 
85
 
 
 
 
86
  with torch.no_grad():
87
- inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
88
- attention_mask = torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device)
89
- output_ids = self.text_model.generate(
90
- inputs_embeds=inputs_embeds,
91
- attention_mask=attention_mask,
92
- **generate_config,
93
  )
 
 
 
 
94
 
95
- return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
 
 
 
 
 
 
 
 
96
 
97
- # Note: Not ready for use yet, intended for September release.
98
- def caption(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  self,
100
- images: List[Image.Image],
101
- tokenizer,
102
- length: Optional[Literal["short"]] = None,
103
- **kwargs,
104
  ):
105
- image_embeds = self.encode_image(images)
106
-
107
- templated_prompts = [
108
- f"<image>\n\n{'Short caption' if length == 'short' else 'Caption'}:" for _ in images
109
- ]
110
- inputs_embeds = torch.stack([
111
- self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
112
- for prompt, image_embed in zip(templated_prompts, image_embeds)
113
- ])
114
- attention_mask = torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), device=self.device)
115
-
116
- generate_config = {
117
- "eos_token_id": tokenizer.eos_token_id,
118
- "bos_token_id": tokenizer.bos_token_id,
119
- "pad_token_id": tokenizer.bos_token_id,
120
- "repetition_penalty": 1.2,
121
- "max_new_tokens": 512,
122
- **kwargs,
123
- }
124
 
125
- with torch.no_grad():
126
- output_ids = self.text_model.generate(
127
- inputs_embeds=inputs_embeds,
128
- attention_mask=attention_mask,
129
- **generate_config,
130
- )
 
 
 
131
 
132
- return [
133
- x.strip()
134
- for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
135
- ]
136
 
137
- def answer_question(
 
 
 
 
 
 
 
 
 
 
 
138
  self,
139
- image_embeds,
140
- question,
141
- tokenizer,
142
- chat_history="",
143
- result_queue=None,
144
- max_new_tokens=256,
145
- **kwargs,
146
  ):
147
- prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer:"
148
- answer = self.generate(
149
- image_embeds,
150
- prompt,
151
- tokenizer=tokenizer,
152
- max_new_tokens=max_new_tokens,
153
- **kwargs,
154
- )[0]
155
- cleaned_answer = answer.strip()
156
-
157
- # Use the result_queue to pass the result if it is provided
158
- if result_queue:
159
- result_queue.put(cleaned_answer)
 
 
 
 
 
 
 
 
 
160
  else:
161
- return cleaned_answer
162
 
163
- def batch_answer(
164
  self,
165
- images,
166
- prompts,
167
- tokenizer,
168
- **kwargs,
 
 
169
  ):
170
- image_embeds = self.encode_image(images)
171
 
172
- templated_prompts = [
173
- f"<image>\n\nQuestion: {prompt}\n\nAnswer:" for prompt in prompts
174
- ]
175
- prompt_embs = [
176
- self.input_embeds(prompt, image_embed.unsqueeze(0), tokenizer)[0]
177
- for prompt, image_embed in zip(templated_prompts, image_embeds)
178
- ]
 
 
 
179
 
180
- bos_emb = prompt_embs[0][0]
181
- max_len = max([p.shape[0] for p in prompt_embs])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
- inputs_embeds = torch.cat(
 
 
 
 
 
 
 
 
 
 
 
 
184
  [
185
- torch.cat([bos_emb.repeat(max_len - p.shape[0], 1), p]).unsqueeze(0)
186
- for p in prompt_embs
 
187
  ],
188
- dim=0,
189
  )
190
- attention_mask = torch.cat(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  [
192
- torch.cat(
193
- [
194
- torch.zeros(
195
- 1,
196
- max_len - p.shape[0],
197
- device=self.device,
198
- dtype=torch.long,
199
- ),
200
- torch.ones(1, p.shape[0], device=self.device, dtype=torch.long),
201
- ],
202
- dim=1,
203
- )
204
- for p in prompt_embs
205
  ],
206
- dim=0,
207
  )
208
 
209
- generate_config = {
210
- "eos_token_id": tokenizer.eos_token_id,
211
- "bos_token_id": tokenizer.bos_token_id,
212
- "pad_token_id": tokenizer.bos_token_id,
213
- "max_new_tokens": 512,
214
- **kwargs,
215
- }
216
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  with torch.no_grad():
218
- output_ids = self.text_model.generate(
219
- inputs_embeds=inputs_embeds,
220
- attention_mask=attention_mask,
221
- **generate_config,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  )
223
 
224
- return [
225
- x.strip()
226
- for x in tokenizer.batch_decode(output_ids, skip_special_tokens=True)
227
- ]
 
 
 
 
228
 
229
- def detect(self, image: Image.Image, query: str, tokenizer):
230
- pass
 
1
  import torch
2
+ import torch.nn as nn
3
+ import random
4
 
5
+ from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional
 
6
  from PIL import Image
7
+ from dataclasses import dataclass
8
+ from tokenizers import Tokenizer
9
 
10
+ from .config import MoondreamConfig
11
+ from .image_crops import reconstruct_from_crops
12
+ from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
13
+ from .text import build_text_model, prefill, text_encoder, lm_head, decode_one_token
14
+ from .region import decode_coordinate, encode_coordinate, decode_size, encode_size
15
+ from .utils import remove_outlier_points
16
 
 
 
 
17
 
18
+ SamplingSettings = TypedDict(
19
+ "SamplingSettings",
20
+ {"max_tokens": int},
21
+ total=False,
22
+ )
23
+
24
+ DEFAULT_MAX_TOKENS = 512
25
+
26
+
27
+ @dataclass(frozen=True)
28
+ class EncodedImage:
29
+ pos: int
30
+ kv_cache: torch.Tensor
31
+
32
+
33
+ def _min_p_sampler(
34
+ logits: torch.Tensor,
35
+ min_p: float = 0.1,
36
+ filter_value: float = 0,
37
+ min_tokens_to_keep: int = 1,
38
+ temp=0.5,
39
+ ) -> torch.Tensor:
40
+ """
41
+ Min-p sampler adapted from https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17
42
+ https://arxiv.org/pdf/2407.01082
43
+ """
44
+ logits = logits / temp
45
+ probs = torch.softmax(logits, dim=-1)
46
+ top_probs, _ = probs.max(dim=-1, keepdim=True)
47
+ scaled_min_p = min_p * top_probs
48
+ tokens_to_remove = probs < scaled_min_p
49
+ sorted_indices = torch.argsort(logits, descending=True, dim=-1)
50
+ sorted_indices_to_remove = torch.gather(
51
+ tokens_to_remove, dim=-1, index=sorted_indices
52
+ )
53
+ if min_tokens_to_keep > 1:
54
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = False
55
+
56
+ indices_to_remove = sorted_indices_to_remove.scatter(
57
+ 1, sorted_indices, sorted_indices_to_remove
58
+ )
59
+ logits = logits.masked_fill(indices_to_remove, filter_value)
60
+ token = torch.multinomial(logits, num_samples=1)
61
+ return token.squeeze(0)
62
+
63
+
64
+ class MoondreamModel(nn.Module):
65
+ def __init__(self, config: MoondreamConfig, dtype=torch.float16):
66
+ super().__init__()
67
+ self.config = config
68
+
69
+ self.tokenizer = Tokenizer.from_pretrained(
70
+ "vikhyatk/moondream2", revision="2024-08-26"
71
  )
72
+ self.vision = build_vision_model(config.vision, dtype)
73
+ self.text = build_text_model(config.text, dtype)
74
 
75
+ # Region Model
76
+ self.region = nn.ModuleDict(
77
+ {
78
+ "coord_encoder": nn.Linear(
79
+ config.region.coord_feat_dim, config.region.dim, dtype=dtype
80
+ ),
81
+ "coord_decoder": nn.ModuleDict(
82
+ {
83
+ "fc1": nn.Linear(
84
+ config.region.dim, config.region.inner_dim, dtype=dtype
85
+ ),
86
+ "fc2": nn.Linear(
87
+ config.region.inner_dim,
88
+ config.region.coord_out_dim,
89
+ dtype=dtype,
90
+ ),
91
+ }
92
+ ),
93
+ "size_encoder": nn.Linear(
94
+ config.region.size_feat_dim, config.region.dim, dtype=dtype
95
+ ),
96
+ "size_decoder": nn.ModuleDict(
97
+ {
98
+ "fc1": nn.Linear(
99
+ config.region.dim, config.region.inner_dim, dtype=dtype
100
+ ),
101
+ "fc2": nn.Linear(
102
+ config.region.inner_dim,
103
+ config.region.size_out_dim,
104
+ dtype=dtype,
105
+ ),
106
+ }
107
+ ),
108
+ }
109
+ )
110
+ self.region.coord_features = nn.Parameter(
111
+ torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
112
+ )
113
+ self.region.size_features = nn.Parameter(
114
+ torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
115
+ )
116
+
117
+ self.ops = {
118
+ "vision_encoder": vision_encoder,
119
+ "vision_projection": vision_projection,
120
+ "prefill": prefill,
121
+ "decode_one_token": decode_one_token,
122
+ }
123
 
124
  @property
125
  def device(self):
126
+ return self.vision.pos_emb.device
127
 
128
+ def compile(self):
129
+ self.ops["vision_encoder"] = torch.compile(
130
+ self.ops["vision_encoder"], fullgraph=True
131
+ )
132
+ # Need to figure out how to mark the 'reconstructed' input shape as dynamic
133
+ # self.ops["vision_projection"] = torch.compile(
134
+ # self.ops["vision_projection"], fullgraph=True
135
+ # )
136
+ self.ops["prefill"] = torch.compile(self.ops["prefill"], fullgraph=True)
137
+ self.ops["decode_one_token"] = torch.compile(
138
+ self.ops["decode_one_token"], fullgraph=True
139
+ )
140
 
141
+ def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
142
+ all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
143
+ torch._dynamo.mark_dynamic(all_crops, 0)
 
 
144
 
145
+ outputs = self.ops["vision_encoder"](all_crops, self.vision, self.config.vision)
146
 
147
+ global_features = outputs[0]
148
+ local_features = outputs[1:].view(
149
+ -1,
150
+ self.config.vision.enc_n_layers,
151
+ self.config.vision.enc_n_layers,
152
+ self.config.vision.enc_dim,
153
  )
154
 
155
+ reconstructed = reconstruct_from_crops(
156
+ local_features,
157
+ tiling,
158
+ patch_size=1,
159
+ overlap_margin=self.config.vision.overlap_margin,
160
+ )
 
 
 
 
161
 
162
+ return self.ops["vision_projection"](
163
+ global_features, reconstructed, self.vision, self.config.vision
164
+ )
165
 
166
+ def encode_image(self, image: Union[Image.Image, EncodedImage]) -> EncodedImage:
167
+ if isinstance(image, EncodedImage):
168
+ return image
169
+ elif not isinstance(image, Image.Image):
170
+ raise ValueError("image must be a PIL Image or EncodedImage")
171
 
172
+ # Run through text model in addition to the vision encoder, to minimize
173
+ # re-computation if multiple queries are performed on this image.
174
+ kv_cache = torch.zeros(
175
+ self.config.text.n_layers,
176
+ 2, # k, v
177
+ 1, # batch size
178
+ self.config.text.n_heads,
179
+ self.config.text.max_context, # static cache
180
+ self.config.text.dim // self.config.text.n_heads, # head dim
181
+ device=self.device,
182
+ dtype=torch.float16,
183
+ )
184
+ with torch.no_grad():
185
+ img_emb = self._run_vision_encoder(image)
186
+ bos_emb = text_encoder(
187
+ torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
188
+ self.text,
189
+ )
190
+ inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
191
+ self.ops["prefill"](inputs_embeds, kv_cache, 0, self.text, self.config.text)
192
+ return EncodedImage(pos=inputs_embeds.size(1), kv_cache=kv_cache)
193
 
194
+ def _prefill_prompt(
195
+ self, kv_cache: torch.Tensor, prompt_tokens: torch.Tensor, pos: int
196
+ ):
197
  with torch.no_grad():
198
+ prompt_emb = text_encoder(prompt_tokens, self.text)
199
+ hidden = self.ops["prefill"](
200
+ prompt_emb, kv_cache, pos, self.text, self.config.text
 
 
 
201
  )
202
+ logits = lm_head(hidden, self.text)
203
+ next_token = torch.argmax(logits, dim=-1)
204
+ pos = pos + prompt_emb.size(1)
205
+ return logits, hidden, next_token, pos
206
 
207
+ def _generate_text(
208
+ self,
209
+ prompt_tokens: torch.Tensor,
210
+ kv_cache: torch.Tensor,
211
+ pos: int,
212
+ max_tokens: int,
213
+ ):
214
+ kv_cache = kv_cache.clone()
215
+ _, _, next_token, pos = self._prefill_prompt(kv_cache, prompt_tokens, pos)
216
 
217
+ def generator(next_token, pos):
218
+ generated_tokens = 0
219
+
220
+ while (
221
+ next_token_id := next_token.item()
222
+ ) != self.config.tokenizer.eos_id and generated_tokens < max_tokens:
223
+ yield self.tokenizer.decode([next_token_id])
224
+
225
+ with torch.no_grad():
226
+ next_emb = text_encoder(next_token, self.text)
227
+ logits, _, kv_cache_update = self.ops["decode_one_token"](
228
+ next_emb, kv_cache, pos, self.text, self.config.text
229
+ )
230
+ kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
231
+ kv_cache_update
232
+ )
233
+ pos += 1
234
+ next_token = torch.argmax(logits, dim=-1)
235
+ generated_tokens += 1
236
+
237
+ return generator(next_token, pos)
238
+
239
+ def query(
240
  self,
241
+ image: Union[Image.Image, EncodedImage],
242
+ question: str,
243
+ stream: bool = False,
244
+ settings: Optional[SamplingSettings] = None,
245
  ):
246
+ if self.config.tokenizer.templates["query"] is None:
247
+ raise NotImplementedError("Model does not support querying.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
 
249
+ image = self.encode_image(image)
250
+ prompt_tokens = torch.tensor(
251
+ [
252
+ self.config.tokenizer.templates["query"]["prefix"]
253
+ + self.tokenizer.encode(question).ids
254
+ + self.config.tokenizer.templates["query"]["suffix"]
255
+ ],
256
+ device=self.device,
257
+ )
258
 
259
+ max_tokens = DEFAULT_MAX_TOKENS
260
+ if settings:
261
+ max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS)
 
262
 
263
+ def generator():
264
+ for token in self._generate_text(
265
+ prompt_tokens, image.kv_cache, image.pos, max_tokens
266
+ ):
267
+ yield token
268
+
269
+ if stream:
270
+ return {"answer": generator()}
271
+ else:
272
+ return {"answer": "".join(list(generator()))}
273
+
274
+ def caption(
275
  self,
276
+ image: Union[Image.Image, EncodedImage],
277
+ length: Literal["normal", "short"] = "normal",
278
+ stream: bool = False,
279
+ settings: Optional[SamplingSettings] = None,
 
 
 
280
  ):
281
+ if self.config.tokenizer.templates["caption"] is None:
282
+ raise NotImplementedError("Model does not support captioning.")
283
+ if length not in self.config.tokenizer.templates["caption"]:
284
+ raise ValueError(f"Model does not support caption length '{length}'.")
285
+
286
+ image = self.encode_image(image)
287
+ prompt_tokens = torch.tensor(
288
+ [self.config.tokenizer.templates["caption"][length]], device=self.device
289
+ )
290
+
291
+ max_tokens = DEFAULT_MAX_TOKENS
292
+ if settings:
293
+ max_tokens = settings.get("max_tokens", DEFAULT_MAX_TOKENS)
294
+
295
+ def generator():
296
+ for token in self._generate_text(
297
+ prompt_tokens, image.kv_cache, image.pos, max_tokens
298
+ ):
299
+ yield token
300
+
301
+ if stream:
302
+ return {"caption": generator()}
303
  else:
304
+ return {"caption": "".join(list(generator()))}
305
 
306
+ def _generate_points(
307
  self,
308
+ hidden: torch.Tensor,
309
+ kv_cache: torch.Tensor,
310
+ next_token: torch.Tensor,
311
+ pos: int,
312
+ include_size: bool = True,
313
+ max_points: int = 50,
314
  ):
315
+ out = []
316
 
317
+ with torch.no_grad():
318
+ while (
319
+ next_token.item() != self.config.tokenizer.eos_id
320
+ and len(out) < max_points
321
+ ):
322
+ x_logits = decode_coordinate(hidden, self.region)
323
+ x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
324
+ next_emb = encode_coordinate(
325
+ x_center.to(dtype=x_logits.dtype), self.region
326
+ )
327
 
328
+ # Decode y-coordinate
329
+ _, hidden, kv_cache_update = self.ops["decode_one_token"](
330
+ next_emb, kv_cache, pos, self.text, self.config.text
331
+ )
332
+ kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
333
+ kv_cache_update
334
+ )
335
+ pos += 1
336
+ y_logits = decode_coordinate(hidden, self.region)
337
+ y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
338
+ next_emb = encode_coordinate(
339
+ y_center.to(dtype=y_logits.dtype), self.region
340
+ )
341
+
342
+ # Decode size
343
+ if include_size:
344
+ logits, hidden, kv_cache_update = self.ops["decode_one_token"](
345
+ next_emb, kv_cache, pos, self.text, self.config.text
346
+ )
347
+ kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
348
+ kv_cache_update
349
+ )
350
+ pos += 1
351
+ size_logits = decode_size(hidden, self.region)
352
+ w = torch.argmax(size_logits[0], dim=-1) / size_logits.size(-1)
353
+ h = torch.argmax(size_logits[1], dim=-1) / size_logits.size(-1)
354
+ next_emb = encode_size(
355
+ torch.tensor(
356
+ [w, h], device=self.device, dtype=size_logits.dtype
357
+ ),
358
+ self.region,
359
+ )[None]
360
+
361
+ # Add object
362
+ out.append(
363
+ {
364
+ "x_min": x_center.item() - w.item() / 2,
365
+ "y_min": y_center.item() - h.item() / 2,
366
+ "x_max": x_center.item() + w.item() / 2,
367
+ "y_max": y_center.item() + h.item() / 2,
368
+ }
369
+ )
370
+ else:
371
+ out.append({"x": x_center.item(), "y": y_center.item()})
372
+
373
+ # Decode next token (x-coordinate, or eos)
374
+ logits, hidden, kv_cache_update = self.ops["decode_one_token"](
375
+ next_emb, kv_cache, pos, self.text, self.config.text
376
+ )
377
+ kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
378
+ kv_cache_update
379
+ )
380
+ pos += 1
381
+ next_token = torch.argmax(logits, dim=-1)
382
 
383
+ return out
384
+
385
+ def detect(
386
+ self,
387
+ image: Union[Image.Image, EncodedImage],
388
+ object: str,
389
+ settings: Optional[SamplingSettings] = None,
390
+ ):
391
+ if self.config.tokenizer.templates["detect"] is None:
392
+ raise NotImplementedError("Model does not support object detection.")
393
+
394
+ image = self.encode_image(image)
395
+ prompt_tokens = torch.tensor(
396
  [
397
+ self.config.tokenizer.templates["detect"]["prefix"]
398
+ + self.tokenizer.encode(object).ids
399
+ + self.config.tokenizer.templates["detect"]["suffix"]
400
  ],
401
+ device=self.device,
402
  )
403
+
404
+ kv_cache = image.kv_cache.clone()
405
+ _, hidden, next_token, pos = self._prefill_prompt(
406
+ kv_cache, prompt_tokens, image.pos
407
+ )
408
+ hidden = hidden[:, -1:, :]
409
+
410
+ objects = self._generate_points(
411
+ hidden, kv_cache, next_token, pos, include_size=True, max_points=50
412
+ )
413
+
414
+ return {"objects": objects}
415
+
416
+ def point(
417
+ self,
418
+ image: Union[Image.Image, EncodedImage],
419
+ object: str,
420
+ settings: Optional[SamplingSettings] = None,
421
+ ):
422
+ if self.config.tokenizer.templates["point"] is None:
423
+ raise NotImplementedError("Model does not support pointing.")
424
+
425
+ image = self.encode_image(image)
426
+ prompt_tokens = torch.tensor(
427
  [
428
+ self.config.tokenizer.templates["point"]["prefix"]
429
+ + self.tokenizer.encode(object).ids
430
+ + self.config.tokenizer.templates["point"]["suffix"]
 
 
 
 
 
 
 
 
 
 
431
  ],
432
+ device=self.device,
433
  )
434
 
435
+ kv_cache = image.kv_cache.clone()
436
+ _, hidden, next_token, pos = self._prefill_prompt(
437
+ kv_cache, prompt_tokens, image.pos
438
+ )
439
+ hidden = hidden[:, -1:, :]
 
 
440
 
441
+ objects = self._generate_points(
442
+ hidden, kv_cache, next_token, pos, include_size=False, max_points=50
443
+ )
444
+
445
+ return {"points": objects}
446
+
447
+ def _detect_gaze(
448
+ self,
449
+ image: EncodedImage,
450
+ source: Tuple[float, float],
451
+ force_detect: bool = False,
452
+ ):
453
  with torch.no_grad():
454
+ before_emb = text_encoder(
455
+ torch.tensor(
456
+ [self.tokenizer.encode("\n\nPoint:").ids], device=self.device
457
+ ),
458
+ self.text,
459
+ )
460
+ after_emb = text_encoder(
461
+ torch.tensor(
462
+ [self.tokenizer.encode(" gaze\n\n").ids], device=self.device
463
+ ),
464
+ self.text,
465
+ )
466
+ x_emb = encode_coordinate(
467
+ torch.tensor([[[source[0]]]], device=self.device, dtype=torch.float16),
468
+ self.region,
469
+ )
470
+ y_emb = encode_coordinate(
471
+ torch.tensor([[[source[1]]]], device=self.device, dtype=torch.float16),
472
+ self.region,
473
+ )
474
+
475
+ prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1)
476
+
477
+ kv_cache = image.kv_cache.clone()
478
+ hidden = self.ops["prefill"](
479
+ prompt_emb, kv_cache, image.pos, self.text, self.config.text
480
+ )
481
+ logits = lm_head(hidden, self.text)
482
+ next_token = torch.argmax(logits, dim=-1)
483
+ pos = image.pos + prompt_emb.size(1)
484
+ hidden = hidden[:, -1:, :]
485
+
486
+ if force_detect:
487
+ next_token = torch.tensor([[0]], device=self.device)
488
+
489
+ if next_token.item() == self.config.tokenizer.eos_id:
490
+ return None
491
+
492
+ gaze = self._generate_points(
493
+ hidden, kv_cache, next_token, pos, include_size=False, max_points=1
494
+ )
495
+ return gaze[0]
496
+
497
+ def detect_gaze(
498
+ self,
499
+ image: Union[Image.Image, EncodedImage],
500
+ eye: Optional[Tuple[float, float]] = None,
501
+ face: Optional[Dict[str, float]] = None,
502
+ unstable_settings: Dict[str, Any] = {},
503
+ ):
504
+ if "force_detect" in unstable_settings:
505
+ force_detect = unstable_settings["force_detect"]
506
+ else:
507
+ force_detect = False
508
+
509
+ if "prioritize_accuracy" in unstable_settings:
510
+ prioritize_accuracy = unstable_settings["prioritize_accuracy"]
511
+ else:
512
+ prioritize_accuracy = False
513
+
514
+ if not prioritize_accuracy:
515
+ if eye is None:
516
+ raise ValueError("eye must be provided when prioritize_accuracy=False")
517
+ image = self.encode_image(image)
518
+ return {"gaze": self._detect_gaze(image, eye, force_detect=force_detect)}
519
+ else:
520
+ if (
521
+ not isinstance(image, Image.Image)
522
+ and "flip_enc_img" not in unstable_settings
523
+ ):
524
+ raise ValueError(
525
+ "image must be a PIL Image when prioritize_accuracy=True, "
526
+ "or flip_enc_img must be provided"
527
+ )
528
+ if face is None:
529
+ raise ValueError("face must be provided when prioritize_accuracy=True")
530
+
531
+ encoded_image = self.encode_image(image)
532
+ if (
533
+ isinstance(image, Image.Image)
534
+ and "flip_enc_img" not in unstable_settings
535
+ ):
536
+ flipped_pil = image.copy()
537
+ flipped_pil = flipped_pil.transpose(method=Image.FLIP_LEFT_RIGHT)
538
+ encoded_flipped_image = self.encode_image(flipped_pil)
539
+ else:
540
+ encoded_flipped_image = unstable_settings["flip_enc_img"]
541
+
542
+ N = 10
543
+
544
+ detections = [
545
+ self._detect_gaze(
546
+ encoded_image,
547
+ (
548
+ random.uniform(face["x_min"], face["x_max"]),
549
+ random.uniform(face["y_min"], face["y_max"]),
550
+ ),
551
+ force_detect=force_detect,
552
+ )
553
+ for _ in range(N)
554
+ ]
555
+ detections = [
556
+ (gaze["x"], gaze["y"]) for gaze in detections if gaze is not None
557
+ ]
558
+ flipped_detections = [
559
+ self._detect_gaze(
560
+ encoded_flipped_image,
561
+ (
562
+ 1 - random.uniform(face["x_min"], face["x_max"]),
563
+ random.uniform(face["y_min"], face["y_max"]),
564
+ ),
565
+ force_detect=force_detect,
566
+ )
567
+ for _ in range(N)
568
+ ]
569
+ detections.extend(
570
+ [
571
+ (1 - gaze["x"], gaze["y"])
572
+ for gaze in flipped_detections
573
+ if gaze is not None
574
+ ]
575
  )
576
 
577
+ if len(detections) < N:
578
+ return {"gaze": None}
579
+
580
+ detections = remove_outlier_points(detections)
581
+ mean_gaze = (
582
+ sum(gaze[0] for gaze in detections) / len(detections),
583
+ sum(gaze[1] for gaze in detections) / len(detections),
584
+ )
585
 
586
+ return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}
 
region.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+
4
+ from .weights import RegionModel
5
+ from .layers import linear, mlp
6
+
7
+
8
+ def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
9
+ """
10
+ Applies Fourier feature mapping to input tensor x using frequency matrix w. This
11
+ projects inputs through sinusoidal functions to create higher dimensional features
12
+ that help mitigate spectral bias - the tendency of neural networks to learn
13
+ low-frequency functions more easily than high-frequency ones. By explicitly
14
+ mapping inputs to higher frequencies through sin/cos transformations, we enable
15
+ better learning of fine details and higher frequency patterns.
16
+
17
+ Args:
18
+ x: Input tensor to transform
19
+ w: Matrix of frequencies for the Fourier features transformation
20
+
21
+ Returns:
22
+ Concatenated cosine and sine transformed features as a tensor
23
+ """
24
+ f = 2 * math.pi * x @ w
25
+ return torch.cat([f.cos(), f.sin()], dim=-1)
26
+
27
+
28
+ def encode_coordinate(coord: torch.Tensor, w: RegionModel) -> torch.Tensor:
29
+ """
30
+ Takes as input a tensor containing a single float coordinate value (x or y)
31
+ and encodes it into hidden states for input to the text model.
32
+
33
+ Args:
34
+ coord: Tensor with single float coordinate value
35
+
36
+ Returns:
37
+ Encoded hidden states tensor for input to text model
38
+ """
39
+ return linear(fourier_features(coord, w.coord_features), w.coord_encoder)
40
+
41
+
42
+ def decode_coordinate(hidden_state: torch.Tensor, w: RegionModel) -> torch.Tensor:
43
+ """
44
+ Takes as input the last hidden state from the text model and outputs a single logit
45
+ representing either an x or y coordinate prediction.
46
+
47
+ Args:
48
+ hidden_state: The final hidden state tensor from the text model.
49
+
50
+ Returns:
51
+ A single logit representing the predicted coordinate value (x or y)
52
+ """
53
+ return mlp(hidden_state, w.coord_decoder)
54
+
55
+
56
+ def encode_size(size: torch.Tensor, w: RegionModel) -> torch.Tensor:
57
+ """
58
+ Takes a tensor containing normalized width and height values in range [0,1]
59
+ and encodes them into hidden states for input to the text model.
60
+
61
+ Args:
62
+ size: Tensor with two floats for width and height in range [0,1]
63
+
64
+ Returns:
65
+ Encoded hidden states tensor for input to text model
66
+ """
67
+ return linear(fourier_features(size, w.size_features), w.size_encoder)
68
+
69
+
70
+ def decode_size(hidden_state: torch.Tensor, w: RegionModel) -> torch.Tensor:
71
+ """
72
+ Takes as input the last hidden state from the text model and outputs two logits
73
+ for width and height respectively.
74
+
75
+ Args:
76
+ hidden_state: The final hidden state tensor from the text model.
77
+
78
+ Returns:
79
+ A tensor containing two logits - one for predicted width and one for
80
+ predicted height.
81
+ """
82
+ return mlp(hidden_state, w.size_decoder).view(2, -1)
rope.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ethically sourced from https://github.com/xjdr-alt/entropix
2
+
3
+ import torch
4
+
5
+
6
+ def precompute_freqs_cis(
7
+ dim: int,
8
+ end: int,
9
+ theta: float = 10000.0,
10
+ use_scaled: bool = False,
11
+ dtype: torch.dtype = torch.float32,
12
+ ) -> torch.Tensor:
13
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))
14
+ t = torch.arange(end, dtype=dtype).unsqueeze(1)
15
+ freqs = t * freqs.unsqueeze(0)
16
+ freqs = torch.exp(1j * freqs)
17
+ return torch.stack([freqs.real, freqs.imag], dim=-1)
18
+
19
+
20
+ def apply_rotary_emb(
21
+ x: torch.Tensor,
22
+ freqs_cis: torch.Tensor,
23
+ position_ids: torch.Tensor,
24
+ num_heads: int,
25
+ rot_dim: int = 32,
26
+ interleave: bool = False,
27
+ ) -> torch.Tensor:
28
+ assert rot_dim == freqs_cis.shape[-2] * 2
29
+ assert num_heads == x.shape[1]
30
+
31
+ x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
32
+
33
+ if interleave:
34
+ xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
35
+ xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
36
+ else:
37
+ d_q = x_rot.shape[-1] // 2
38
+ xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]
39
+
40
+ freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0)
41
+ freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)
42
+
43
+ # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
44
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
45
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
46
+ xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)
47
+
48
+ return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)
text.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from .layers import layer_norm, linear, mlp
6
+ from .rope import apply_rotary_emb, precompute_freqs_cis
7
+ from .weights import AttentionWeights
8
+ from .config import TextConfig
9
+
10
+
11
+ def text_encoder(input_ids: torch.Tensor, w: nn.Module):
12
+ return F.embedding(input_ids, w.wte)
13
+
14
+
15
+ def attn(
16
+ x: torch.Tensor,
17
+ w: AttentionWeights,
18
+ freqs_cis: torch.Tensor,
19
+ layer_kv_cache: torch.Tensor,
20
+ attn_mask: torch.Tensor,
21
+ n_heads: int,
22
+ pos: int,
23
+ ):
24
+ bsz, q_len, d_model = x.shape
25
+ head_dim = d_model // n_heads
26
+
27
+ q, k, v = [
28
+ t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
29
+ for t in linear(x, w.qkv).chunk(3, dim=-1)
30
+ ]
31
+
32
+ position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
33
+ q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
34
+ k = apply_rotary_emb(k, freqs_cis, position_ids, n_heads)
35
+
36
+ k_, v_ = k, v
37
+ if layer_kv_cache is not None:
38
+ k = torch.cat([layer_kv_cache[0, :, :, :pos, :], k], dim=2)
39
+ v = torch.cat([layer_kv_cache[1, :, :, :pos, :], v], dim=2)
40
+
41
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask).to(
42
+ # This type conversion isn't needed when running in PyTorch directly, but the
43
+ # ONNX export runs attention in float32 because the attention mask is cast to
44
+ # float32.
45
+ x.dtype
46
+ )
47
+ out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
48
+ out = linear(out, w.proj)
49
+ return out, torch.stack([k_, v_])
50
+
51
+
52
+ def text_decoder(
53
+ inputs_embeds: torch.Tensor,
54
+ w: nn.Module,
55
+ kv_cache: torch.Tensor,
56
+ pos: int,
57
+ config: TextConfig,
58
+ ):
59
+ hidden_BTC = inputs_embeds
60
+ new_kv_cache = [torch.empty(0)] * len(w.blocks)
61
+
62
+ attn_mask = w.attn_mask[
63
+ :, :, pos : pos + hidden_BTC.size(1), : pos + hidden_BTC.size(1)
64
+ ]
65
+
66
+ for i, block in enumerate(w.blocks):
67
+ l_in = layer_norm(hidden_BTC, block.ln)
68
+ l_attn, new_kv_cache[i] = attn(
69
+ l_in,
70
+ block.attn,
71
+ freqs_cis=w.freqs_cis,
72
+ layer_kv_cache=kv_cache[i],
73
+ attn_mask=attn_mask,
74
+ n_heads=config.n_heads,
75
+ pos=pos,
76
+ )
77
+ l_mlp = mlp(l_in, block.mlp)
78
+ hidden_BTC = hidden_BTC + l_attn + l_mlp
79
+
80
+ return hidden_BTC, torch.stack(new_kv_cache)
81
+
82
+
83
+ def lm_head(hidden_BTC: torch.Tensor, w: nn.Module):
84
+ hidden_BC = hidden_BTC[:, -1, :]
85
+ hidden_BC = layer_norm(hidden_BC, w.post_ln)
86
+ logits = linear(hidden_BC, w.lm_head)
87
+ return logits
88
+
89
+
90
+ def prefill(
91
+ inputs_embeds: torch.Tensor,
92
+ kv_cache: torch.Tensor,
93
+ pos: int,
94
+ w: nn.Module,
95
+ config: TextConfig,
96
+ ):
97
+ # Updates kv_cache in-place
98
+ hidden, kv_cache[:, :, :, :, pos : pos + inputs_embeds.size(1), :] = text_decoder(
99
+ inputs_embeds, w, kv_cache, pos, config
100
+ )
101
+ return hidden
102
+
103
+
104
+ def decode_one_token(
105
+ token_emb: torch.Tensor,
106
+ kv_cache: torch.Tensor,
107
+ pos: int,
108
+ w: nn.Module,
109
+ config: TextConfig,
110
+ ):
111
+ hidden, kv_cache_update = text_decoder(token_emb[None], w, kv_cache, pos, config)
112
+ logits = lm_head(hidden, w)
113
+ return logits, hidden, kv_cache_update
114
+
115
+
116
+ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
117
+ text = nn.ModuleDict(
118
+ {
119
+ "blocks": nn.ModuleList(
120
+ [
121
+ nn.ModuleDict(
122
+ {
123
+ "ln": nn.LayerNorm(config.dim, dtype=dtype),
124
+ "attn": nn.ModuleDict(
125
+ {
126
+ "qkv": nn.Linear(
127
+ config.dim, 3 * config.dim, dtype=dtype
128
+ ),
129
+ "proj": nn.Linear(
130
+ config.dim, config.dim, dtype=dtype
131
+ ),
132
+ }
133
+ ),
134
+ "mlp": nn.ModuleDict(
135
+ {
136
+ "fc1": nn.Linear(
137
+ config.dim, 4 * config.dim, dtype=dtype
138
+ ),
139
+ "fc2": nn.Linear(
140
+ 4 * config.dim, config.dim, dtype=dtype
141
+ ),
142
+ }
143
+ ),
144
+ }
145
+ )
146
+ for _ in range(config.n_layers)
147
+ ]
148
+ ),
149
+ "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
150
+ "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
151
+ }
152
+ )
153
+ text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
154
+ text.register_buffer(
155
+ "freqs_cis",
156
+ precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
157
+ persistent=False,
158
+ )
159
+
160
+ attn_mask = torch.tril(
161
+ torch.ones(1, 1, config.max_context, config.max_context, dtype=torch.bool)
162
+ )
163
+ if config.prefix_attn != 0:
164
+ attn_mask[..., : config.prefix_attn, : config.prefix_attn] = 1
165
+ text.register_buffer("attn_mask", attn_mask, persistent=False)
166
+
167
+ return text
utils.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):
5
+ """
6
+ Robust outlier detection for list of (x,y) tuples.
7
+ Only requires numpy.
8
+
9
+ Args:
10
+ points_tuples: list of (x,y) tuples
11
+ k_nearest: number of neighbors to consider
12
+ threshold: multiplier for median distance
13
+
14
+ Returns:
15
+ list: filtered list of (x,y) tuples with outliers removed
16
+ list: list of booleans indicating which points were kept (True = kept)
17
+ """
18
+ points = np.array(points_tuples)
19
+ n_points = len(points)
20
+
21
+ # Calculate pairwise distances manually
22
+ dist_matrix = np.zeros((n_points, n_points))
23
+ for i in range(n_points):
24
+ for j in range(i + 1, n_points):
25
+ # Euclidean distance between points i and j
26
+ dist = np.sqrt(np.sum((points[i] - points[j]) ** 2))
27
+ dist_matrix[i, j] = dist
28
+ dist_matrix[j, i] = dist
29
+
30
+ # Get k nearest neighbors' distances
31
+ k = min(k_nearest, n_points - 1)
32
+ neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k]
33
+ avg_neighbor_dist = np.mean(neighbor_distances, axis=1)
34
+
35
+ # Calculate mask using median distance
36
+ median_dist = np.median(avg_neighbor_dist)
37
+ mask = avg_neighbor_dist <= threshold * median_dist
38
+
39
+ # Return filtered tuples and mask
40
+ filtered_tuples = [t for t, m in zip(points_tuples, mask) if m]
41
+ return filtered_tuples
vision.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+ from typing import Union, Tuple
+ from einops import rearrange
+ from PIL import Image
+
+ from .layers import attn, layer_norm, linear, mlp
+ from .image_crops import overlap_crop_image
+ from .config import VisionConfig
+
+ if torch.backends.mps.is_available():
+ # Non-divisible input sizes are not implemented on MPS device yet.
+ # https://github.com/pytorch/pytorch/issues/96056
+ def adaptive_avg_pool2d(input, output_size):
+ return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to("mps")
+
+ else:
+ adaptive_avg_pool2d = F.adaptive_avg_pool2d
+
+ DeviceLike = Union[str, torch.device, int]
+
+
+ def prepare_crops(
+ image: Image.Image, config: VisionConfig, device: DeviceLike
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ np_image = np.array(image.convert("RGB"))
+ overlap_crops = overlap_crop_image(
+ np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
+ )
+ all_crops = overlap_crops["crops"]
+ all_crops = np.transpose(all_crops, (0, 3, 1, 2))
+ all_crops = (
+ torch.from_numpy(all_crops)
+ .to(device=device, dtype=torch.float16)
+ .div_(255.0)
+ .sub_(0.5)
+ .div_(0.5)
+ )
+ return all_crops, overlap_crops["tiling"]
+
+
+ def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
+ x = rearrange(
+ input_BCHW,
+ "b c (h p1) (w p2) -> b (h w) (c p1 p2)",
+ p1=config.enc_patch_size,
+ p2=config.enc_patch_size,
+ ) # B3HW -> B(HxW)(3xP1xP2), aka BTC
+
+ x = linear(x, w.patch_emb)
+ x = x + w.pos_emb
+ for block in w.blocks:
+ x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)
+ x = x + mlp(layer_norm(x, block.ln2), block.mlp)
+ x = layer_norm(x, w.post_ln)
+
+ return x
+
+
+ def vision_projection(
+ global_features: torch.Tensor,
+ reconstructed: torch.Tensor,
+ w: nn.Module,
+ config: VisionConfig,
+ ):
+ reconstructed = reconstructed.permute(2, 0, 1)
+ reconstructed = adaptive_avg_pool2d(
+ reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers)
+ )
+ reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim)
+ final_features = torch.cat([global_features, reconstructed], dim=-1)
+ return mlp(final_features, w.proj_mlp)
+
+
+ def build_vision_model(config: VisionConfig, dtype: torch.dtype):
+ patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
+ grid_size = config.crop_size // config.enc_patch_size
+ num_patches = grid_size * grid_size
+
+ vision = nn.ModuleDict(
+ {
+ "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype),
+ "blocks": nn.ModuleList(
+ [
+ nn.ModuleDict(
+ {
+ "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype),
+ "attn": nn.ModuleDict(
+ {
+ "qkv": nn.Linear(
+ config.enc_dim, 3 * config.enc_dim, dtype=dtype
+ ),
+ "proj": nn.Linear(
+ config.enc_dim, config.enc_dim, dtype=dtype
+ ),
+ }
+ ),
+ "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype),
+ "mlp": nn.ModuleDict(
+ {
+ "fc1": nn.Linear(
+ config.enc_dim, config.enc_ff_dim, dtype=dtype
+ ),
+ "fc2": nn.Linear(
+ config.enc_ff_dim, config.enc_dim, dtype=dtype
+ ),
+ }
+ ),
+ }
+ )
+ for _ in range(config.enc_n_layers)
+ ]
+ ),
+ "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype),
+ "proj_mlp": nn.ModuleDict(
+ {
+ "fc1": nn.Linear(
+ config.enc_dim * 2, config.proj_inner_dim, dtype=dtype
+ ),
+ "fc2": nn.Linear(
+ config.proj_inner_dim, config.proj_out_dim, dtype=dtype
+ ),
+ }
+ ),
+ }
+ )
+ vision.pos_emb = nn.Parameter(
+ torch.zeros(1, num_patches, config.enc_dim, dtype=dtype)
+ )
+ return vision
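To make the data flow concrete, here is a small shape-check sketch for the encoder above. It is illustrative only: the import paths are placeholders, it assumes the commit's modules are importable as a package, and it uses float32 on CPU with randomly initialized weights rather than the real checkpoint.

import torch

from config import VisionConfig                      # placeholder import paths
from vision import build_vision_model, vision_encoder

cfg = VisionConfig()
vision = build_vision_model(cfg, dtype=torch.float32)  # float32 for a CPU-only check

# A dummy batch shaped like the output of prepare_crops: (num_crops, C, crop_size, crop_size).
crops = torch.zeros(1, cfg.in_channels, cfg.crop_size, cfg.crop_size, dtype=torch.float32)
with torch.no_grad():
    feats = vision_encoder(crops, vision, cfg)

# One token per patch: (crop_size // enc_patch_size) ** 2 == 729 tokens of width enc_dim.
print(feats.shape)  # (1, 729, cfg.enc_dim)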
weights.py ADDED
@@ -0,0 +1,292 @@
+ import safetensors
+ import torch
+ import torch.nn as nn
+
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from typing import Callable, List
+
+ from .layers import AttentionWeights, LayerNormWeights, LinearWeights, MLPWeights
+
+
+ @dataclass
+ class VisionBlock:
+ ln1: LayerNormWeights
+ attn: AttentionWeights
+ ln2: LayerNormWeights
+ mlp: MLPWeights
+
+
+ @dataclass
+ class VisionModel:
+ patch_emb: LinearWeights
+ pos_emb: torch.Tensor
+ blocks: List[VisionBlock]
+ post_ln: LayerNormWeights
+ proj_mlp: MLPWeights
+
+
+ @dataclass
+ class TextBlock:
+ ln: LayerNormWeights
+ attn: AttentionWeights
+ mlp: MLPWeights
+
+
+ @dataclass
+ class TextModel:
+ wte: torch.Tensor
+ blocks: List[TextBlock]
+ post_ln: LayerNormWeights
+ lm_head: LinearWeights
+
+
+ @dataclass
+ class RegionModel:
+ coord_features: torch.Tensor
+ coord_encoder: LinearWeights
+ coord_decoder: MLPWeights
+ size_features: torch.Tensor
+ size_encoder: LinearWeights
+ size_decoder: MLPWeights
+
+
+ @dataclass
+ class MoondreamModel:
+ vision: VisionModel
+ text: TextModel
+ region: RegionModel
+
+
+ @contextmanager
+ def safetensors_open(safetensors_file: str):
+ """
+ Simplify interfacing with safetensors files. Eliminates the need to ignore
+ type errors when using the `safe_open` function.
+ """
+ with safetensors.safe_open(
+ safetensors_file, framework="pt"
+ ) as st: # pyright: ignore
+
+ def get_tensor(name: str) -> torch.Tensor:
+ return st.get_tensor(name)
+
+ def get_keys() -> List[str]:
+ return st.keys()
+
+ get_tensor.keys = get_keys
+
+ yield get_tensor
+
+
+ def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
+ """Internal function to load weights using a tensor getter function."""
+ model = model.to(dtype=torch.float16)
+
+ # Vision Model
+ model.vision["patch_emb"].weight.data.copy_(
+ get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.weight")
+ )
+ model.vision["patch_emb"].bias.data.copy_(
+ get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.bias")
+ )
+ model.vision.pos_emb.data.copy_(
+ get_tensor("vision_encoder.encoder.model.visual.pos_embed")
+ )
+
+ for i in range(len(model.vision["blocks"])):
+ prefix = f"vision_encoder.encoder.model.visual.blocks.{i}"
+
+ # Layer norms
+ model.vision["blocks"][i]["ln1"].weight.data.copy_(
+ get_tensor(f"{prefix}.norm1.weight")
+ )
+ model.vision["blocks"][i]["ln1"].bias.data.copy_(
+ get_tensor(f"{prefix}.norm1.bias")
+ )
+ model.vision["blocks"][i]["ln2"].weight.data.copy_(
+ get_tensor(f"{prefix}.norm2.weight")
+ )
+ model.vision["blocks"][i]["ln2"].bias.data.copy_(
+ get_tensor(f"{prefix}.norm2.bias")
+ )
+
+ # Attention
+ model.vision["blocks"][i]["attn"]["qkv"].weight.data.copy_(
+ get_tensor(f"{prefix}.attn.qkv.weight")
+ )
+ model.vision["blocks"][i]["attn"]["qkv"].bias.data.copy_(
+ get_tensor(f"{prefix}.attn.qkv.bias")
+ )
+ model.vision["blocks"][i]["attn"]["proj"].weight.data.copy_(
+ get_tensor(f"{prefix}.attn.proj.weight")
+ )
+ model.vision["blocks"][i]["attn"]["proj"].bias.data.copy_(
+ get_tensor(f"{prefix}.attn.proj.bias")
+ )
+
+ # MLP
+ model.vision["blocks"][i]["mlp"]["fc1"].weight.data.copy_(
+ get_tensor(f"{prefix}.mlp.fc1.weight")
+ )
+ model.vision["blocks"][i]["mlp"]["fc1"].bias.data.copy_(
+ get_tensor(f"{prefix}.mlp.fc1.bias")
+ )
+ model.vision["blocks"][i]["mlp"]["fc2"].weight.data.copy_(
+ get_tensor(f"{prefix}.mlp.fc2.weight")
+ )
+ model.vision["blocks"][i]["mlp"]["fc2"].bias.data.copy_(
+ get_tensor(f"{prefix}.mlp.fc2.bias")
+ )
+
+ model.vision["post_ln"].weight.data.copy_(
+ get_tensor("vision_encoder.encoder.model.visual.norm.weight")
+ )
+ model.vision["post_ln"].bias.data.copy_(
+ get_tensor("vision_encoder.encoder.model.visual.norm.bias")
+ )
+
+ model.vision["proj_mlp"]["fc1"].weight.data.copy_(
+ get_tensor("vision_encoder.projection.mlp.fc1.weight")
+ )
+ model.vision["proj_mlp"]["fc1"].bias.data.copy_(
+ get_tensor("vision_encoder.projection.mlp.fc1.bias")
+ )
+ model.vision["proj_mlp"]["fc2"].weight.data.copy_(
+ get_tensor("vision_encoder.projection.mlp.fc2.weight")
+ )
+ model.vision["proj_mlp"]["fc2"].bias.data.copy_(
+ get_tensor("vision_encoder.projection.mlp.fc2.bias")
+ )
+
+ # Text Model
+ model.text.wte.data.copy_(get_tensor("text_model.transformer.embd.wte.weight"))
+
+ for i in range(len(model.text["blocks"])):
+ prefix = f"text_model.transformer.h.{i}"
+
+ # Layer norm
+ model.text["blocks"][i]["ln"].weight.data.copy_(
+ get_tensor(f"{prefix}.ln.weight")
+ )
+ model.text["blocks"][i]["ln"].bias.data.copy_(get_tensor(f"{prefix}.ln.bias"))
+
+ # Attention
+ model.text["blocks"][i]["attn"]["qkv"].weight.data.copy_(
+ get_tensor(f"{prefix}.mixer.Wqkv.weight")
+ )
+ model.text["blocks"][i]["attn"]["qkv"].bias.data.copy_(
+ get_tensor(f"{prefix}.mixer.Wqkv.bias")
+ )
+ model.text["blocks"][i]["attn"]["proj"].weight.data.copy_(
+ get_tensor(f"{prefix}.mixer.out_proj.weight")
+ )
+ model.text["blocks"][i]["attn"]["proj"].bias.data.copy_(
+ get_tensor(f"{prefix}.mixer.out_proj.bias")
+ )
+
+ # MLP
+ model.text["blocks"][i]["mlp"]["fc1"].weight.data.copy_(
+ get_tensor(f"{prefix}.mlp.fc1.weight")
+ )
+ model.text["blocks"][i]["mlp"]["fc1"].bias.data.copy_(
+ get_tensor(f"{prefix}.mlp.fc1.bias")
+ )
+ model.text["blocks"][i]["mlp"]["fc2"].weight.data.copy_(
+ get_tensor(f"{prefix}.mlp.fc2.weight")
+ )
+ model.text["blocks"][i]["mlp"]["fc2"].bias.data.copy_(
+ get_tensor(f"{prefix}.mlp.fc2.bias")
+ )
+
+ model.text["post_ln"].weight.data.copy_(get_tensor("text_model.lm_head.ln.weight"))
+ model.text["post_ln"].bias.data.copy_(get_tensor("text_model.lm_head.ln.bias"))
+
+ model.text["lm_head"].weight.data.copy_(
+ get_tensor("text_model.lm_head.linear.weight")
+ )
+ model.text["lm_head"].bias.data.copy_(get_tensor("text_model.lm_head.linear.bias"))
+
+ # Region Model
+ model.region.coord_features.data.copy_(
+ get_tensor("region_model.coordinate_features.weight").T
+ )
+ model.region["coord_encoder"].weight.data.copy_(
+ get_tensor("region_model.coordinate_encoder.weight")
+ )
+ model.region["coord_encoder"].bias.data.copy_(
+ get_tensor("region_model.coordinate_encoder.bias")
+ )
+
+ model.region["coord_decoder"]["fc1"].weight.data.copy_(
+ get_tensor("region_model.coordinate_decoder.fc1.weight")
+ )
+ model.region["coord_decoder"]["fc1"].bias.data.copy_(
+ get_tensor("region_model.coordinate_decoder.fc1.bias")
+ )
+ model.region["coord_decoder"]["fc2"].weight.data.copy_(
+ get_tensor("region_model.coordinate_decoder.fc2.weight")
+ )
+ model.region["coord_decoder"]["fc2"].bias.data.copy_(
+ get_tensor("region_model.coordinate_decoder.fc2.bias")
+ )
+
+ model.region.size_features.data.copy_(
+ get_tensor("region_model.size_features.weight").T
+ )
+ model.region["size_encoder"].weight.data.copy_(
+ get_tensor("region_model.size_encoder.weight")
+ )
+ model.region["size_encoder"].bias.data.copy_(
+ get_tensor("region_model.size_encoder.bias")
+ )
+
+ model.region["size_decoder"]["fc1"].weight.data.copy_(
+ get_tensor("region_model.size_decoder.fc1.weight")
+ )
+ model.region["size_decoder"]["fc1"].bias.data.copy_(
+ get_tensor("region_model.size_decoder.fc1.bias")
+ )
+ model.region["size_decoder"]["fc2"].weight.data.copy_(
+ get_tensor("region_model.size_decoder.fc2.weight")
+ )
+ model.region["size_decoder"]["fc2"].bias.data.copy_(
+ get_tensor("region_model.size_decoder.fc2.bias")
+ )
+
+
+ def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
+ """Load weights from a safetensors file into a MoondreamModel instance."""
+ with safetensors_open(weights_file) as get_tensor:
+ # Wrap the get_tensor function to handle key normalization
+ name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()}
+ _load_weights(lambda x: get_tensor(name_map[x]).to(dtype=torch.float16), model)
+
+
+ def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
+ """Load weights from a PyTorch file into a MoondreamModel instance."""
+ device = str(torch.empty(0).device)
+ tensors = torch.load(weights_file, map_location=device, weights_only=True)
+ tensors = {
+ k.replace("._orig_mod", ""): v.to(dtype=torch.float16)
+ for k, v in tensors.items()
+ }
+ _load_weights(lambda x: tensors[x], model)
+
+
+ def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
+ """
+ Load weights from either a safetensors or PyTorch file directly into a MoondreamModel instance.
+
+ Args:
+ weights_file: Path to weights file (either .safetensors or .pt)
+ model: MoondreamModel instance to load weights into
+ """
+ if weights_file.endswith(".safetensors"):
+ load_weights_from_safetensors(weights_file, model)
+ else:
+ load_weights_from_pt(weights_file, model)
+
+ # Make all parameters contiguous
+ for param in model.parameters():
+ param.data = param.data.contiguous()
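As a usage sketch, the `safetensors_open` context manager above is also handy on its own for inspecting a checkpoint before loading it. The snippet below is illustrative, not part of the commit; the import path and file name are placeholders.

from weights import safetensors_open  # placeholder import path

# Peek at a checkpoint's tensor names and one tensor's shape/dtype.
with safetensors_open("model.safetensors") as get_tensor:  # placeholder file name
    names = get_tensor.keys()
    print(f"{len(names)} tensors in checkpoint")
    first = get_tensor(names[0])
    print(names[0], tuple(first.shape), first.dtype)

For actually populating a model, `load_weights_into_model` dispatches on the file extension, accepting either a `.safetensors` file or a `.pt` checkpoint, and finishes by making every parameter contiguous.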