Upload HfMoondream

- image_crops.py +36 -13
- moondream.py +2 -0
- text.py +0 -12
image_crops.py
CHANGED

```diff
@@ -1,10 +1,18 @@
 import math
 import numpy as np
 import torch
-import pyvips
 
 from typing import TypedDict
 
+try:
+    import pyvips
+
+    HAS_VIPS = True
+except:
+    from PIL import Image
+
+    HAS_VIPS = False
+
 
 def select_tiling(
     height: int, width: int, crop_size: int, max_crops: int
@@ -113,18 +121,33 @@ def overlap_crop_image(
         tiling[1] * crop_window_size + total_margin_pixels,
     )
 
-    # Convert to vips for resizing
-    vips_image = pyvips.Image.new_from_array(image)
-    scale_x = target_size[1] / image.shape[1]
-    scale_y = target_size[0] / image.shape[0]
-    resized = vips_image.resize(scale_x, vscale=scale_y)
-    image = resized.numpy()
-
-    # Create global crop
-    scale_x = base_size[1] / vips_image.width
-    scale_y = base_size[0] / vips_image.height
-    global_vips = vips_image.resize(scale_x, vscale=scale_y)
-    crops[0] = global_vips.numpy()
+    if HAS_VIPS:
+        # Convert to vips for resizing
+        vips_image = pyvips.Image.new_from_array(image)
+        scale_x = target_size[1] / image.shape[1]
+        scale_y = target_size[0] / image.shape[0]
+        resized = vips_image.resize(scale_x, vscale=scale_y)
+        image = resized.numpy()
+
+        # Create global crop
+        scale_x = base_size[1] / vips_image.width
+        scale_y = base_size[0] / vips_image.height
+        global_vips = vips_image.resize(scale_x, vscale=scale_y)
+        crops[0] = global_vips.numpy()
+    else:
+        # Fallback to PIL
+        pil_img = Image.fromarray(image)
+        resized = pil_img.resize(
+            (int(target_size[1]), int(target_size[0])),
+            resample=Image.Resampling.LANCZOS,
+        )
+        image = np.asarray(resized)
+
+        # Create global crop
+        global_pil = pil_img.resize(
+            (int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS
+        )
+        crops[0] = np.asarray(global_pil)
 
     for i in range(tiling[0]):
         for j in range(tiling[1]):
```
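This is the standard optional-dependency guard: try the fast backend (pyvips), fall back to a ubiquitous one (PIL). A minimal, self-contained sketch of the same pattern follows; `resize_image` and the shapes are illustrative, not part of the actual module, and the two backends will not generally produce bit-identical pixels since pyvips's default resize kernel is not guaranteed to match PIL's LANCZOS.

```python
import numpy as np

try:
    import pyvips

    HAS_VIPS = True
except ImportError:  # narrower than the bare `except:` in the commit
    from PIL import Image

    HAS_VIPS = False


def resize_image(image: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Resize an HxWxC uint8 array to (out_h, out_w) with whichever backend exists."""
    if HAS_VIPS:
        # pyvips takes a horizontal scale factor plus an optional vertical one
        vips_image = pyvips.Image.new_from_array(image)
        resized = vips_image.resize(
            out_w / image.shape[1], vscale=out_h / image.shape[0]
        )
        return resized.numpy()
    # PIL takes an explicit (width, height) output size
    pil_img = Image.fromarray(image)
    return np.asarray(pil_img.resize((out_w, out_h), resample=Image.Resampling.LANCZOS))


img = np.random.randint(0, 256, (512, 640, 3), dtype=np.uint8)
out = resize_image(img, 378, 378)
assert out.shape[:2] == (378, 378)
```

A narrower `except ImportError:`, as in the sketch, would avoid masking unrelated failures inside the `try` block.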
moondream.py
CHANGED

```diff
@@ -182,6 +182,7 @@ class MoondreamModel(nn.Module):
 
     def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
         all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
+
        torch._dynamo.mark_dynamic(all_crops, 0)
 
         outputs = self._vis_enc(all_crops)
@@ -249,6 +250,7 @@ class MoondreamModel(nn.Module):
         with torch.inference_mode():
             prompt_emb = text_encoder(prompt_tokens, self.text)
             torch._dynamo.mark_dynamic(prompt_emb, 1)
+
             mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
             pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
             hidden = self._prefill(prompt_emb, mask, pos_ids)
```
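Both moondream.py hunks are whitespace-only: a blank line is inserted around the existing `torch._dynamo.mark_dynamic` calls. For context, `mark_dynamic` tells `torch.compile` to treat a tensor dimension as symbolic, so the traced graph is not re-specialized (and recompiled) for every crop count or prompt length. A minimal sketch, with a stand-in `encoder` rather than Moondream's actual encoder:

```python
import torch
import torch.nn as nn

encoder = torch.compile(nn.Linear(16, 8))

for n_crops in (3, 5, 9):
    x = torch.randn(n_crops, 16)
    # Mark dim 0 (the number of crops) as dynamic before the compiled call,
    # so a single traced graph serves all three batch sizes.
    torch._dynamo.mark_dynamic(x, 0)
    _ = encoder(x)
```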
text.py
CHANGED

```diff
@@ -35,18 +35,6 @@ def attn(
     k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
     v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
 
-    # q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
-    # k = (
-    #     qkv_out[..., q_dim : q_dim + kv_dim]
-    #     .view(bsz, q_len, n_kv_heads, head_dim)
-    #     .transpose(1, 2)
-    # )
-    # v = (
-    #     qkv_out[..., q_dim + kv_dim :]
-    #     .view(bsz, q_len, n_kv_heads, head_dim)
-    #     .transpose(1, 2)
-    # )
-
     q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
     k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
 
```
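The text.py change is pure dead-code removal: the deleted comments were a leftover fused-QKV split, superseded by the live q/k/v views just above them. For reference, a runnable sketch of that fused-projection split; the dimensions here are made up, not Moondream's actual config:

```python
import torch

bsz, q_len = 2, 7
n_heads, n_kv_heads, head_dim = 8, 2, 16  # illustrative grouped-query sizes
q_dim, kv_dim = n_heads * head_dim, n_kv_heads * head_dim

# One fused projection output carrying q, k, and v side by side.
qkv_out = torch.randn(bsz, q_len, q_dim + 2 * kv_dim)

q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
k = (
    qkv_out[..., q_dim : q_dim + kv_dim]
    .view(bsz, q_len, n_kv_heads, head_dim)
    .transpose(1, 2)
)
v = (
    qkv_out[..., q_dim + kv_dim :]
    .view(bsz, q_len, n_kv_heads, head_dim)
    .transpose(1, 2)
)

assert q.shape == (bsz, n_heads, q_len, head_dim)
assert k.shape == v.shape == (bsz, n_kv_heads, q_len, head_dim)
```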