support infer in cpu device (#4)

- make vit work with cpu device (73cd0c73d98c27f82cc525c1cc5f48118ea3c686)
- fix import (9d09fe5da66ed112c3a7e066bce23e8aacee3fe3)
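The substance of the first commit shows up in the Rope2DPosEmb hunks below: the RoPE frequency table used to be precomputed against a hard-coded device="cuda", so running the vision tower on a CUDA-less host failed. The new code stores None at construction and builds the table on first use, on whatever device the input grid tensor lives on. A minimal sketch of that pattern, with hypothetical names (LazyDeviceCache, _precompute) standing in for Rope2DPosEmb and _precompute_freqs_cis:

import torch
from torch import nn

class LazyDeviceCache(nn.Module):
    """Hypothetical stand-in for Rope2DPosEmb: the cached table is no
    longer pinned to device="cuda" at construction time, but built on
    first use, on the device the inputs actually live on."""

    def __init__(self, n: int = 16):
        super().__init__()
        self.n = n
        self.cache = None  # mirrors the new `self.freqs_cis = None`

    def _precompute(self, device: torch.device) -> torch.Tensor:
        # stand-in for _precompute_freqs_cis(device): any device-dependent table
        return torch.arange(self.n, device=device).float()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.cache is None:  # lazy, device-aware init
            self.cache = self._precompute(x.device)
        return x + self.cache[: x.size(-1)]

m = LazyDeviceCache()
print(m(torch.zeros(4)))  # fine on a CPU-only host; the old CUDA-pinned table was not

Nothing touches CUDA until forward runs, so the same module works unchanged on CPU-only machines and still lands on GPU when the inputs do.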
modeling_kimi_vl.py (+28 -60)
@@ -44,7 +44,6 @@ import math
 import warnings
 from typing import List, Optional, Tuple, Union
 from copy import deepcopy
-from functools import cached_property
 from typing import Union, Tuple, Sequence, Optional, List

 import numpy as np
@@ -66,10 +65,7 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
-from transformers.pytorch_utils import (
-    ALL_LAYERNORM_LAYERS,
-    is_torch_greater_or_equal_than_1_13,
-)
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -280,18 +276,18 @@ class MoonVisionPatchEmbed(nn.Module):
             height=pos_emb_height, width=pos_emb_width, dim=out_dim
         )

-    def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
         """
         Args:
             x (L, Channels): input tensor
-            grid_hw (N, 2): grid height and width
+            grid_hws (N, 2): grid height and width

         Returns:
             (L, Cout) tensor
         """
         x = self.proj(x).view(x.size(0), -1)
         # apply positional embedding
-        x = self.pos_emb(x, grid_hw)
+        x = self.pos_emb(x, grid_hws)
         return x

@@ -317,22 +313,20 @@ class Rope2DPosEmb(nn.Module):
         device (str): the device to store the precomputed cis
     """

-    def __init__(
-        self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
-    ):
+    def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
         super().__init__()
         self.dim = dim
         assert self.dim % 4 == 0, "dim must be divisible by 4"
         self.max_height = max_height
         self.max_width = max_width
         self.theta_base = theta_base
-        self.device = device
+
+        self.freqs_cis = None

     def extra_repr(self):
         return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"

-    @cached_property
-    def precomputed_freqs_cis(self) -> torch.Tensor:
+    def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
         """Calculate the cis(freqs) for each position in the 2D grid.

         Return: complex tensor of shape (max_height, max_width, dim//2) and value:
@@ -341,11 +335,11 @@ class Rope2DPosEmb(nn.Module):
         note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
         """
         N = self.max_height * self.max_width
-        flat_pos = torch.arange(0, N).float().to(self.device)
+        flat_pos = torch.arange(0, N).float().to(device)
         x_pos = flat_pos % self.max_width
         y_pos = flat_pos // self.max_width
         dim_range = (
-            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(self.device)
+            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device)
         )  # C/4
         freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
         x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
@@ -360,13 +354,17 @@ class Rope2DPosEmb(nn.Module):
         freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
         return freqs_cis

-    def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
+    def get_freqs_cis(self, grid_hws: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples.
+            grid_hws (torch.Tensor): grid height and width
+
         Returns:
             freqs_cis: tensor of shape (sum(t * height * width), dim//2)
         """
+        if self.freqs_cis is None:
+            self.freqs_cis = self._precompute_freqs_cis(grid_hws.device)
+
         shapes = grid_hws.tolist()
         assert all(
             1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
@@ -376,41 +374,11 @@ class Rope2DPosEmb(nn.Module):
             self.max_width,
         )
         freqs_cis = torch.cat(
-            [
-                self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
-                for h, w in shapes
-            ],
+            [self.freqs_cis[:h, :w].reshape(-1, self.dim // 2) for h, w in shapes],
             dim=0,
         )
         return freqs_cis

-    def get_freqs_cis_by_idx(
-        self, pos_idx: torch.Tensor, pos_idx_mask: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Args:
-            pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token.
-            pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
-                Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask with be ones.
-        Return:
-            freqs_cis: tensor of shape (..., dim//2)
-        """
-        assert (
-            pos_idx.shape[:-1] == pos_idx_mask.shape
-            and pos_idx.shape[-1] == 2
-            and pos_idx.ndim == pos_idx_mask.ndim + 1
-        ), (pos_idx.shape, pos_idx_mask.shape)
-        assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype
-
-        shp = pos_idx_mask.shape + (self.dim // 2,)  # ..., head_dim/2
-        freqs_cis = torch.ones(
-            shp, dtype=torch.complex64, device=self.device
-        )  # ..., head_dim/2
-        freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[
-            pos_idx[..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]
-        ]
-        return freqs_cis
-

 class MLP2(nn.Module):
     """
@@ -537,14 +505,14 @@ class MoonVitEncoder(nn.Module):
         self.final_layernorm = nn.LayerNorm(hidden_dim)

     def forward(
-        self, hidden_states: torch.Tensor, grid_hw: torch.Tensor
+        self, hidden_states: torch.Tensor, grid_hws: torch.Tensor
     ) -> torch.Tensor:
-        rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(grid_hws=grid_hw)
+        rope_freqs_cis = self.rope_2d.get_freqs_cis(grid_hws=grid_hws)

         lengths = torch.cat(
             (
-                torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
-                grid_hw[:, 0] * grid_hw[:, 1],
+                torch.zeros(1, device=hidden_states.device, dtype=grid_hws.dtype),
+                grid_hws[:, 0] * grid_hws[:, 1],
             )
         )
         cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
@@ -561,14 +529,14 @@ class MoonVitEncoder(nn.Module):

 def patch_merger(
     x: torch.Tensor,
-    grid_hw: torch.Tensor,
+    grid_hws: torch.Tensor,
     merge_kernel_size: list[int, int] = (2, 2),
 ) -> List[torch.Tensor]:
     d_model = x.size(-1)

     outputs = []
     pre_sum = 0
-    for x_shape in grid_hw.tolist():
+    for x_shape in grid_hws.tolist():
         height, width = x_shape[0], x_shape[1]
         # Get the current sequence
         seq = x[pre_sum : pre_sum + height * width]
@@ -2290,20 +2258,20 @@ class MoonVitPretrainedModel(PreTrainedModel):
         )

     def forward(
-        self, pixel_values: torch.Tensor, grid_hw: torch.Tensor
+        self, pixel_values: torch.Tensor, grid_hws: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
             pixel_values (torch.Tensor): The input pixel values.
-            grid_hw (torch.Tensor): The grid height and width.
+            grid_hws (torch.Tensor): The grid height and width.

         Returns:
             torch.Tensor: The output tokens.
         """
-        hidden_states = self.patch_embed(pixel_values, grid_hw)
-        hidden_states = self.encoder(hidden_states, grid_hw)
+        hidden_states = self.patch_embed(pixel_values, grid_hws)
+        hidden_states = self.encoder(hidden_states, grid_hws)
         hidden_states = patch_merger(
-            hidden_states, grid_hw, merge_kernel_size=self.merge_kernel_size
+            hidden_states, grid_hws, merge_kernel_size=self.merge_kernel_size
         )
         return hidden_states

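With both commits applied, the vision tower no longer requires a CUDA context, so the whole model can be loaded and run on CPU. A sketch of how that might be exercised; the repo id, processor calls, and chat-template usage follow the usual Kimi-VL model-card pattern and are assumptions, not part of this diff:

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_path = "moonshotai/Kimi-VL-A3B-Instruct"  # assumed repo id
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float32,  # safest dtype on CPU
    device_map="cpu",           # the point of this commit: no CUDA required
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

image = Image.open("demo.png")
messages = [{"role": "user", "content": [
    {"type": "image", "image": "demo.png"},
    {"type": "text", "text": "Describe this image."},
]}]
text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = processor(images=image, text=text, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(out, skip_special_tokens=True)[0])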