zhouzaida committed
Commit 9e0cbfa · verified · Parent: 704b5c8

Support inference on CPU device (#4)


- Make ViT work with CPU device (73cd0c73d98c27f82cc525c1cc5f48118ea3c686)
- Fix import (9d09fe5da66ed112c3a7e066bce23e8aacee3fe3)
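In practical terms: before this change, `Rope2DPosEmb` was constructed with a hard-coded `device="cuda"`, and its frequency table was materialized on that device the first time the `cached_property` was accessed, so inference failed on machines without a GPU. A minimal CPU smoke test of the fixed code might look like the sketch below (the repo id is an assumption for illustration, not part of this commit):

import torch
from transformers import AutoModelForCausalLM

# Hypothetical checkpoint id: any repo shipping this modeling_kimi_vl.py works.
repo = "moonshotai/Kimi-VL-A3B-Instruct"

# Before this commit, the first forward pass on a CUDA-less machine raised,
# because the RoPE tables were created on "cuda". With the lazy,
# device-following cache they are built on whatever device the inputs use.
model = AutoModelForCausalLM.from_pretrained(
    repo,
    torch_dtype=torch.float32,  # plain fp32 on CPU
    device_map="cpu",
    trust_remote_code=True,     # pulls in modeling_kimi_vl.py from the repo
)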

Files changed (1)
  1. modeling_kimi_vl.py +28 -60
modeling_kimi_vl.py CHANGED

@@ -44,7 +44,6 @@ import math
 import warnings
 from typing import List, Optional, Tuple, Union
 from copy import deepcopy
-from functools import cached_property
 from typing import Union, Tuple, Sequence, Optional, List
 
 import numpy as np
@@ -66,10 +65,7 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
-from transformers.pytorch_utils import (
-    ALL_LAYERNORM_LAYERS,
-    is_torch_greater_or_equal_than_1_13,
-)
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -280,18 +276,18 @@ class MoonVisionPatchEmbed(nn.Module):
             height=pos_emb_height, width=pos_emb_width, dim=out_dim
         )
 
-    def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
         """
         Args:
             x (L, Channels): input tensor
-            grid_hw (N, 2): grid height and width
+            grid_hws (N, 2): grid height and width
 
         Returns:
             (L, Cout) tensor
         """
         x = self.proj(x).view(x.size(0), -1)
         # apply positional embedding
-        x = self.pos_emb(x, grid_hw)
+        x = self.pos_emb(x, grid_hws)
         return x
 
@@ -317,22 +313,20 @@ class Rope2DPosEmb(nn.Module):
         device (str): the device to store the precomputed cis
     """
 
-    def __init__(
-        self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
-    ):
+    def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
         super().__init__()
         self.dim = dim
         assert self.dim % 4 == 0, "dim must be divisible by 4"
         self.max_height = max_height
         self.max_width = max_width
         self.theta_base = theta_base
-        self.device = device
+
+        self.freqs_cis = None
 
     def extra_repr(self):
         return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"
 
-    @cached_property
-    def precomputed_freqs_cis(self) -> torch.Tensor:
+    def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
         """Calculate the cis(freqs) for each position in the 2D grid.
@@ -341,11 +335,11 @@ class Rope2DPosEmb(nn.Module):
         note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
         """
         N = self.max_height * self.max_width
-        flat_pos = torch.arange(0, N).float().to(self.device)
+        flat_pos = torch.arange(0, N).float().to(device)
         x_pos = flat_pos % self.max_width
         y_pos = flat_pos // self.max_width
         dim_range = (
-            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(self.device)
+            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device)
         )  # C/4
         freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
         x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
@@ -360,13 +354,17 @@ class Rope2DPosEmb(nn.Module):
         freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
         return freqs_cis
 
-    def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
+    def get_freqs_cis(self, grid_hws: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples.
+            grid_hws (torch.Tensor): grid height and width
+
         Returns:
             freqs_cis: tensor of shape (sum(t * height * width), dim//2)
         """
+        if self.freqs_cis is None:
+            self.freqs_cis = self._precompute_freqs_cis(grid_hws.device)
+
         shapes = grid_hws.tolist()
         assert all(
             1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
@@ -376,41 +374,11 @@ class Rope2DPosEmb(nn.Module):
             self.max_width,
         )
         freqs_cis = torch.cat(
-            [
-                self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
-                for h, w in shapes
-            ],
+            [self.freqs_cis[:h, :w].reshape(-1, self.dim // 2) for h, w in shapes],
             dim=0,
         )
         return freqs_cis
 
-    def get_freqs_cis_by_idx(
-        self, pos_idx: torch.Tensor, pos_idx_mask: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Args:
-            pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token.
-            pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
-                Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask will be ones.
-        Return:
-            freqs_cis: tensor of shape (..., dim//2)
-        """
-        assert (
-            pos_idx.shape[:-1] == pos_idx_mask.shape
-            and pos_idx.shape[-1] == 2
-            and pos_idx.ndim == pos_idx_mask.ndim + 1
-        ), (pos_idx.shape, pos_idx_mask.shape)
-        assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype
-
-        shp = pos_idx_mask.shape + (self.dim // 2,)  # ..., head_dim/2
-        freqs_cis = torch.ones(
-            shp, dtype=torch.complex64, device=self.device
-        )  # ..., head_dim/2
-        freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[
-            pos_idx[..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]
-        ]
-        return freqs_cis
-
 
 class MLP2(nn.Module):
     """
@@ -537,14 +505,14 @@ class MoonVitEncoder(nn.Module):
         self.final_layernorm = nn.LayerNorm(hidden_dim)
 
     def forward(
-        self, hidden_states: torch.Tensor, grid_hw: torch.Tensor
+        self, hidden_states: torch.Tensor, grid_hws: torch.Tensor
     ) -> torch.Tensor:
-        rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(grid_hws=grid_hw)
+        rope_freqs_cis = self.rope_2d.get_freqs_cis(grid_hws=grid_hws)
 
         lengths = torch.cat(
             (
-                torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
-                grid_hw[:, 0] * grid_hw[:, 1],
+                torch.zeros(1, device=hidden_states.device, dtype=grid_hws.dtype),
+                grid_hws[:, 0] * grid_hws[:, 1],
             )
         )
         cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
@@ -561,14 +529,14 @@
 
 def patch_merger(
     x: torch.Tensor,
-    grid_hw: torch.Tensor,
+    grid_hws: torch.Tensor,
     merge_kernel_size: list[int, int] = (2, 2),
 ) -> List[torch.Tensor]:
     d_model = x.size(-1)
 
     outputs = []
     pre_sum = 0
-    for x_shape in grid_hw.tolist():
+    for x_shape in grid_hws.tolist():
         height, width = x_shape[0], x_shape[1]
         # Get the current sequence
         seq = x[pre_sum : pre_sum + height * width]
@@ -2290,20 +2258,20 @@ class MoonVitPretrainedModel(PreTrainedModel):
         )
 
     def forward(
-        self, pixel_values: torch.Tensor, grid_hw: torch.Tensor
+        self, pixel_values: torch.Tensor, grid_hws: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
             pixel_values (torch.Tensor): The input pixel values.
-            grid_hw (torch.Tensor): The grid height and width.
+            grid_hws (torch.Tensor): The grid height and width.
 
         Returns:
             torch.Tensor: The output tokens.
         """
-        hidden_states = self.patch_embed(pixel_values, grid_hw)
-        hidden_states = self.encoder(hidden_states, grid_hw)
+        hidden_states = self.patch_embed(pixel_values, grid_hws)
+        hidden_states = self.encoder(hidden_states, grid_hws)
         hidden_states = patch_merger(
-            hidden_states, grid_hw, merge_kernel_size=self.merge_kernel_size
+            hidden_states, grid_hws, merge_kernel_size=self.merge_kernel_size
         )
         return hidden_states
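The core pattern of this diff is easier to see in one piece: drop the constructor's `device="cuda"` default and the eager-on-CUDA `cached_property`, and instead materialize the table on the first call, on the device of the incoming `grid_hws` tensor. A self-contained sketch of that pattern follows; the class name, sizes, and the x/y frequency layout are simplified for illustration and are not the file's exact code.

import torch
from torch import nn


class LazyRope2D(nn.Module):
    """Sketch of the device-following cache pattern introduced by this commit."""

    def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
        super().__init__()
        assert dim % 4 == 0, "dim must be divisible by 4"
        self.dim = dim
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base
        # No device argument: the table is built lazily, on whatever device
        # the first batch lives on (CPU or GPU).
        self.freqs_cis = None

    def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
        n = self.max_height * self.max_width
        flat_pos = torch.arange(0, n, device=device).float()
        x_pos = flat_pos % self.max_width
        y_pos = flat_pos // self.max_width
        dim_range = torch.arange(0, self.dim, 4, device=device).float()  # C/4
        freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
        # Simplified layout: x-angles then y-angles, concatenated to C/2.
        angles = torch.cat(
            (torch.outer(x_pos, freqs), torch.outer(y_pos, freqs)), dim=-1
        )
        freqs_cis = torch.polar(torch.ones_like(angles), angles)  # cis(angle)
        return freqs_cis.reshape(self.max_height, self.max_width, -1)

    def get_freqs_cis(self, grid_hws: torch.Tensor) -> torch.Tensor:
        if self.freqs_cis is None:  # the first call decides the device
            self.freqs_cis = self._precompute_freqs_cis(grid_hws.device)
        return torch.cat(
            [self.freqs_cis[:h, :w].reshape(-1, self.dim // 2)
             for h, w in grid_hws.tolist()],
            dim=0,
        )


# Runs on a CUDA-less machine; nothing is pinned to "cuda" anymore.
rope = LazyRope2D(dim=64, max_height=16, max_width=16)
grid_hws = torch.tensor([[4, 6], [3, 3]])  # patch grids of two images
print(rope.get_freqs_cis(grid_hws).shape)  # torch.Size([33, 32])

One caveat, shared with the committed code: the cache is keyed to the first device seen, so moving the module to a different device after the first forward pass would leave the table behind on the old device.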