File size: 4,680 Bytes
704b5c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Image processor class for KimiVL."""

import math
import numpy as np
from PIL import Image
from typing import Optional, Union

import torch
from torchvision.transforms import functional as TF
from transformers.image_utils import ImageInput, make_list_of_images, valid_images
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.utils import TensorType


# Default per-channel (R, G, B) normalization statistics, used as the
# `image_mean` / `image_std` defaults of KimiVLImageProcessor below and applied
# to [0, 1]-scaled RGB tensors via `TF.normalize`.
# NOTE(review): per the names these are the OpenAI (CLIP) dataset statistics —
# confirm they match the checkpoint's preprocessor config.
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)


class KimiVLImageProcessor(BaseImageProcessor):
    """Image processor for Kimi-VL.

    Each image is downscaled when its whole-patch count exceeds
    ``in_token_limit``, then padded (``pad_input=True``) or center-cropped so its
    sides are patch-aligned, converted to a normalized RGB tensor, and split into
    ``patch_size x patch_size`` patches.

    NOTE(review): ``rescale`` and ``normalize`` reuse the names of
    ``BaseImageProcessor`` methods with different signatures and meaning
    (token-budget resize / fixed-statistics normalization); confirm nothing in the
    base class relies on the original signatures.
    """

    model_type = "kimi_vl"

    def __init__(
        self,
        patch_size: int = 14,
        pad_input: bool = False,
        image_mean: tuple[float, float, float] = OPENAI_DATASET_MEAN,
        image_std: tuple[float, float, float] = OPENAI_DATASET_STD,
        in_token_limit: int = 4096,
        merge_kernel_size: Union[list[int], tuple[int, int]] = (2, 2),
        **kwargs,
    ):
        """
        Args:
            patch_size: Side length, in pixels, of each square patch.
            pad_input: If True, zero-pad images on the right/bottom up to a
                multiple of ``merge_kernel_size * patch_size``; otherwise
                center-crop them down to a multiple of ``patch_size``.
            image_mean: Per-channel mean passed to ``TF.normalize``.
            image_std: Per-channel std passed to ``TF.normalize``.
            in_token_limit: Whole-patch count above which images are downscaled.
            merge_kernel_size: (height, width) padding granularity in patches;
                only consulted when ``pad_input`` is True.
            **kwargs: Forwarded to ``BaseImageProcessor``.
        """
        super().__init__(**kwargs)
        self.in_token_limit = in_token_limit
        self.patch_size = patch_size
        self.pad_input = pad_input
        self.image_mean = image_mean
        self.image_std = image_std
        # Store a private list copy: instances never share (or alias a caller's)
        # list object, and the attribute stays a list for config serialization.
        self.merge_kernel_size = list(merge_kernel_size)

    def rescale(
        self,
        image: Image.Image,
        merge_kernel_size: Union[list[int], tuple[int, int]] = (2, 2),
    ) -> Image.Image:
        """Resize and pad/crop ``image`` so both sides are patch-aligned.

        1. If the image holds more than ``in_token_limit`` whole patches, it is
           bicubically downscaled by one uniform factor targeting that budget.
        2. With ``pad_input``, the right and bottom edges are zero-padded up to a
           multiple of ``merge_kernel_size[i] * patch_size``; otherwise the image
           is center-cropped down to a multiple of ``patch_size``.

        NOTE: padding happens after the downscale, so with ``pad_input`` the final
        patch count can slightly exceed ``in_token_limit``.

        Args:
            image: Input PIL image.
            merge_kernel_size: (height, width) padding granularity in patches.

        Returns:
            The resized and padded/cropped PIL image.

        Raises:
            ValueError: if the result is smaller than one patch along either
                side, or spans 512 or more patches along either side (presumably
                the positional-embedding limit, per the error message).
        """
        patch_size = self.patch_size
        orig_w, orig_h = image.size

        num_patches = (orig_w // patch_size) * (orig_h // patch_size)
        if num_patches > self.in_token_limit:
            scale = math.sqrt(self.in_token_limit / num_patches)
            new_w, new_h = int(orig_w * scale), int(orig_h * scale)
            image = image.resize((new_w, new_h), Image.Resampling.BICUBIC)

        new_w, new_h = image.size
        if self.pad_input:
            pad_size_h = merge_kernel_size[0] * patch_size
            pad_size_w = merge_kernel_size[1] * patch_size
            # Distance to the next multiple; the outer modulo makes it 0 when
            # the side is already aligned.
            pad_h = (pad_size_h - new_h % pad_size_h) % pad_size_h
            pad_w = (pad_size_w - new_w % pad_size_w) % pad_size_w
            # torchvision's 4-element padding is (left, top, right, bottom).
            image = TF.pad(image, (0, 0, pad_w, pad_h))
        else:
            new_w -= new_w % patch_size
            new_h -= new_h % patch_size
            image = TF.center_crop(image, (new_h, new_w))

        w, h = image.size
        # Cropping an image narrower/shorter than one patch would otherwise yield
        # a zero-sized image and an empty patch tensor downstream.
        if w // patch_size == 0 or h // patch_size == 0:
            raise ValueError(
                f"Image of size {orig_w}x{orig_h} is too small to yield a single "
                f"{patch_size}x{patch_size} patch."
            )
        if w // patch_size >= 512 or h // patch_size >= 512:
            raise ValueError("Exceed pos emb")

        return image

    def to_tensor(self, image: Image.Image) -> torch.Tensor:
        """Convert ``image`` to RGB and then to a float (3, H, W) tensor in [0, 1]."""
        return TF.to_tensor(image.convert("RGB"))

    def normalize(self, image: torch.Tensor) -> torch.Tensor:
        """Normalize channels with this processor's ``image_mean`` / ``image_std``."""
        return TF.normalize(image, self.image_mean, self.image_std)

    def patchify(self, image: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
        """Split a (C, H, W) tensor into non-overlapping square patches.

        H and W must be multiples of ``patch_size`` (ensured by ``rescale``).

        Returns:
            patches: (H/p * W/p, C, p, p) tensor, patches in row-major order.
            grid_hw: (H/p, W/p), the patch-grid dimensions.
        """
        patch_size = self.patch_size
        C, H, W = image.shape
        grid_h, grid_w = H // patch_size, W // patch_size
        # (C, gh, p, gw, p) -> (gh, gw, C, p, p) -> (gh * gw, C, p, p)
        patches = image.reshape(C, grid_h, patch_size, grid_w, patch_size)
        patches = patches.permute(1, 3, 0, 2, 4)
        patches = patches.contiguous().view(-1, C, patch_size, patch_size)
        return patches, (grid_h, grid_w)

    def _preprocess(self, image: ImageInput) -> tuple[torch.Tensor, tuple[int, int]]:
        """
        Resize, normalize and patchify a single image.

        Args:
            image (`PIL.Image.Image`):
                Image to preprocess. Any PIL mode is accepted; it is converted to
                RGB and scaled to [0, 1] by ``to_tensor`` before normalization.

        Returns:
            patches: torch.Tensor of shape (num_patches, 3, patch_size, patch_size)
            grid_hw: tuple[int, int], the (rows, cols) patch grid

        Raises:
            TypeError: if ``image`` is not a PIL image.
            ValueError: propagated from ``rescale``.
        """
        # `valid_images` in `preprocess` also admits arrays/tensors, but
        # `rescale` relies on the PIL API (`.size` tuple, `.resize`).
        if not isinstance(image, Image.Image):
            raise TypeError(
                f"{type(self).__name__} only supports PIL.Image.Image inputs, "
                f"got {type(image).__name__}."
            )
        image = self.rescale(image, self.merge_kernel_size)
        image = self.to_tensor(image)
        image = self.normalize(image)
        return self.patchify(image)

    def preprocess(
        self,
        images: ImageInput,
        return_tensors: Optional[Union[str, TensorType]] = None,
    ) -> BatchFeature:
        """Preprocess a single image or a list of images.

        Args:
            images: A PIL image or a list of PIL images.
            return_tensors: Tensor type forwarded to ``BatchFeature``.

        Returns:
            BatchFeature with
              - ``pixel_values``: every image's patches concatenated along dim 0,
                shape (total_patches, 3, patch_size, patch_size);
              - ``image_grid_hws``: array with one (rows, cols) patch-grid row per
                image, in input order.

        Raises:
            ValueError: if the inputs fail ``valid_images``, or from ``rescale``.
            TypeError: if an input is not a PIL image.
        """
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        pixel_values, image_grid_hws = [], []
        for image in images:
            patches, grid_hw = self._preprocess(image)
            pixel_values.append(patches)
            image_grid_hws.append(grid_hw)

        data = {
            "pixel_values": torch.concat(pixel_values, dim=0),
            "image_grid_hws": np.array(image_grid_hws),
        }
        return BatchFeature(data=data, tensor_type=return_tensors)