zhouzaida committed

Commit 73cd0c7 · Parent(s): 704b5c8

make vit work with cpu device

Files changed: modeling_kimi_vl.py (+29 -61)

modeling_kimi_vl.py CHANGED
@@ -44,7 +44,6 @@ import math
 import warnings
 from typing import List, Optional, Tuple, Union
 from copy import deepcopy
-from functools import cached_property
 from typing import Union, Tuple, Sequence, Optional, List

 import numpy as np
@@ -66,10 +65,7 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
-from transformers.pytorch_utils import (
-    ALL_LAYERNORM_LAYERS,
-    is_torch_greater_or_equal_than_1_13,
-)
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -80,7 +76,7 @@ from transformers.utils import (
 )
 from transformers.utils.import_utils import is_torch_fx_available

-from 
+from configuration_kimi_vl import MoonViTConfig, DeepseekV3Config, KimiVLConfig


 if is_flash_attn_2_available():
@@ -280,18 +276,18 @@ class MoonVisionPatchEmbed(nn.Module):
             height=pos_emb_height, width=pos_emb_width, dim=out_dim
         )

-    def forward(self, x: torch.Tensor, 
+    def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
         """
         Args:
             x (L, Channels): input tensor
-
+            grid_hws (N, 2): grid height and width

         Returns:
             (L, Cout) tensor
         """
         x = self.proj(x).view(x.size(0), -1)
         # apply positional embedding
-        x = self.pos_emb(x, 
+        x = self.pos_emb(x, grid_hws)
         return x


@@ -317,22 +313,20 @@ class Rope2DPosEmb(nn.Module):
         device (str): the device to store the precomputed cis
     """

-    def __init__(
-        self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
-    ):
+    def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
         super().__init__()
         self.dim = dim
         assert self.dim % 4 == 0, "dim must be divisible by 4"
         self.max_height = max_height
         self.max_width = max_width
         self.theta_base = theta_base
-        self.device = device
+
+        self.freqs_cis = None

     def extra_repr(self):
         return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"

-    @cached_property
-    def precomputed_freqs_cis(self) -> torch.Tensor:
+    def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
         """Calculate the cis(freqs) for each position in the 2D grid.

         Return: complex tensor of shape (max_height, max_width, dim//2) and value:
@@ -341,11 +335,11 @@ class Rope2DPosEmb(nn.Module):
             note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
         """
         N = self.max_height * self.max_width
-        flat_pos = torch.arange(0, N).float().to(self.device)
+        flat_pos = torch.arange(0, N).float().to(device)
         x_pos = flat_pos % self.max_width
         y_pos = flat_pos // self.max_width
         dim_range = (
-            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(self.device)
+            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device)
         )  # C/4
         freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
         x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
@@ -360,13 +354,17 @@ class Rope2DPosEmb(nn.Module):
         freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
         return freqs_cis

-    def 
+    def get_freqs_cis(self, grid_hws: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            grid_hws (torch.Tensor): 
+            grid_hws (torch.Tensor): grid height and width
+
         Returns:
             freqs_cis: tensor of shape (sum(t * height * width), dim//2)
         """
+        if self.freqs_cis is None:
+            self.freqs_cis = self._precompute_freqs_cis(grid_hws.device)
+
         shapes = grid_hws.tolist()
         assert all(
             1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
@@ -376,41 +374,11 @@ class Rope2DPosEmb(nn.Module):
             self.max_width,
         )
         freqs_cis = torch.cat(
-            [
-                self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
-                for h, w in shapes
-            ],
+            [self.freqs_cis[:h, :w].reshape(-1, self.dim // 2) for h, w in shapes],
             dim=0,
         )
         return freqs_cis

-    def get_freqs_cis_by_idx(
-        self, pos_idx: torch.Tensor, pos_idx_mask: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Args:
-            pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token.
-            pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
-                Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask with be ones.
-        Return:
-            freqs_cis: tensor of shape (..., dim//2)
-        """
-        assert (
-            pos_idx.shape[:-1] == pos_idx_mask.shape
-            and pos_idx.shape[-1] == 2
-            and pos_idx.ndim == pos_idx_mask.ndim + 1
-        ), (pos_idx.shape, pos_idx_mask.shape)
-        assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype
-
-        shp = pos_idx_mask.shape + (self.dim // 2,)  # ..., head_dim/2
-        freqs_cis = torch.ones(
-            shp, dtype=torch.complex64, device=self.device
-        )  # ..., head_dim/2
-        freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[
-            pos_idx[..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]
-        ]
-        return freqs_cis
-

 class MLP2(nn.Module):
     """
@@ -537,14 +505,14 @@ class MoonVitEncoder(nn.Module):
         self.final_layernorm = nn.LayerNorm(hidden_dim)

     def forward(
-        self, hidden_states: torch.Tensor, 
+        self, hidden_states: torch.Tensor, grid_hws: torch.Tensor
     ) -> torch.Tensor:
-        rope_freqs_cis = self.rope_2d.
+        rope_freqs_cis = self.rope_2d.get_freqs_cis(grid_hws=grid_hws)

         lengths = torch.cat(
             (
-                torch.zeros(1, device=hidden_states.device, dtype=
-
+                torch.zeros(1, device=hidden_states.device, dtype=grid_hws.dtype),
+                grid_hws[:, 0] * grid_hws[:, 1],
             )
         )
         cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
@@ -561,14 +529,14 @@

 def patch_merger(
     x: torch.Tensor,
-
+    grid_hws: torch.Tensor,
     merge_kernel_size: list[int, int] = (2, 2),
 ) -> List[torch.Tensor]:
     d_model = x.size(-1)

     outputs = []
     pre_sum = 0
-    for x_shape in 
+    for x_shape in grid_hws.tolist():
         height, width = x_shape[0], x_shape[1]
         # Get the current sequence
         seq = x[pre_sum : pre_sum + height * width]
@@ -2290,20 +2258,20 @@ class MoonVitPretrainedModel(PreTrainedModel):
         )

     def forward(
-        self, pixel_values: torch.Tensor, 
+        self, pixel_values: torch.Tensor, grid_hws: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
             pixel_values (torch.Tensor): The input pixel values.
-
+            grid_hws (torch.Tensor): The grid height and width.

         Returns:
             torch.Tensor: The output tokens.
         """
-        hidden_states = self.patch_embed(pixel_values, 
-        hidden_states = self.encoder(hidden_states, 
+        hidden_states = self.patch_embed(pixel_values, grid_hws)
+        hidden_states = self.encoder(hidden_states, grid_hws)
         hidden_states = patch_merger(
-            hidden_states, 
+            hidden_states, grid_hws, merge_kernel_size=self.merge_kernel_size
        )
         return hidden_states
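The substance of the commit is replacing the eager `precomputed_freqs_cis` table, which was built in `__init__` with a hard-coded `device="cuda"` default, with a `_precompute_freqs_cis(device)` that is materialized lazily on whatever device the first `grid_hws` tensor lives on. Below is a minimal, self-contained sketch of that pattern; the class name `LazyRope2D` and the concatenated x/y frequency layout are illustrative assumptions (the unchanged middle of `_precompute_freqs_cis` is not shown in this diff), not the repository's exact code.

import torch
import torch.nn as nn


class LazyRope2D(nn.Module):
    """Lazy, device-agnostic 2D RoPE table (sketch, not the repository file)."""

    def __init__(self, dim: int, max_height: int, max_width: int, theta_base: float = 10000.0):
        super().__init__()
        assert dim % 4 == 0, "dim must be divisible by 4"
        self.dim = dim
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base
        self.freqs_cis = None  # built on first use, on the caller's device

    def _precompute(self, device: torch.device) -> torch.Tensor:
        # cis(theta) = cos(theta) + i*sin(theta) for every (y, x) position in the max grid
        n = self.max_height * self.max_width
        flat = torch.arange(n, device=device).float()
        x_pos, y_pos = flat % self.max_width, flat // self.max_width
        freqs = 1.0 / (self.theta_base ** (torch.arange(0, self.dim, 4, device=device).float() / self.dim))
        # illustrative layout: x-angles then y-angles along the last dim -> (N, dim//2)
        angles = torch.cat((torch.outer(x_pos, freqs), torch.outer(y_pos, freqs)), dim=-1)
        cis = torch.polar(torch.ones_like(angles), angles)  # complex64, (N, dim//2)
        return cis.reshape(self.max_height, self.max_width, -1)

    def get_freqs_cis(self, grid_hws: torch.Tensor) -> torch.Tensor:
        if self.freqs_cis is None:  # lazy init: behaves the same on CPU, CUDA, or MPS
            self.freqs_cis = self._precompute(grid_hws.device)
        return torch.cat(
            [self.freqs_cis[:h, :w].reshape(-1, self.dim // 2) for h, w in grid_hws.tolist()],
            dim=0,
        )


# Quick CPU check of the sketch: two images with 14x14 and 7x21 patch grids.
rope = LazyRope2D(dim=64, max_height=64, max_width=64)
freqs = rope.get_freqs_cis(torch.tensor([[14, 14], [7, 21]]))
print(freqs.shape, freqs.device)  # torch.Size([343, 32]) cpu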

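For context, the `lengths`/`cu_seqlens` lines changed in `MoonVitEncoder.forward` simply turn `grid_hws` into cumulative per-image token boundaries, so that path needs no device-specific handling either. A small worked example (the two tensor lines mirror the diff; the grid values are illustrative):

import torch

grid_hws = torch.tensor([[14, 14], [7, 21]])  # per-image patch-grid heights and widths
lengths = torch.cat((torch.zeros(1, dtype=grid_hws.dtype), grid_hws[:, 0] * grid_hws[:, 1]))
cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
print(cu_seqlens.tolist())  # [0, 196, 343] -> cumulative sequence lengths for varlen attention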