support infer in cpu device (#4)

- make vit work with cpu device (73cd0c73d98c27f82cc525c1cc5f48118ea3c686)
- fix import (9d09fe5da66ed112c3a7e066bce23e8aacee3fe3)

This change removes the hard-coded CUDA device from the MoonViT vision tower: the 2D RoPE table is now computed lazily on the device of the incoming tensors instead of being cached on "cuda" at construction time, so the model can run inference on CPU.

modeling_kimi_vl.py (+28 -60)
@@ -44,7 +44,6 @@ import math
 import warnings
 from typing import List, Optional, Tuple, Union
 from copy import deepcopy
-from functools import cached_property
 from typing import Union, Tuple, Sequence, Optional, List
 
 import numpy as np
@@ -66,10 +65,7 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
-from transformers.pytorch_utils import (
-    ALL_LAYERNORM_LAYERS,
-    is_torch_greater_or_equal_than_1_13,
-)
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -280,18 +276,18 @@ class MoonVisionPatchEmbed(nn.Module):
             height=pos_emb_height, width=pos_emb_width, dim=out_dim
         )
 
-    def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
         """
         Args:
             x (L, Channels): input tensor
-            grid_hw (N, 2): grid height and width
+            grid_hws (N, 2): grid height and width
 
         Returns:
             (L, Cout) tensor
         """
         x = self.proj(x).view(x.size(0), -1)
         # apply positional embedding
-        x = self.pos_emb(x, grid_hw)
+        x = self.pos_emb(x, grid_hws)
         return x
 
 
@@ -317,22 +313,20 @@ class Rope2DPosEmb(nn.Module):
         device (str): the device to store the precomputed cis
     """
 
-    def __init__(
-        self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
-    ):
+    def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
         super().__init__()
         self.dim = dim
         assert self.dim % 4 == 0, "dim must be divisible by 4"
         self.max_height = max_height
         self.max_width = max_width
         self.theta_base = theta_base
-        self.device = device
+
+        self.freqs_cis = None
 
     def extra_repr(self):
         return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"
 
-    @cached_property
-    def precomputed_freqs_cis(self) -> torch.Tensor:
+    def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
         """Calculate the cis(freqs) for each position in the 2D grid.
 
         Return: complex tensor of shape (max_height, max_width, dim//2) and value:
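This hunk is the heart of the CPU fix: the constructor no longer takes a device (previously defaulting to "cuda"), and the `@cached_property` table is replaced by an explicit `self.freqs_cis = None` that is filled on first use, on the device of the caller's grid tensor. Below is a condensed, self-contained sketch of that pattern; the class name is illustrative and the precompute body is paraphrased from the surrounding hunks (the interleaving step is an assumption, not the file's verbatim code):

```python
import torch
import torch.nn as nn


class LazyRope2D(nn.Module):
    """Device-agnostic 2D RoPE: the cis table is built on first use,
    on whatever device the caller's grid tensor lives on."""

    def __init__(self, dim: int, max_height: int, max_width: int, theta_base: float = 10000.0):
        super().__init__()
        assert dim % 4 == 0, "dim must be divisible by 4"
        self.dim = dim
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base
        self.freqs_cis = None  # lazily filled; replaces @cached_property on a fixed device

    def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
        # cis(x) = cos x + i sin x for every (y, x) position in the max grid
        N = self.max_height * self.max_width
        flat_pos = torch.arange(0, N).float().to(device)
        x_pos = flat_pos % self.max_width
        y_pos = flat_pos // self.max_width
        dim_range = torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device)  # C/4
        freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
        x_cis = torch.polar(torch.ones(N, self.dim // 4, device=device), torch.outer(x_pos, freqs))
        y_cis = torch.polar(torch.ones(N, self.dim // 4, device=device), torch.outer(y_pos, freqs))
        # interleave the x- and y-halves -> (max_height, max_width, dim // 2)
        freqs_cis = torch.cat([x_cis.unsqueeze(-1), y_cis.unsqueeze(-1)], dim=-1)
        return freqs_cis.reshape(self.max_height, self.max_width, -1)

    def get_freqs_cis(self, grid_hws: torch.Tensor) -> torch.Tensor:
        # the first call decides the device; later calls reuse the cached table
        if self.freqs_cis is None:
            self.freqs_cis = self._precompute_freqs_cis(grid_hws.device)
        return torch.cat(
            [self.freqs_cis[:h, :w].reshape(-1, self.dim // 2) for h, w in grid_hws.tolist()],
            dim=0,
        )
```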
@@ -341,11 +335,11 @@ class Rope2DPosEmb(nn.Module):
         note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
         """
         N = self.max_height * self.max_width
-        flat_pos = torch.arange(0, N).float().to(self.device)
+        flat_pos = torch.arange(0, N).float().to(device)
         x_pos = flat_pos % self.max_width
         y_pos = flat_pos // self.max_width
         dim_range = (
-            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(self.device)
+            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device)
         )  # C/4
         freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
         x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
@@ -360,13 +354,17 @@ class Rope2DPosEmb(nn.Module):
         freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
         return freqs_cis
 
-    def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
+    def get_freqs_cis(self, grid_hws: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples.
+            grid_hws (torch.Tensor): grid height and width
+
         Returns:
             freqs_cis: tensor of shape (sum(t * height * width), dim//2)
         """
+        if self.freqs_cis is None:
+            self.freqs_cis = self._precompute_freqs_cis(grid_hws.device)
+
         shapes = grid_hws.tolist()
         assert all(
             1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
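With the lazy cache in place, the returned table simply follows the device of `grid_hws`. A quick hypothetical check, continuing the sketch above (shapes invented for illustration):

```python
rope = LazyRope2D(dim=64, max_height=64, max_width=64)
grid_hws = torch.tensor([[16, 24], [8, 8]])  # two images; a plain CPU tensor

freqs_cis = rope.get_freqs_cis(grid_hws)
print(freqs_cis.shape)   # torch.Size([448, 32]) -> (16*24 + 8*8, dim // 2)
print(freqs_cis.device)  # cpu -- nothing was pinned to "cuda"
```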
@@ -376,41 +374,11 @@ class Rope2DPosEmb(nn.Module):
             self.max_width,
         )
         freqs_cis = torch.cat(
-            [
-                self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
-                for h, w in shapes
-            ],
+            [self.freqs_cis[:h, :w].reshape(-1, self.dim // 2) for h, w in shapes],
             dim=0,
         )
         return freqs_cis
 
-    def get_freqs_cis_by_idx(
-        self, pos_idx: torch.Tensor, pos_idx_mask: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Args:
-            pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token.
-            pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
-                Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask with be ones.
-        Return:
-            freqs_cis: tensor of shape (..., dim//2)
-        """
-        assert (
-            pos_idx.shape[:-1] == pos_idx_mask.shape
-            and pos_idx.shape[-1] == 2
-            and pos_idx.ndim == pos_idx_mask.ndim + 1
-        ), (pos_idx.shape, pos_idx_mask.shape)
-        assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype
-
-        shp = pos_idx_mask.shape + (self.dim // 2,)  # ..., head_dim/2
-        freqs_cis = torch.ones(
-            shp, dtype=torch.complex64, device=self.device
-        )  # ..., head_dim/2
-        freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[
-            pos_idx[..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]
-        ]
-        return freqs_cis
-
 
 class MLP2(nn.Module):
     """
@@ -537,14 +505,14 @@ class MoonVitEncoder(nn.Module):
         self.final_layernorm = nn.LayerNorm(hidden_dim)
 
     def forward(
-        self, hidden_states: torch.Tensor, grid_hw: torch.Tensor
+        self, hidden_states: torch.Tensor, grid_hws: torch.Tensor
    ) -> torch.Tensor:
-        rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(grid_hws=grid_hw)
+        rope_freqs_cis = self.rope_2d.get_freqs_cis(grid_hws=grid_hws)
 
         lengths = torch.cat(
             (
-                torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
-                grid_hw[:, 0] * grid_hw[:, 1],
+                torch.zeros(1, device=hidden_states.device, dtype=grid_hws.dtype),
+                grid_hws[:, 0] * grid_hws[:, 1],
             )
         )
         cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
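The encoder hunk is a rename plus the switch to the new `get_freqs_cis`; the varlen-attention bookkeeping itself is unchanged: per-image token counts `h * w` are prefixed with a zero and cumulatively summed into `cu_seqlens`. A standalone illustration of that arithmetic (tensors invented; the real code also passes `device=hidden_states.device` to `torch.zeros`):

```python
import torch

grid_hws = torch.tensor([[16, 24], [8, 8]])
lengths = torch.cat(
    (
        torch.zeros(1, dtype=grid_hws.dtype),
        grid_hws[:, 0] * grid_hws[:, 1],
    )
)
cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
print(cu_seqlens)  # tensor([  0, 384, 448], dtype=torch.int32)
```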
@@ -561,14 +529,14 @@
 
 def patch_merger(
     x: torch.Tensor,
-    grid_hw: torch.Tensor,
+    grid_hws: torch.Tensor,
     merge_kernel_size: list[int, int] = (2, 2),
 ) -> List[torch.Tensor]:
     d_model = x.size(-1)
 
     outputs = []
     pre_sum = 0
-    for x_shape in grid_hw.tolist():
+    for x_shape in grid_hws.tolist():
         height, width = x_shape[0], x_shape[1]
         # Get the current sequence
         seq = x[pre_sum : pre_sum + height * width]
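`patch_merger` walks the flat token sequence image by image, using `grid_hws` to slice out each image's `h * w` tokens and collapse every 2x2 neighborhood into one merged token. A single-image sketch of that reshuffle (the helper name and shapes are hypothetical, not from the file):

```python
import torch


def merge_one(seq: torch.Tensor, height: int, width: int, kh: int = 2, kw: int = 2) -> torch.Tensor:
    d = seq.size(-1)
    # (h * w, d) -> (h / kh, kh, w / kw, kw, d): rows grouped by kernel cell
    grid = seq.view(height // kh, kh, width // kw, kw, d)
    # bring the kernel axes together and flatten each 2x2 cell into one token
    return grid.permute(0, 2, 1, 3, 4).reshape(-1, kh * kw * d)


tokens = torch.randn(16 * 24, 64)
print(merge_one(tokens, 16, 24).shape)  # torch.Size([96, 256])
```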
@@ -2290,20 +2258,20 @@ class MoonVitPretrainedModel(PreTrainedModel):
         )
 
     def forward(
-        self, pixel_values: torch.Tensor, grid_hw: torch.Tensor
+        self, pixel_values: torch.Tensor, grid_hws: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
             pixel_values (torch.Tensor): The input pixel values.
-            grid_hw (torch.Tensor): The grid height and width.
+            grid_hws (torch.Tensor): The grid height and width.
 
         Returns:
             torch.Tensor: The output tokens.
         """
-        hidden_states = self.patch_embed(pixel_values, grid_hw)
-        hidden_states = self.encoder(hidden_states, grid_hw)
+        hidden_states = self.patch_embed(pixel_values, grid_hws)
+        hidden_states = self.encoder(hidden_states, grid_hws)
         hidden_states = patch_merger(
-            hidden_states, grid_hw, merge_kernel_size=self.merge_kernel_size
+            hidden_states, grid_hws, merge_kernel_size=self.merge_kernel_size
         )
         return hidden_states
 
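Net effect of the commit: nothing in the vision tower assumes CUDA anymore, so the checkpoint should load and run on a CPU-only machine. A hypothetical smoke test (the repo id and dtype choice are assumptions; the commit itself only touches modeling_kimi_vl.py):

```python
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "moonshotai/Kimi-VL-A3B-Instruct"  # assumed repo id for this file
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,  # bfloat16 may also work on recent CPUs
    device_map="cpu",           # valid now that the ViT has no hard-coded "cuda"
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
```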