Hyggge committed
Commit 7e9d312 · 1 Parent(s): 4817fd0

feat: add modeling code

config.json CHANGED
@@ -3,6 +3,11 @@
  "architectures": [
    "ValleyQwen2ForCausalLM"
  ],
+ "auto_map": {
+   "AutoConfig": "modeling_valley.ValleyConfig",
+   "AutoModel": "modeling_valley.ValleyQwen2ForCausalLM",
+   "AutoModelForCausalLM": "modeling_valley.ValleyQwen2ForCausalLM"
+ },
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eagle_vision_tower": "Qwen/Qwen2-VL-7B-Instruct",
@@ -32,7 +37,6 @@
  "mm_vision_select_layer": -2,
  "mm_vision_siglip_select_layer": -1,
  "mm_vision_tower": "google/siglip-so400m-patch14-384",
- "model_class": "valley-product",
  "model_type": "valley",
  "num_attention_heads": 28,
  "num_hidden_layers": 28,
modeling_projector.py ADDED
@@ -0,0 +1,162 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ def build_vision_projector(config, delay_load=False, **kwargs):
6
+ projector_type = getattr(config, 'mm_projector_type', 'linear')
7
+
8
+ if projector_type == 'conv_adapter':
9
+ return ConvAdapter(config.mm_hidden_size, config.hidden_size, getattr(config, "mlp_hidden_dim", None))
10
+ elif projector_type == 'mlp_pixel_shuffle':
11
+ return MlpPixelShuffle(config.mm_hidden_size, config.hidden_size,
12
+ config.pixelshuffle_downsample_ratio, getattr(config, "mlp_hidden_dim", None))
13
+ elif projector_type == 'ovis_conv_adapter':
14
+ return OvisConvAdapter(config.mm_hidden_size, config.hidden_size, getattr(config, "mlp_hidden_dim", 32000),
15
+ getattr(config, "tokenize_function", "softmax"))
16
+ raise ValueError(f'Unknown projector type: {projector_type}')
17
+
18
+
19
+ class ConvAdapter(nn.Module):
20
+ def __init__(self, dim_in, dim_out, mlp_hidden_dim=None):
21
+ super().__init__()
22
+ self.mm_projector_type = 'conv_adapter'
23
+ if mlp_hidden_dim is None:
24
+ self.mlp = nn.Sequential(
25
+ nn.Linear(dim_in, dim_out),
26
+ nn.GELU(),
27
+ nn.Linear(dim_out, dim_out)
28
+ )
29
+ else:
30
+ self.mlp = nn.Sequential(
31
+ nn.Linear(dim_in, mlp_hidden_dim),
32
+ nn.GELU(),
33
+ nn.Linear(mlp_hidden_dim, dim_out)
34
+ )
35
+ self.conv = nn.Conv2d(dim_out, dim_out, kernel_size=(3, 3), stride=(2, 2), padding=1)
36
+
37
+ def forward(self, x):
38
+ """
39
+ Args:
40
+ x (torch.Tensor): image features
41
+ shape (F, v, D)
42
+ Returns:
43
+ shape (F, n, D) where n is the reduced number of tokens
44
+ """
45
+ x = self.mlp(x)
46
+
47
+ f, v, d = x.shape
48
+ s = int(math.sqrt(v - 1))
49
+ x = x[:, 1:, :] # remove cls_token
50
+ x = x.reshape(f, s, s, d).permute([0, 3, 1, 2])
51
+ x = self.conv(x)
52
+ x = x.permute([0, 2, 3, 1]).reshape(f, -1, d)
53
+ return x
54
+
55
+
56
+ class MlpPixelShuffle(nn.Module):
57
+ def __init__(self, dim_in, dim_out, pixelshuffle_downsample_ratio, mlp_hidden_dim=None):
58
+ super().__init__()
59
+ self.mm_projector_type = 'mlp_pixel_shuffle'
60
+ if mlp_hidden_dim is None:
61
+ self.mlp = nn.Sequential(
62
+ nn.Linear(int(dim_in * (pixelshuffle_downsample_ratio ** 2)), dim_out),
63
+ nn.GELU(),
64
+ nn.Linear(dim_out, dim_out)
65
+ )
66
+ else:
67
+ self.mlp = nn.Sequential(
68
+ nn.Linear(int(dim_in * (pixelshuffle_downsample_ratio ** 2)), mlp_hidden_dim),
69
+ nn.GELU(),
70
+ nn.Linear(mlp_hidden_dim, dim_out)
71
+ )
72
+ self.scale_factor = pixelshuffle_downsample_ratio
73
+
74
+ def pixel_shuffle(self, x, scale_factor=2):
75
+ # fold each (scale_factor x scale_factor) spatial block into the channel dimension
76
+
77
+ n, w, h, c = x.size()
78
+ # N, W, H, C --> N, W, H / scale, C * scale
79
+ x = x.view(n, w, int(h / scale_factor), int(c * scale_factor))
80
+ # N, W, H / scale, C * scale --> N, H / scale, W, C * scale
81
+ x = x.permute(0, 2, 1, 3).contiguous()
82
+ # N, H / scale, W, C * scale --> N, H / scale, W / scale, C * (scale ** 2)
83
+ x = x.view(n, int(h / scale_factor), int(w / scale_factor),
84
+ int(c * (scale_factor * scale_factor)))
85
+
86
+ x = x.permute(0, 2, 1, 3).contiguous()
87
+
88
+ return x
89
+
90
+ def forward(self, x):
91
+ """
92
+ Args:
93
+ x (torch.Tensor): image features
94
+ shape (F, v, D)
95
+ Returns:
96
+ shape (F, n, D) where n is the reduced number of tokens
97
+ """
98
+ x = x[:, 1:, :] # remove cls_token
99
+ h = w = int(x.shape[1] ** 0.5)
100
+ x = x.view(x.shape[0], h, w, -1)
101
+ x = self.pixel_shuffle(x, self.scale_factor)
102
+ x = self.mlp(x)
103
+ x = x.view(x.shape[0],-1,x.shape[-1])
104
+ return x
105
+
106
+
107
+ class OvisConvAdapter(nn.Module):
108
+ def __init__(self, dim_in, dim_out, vocab_size, tokenize_function="softmax"):
109
+ super().__init__()
110
+ self.mm_projector_type = 'ovis_conv_adapter'
111
+ self.conv = nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), padding=1)
112
+ self.mlp = torch.nn.Sequential(
113
+ torch.nn.Linear(dim_in, vocab_size, bias=False),
114
+ torch.nn.LayerNorm(vocab_size)
115
+ )
116
+ self.embedding = torch.nn.Embedding(vocab_size, dim_out)
117
+ self.tokenize_function = tokenize_function
118
+
119
+ def tokenize(self, logits):
120
+ def st_argmax(y_soft, dim):  # straight-through argmax
121
+ index = y_soft.max(dim, keepdim=True)[1]
122
+ y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
123
+ ret = y_hard - y_soft.detach() + y_soft
124
+ return ret
125
+
126
+ if self.tokenize_function == 'softmax':
127
+ tokens = torch.nn.functional.softmax(logits, dim=-1)
128
+ elif self.tokenize_function == 'gumbel_argmax':
129
+ tokens = torch.nn.functional.gumbel_softmax(logits, tau=getattr(self, "tau", 1.0), hard=True)  # this module has no self.config; tau falls back to 1.0
130
+ elif self.tokenize_function == 'st_argmax':
131
+ tokens = st_argmax(logits, dim=-1)
132
+ else:
133
+ raise ValueError(
134
+ 'Invalid `tokenize_function`, expected softmax, gumbel_argmax or st_argmax,'
135
+ f' but got {self.tokenize_function}'
136
+ )
137
+ return tokens
138
+
139
+ def forward(self, x):
140
+ """
141
+ Args:
142
+ x (torch.Tensor): image features
143
+ shape (F, v, D)
144
+ Returns:
145
+ shape (F, n, D) where n is the reduced number of tokens
146
+ """
147
+ # conv
148
+ f, v, d = x.shape
149
+ s = int(math.sqrt(v - 1))
150
+ x = x[:, 1:, :] # remove cls_token
151
+ x = x.reshape(f, s, s, d).permute([0, 3, 1, 2])
152
+ x = self.conv(x)
153
+ x = x.permute([0, 2, 3, 1]).reshape(f, -1, d)
154
+
155
+ # tokenize
156
+ logits = self.mlp(x)
157
+ visual_tokens = self.tokenize(logits)
158
+
159
+ # get embeddings
160
+ out = torch.matmul(visual_tokens, self.embedding.weight)
161
+
162
+ return out
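As a quick sanity check of the projectors above, here is a hedged sketch that builds the conv_adapter variant from a minimal config object; the sizes are illustrative (SigLIP-so400m yields 1152-d features with one cls token plus a 27x27 patch grid, projected into a 3584-d LLM space):

import torch
from types import SimpleNamespace

# Illustrative config; real values come from config.json.
cfg = SimpleNamespace(mm_projector_type="conv_adapter", mm_hidden_size=1152, hidden_size=3584)
projector = build_vision_projector(cfg)
feats = torch.randn(2, 1 + 27 * 27, 1152)  # (frames, cls + patches, dim)
print(projector(feats).shape)  # torch.Size([2, 196, 3584]) after the stride-2 conv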
modeling_valley.py ADDED
@@ -0,0 +1,520 @@
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import numpy as np
17
+ from torch import nn
18
+ from torch.nn import CrossEntropyLoss
19
+ from abc import ABC, abstractmethod
20
+ from typing import List, Optional, Tuple, Union
21
+ from transformers.modeling_outputs import CausalLMOutputWithPast
22
+ from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2ForCausalLM, Qwen2Model
23
+
24
+ from .modeling_vision_tower import build_vision_tower
25
+ from .modeling_projector import build_vision_projector
26
+ from .utils import get_anyres_image_grid_shape, unpad_image, IGNORE_INDEX, IMAGE_TOKEN_INDEX
27
+
28
+
29
+ class ValleyConfig(Qwen2Config):
30
+ model_type = "valley"
31
+
32
+ class ValleyMetaModel:
33
+ def __init__(self, config):
34
+ super(ValleyMetaModel, self).__init__(config)
35
+ # Build vision tower
36
+ if hasattr(config, "mm_vision_tower"):
37
+ if getattr(config, "eagle_vision_tower", None) is not None:
38
+ self.vision_tower, self.qwen2vl_vision_tower = build_vision_tower(config, delay_load=False)
39
+ else:
40
+ self.vision_tower = build_vision_tower(config, delay_load=False)
41
+ # Build Projector
42
+ if hasattr(config, "mm_projector_type"):
43
+ self.mm_projector = build_vision_projector(config)
44
+
45
+ def get_vision_tower(self):
46
+ vision_tower = getattr(self, "vision_tower", None)
47
+ if getattr(self.config, "eagle_vision_tower", None) is not None:
48
+ qwen2vl_vision_tower = getattr(self, "qwen2vl_vision_tower", None)
49
+ return vision_tower, qwen2vl_vision_tower
50
+ else:
51
+ return vision_tower
52
+
53
+ class ValleyMetaForCausalLM(ABC):
54
+ @abstractmethod
55
+ def get_model(self):
56
+ pass
57
+
58
+ def get_vision_tower(self):
59
+ return self.get_model().get_vision_tower()
60
+
61
+ def split_by_instance(self, original_list, split_sizes):
62
+ start = 0
63
+ sub_lists = []
64
+ for size in split_sizes:
65
+ end = start + size
66
+ sub_list = original_list[start:end]
67
+ sub_lists.append([x.to(self.device) for x in sub_list])
68
+ start = end
69
+ return sub_lists
70
+
71
+ def encode_images_qwen2vl(self, pixel_values = None, grid_thw = None, split_sizes=None):
72
+ _, qwen2vl_vision_tower = self.get_model().get_vision_tower()
73
+ qwen2vl_image_features = qwen2vl_vision_tower(pixel_values, grid_thw)
74
+ qwen2vl_image_split_sizes = torch.prod(grid_thw[:, 1:3]//2, dim=1)
75
+ qwen2vl_image_features = torch.split(qwen2vl_image_features, qwen2vl_image_split_sizes.tolist(), dim=0)
76
+ qwen2vl_image_features = self.split_by_instance(qwen2vl_image_features, split_sizes)
77
+ return qwen2vl_image_features
78
+
79
+ def encode_images(self, images = None, split_sizes = None):
80
+ """
81
+ images: (if not anyres) images.shape = [n,3,336,336] , n = number of images + (number of video) * 8
82
+ images: (if anyres) images.shape = [n,3,336,336] , n = number of tiles * number of images
83
+ """
84
+ if getattr(self.config, "eagle_vision_tower", None) is not None:
85
+ siglip_vision_tower, _ = self.get_model().get_vision_tower()
86
+ image_features = siglip_vision_tower(images)
87
+ image_features = self.get_model().mm_projector(image_features)
88
+ else:
89
+ image_features = self.get_model().get_vision_tower()(images)
90
+ image_features = self.get_model().mm_projector(image_features)
91
+
92
+ if getattr(self.config,'anyres', False) and getattr(self.config, 'max_vision_token', None) is not None:
93
+ assert split_sizes is not None
94
+ image_features = list(torch.split(image_features, split_sizes, dim=0))
95
+ for i, image_feature in enumerate(image_features):
96
+ hidden_dim = image_feature.shape[-1]
97
+ image_tokens = image_feature.shape[0]*image_feature.shape[1]
98
+ if getattr(self.config, "eagle_vision_tower", None) is not None:
99
+ pass # the max_vision_token will be processed in the unpad image token part
100
+ else:
101
+ if image_tokens > self.config.max_vision_token:
102
+ intput_shape = int((image_feature.shape[1])**0.5)
103
+ output_shape = int((self.config.max_vision_token/image_feature.shape[0])**0.5)
104
+ image_feature = image_feature.view(image_feature.shape[0],intput_shape, intput_shape, -1).permute(0,3,1,2)
105
+ m = nn.AdaptiveAvgPool2d(output_shape) # different from roi pooling, but in square image, it seems the same
106
+ pooling_feature = m(image_feature).permute(0,2,3,1)
107
+ image_features[i] = pooling_feature.view(image_feature.shape[0], -1, hidden_dim)
108
+ split_sizes = None # have already split, set the flag
109
+
110
+ if getattr(self.config, 'mm_use_im_start_end', False):
111
+ raise ValueError('mm_use_im_start is not support')
112
+ if split_sizes is not None:
113
+ image_features = torch.split(image_features, split_sizes, dim=0)
114
+
115
+ return image_features
116
+
117
+
118
+ def prepare_inputs_labels_for_multimodal(
119
+ self, input_ids, position_ids, attention_mask, past_key_values, labels, images,
120
+ image_sizes, pixel_values, pixel_values_videos, image_grid_thw, video_grid_thw):
121
+
122
+ vision_tower = self.get_vision_tower()
123
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
124
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
125
+ target_shape = past_key_values[-1][-1].shape[-2] + 1
126
+ attention_mask = torch.cat((attention_mask, torch.ones(
127
+ (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
128
+ dtype=attention_mask.dtype,
129
+ device=attention_mask.device
130
+ )), dim=1)
131
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
132
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
133
+
134
+ # Step1: Get image embedings
135
+ if type(images) is list or images.ndim == 5:
136
+ # Without slicing the image
137
+ if not getattr(self.config,'anyres', False):
138
+ concat_images = torch.cat([image for image in images], dim=0) # to do batch compute
139
+ split_sizes = [image.shape[0] for image in images]
140
+
141
+ # Get vision tower feature, check whether only use navit firstly
142
+ if getattr(self.config, 'eagle_vision_tower', None) is not None and getattr(self.config, 'only_navit', False):
143
+ image_features = None
144
+ else:
145
+ image_features = self.encode_images(concat_images, split_sizes)
146
+ image_features = [x.to(self.device) for x in image_features]
147
+
148
+ # Get Eagle features
149
+ if getattr(self.config, 'eagle_vision_tower', None) is not None:
150
+ if pixel_values is not None:
151
+ qwen2vl_image_features = self.encode_images_qwen2vl(pixel_values, image_grid_thw, split_sizes)
152
+ elif pixel_values_videos is not None:
153
+ qwen2vl_image_features = self.encode_images_qwen2vl(pixel_values_videos, video_grid_thw, split_sizes)
154
+ else:
155
+ qwen2vl_image_features = None
156
+
157
+ # Slicing the image, each image contains some sub_images:
158
+ # images = [
159
+ # [image1_tiles(n1,3,336,336), image2_tiles(n2,3,336,336), ...],
160
+ # [image1_tiles(n1,3,336,336), image2_tiles(n2,3,336,336), ...], ...
161
+ # ]
162
+ else:
163
+ split_sizes = [len(image) for image in images]
164
+ # Get Eagle features
165
+ if getattr(self.config, "eagle_vision_tower", None) is not None:
166
+ if pixel_values is not None:
167
+ qwen2vl_image_features = self.encode_images_qwen2vl(pixel_values, image_grid_thw, split_sizes)
168
+ elif pixel_values_videos is not None:
169
+ qwen2vl_image_features = self.encode_images_qwen2vl(pixel_values_videos, video_grid_thw, split_sizes)
170
+ else:
171
+ qwen2vl_image_features = None
172
+
173
+ # Get vision tower feature, check whether only use navit firstly
174
+ if getattr(self.config, 'eagle_vision_tower', None) is not None and getattr(self.config, 'only_navit', False):
175
+ image_features = None
176
+ else:
177
+ image_features = []
178
+ all_concat_images = []
179
+ all_split_sizes = []
180
+ for batch_images in images:
181
+ concat_images = torch.cat([image for image in batch_images], dim=0) # to do batch compute
182
+ split_sizes = [image.shape[0] for image in batch_images]
183
+ all_concat_images.append(concat_images)
184
+ all_split_sizes.append(split_sizes)
185
+ all_image_features = self.encode_images(images=torch.cat(all_concat_images, dim=0), split_sizes=sum(all_split_sizes, []))
186
+
187
+ idx = 0
188
+ for split_sizes in all_split_sizes:
189
+ batch_image_features = all_image_features[idx:idx+len(split_sizes)]
190
+ idx += len(split_sizes)
191
+ if type(batch_image_features[0]) is list:
192
+ batch_image_features = [torch.cat(x).to(self.device) for x in batch_image_features]
193
+ else:
194
+ batch_image_features = [x.view(-1,x.shape[-1]).to(self.device) for x in batch_image_features] # tiles feature need to flatten in token dimention, [n_tiles, T, d] -> [n_tiles * T, d]
195
+ image_features.append(batch_image_features)
196
+
197
+ if getattr(self.config, "eagle_vision_tower", None) is not None and getattr(self.config, 'only_navit', False) == False:
198
+ # unpad image tokens
199
+ height = width = self.config.num_patches_per_side
200
+ new_image_features = []
201
+ for batch_image_features, batch_image_sizes in zip(image_features, image_sizes):
202
+ batch_image_features_list = []
203
+ for cur_image_feature, cur_image_size in zip(batch_image_features, batch_image_sizes):
204
+ base_image_feature = cur_image_feature[:width*height, :]
205
+ image_feature = cur_image_feature[width*height:, :]
206
+ if image_feature.shape[0] != 0:
207
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(
208
+ cur_image_size,
209
+ self.config.grid_pinpoints,
210
+ self.config.vit_crop_size
211
+ )
212
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) # (num_patch_H, num_patch_W, H, W, C)
213
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() # (C, num_patch_H, H, num_patch_W, W)
214
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3) # (C, num_token_H, num_token_W)
215
+ image_feature = unpad_image(image_feature, cur_image_size) # (C, num_token_H_unpad, num_token_W_unpad)
216
+ input_shape = (image_feature.shape[-2], image_feature.shape[-1])
217
+ subimage_tokens = np.prod(input_shape)
218
+
219
+ # adaptive avg 2d pool for reducing token num
220
+ max_subimage_tokens = self.config.max_vision_token-width*height
221
+ if subimage_tokens > max_subimage_tokens:
222
+ aspect_ratio = input_shape[0] / input_shape[1]
223
+ output_shape = (
224
+ int((max_subimage_tokens/aspect_ratio)**0.5*aspect_ratio),
225
+ int((max_subimage_tokens/aspect_ratio)**0.5)
226
+ )
227
+ m = nn.AdaptiveAvgPool2d(output_shape)
228
+ image_feature = m(image_feature)
229
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
230
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
231
+ else:
232
+ image_feature = cur_image_feature
233
+ batch_image_features_list.append(image_feature)
234
+ new_image_features.append(batch_image_features_list)
235
+
236
+ image_features = new_image_features
237
+
238
+ else:
239
+ image_features = self.encode_images(images).to(self.device)
240
+
241
+
242
+ # Step2: Iterate through each sample in the batch, insert image embedings into input_embeds
243
+ # and filling labels, attention mask at the same time. Finally, get `new_input_embed`,
244
+ # `new_labels`, new_attention_mask`.
245
+ _labels = labels
246
+ _position_ids = position_ids
247
+ _attention_mask = attention_mask
248
+ if attention_mask is None:
249
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
250
+ if position_ids is None:
251
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
252
+ if labels is None:
253
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
254
+
255
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask.bool())]
256
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask.bool())]
257
+ attention_mask = [cur_attention_mask[cur_attention_mask.bool()] for cur_attention_mask in attention_mask]
258
+ new_input_embeds = []
259
+ new_labels = []
260
+ new_attention_mask = []
261
+
262
+ for batch_idx, cur_input_ids in enumerate(input_ids):
263
+ cur_batch_image_idx = 0
264
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
265
+
266
+ # Step2-1: If this piece of data is pure text, then concat a dummy image to ensure the whole compute graph is same on all device
267
+ if num_images == 0:
268
+ if getattr(self.config, "eagle_vision_tower", None) is not None:
269
+ if getattr(self.config, 'only_navit', False):
270
+ cur_image_features = qwen2vl_image_features[batch_idx][cur_batch_image_idx]
271
+ else:
272
+ siglip_feat = image_features[batch_idx][cur_batch_image_idx]
273
+ try:
274
+ qwen2vl_feat = qwen2vl_image_features[batch_idx][cur_batch_image_idx]
275
+ cur_image_features = torch.cat((siglip_feat, qwen2vl_feat), dim=0)
276
+ except Exception as e:
277
+ print(e)
278
+ print("only siglip feature:", siglip_feat.shape)
279
+ cur_image_features = siglip_feat
280
+ else:
281
+ cur_image_features = image_features[batch_idx][cur_batch_image_idx]
282
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
283
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features.squeeze(0)[0:0]], dim=0)
284
+ new_input_embeds.append(cur_input_embeds)
285
+ new_labels.append(labels[batch_idx])
286
+ new_attention_mask.append(attention_mask[batch_idx])
287
+ cur_batch_image_idx += 1
288
+ continue
289
+
290
+ # Step2-2: Split input_ids, labels, attention_mask by IMAGE_TOKEN_INDEX
291
+ cur_input_ids_noim, cur_labels_noim, cur_attention_mask_noim = [], [], []
292
+ cur_labels = labels[batch_idx]
293
+ cur_attention_mask = attention_mask[batch_idx]
294
+ cur_img_attention_mask = [
295
+ attention_mask[batch_idx][i].item()
296
+ for i in torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
297
+ ]
298
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
299
+ for i in range(len(image_token_indices) - 1):
300
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
301
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
302
+ cur_attention_mask_noim.append(cur_attention_mask[image_token_indices[i]+1:image_token_indices[i+1]])
303
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
304
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
305
+ cur_input_embeds_no_im = list(torch.split(cur_input_embeds, split_sizes, dim=0))# get text features
306
+
307
+ # Step2-3: Insert image embedings
308
+ cur_new_input_embeds, cur_new_labels, cur_new_attention_mask = [], [], []
309
+ for i in range(num_images + 1): # to add multimodal feature internal the text feature
310
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
311
+ cur_new_labels.append(cur_labels_noim[i])
312
+ cur_new_attention_mask.append(cur_attention_mask_noim[i])
313
+ if i < num_images:
314
+ if getattr(self.config, "eagle_vision_tower", None) is not None:
315
+ if getattr(self.config, 'only_navit', False):
316
+ cur_image_features = qwen2vl_image_features[batch_idx][cur_batch_image_idx]
317
+ else:
318
+ siglip_feat = image_features[batch_idx][cur_batch_image_idx]
319
+ try:
320
+ qwen2vl_feat = qwen2vl_image_features[batch_idx][cur_batch_image_idx]
321
+ cur_image_features = torch.cat((siglip_feat, qwen2vl_feat), dim=0)
322
+ except Exception as e:
323
+ print(e)
324
+ print("only siglip feature:", siglip_feat.shape)
325
+ cur_image_features = siglip_feat
326
+ else:
327
+ cur_image_features = image_features[batch_idx][cur_batch_image_idx]
328
+ cur_batch_image_idx += 1
329
+ cur_new_input_embeds.append(cur_image_features)
330
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
331
+ cur_new_attention_mask.append(torch.full((cur_image_features.shape[0],), True, device=cur_attention_mask.device, dtype=cur_attention_mask.dtype))
332
+
333
+ # Step2-4: Concat image embedings and text embedings
334
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
335
+ cur_new_labels = torch.cat(cur_new_labels)
336
+ cur_new_attention_mask = torch.cat(cur_new_attention_mask)
337
+ new_input_embeds.append(cur_new_input_embeds)
338
+ new_labels.append(cur_new_labels)
339
+ new_attention_mask.append(cur_new_attention_mask)
340
+
341
+ # Step3: Truncate sequences to max length as image embeddings can make the sequence longer
342
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
343
+ if tokenizer_model_max_length is not None:
344
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
345
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
346
+ new_attention_mask = [x[:tokenizer_model_max_length] for x in new_attention_mask]
347
+
348
+ # Step4: Pad and stack input_embeds, labels, attention_mask
349
+ max_len = max(x.shape[0] for x in new_input_embeds)
350
+ batch_size = len(new_input_embeds)
351
+ new_input_embeds_padded = []
352
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
353
+ new_attention_mask_padded = torch.zeros((batch_size, max_len), dtype=new_attention_mask[0].dtype, device=new_attention_mask[0].device)
354
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
355
+
356
+ for i, (cur_new_embed, cur_new_labels, cur_attention_mask) in enumerate(zip(new_input_embeds, new_labels, new_attention_mask)):
357
+ cur_len = cur_new_embed.shape[0]
358
+ # Right padding when inferencing
359
+ if not self.training and not getattr(self, "right_padding", None):
360
+ new_input_embeds_padded.append(torch.cat((
361
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
362
+ cur_new_embed
363
+ ), dim=0))
364
+ if cur_len > 0:
365
+ new_labels_padded[i, -cur_len:] = cur_new_labels
366
+ new_attention_mask_padded[i, -cur_len:] = cur_attention_mask
367
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
368
+
369
+ # Left padding while training
370
+ else:
371
+ new_input_embeds_padded.append(torch.cat((
372
+ cur_new_embed,
373
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
374
+ ), dim=0))
375
+ if cur_len > 0:
376
+ new_labels_padded[i, :cur_len] = cur_new_labels
377
+ new_attention_mask_padded[i, :cur_len] = cur_attention_mask
378
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
379
+
380
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
381
+ new_labels = new_labels_padded if _labels is not None else None
382
+ new_attention_mask = new_attention_mask_padded if _attention_mask is not None else None
383
+ if _position_ids is None:
384
+ position_ids = None
385
+
386
+ return None, position_ids, new_attention_mask, past_key_values, new_input_embeds, new_labels
387
+
388
+
389
+ class ValleyQwen2Model(ValleyMetaModel, Qwen2Model):
390
+ config_class = ValleyConfig
391
+ def __init__(self, config: Qwen2Config):
392
+ super(ValleyQwen2Model, self).__init__(config)
393
+
394
+
395
+ class ValleyQwen2ForCausalLM(Qwen2ForCausalLM, ValleyMetaForCausalLM):
396
+ config_class = ValleyConfig
397
+
398
+ def __init__(self, config):
399
+ super(Qwen2ForCausalLM, self).__init__(config)
400
+ self.model = ValleyQwen2Model(config)
401
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
402
+ self.post_init()
403
+
404
+ def get_model(self):
405
+ return self.model
406
+
407
+ def forward(
408
+ self,
409
+ input_ids: torch.LongTensor = None,
410
+ attention_mask: Optional[torch.Tensor] = None,
411
+ position_ids: Optional[torch.LongTensor] = None,
412
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
413
+ inputs_embeds: Optional[torch.FloatTensor] = None,
414
+ labels: Optional[torch.LongTensor] = None,
415
+ use_cache: Optional[bool] = None,
416
+ output_attentions: Optional[bool] = None,
417
+ output_hidden_states: Optional[bool] = None,
418
+ images: Optional[torch.FloatTensor] = None,
419
+ return_dict: Optional[bool] = None,
420
+ image_sizes: Optional[List[List[int]]] = None,
421
+ pixel_values: Optional[torch.Tensor] = None,
422
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
423
+ image_grid_thw: Optional[torch.LongTensor] = None,
424
+ video_grid_thw: Optional[torch.LongTensor] = None,
425
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
426
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
427
+ output_hidden_states = (
428
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
429
+ )
430
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
431
+
432
+ if inputs_embeds is None:
433
+ (
434
+ input_ids,
435
+ position_ids,
436
+ attention_mask,
437
+ past_key_values,
438
+ inputs_embeds,
439
+ labels
440
+ ) = self.prepare_inputs_labels_for_multimodal(
441
+ input_ids,
442
+ position_ids,
443
+ attention_mask,
444
+ past_key_values,
445
+ labels,
446
+ images,
447
+ image_sizes,
448
+ pixel_values,
449
+ pixel_values_videos,
450
+ image_grid_thw,
451
+ video_grid_thw,
452
+ )
453
+
454
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
455
+ outputs = self.model(
456
+ input_ids=input_ids,
457
+ attention_mask=attention_mask,
458
+ position_ids=position_ids,
459
+ past_key_values=past_key_values,
460
+ inputs_embeds=inputs_embeds,
461
+ use_cache=use_cache,
462
+ output_attentions=output_attentions,
463
+ output_hidden_states=output_hidden_states,
464
+ return_dict=return_dict,
465
+ )
466
+
467
+ hidden_states = outputs[0]
468
+ logits = self.lm_head(hidden_states)
469
+
470
+ loss = None
471
+ if labels is not None:
472
+ # Shift so that tokens < n predict n
473
+ shift_logits = logits[..., :-1, :].contiguous()
474
+ shift_labels = labels[..., 1:].contiguous()
475
+ loss_fct = CrossEntropyLoss(reduction='mean')
476
+ bs = shift_labels.shape[0]
477
+ shift_labels = shift_labels.to(shift_logits.device)
478
+ loss = torch.stack([loss_fct(shift_logits[i], shift_labels[i]) for i in range(bs)])
479
+
480
+ if not return_dict:
481
+ output = (logits,) + outputs[1:]
482
+ return (loss,) + output if loss is not None else output
483
+
484
+ return CausalLMOutputWithPast(
485
+ loss=loss,
486
+ logits=logits,
487
+ past_key_values=outputs.past_key_values,
488
+ hidden_states=outputs.hidden_states,
489
+ attentions=outputs.attentions,
490
+ )
491
+
492
+ def prepare_inputs_for_generation(
493
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
494
+ ):
495
+ if past_key_values:
496
+ input_ids = input_ids[:, -1:]
497
+
498
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
499
+ if inputs_embeds is not None and past_key_values is None:
500
+ model_inputs = {"inputs_embeds": inputs_embeds}
501
+ else:
502
+ model_inputs = {"input_ids": input_ids}
503
+
504
+ model_inputs.update(
505
+ {
506
+ "past_key_values": past_key_values,
507
+ "use_cache": kwargs.get("use_cache"),
508
+ "attention_mask": attention_mask,
509
+ "images": kwargs.get("images", None),
510
+ "image_sizes": kwargs.get("image_sizes", None),
511
+ "pixel_values": kwargs.get("pixel_values", None),
512
+ "pixel_values_videos": kwargs.get("pixel_values_videos", None),
513
+ "image_grid_thw": kwargs.get("image_grid_thw", None),
514
+ "video_grid_thw": kwargs.get("video_grid_thw", None),
515
+ }
516
+ )
517
+ return model_inputs
518
+
519
+ AutoConfig.register("valley", ValleyConfig)
520
+ AutoModelForCausalLM.register(ValleyConfig, ValleyQwen2ForCausalLM)
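Note that forward() above returns a per-sample loss vector (one cross-entropy value per batch element) rather than a scalar, so a training step has to reduce it explicitly. A hedged sketch, assuming `model` and `batch` (with input_ids, labels, images, ...) were prepared elsewhere, e.g. with the processor below:

outputs = model(**batch)      # batch includes labels, so a loss vector is returned
loss = outputs.loss.mean()    # reduce the per-sample cross-entropy values
loss.backward()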
modeling_vision_tower.py ADDED
@@ -0,0 +1,161 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel
4
+ from transformers import PretrainedConfig
5
+
6
+ siglip_config = PretrainedConfig.from_dict(
7
+ {
8
+ "attention_dropout": 0.0,
9
+ "hidden_act": "gelu_pytorch_tanh",
10
+ "hidden_size": 1152,
11
+ "image_size": 384,
12
+ "intermediate_size": 4304,
13
+ "layer_norm_eps": 1e-06,
14
+ "model_type": "siglip_vision_model",
15
+ "num_attention_heads": 16,
16
+ "num_channels": 3,
17
+ "num_hidden_layers": 27,
18
+ "patch_size": 14,
19
+ }
20
+ )
21
+
22
+ qwen2vl_vit_config = PretrainedConfig.from_dict(
23
+ {
24
+ "depth": 32,
25
+ "embed_dim": 1280,
26
+ "hidden_act": "quick_gelu",
27
+ "hidden_size": 3584,
28
+ "in_channels": 3,
29
+ "in_chans": 3,
30
+ "mlp_ratio": 4,
31
+ "model_type": "qwen2_vl",
32
+ "num_heads": 16,
33
+ "patch_size": 14,
34
+ "spatial_merge_size": 2,
35
+ "spatial_patch_size": 14,
36
+ "temporal_patch_size": 2,
37
+ "_attn_implementation": "flash_attention_2",
38
+ "_attn_implementation_internal": "flash_attention_2"
39
+ }
40
+ )
41
+
42
+ def build_vision_tower(vision_tower_cfg, **kwargs):
43
+ vision_tower = getattr(vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None))
44
+ if "siglip-so400m-patch14-384" in vision_tower:
45
+ # Eagle
46
+ if getattr(vision_tower_cfg, "eagle_vision_tower", None) is not None:
47
+ qwen2vl_vision_tower = Qwen2VisionTransformerPretrainedModel._from_config(qwen2vl_vit_config)
48
+
49
+ if getattr(vision_tower_cfg, "navit_merger_hidden_dim", None) is not None:
50
+ del qwen2vl_vision_tower.merger
51
+ qwen2vl_vision_tower.merger = CustomPatchMerger(
52
+ vision_tower_cfg.hidden_size,
53
+ context_dim=1280,
54
+ hidden_dim=getattr(vision_tower_cfg, "navit_merger_hidden_dim", None)
55
+ ) # random initialize
56
+ qwen2vl_vision_tower.requires_grad_(False)
57
+
58
+ # If only use navit, delete siglip_vision_tower
59
+ if getattr(vision_tower_cfg, "only_navit", False):
60
+ siglip_vision_tower = None
61
+ else:
62
+ siglip_vision_tower = SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
63
+
64
+ return siglip_vision_tower, qwen2vl_vision_tower
65
+ # Non-Eagle
66
+ else:
67
+ siglip_vision_tower = SigLipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
68
+ return siglip_vision_tower
69
+ else:
70
+ raise ValueError(f"Unknown vision tower: {vision_tower}")
71
+
72
+ class SigLipVisionTower(nn.Module):
73
+ def __init__(self, vision_tower, args, delay_load=False, cache_dir="./cache_dir"):
74
+ super().__init__()
75
+ self.is_loaded = False
76
+ self.image_tower_name = vision_tower
77
+ self.select_layer = args.mm_vision_select_layer
78
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
79
+ self.cache_dir = cache_dir
80
+
81
+ if not delay_load:
82
+ self.load_model()
83
+ else:
84
+ from transformers import SiglipVisionModel
85
+ self.cfg_only = siglip_config
86
+ self.vision_tower = SiglipVisionModel._from_config(siglip_config) # dummy-load
87
+
88
+ def load_model(self):
89
+ from transformers import SiglipVisionModel
90
+ self.vision_tower = SiglipVisionModel._from_config(siglip_config)
91
+ self.vision_tower.requires_grad_(False)
92
+ self.is_loaded = True
93
+
94
+ def feature_select(self, image_forward_outs):
95
+ assert self.select_feature == "cls_patch"
96
+ image_features = torch.cat([image_forward_outs[:, :1, :], image_forward_outs], dim=1)
97
+ return image_features
98
+
99
+ def forward(self, images):
100
+ if type(images) is list:
101
+ image_features = []
102
+ for image in images:
103
+ image_forward_out = self.vision_tower(
104
+ image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
105
+ output_hidden_states=True,
106
+ return_dict=True,
107
+ )
108
+ image_feature = self.feature_select(image_forward_out.last_hidden_state).to(image.dtype)
109
+ image_features.append(image_feature)
110
+ else:
111
+ image_forward_outs = self.vision_tower(
112
+ images.to(device=self.device, dtype=self.dtype),
113
+ output_hidden_states=True,
114
+ return_dict=True,
115
+ )
116
+ image_features = self.feature_select(image_forward_outs.last_hidden_state).to(images.dtype)
117
+
118
+ return image_features
119
+
120
+ @property
121
+ def dummy_feature(self):
122
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
123
+
124
+ @property
125
+ def dtype(self):
126
+ return self.vision_tower.dtype
127
+
128
+ @property
129
+ def device(self):
130
+ return self.vision_tower.device
131
+
132
+ @property
133
+ def config(self):
134
+ if self.is_loaded:
135
+ return self.vision_tower.config
136
+ else:
137
+ return self.cfg_only
138
+
139
+ @property
140
+ def hidden_size(self):
141
+ return self.config.hidden_size
142
+
143
+ @property
144
+ def num_patches(self):
145
+ return (self.config.image_size // self.config.patch_size) ** 2
146
+
147
+
148
+ class CustomPatchMerger(nn.Module):
149
+ def __init__(self, dim: int, context_dim: int, hidden_dim: int, spatial_merge_size: int = 2) -> None:
150
+ super().__init__()
151
+ self.input_dim = context_dim * (spatial_merge_size**2)
152
+ self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
153
+ self.mlp = nn.Sequential(
154
+ nn.Linear(self.input_dim, hidden_dim),
155
+ nn.GELU(),
156
+ nn.Linear(hidden_dim, dim),
157
+ )
158
+
159
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
160
+ x = self.mlp(self.ln_q(x).view(-1, self.input_dim))
161
+ return x
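A hedged sketch of wiring the dual towers from a minimal config object. Field values are illustrative, and the Qwen2-VL branch expects flash-attn to be installed since qwen2vl_vit_config requests flash_attention_2:

from types import SimpleNamespace

cfg = SimpleNamespace(
    mm_vision_tower="google/siglip-so400m-patch14-384",
    eagle_vision_tower="Qwen/Qwen2-VL-7B-Instruct",
    mm_vision_select_layer=-2,
    mm_vision_select_feature="cls_patch",
)
# Both towers are built from their hard-coded configs (randomly initialized here);
# in practice the weights come from the checkpoint's state_dict.
siglip_tower, qwen2vl_tower = build_vision_tower(cfg)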
preprocessor_config.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "processor_class": "ValleyProcessor",
3
+ "auto_map": {
4
+ "AutoProcessor": "processing_valley.ValleyProcessor"
5
+ },
6
+ "min_pixels": 1,
7
+ "qwen2vl_processor_config": {
8
+ "min_pixels": 3136,
9
+ "max_pixels": 12845056,
10
+ "patch_size": 14,
11
+ "temporal_patch_size": 2,
12
+ "merge_size": 2,
13
+ "image_mean": [
14
+ 0.48145466,
15
+ 0.4578275,
16
+ 0.40821073
17
+ ],
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "image_processor_type": "Qwen2VLImageProcessor",
24
+ "processor_class": "Qwen2VLProcessor"
25
+ }
26
+ }
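Because of the auto_map entry above, the processor can be fetched through AutoProcessor once remote code is enabled; a minimal sketch with a placeholder repo id:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("your-org/valley-checkpoint", trust_remote_code=True)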
processing_valley.py ADDED
@@ -0,0 +1,312 @@
1
+ import re
2
+ import types
3
+ import io
4
+ import torch
5
+ from PIL import Image
6
+ from qwen_vl_utils import fetch_image
7
+
8
+ from transformers import (
9
+ ProcessorMixin,
10
+ SiglipImageProcessor,
11
+ BatchFeature,
12
+ Qwen2VLImageProcessor,
13
+ PreTrainedTokenizer
14
+ )
15
+
16
+ from .utils import (
17
+ process_anyres_image,
18
+ BLACK_IMG_ENV,
19
+ DEFAULT_IM_END_TOKEN,
20
+ DEFAULT_IM_START_TOKEN,
21
+ DEFAULT_IMAGE_TOKEN,
22
+ DEFAULT_VI_END_TOKEN,
23
+ DEFAULT_VI_START_TOKEN,
24
+ DEFAULT_VIDEO_TOKEN,
25
+ IMAGE_TOKEN_INDEX,
26
+ SEQ_MAX_LEN,
27
+ )
28
+
29
+ siglip_processor_config = {
30
+ "do_normalize": True,
31
+ "do_rescale": True,
32
+ "do_resize": True,
33
+ "image_mean": [
34
+ 0.5,
35
+ 0.5,
36
+ 0.5
37
+ ],
38
+ "image_processor_type": "SiglipImageProcessor",
39
+ "image_std": [
40
+ 0.5,
41
+ 0.5,
42
+ 0.5
43
+ ],
44
+ "processor_class": "SiglipProcessor",
45
+ "resample": 3,
46
+ "rescale_factor": 0.00392156862745098,
47
+ "size": {
48
+ "height": 384,
49
+ "width": 384
50
+ }
51
+ }
52
+
53
+ qwen2vl_processor_config = {
54
+ "min_pixels": 3136,
55
+ "max_pixels": 12845056,
56
+ "patch_size": 14,
57
+ "temporal_patch_size": 2,
58
+ "merge_size": 2,
59
+ "image_mean": [
60
+ 0.48145466,
61
+ 0.4578275,
62
+ 0.40821073
63
+ ],
64
+ "image_std": [
65
+ 0.26862954,
66
+ 0.26130258,
67
+ 0.27577711
68
+ ],
69
+ "image_processor_type": "Qwen2VLImageProcessor",
70
+ "processor_class": "Qwen2VLProcessor"
71
+ }
72
+
73
+ class ValleyProcessor(ProcessorMixin):
74
+ attributes = ["tokenizer"]
75
+ optional_attributes = [
76
+ "max_pixels",
77
+ "min_pixels",
78
+ "anyres",
79
+ "only_crop_single_image",
80
+ "grid_pinpoints",
81
+ "use_special_start_end_token",
82
+ ]
83
+ tokenizer_class = "AutoTokenizer"
84
+
85
+ def __init__(self, tokenizer=None, **kwargs):
86
+ super().__init__(tokenizer, **kwargs)
87
+ self.black_img = BLACK_IMG_ENV
88
+ self.siglip_image_processor = SiglipImageProcessor.from_dict(siglip_processor_config)
89
+ self.qwen2vl_image_processor = Qwen2VLImageProcessor.from_dict(
90
+ qwen2vl_processor_config,
91
+ max_pixels=kwargs.get("max_pixels", 1280*28*28),
92
+ min_pixels=kwargs.get("min_pixels", 4*28*28)
93
+ )
94
+
95
+ self.anyres = kwargs.get("anyres", True)
96
+ self.grid_pinpoints = kwargs.get("grid_pinpoints", "(1x1),...,(3x3)")
97
+ self.only_crop_single_image = kwargs.get("only_crop_single_image", True)
98
+ self.use_special_start_end_token = kwargs.get("use_special_start_end_token", True)
99
+
100
+ def preprocess_images_siglip(self, images) -> torch.FloatTensor:
101
+ if isinstance(images[0], str):
102
+ images_pil = [Image.open(img).convert("RGB") for img in images]
103
+ elif isinstance(images[0], Image.Image):
104
+ images_pil = [img.convert("RGB") for img in images]
105
+ elif isinstance(images[0], bytes):
106
+ images_pil = [Image.open(io.BytesIO(img)).convert("RGB") for img in images]
107
+ else:
108
+ raise ValueError("unsupported type")
109
+
110
+ processed_images = []
111
+ have_multi_images = len(images_pil) > 1
112
+ for img in images_pil:
113
+ if self.anyres:
114
+ if not self.only_crop_single_image or not have_multi_images:
115
+ image = process_anyres_image(img, self.siglip_image_processor, self.grid_pinpoints)
116
+ else:
117
+ image = [self.siglip_image_processor(img, return_tensors="pt")["pixel_values"][0]]
118
+ else:
119
+ image = self.siglip_image_processor(img, return_tensors="pt")["pixel_values"][0]
120
+
121
+ processed_images.append(image)
122
+
123
+ if not self.anyres:
124
+ return torch.stack(processed_images, dim=0)
125
+ else:
126
+ return [torch.stack(img, dim=0) for img in processed_images]
127
+
128
+ def preprocess_images_qwen2vl(self, images) -> dict:
129
+ if isinstance(images[0], str):
130
+ images_pil = [Image.open(img).convert("RGB") for img in images]
131
+ elif isinstance(images[0], Image.Image):
132
+ images_pil = [img.convert("RGB") for img in images]
133
+ elif isinstance(images[0], bytes):
134
+ images_pil = [Image.open(io.BytesIO(img)).convert("RGB") for img in images]
135
+ else:
136
+ raise ValueError("unsupported type")
137
+
138
+ image_sizes = [[x.size for x in images_pil]]
139
+ data_dict_qwen2vl = self.qwen2vl_image_processor(
140
+ [fetch_image({"image": img}) for img in images_pil],
141
+ return_tensors="pt"
142
+ )
143
+
144
+ data_dict_qwen2vl["image_sizes"] = image_sizes
145
+
146
+ return data_dict_qwen2vl
147
+
148
+ def preprocess_multimodal(self, conversations, img_num):
149
+ for sentence in conversations:
150
+ if sentence["role"] == "system":
151
+ continue
152
+ if DEFAULT_VIDEO_TOKEN in sentence["content"]:
153
+ if self.use_special_start_end_token:
154
+ video_replace_token = (DEFAULT_VI_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_VI_END_TOKEN) * img_num
155
+ else:
156
+ video_replace_token = DEFAULT_IMAGE_TOKEN * img_num
157
+ sentence["content"] = sentence["content"].replace(DEFAULT_VIDEO_TOKEN, "").strip()
158
+ sentence["content"] = video_replace_token + "\n" + sentence["content"]
159
+ else:
160
+ segs = re.split(DEFAULT_IMAGE_TOKEN, sentence["content"])
161
+ if self.use_special_start_end_token:
162
+ sentence["content"] = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN).join(
163
+ segs[: img_num + 1]
164
+ ) + "".join(segs[img_num + 1 :])
165
+ else:
166
+ sentence["content"] = DEFAULT_IMAGE_TOKEN.join(segs[: img_num + 1]) + "".join(segs[img_num + 1 :])
167
+
168
+ return conversations
169
+
170
+ def preprocess_qwen2(
171
+ self,
172
+ conversations,
173
+ tokenizer: PreTrainedTokenizer,
174
+ has_image: bool = False,
175
+ inference: bool = False,
176
+ only_mask_system: bool = False,
177
+ ) -> dict:
178
+ conv = types.SimpleNamespace(
179
+ system="You are a helpful assistant.",
180
+ roles=("user", "assistant"),
181
+ version="qwen2",
182
+ offset=0,
183
+ sep="<|im_start|>",
184
+ sep2="<|im_end|>\n",
185
+ )
186
+
187
+ # Check system prompt
188
+ assert conversations[0]["role"] == "system"
189
+ if conversations[0]["content"] == None:
190
+ conversations[0]["content"] = conv.system # use default system prompt
191
+
192
+ # Check conversation sequence
193
+ for j, sentence in enumerate(conversations[1:]):
194
+ role = sentence["role"]
195
+ assert role == conv.roles[j % 2], "The conversation sequence is incorrect."
196
+
197
+ conversation_str = tokenizer.apply_chat_template(conversations, tokenize=False, add_generation_prompt=inference)
198
+
199
+ # Mask targets
200
+ rounds = conversation_str.split(conv.sep2)
201
+ input_ids_ = torch.tensor([], dtype=torch.int64)
202
+ targets_ = torch.tensor([], dtype=torch.int64)
203
+ for i, rou in enumerate(rounds):
204
+ if rou == "":
205
+ continue
206
+ if (not inference) or (i < (len(rounds) - 1)):
207
+ rou += conv.sep2
208
+ if has_image:
209
+ cur_input_ids_ = self.tokenizer_image_token(rou, tokenizer, return_tensors='pt')
210
+ input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
211
+ if only_mask_system:
212
+ mask_len = len(self.tokenizer_image_token(re.sub(rf'{conv.roles[0]}\n[\s\S]*', f'{conv.roles[0]}:', rou),
213
+ tokenizer))
214
+ else:
215
+ mask_len = len(self.tokenizer_image_token(re.sub(rf'{conv.roles[1]}\n[\s\S]*', f'{conv.roles[1]}:', rou),
216
+ tokenizer))
217
+ targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)
218
+ else:
219
+ cur_input_ids_ = tokenizer(rou, return_tensors='pt')["input_ids"][0, :]
220
+ input_ids_ = torch.cat([input_ids_, cur_input_ids_], dim=0)
221
+ mask_len = len(tokenizer(re.sub(rf'{conv.roles[1]}\n[\s\S]*', rf'{conv.roles[1]}:', rou))["input_ids"][:])
222
+ targets_ = torch.cat([targets_, torch.tensor([-100] * mask_len), cur_input_ids_[mask_len:]], dim=0)
223
+
224
+ return {"input_ids": input_ids_, "labels": targets_}
225
+
226
+
227
+ def tokenizer_image_token(
228
+ self,
229
+ prompt,
230
+ tokenizer,
231
+ image_token_index=IMAGE_TOKEN_INDEX,
232
+ return_tensors=None,
233
+ ):
234
+ def split_with_token(string, token):
235
+ result = string.split(token)
236
+ for i in range(len(result) - 1):
237
+ result.insert(i * 2 + 1, token)
238
+ return result
239
+
240
+ if len(prompt) > SEQ_MAX_LEN:
241
+ raise ValueError(f"Sequence is too long (more than {SEQ_MAX_LEN} characters)")
242
+
243
+ prompt_chunks = split_with_token(prompt, DEFAULT_IMAGE_TOKEN)
244
+ input_ids, offset = ([tokenizer.bos_token_id], 1) if getattr(tokenizer,'bos_token',None) else ([], 0)
245
+ token2index = {DEFAULT_IMAGE_TOKEN: image_token_index}
246
+ for chunk in prompt_chunks:
247
+ if chunk in token2index:
248
+ input_ids.append(token2index[chunk])
249
+ else:
250
+ chunk_ids = tokenizer(chunk).input_ids
251
+ if chunk_ids[0] != getattr(tokenizer,'bos_token_id', None):
252
+ offset = 0
253
+ input_ids.extend(chunk_ids[offset:])
254
+
255
+ if return_tensors is not None:
256
+ if return_tensors == "pt":
257
+ return torch.tensor(input_ids, dtype=torch.long)
258
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
259
+ return input_ids
260
+
261
+
262
+ def __call__(self, messages, inference=True) -> BatchFeature:
263
+ # Deal with images
264
+ if "images" not in messages or not messages["images"] or not messages["images"][0]:
265
+ images = [self.black_img]
266
+ elif type(messages["images"]) == str:
267
+ images = [messages["images"]]
268
+ else:
269
+ images = messages["images"][:16] # support 16 images
270
+
271
+ # Deal with conversations
272
+ conversations = messages["conversations"]
273
+ if conversations[0]["role"] != "system":
274
+ conversations = [{"role":"system", "content": None}] + conversations # dummy system prompt
275
+
276
+ # Insert special token `<image>`
277
+ assert conversations[1]["role"] == "user"
278
+ if images and "<image>" not in conversations[1]["content"]:
279
+ image_token = " ".join(["<image>"] * len(images))
280
+ conversations[1]["content"] = f"{image_token}\n{conversations[1]['content']}"
281
+
282
+ # The last message should be assistant if inference=True
283
+ if inference:
284
+ assert conversations[-1]["role"] == "user", "the last message should be assistant if inference=True"
285
+
286
+ # Image preprocess
287
+ precessed_images_siglip = self.preprocess_images_siglip(images)
288
+ processed_data_dict_qwen2vl = self.preprocess_images_qwen2vl(images)
289
+ source = self.preprocess_multimodal(conversations, len(processed_images_siglip))
290
+ data_dict = self.preprocess_qwen2(source, self.tokenizer, has_image=True, only_mask_system=False, inference=inference)
291
+
292
+ # Construct batch data
293
+ data_dict["input_ids"] = data_dict["input_ids"].unsqueeze(0) # batch_size = 1
294
+ data_dict["labels"] = data_dict["labels"].unsqueeze(0)
295
+ data_dict["images"] = [processed_images_siglip]
296
+
297
+ return BatchFeature(data={**data_dict, **processed_data_dict_qwen2vl})
298
+
299
+ def batch_decode(self, *args, **kwargs):
300
+ """
301
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
302
+ refer to the docstring of this method for more information.
303
+ """
304
+ return self.tokenizer.batch_decode(*args, **kwargs)
305
+
306
+
307
+ def decode(self, *args, **kwargs):
308
+ """
309
+ This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
310
+ the docstring of this method for more information.
311
+ """
312
+ return self.tokenizer.decode(*args, **kwargs)
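End to end, ValleyProcessor takes a messages dict with "images" and "conversations" and returns a BatchFeature ready for the model. A hedged sketch with placeholder paths:

messages = {
    "images": ["demo.jpg"],  # placeholder path; str, bytes, or PIL.Image entries are accepted
    "conversations": [
        {"role": "user", "content": "Describe this image."},
    ],
}
batch = processor(messages, inference=True)
# batch holds input_ids, labels, images (SigLIP tiles), plus pixel_values,
# image_grid_thw and image_sizes from the Qwen2-VL image processor.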
utils.py ADDED
@@ -0,0 +1,251 @@
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+ import math
5
+ import ast
6
+ import re
7
+ import torch
8
+ from transformers import StoppingCriteria
9
+
10
+ IGNORE_INDEX = -100
11
+ IMAGE_TOKEN_INDEX = -200
12
+ GANDALF_TOKEN_INDEX = -300
13
+ DEFAULT_PAD_TOKEN = "[PAD]"
14
+ DEFAULT_EOS_TOKEN = "</s>"
15
+ DEFAULT_BOS_TOKEN = "</s>"
16
+ DEFAULT_UNK_TOKEN = "<unk>"
17
+ DEFAULT_IMAGE_TOKEN = "<image>"
18
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
19
+ DEFAULT_IM_START_TOKEN = "<im_start>"
20
+ DEFAULT_IM_END_TOKEN = "<im_end>"
21
+ DEFAULT_VIDEO_TOKEN = "<video>"
22
+ DEFAULT_VIDEO_FRAME_TOKEN = "<vi_frame>"
23
+ DEFAULT_VI_START_TOKEN = "<vi_start>"
24
+ DEFAULT_VI_END_TOKEN = "<vi_end>"
25
+ DEFAULT_EOC_TOKEN = "<eoc>"
26
+ COR_START_TOKEN = "<cor>"
27
+ COR_END_TOKEN = "<\cor>"
28
+ SEQ_MAX_LEN = 50000
29
+ BLACK_IMG_ENV = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x03\x00\x00\x00\x03\x08\x02\x00\x00\x00\xd9J"\xe8\x00\x00\x00\x12IDAT\x08\x1dcd\x80\x01F\x06\x18`d\x80\x01\x00\x00Z\x00\x04we\x03N\x00\x00\x00\x00IEND\xaeB`\x82'
30
+
31
+
32
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
33
+ """
34
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
35
+
36
+ Args:
37
+ image_size (tuple): The size of the input image in the format (width, height).
38
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
39
+ patch_size (int): The size of each image patch.
40
+
41
+ Returns:
42
+ tuple: The shape of the image patch grid in the format (width, height).
43
+ """
44
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
45
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
46
+ # Use regex to extract the range from the input string
47
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
48
+ range_start = tuple(map(int, matches[0]))
49
+ range_end = tuple(map(int, matches[-1]))
50
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
51
+ grid_pinpoints = [
52
+ (i, j)
53
+ for i in range(range_start[0], range_end[0] + 1)
54
+ for j in range(range_start[1], range_end[1] + 1)
55
+ ]
56
+ # Multiply all elements by patch_size
57
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
58
+ if type(grid_pinpoints) is list:
59
+ possible_resolutions = grid_pinpoints
60
+ else:
61
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
62
+ width, height = select_best_resolution(image_size, possible_resolutions)
63
+ return width // patch_size, height // patch_size
64
+
65
+ def select_best_resolution(original_size, possible_resolutions):
66
+ """
67
+ Selects the best resolution from a list of possible resolutions based on the original size.
68
+
69
+ Args:
70
+ original_size (tuple): The original size of the image in the format (width, height).
71
+ possible_resolutions (list): A list of possible resolutions in the format
72
+ [(width1, height1), (width2, height2), ...].
73
+
74
+ Returns:
75
+ tuple: The best fit resolution in the format (width, height).
76
+ """
77
+ original_width, original_height = original_size
78
+ best_fit = None
79
+ max_effective_resolution = 0
80
+ min_wasted_resolution = float("inf")
81
+
82
+ for width, height in possible_resolutions:
83
+ # Calculate the downscaled size to keep the aspect ratio
84
+ scale = min(width / original_width, height / original_height)
85
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
86
+
87
+ # Calculate effective and wasted resolutions
88
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
89
+ wasted_resolution = (width * height) - effective_resolution
90
+
91
+ if effective_resolution > max_effective_resolution or \
92
+ (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
93
+ max_effective_resolution = effective_resolution
94
+ min_wasted_resolution = wasted_resolution
95
+ best_fit = (width, height)
96
+
97
+ return best_fit
98
+
99
+
100
+ def unpad_image(tensor, original_size):
101
+ """
102
+ Unpads a PyTorch tensor of a padded and resized image.
103
+
104
+ Args:
105
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
106
+ original_size (tuple): The original size of the image (height, width).
107
+
108
+ Returns:
109
+ torch.Tensor: The unpadded image tensor.
110
+ """
111
+ original_width, original_height = original_size
112
+ current_height, current_width = tensor.shape[1:]
113
+
114
+ # Compute aspect ratios
115
+ original_aspect_ratio = original_width / original_height
116
+ current_aspect_ratio = current_width / current_height
117
+
118
+ # Determine padding size and direction
119
+ if original_aspect_ratio > current_aspect_ratio:
120
+ # Padding was added to the height
121
+ scale_factor = current_width / original_width
122
+ new_height = int(original_height * scale_factor)
123
+ padding = (current_height - new_height) // 2
124
+ unpadded_tensor = tensor[:, padding: current_height - padding, :]
125
+ else:
126
+ # Padding was added to the width
127
+ scale_factor = current_height / original_height
128
+ new_width = int(original_width * scale_factor)
129
+ padding = (current_width - new_width) // 2
130
+ unpadded_tensor = tensor[:, :, padding: current_width - padding]
131
+
132
+ return unpadded_tensor
133
+
134
+
135
+ def process_anyres_image(image, processor, grid_pinpoints):
136
+ """
137
+ Process an image with variable resolutions.
138
+
139
+ Args:
140
+ image (PIL.Image.Image): The input image to be processed.
141
+ processor: The image processor object.
142
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
143
+
144
+ Returns:
145
+ torch.Tensor: A tensor containing the processed image patches.
146
+ """
147
+ # Convert grid_pinpoints from string to list
148
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
149
+ try:
150
+ patch_size = processor.size["height"]
151
+ except Exception:
152
+ patch_size = processor.size["shortest_edge"]
153
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
154
+ # Use regex to extract the range from the input string
155
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
156
+ range_start = tuple(map(int, matches[0]))
157
+ range_end = tuple(map(int, matches[-1]))
158
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
159
+ grid_pinpoints = [
160
+ (i, j)
161
+ for i in range(range_start[0], range_end[0] + 1)
162
+ for j in range(range_start[1], range_end[1] + 1)
163
+ ]
164
+ # Multiply all elements by patch_size
165
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
166
+
167
+ if type(grid_pinpoints) is list:
168
+ possible_resolutions = grid_pinpoints
169
+ else:
170
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
171
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
172
+ image_padded = resize_and_pad_image(image, best_resolution)
173
+
174
+ patches = divide_to_patches(image_padded, processor.size["height"])
175
+
176
+ # FIXME: this seems to be a bug that it resizes instead of pad.
177
+ # but to keep it consistent with previous, i will keep it as it is
178
+ # TODO: uncomment below to ablate with the padding
179
+ if isinstance(processor.size, dict):
180
+ shortest_edge = processor.size["height"]
181
+ else:
182
+ shortest_edge = min(processor.size)
183
+ image_original_resize = image.resize((shortest_edge, shortest_edge))
184
+ # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
185
+
186
+ image_patches = [image_original_resize] + patches
187
+ image_patches = [
188
+ processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0]
189
+ for image_patch in image_patches
190
+ ]
191
+ # return torch.stack(image_patches, dim=0)
192
+ return image_patches
193
+
194
+ def resize_and_pad_image(image, target_resolution):
195
+ """
196
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
197
+
198
+ Args:
199
+ image (PIL.Image.Image): The input image.
200
+ target_resolution (tuple): The target resolution (width, height) of the image.
201
+
202
+ Returns:
203
+ PIL.Image.Image: The resized and padded image.
204
+ """
205
+ original_width, original_height = image.size
206
+ target_width, target_height = target_resolution
207
+
208
+ # Determine which dimension (width or height) to fill
209
+ scale_w = target_width / original_width
210
+ scale_h = target_height / original_height
211
+
212
+ if scale_w < scale_h:
213
+ # Width will be filled completely
214
+ new_width = target_width
215
+ new_height = min(math.ceil(original_height * scale_w), target_height)
216
+ else:
217
+ # Height will be filled completely
218
+ new_height = target_height
219
+ new_width = min(math.ceil(original_width * scale_h), target_width)
220
+
221
+ # Resize the image
222
+ resized_image = image.resize((new_width, new_height))
223
+
224
+ # Create a new image with the target size and paste the resized image onto it
225
+ new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
226
+ paste_x = (target_width - new_width) // 2
227
+ paste_y = (target_height - new_height) // 2
228
+ new_image.paste(resized_image, (paste_x, paste_y))
229
+
230
+ return new_image
231
+
232
+ def divide_to_patches(image, patch_size):
233
+ """
234
+ Divides an image into patches of a specified size.
235
+
236
+ Args:
237
+ image (PIL.Image.Image): The input image.
238
+ patch_size (int): The size of each patch.
239
+
240
+ Returns:
241
+ list: A list of PIL.Image.Image objects representing the patches.
242
+ """
243
+ patches = []
244
+ width, height = image.size
245
+ for i in range(0, height, patch_size):
246
+ for j in range(0, width, patch_size):
247
+ box = (j, i, j + patch_size, i + patch_size)
248
+ patch = image.crop(box)
249
+ patches.append(patch)
250
+
251
+ return patches
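For reference, a small worked example of the anyres helpers above; the numbers are illustrative (a 1024x768 image against the default "(1x1),...,(3x3)" pinpoints at the SigLIP input resolution of 384):

grid = get_anyres_image_grid_shape(
    image_size=(1024, 768),            # (width, height)
    grid_pinpoints="(1x1),...,(3x3)",
    patch_size=384,
)
print(grid)  # (3, 2): best-fit resolution 1152x768, i.e. 3 patches wide, 2 high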