vikhyatk committed
Commit 48640e9 · verified · 1 parent: 78c093c

Upload HfMoondream

Files changed (3):
  1. image_crops.py +36 -13
  2. moondream.py +2 -0
  3. text.py +0 -12
image_crops.py CHANGED
@@ -1,10 +1,18 @@
 import math
 import numpy as np
 import torch
-import pyvips
 
 from typing import TypedDict
 
+try:
+    import pyvips
+
+    HAS_VIPS = True
+except:
+    from PIL import Image
+
+    HAS_VIPS = False
+
 
 def select_tiling(
     height: int, width: int, crop_size: int, max_crops: int
@@ -113,18 +121,33 @@ def overlap_crop_image(
         tiling[1] * crop_window_size + total_margin_pixels,
     )
 
-    # Convert to vips for resizing
-    vips_image = pyvips.Image.new_from_array(image)
-    scale_x = target_size[1] / image.shape[1]
-    scale_y = target_size[0] / image.shape[0]
-    resized = vips_image.resize(scale_x, vscale=scale_y)
-    image = resized.numpy()
-
-    # Create global crop
-    scale_x = base_size[1] / vips_image.width
-    scale_y = base_size[0] / vips_image.height
-    global_vips = vips_image.resize(scale_x, vscale=scale_y)
-    crops[0] = global_vips.numpy()
+    if HAS_VIPS:
+        # Convert to vips for resizing
+        vips_image = pyvips.Image.new_from_array(image)
+        scale_x = target_size[1] / image.shape[1]
+        scale_y = target_size[0] / image.shape[0]
+        resized = vips_image.resize(scale_x, vscale=scale_y)
+        image = resized.numpy()
+
+        # Create global crop
+        scale_x = base_size[1] / vips_image.width
+        scale_y = base_size[0] / vips_image.height
+        global_vips = vips_image.resize(scale_x, vscale=scale_y)
+        crops[0] = global_vips.numpy()
+    else:
+        # Fallback to PIL
+        pil_img = Image.fromarray(image)
+        resized = pil_img.resize(
+            (int(target_size[1]), int(target_size[0])),
+            resample=Image.Resampling.LANCZOS,
+        )
+        image = np.asarray(resized)
+
+        # Create global crop
+        global_pil = pil_img.resize(
+            (int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS
+        )
+        crops[0] = np.asarray(global_pil)
 
     for i in range(tiling[0]):
         for j in range(tiling[1]):
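
The change above makes pyvips an optional dependency: if the import fails, the module falls back to PIL for the two resize operations. Below is a minimal, self-contained sketch of the same pattern; the resize_rgb helper and the example shapes are illustrative, not names from the repository, and it narrows the commit's bare `except:` to `except ImportError:`.

import numpy as np

try:
    import pyvips

    HAS_VIPS = True
except ImportError:  # the commit uses a bare `except:`; ImportError is narrower
    from PIL import Image

    HAS_VIPS = False


def resize_rgb(image: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    # Resize an HWC uint8 array with pyvips when available, else PIL.
    if HAS_VIPS:
        vips_image = pyvips.Image.new_from_array(image)
        resized = vips_image.resize(
            out_w / image.shape[1], vscale=out_h / image.shape[0]
        )
        return resized.numpy()
    pil_img = Image.fromarray(image)
    return np.asarray(
        pil_img.resize((out_w, out_h), resample=Image.Resampling.LANCZOS)
    )


print(resize_rgb(np.zeros((448, 672, 3), dtype=np.uint8), 378, 378).shape)
# -> (378, 378, 3) on either code path

One caveat worth knowing: pyvips's resize defaults to a lanczos3 kernel, so the two paths produce close but not bit-identical pixels.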
moondream.py CHANGED
@@ -182,6 +182,7 @@ class MoondreamModel(nn.Module):
 
     def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
         all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)
+
         torch._dynamo.mark_dynamic(all_crops, 0)
 
         outputs = self._vis_enc(all_crops)
@@ -249,6 +250,7 @@ class MoondreamModel(nn.Module):
         with torch.inference_mode():
             prompt_emb = text_encoder(prompt_tokens, self.text)
             torch._dynamo.mark_dynamic(prompt_emb, 1)
+
            mask = self.attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
             pos_ids = torch.arange(pos, pos + prompt_emb.size(1), dtype=torch.long)
             hidden = self._prefill(prompt_emb, mask, pos_ids)
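
These two hunks only add blank lines around the existing `torch._dynamo.mark_dynamic` calls, which tell `torch.compile` that a given tensor dimension (here the crop-batch dimension and the prompt-length dimension) varies across calls, so one graph with a symbolic size is compiled instead of one graph per shape. A minimal sketch of that pattern, assuming nothing from the repository (encode is a toy function):

import torch


@torch.compile
def encode(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x @ x.transpose(-1, -2))


for n in (3, 5, 9):  # e.g. a different number of image crops per call
    crops = torch.randn(n, 16)
    # Mark dim 0 as dynamic up front so the first compile already uses a
    # symbolic size, avoiding a recompile when n changes.
    torch._dynamo.mark_dynamic(crops, 0)
    encode(crops)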
text.py CHANGED
@@ -35,18 +35,6 @@ def attn(
     k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
     v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
 
-    # q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
-    # k = (
-    #     qkv_out[..., q_dim : q_dim + kv_dim]
-    #     .view(bsz, q_len, n_kv_heads, head_dim)
-    #     .transpose(1, 2)
-    # )
-    # v = (
-    #     qkv_out[..., q_dim + kv_dim :]
-    #     .view(bsz, q_len, n_kv_heads, head_dim)
-    #     .transpose(1, 2)
-    # )
-
     q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
     k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
 
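
The deleted lines were a commented-out remnant of slicing q, k, and v out of a fused qkv projection by hand; the live code above them already performs the equivalent reshape. For reference, the surviving pattern views a (bsz, q_len, heads * head_dim) projection as (bsz, q_len, heads, head_dim) and transposes it into the (bsz, heads, q_len, head_dim) layout attention kernels expect. A standalone sketch with illustrative sizes:

import torch

bsz, q_len, n_kv_heads, head_dim = 2, 7, 4, 16
k_proj = torch.randn(bsz, q_len, n_kv_heads * head_dim)

# (bsz, q_len, n_kv_heads * head_dim) -> (bsz, n_kv_heads, q_len, head_dim)
k = k_proj.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
print(k.shape)  # torch.Size([2, 4, 7, 16])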