Bingsu committed on
Commit b4f79b3 · verified · 1 Parent(s): e8c3ba1

Delete preprocessor.py

Files changed (1)
  1. preprocessor.py +0 -1583
preprocessor.py DELETED
@@ -1,1583 +0,0 @@
1
- import base64
2
- import copy
3
- import io
4
- import math
5
- import os
- import tempfile
6
- import uuid
7
- from typing import Dict, List, Optional, Union
8
- from urllib.parse import urlparse
9
-
10
- import av
11
- import cv2
12
- import numpy as np
13
- import requests
14
- import torch
15
- from decord import VideoReader, cpu
16
- from PIL import Image, UnidentifiedImageError
17
- from transformers.image_processing_utils import (
18
- BaseImageProcessor,
19
- BatchFeature,
20
- get_size_dict,
21
- )
22
- from transformers.image_transforms import (
23
- convert_to_rgb,
24
- get_resize_output_image_size,
25
- resize,
26
- to_channel_dimension_format,
27
- )
28
- from transformers.image_utils import (
29
- OPENAI_CLIP_MEAN,
30
- OPENAI_CLIP_STD,
31
- ChannelDimension,
32
- ImageInput,
33
- PILImageResampling,
34
- get_image_size,
35
- infer_channel_dimension_format,
36
- is_scaled_image,
37
- make_list_of_images,
38
- to_numpy_array,
39
- valid_images,
40
- )
41
- from transformers.utils import TensorType, logging
42
-
43
- logger = logging.get_logger(__name__)
44
-
45
-
46
- def determine_possible_resolutions(anyres: bool, max_num_grids: int, grid_size: int, use_1x1_grid: bool = False):
47
- """
48
- Finds and returns possible resolution combinations with a total number of grids less than or equal to max_num_grids.
49
-
50
- For example, if max_num_grids is 4 and the 1x1 grid is allowed, the possible grid combinations are:
51
- [1x1, 1x2, 1x3, 1x4, 2x1, 2x2, 3x1, 4x1], and the resolutions are calculated accordingly.
52
-
53
- Example:
54
- >>> possible_resolutions = determine_possible_resolutions(anyres=True, max_num_grids=4, grid_size=336, use_1x1_grid=True)
55
- >>> print(possible_resolutions)
56
- [[336, 336], [336, 672], [336, 1008], [336, 1344], [672, 336], [672, 672], [1008, 336], [1344, 336]]
57
-
58
- Args:
59
- anyres (bool): Whether to allow any resolution combinations up to the maximum grid count.
60
- max_num_grids (int): The maximum number of grids allowed (height x width must be ≤ this value).
61
- grid_size (int): The size of each grid in pixels (e.g., 336).
62
- use_1x1_grid (bool, optional): Whether to include the 1x1 grid as a valid resolution. Defaults to False.
63
-
64
- Returns:
65
- List[List[int]]: A list of possible [height, width] resolution pairs.
66
- """
67
- possible_resolutions = []
68
- if anyres:
69
- assert max_num_grids > 0
70
- for i in range(1, max_num_grids + 1):
71
- for j in range(1, max_num_grids + 1):
72
- if i == 1 and j == 1 and not use_1x1_grid:
73
- continue
74
- if i * j <= max_num_grids:
75
- possible_resolutions.append([i, j])
76
-
77
- possible_resolutions = [[ys * grid_size, xs * grid_size] for ys, xs in possible_resolutions]
78
-
79
- return possible_resolutions
80
-
81
-
82
- def divide_to_grids(image: np.array, grid_size: int, input_data_format=None) -> List[np.array]:
83
- """
84
- Divides a local image into grids of size (grid_size x grid_size).
85
-
86
- Args:
87
- image (np.array): Input image as a NumPy array.
88
- grid_size (int): The size (in pixels) of each square grid.
89
- input_data_format (optional): Optional format specifier (e.g., "channels_first" or "channels_last").
90
-
91
- Returns:
92
- List[np.array]: A list of image patches, each of size (grid_size x grid_size).
93
- """
94
- grids = []
95
- height, width = get_image_size(image, channel_dim=input_data_format)
96
- for i in range(0, height, grid_size):
97
- for j in range(0, width, grid_size):
98
- if input_data_format == ChannelDimension.LAST:
99
- grid = image[i : i + grid_size, j : j + grid_size]
100
- else:
101
- grid = image[:, i : i + grid_size, j : j + grid_size]
102
- grids.append(grid)
103
-
104
- return grids
105
-
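# Illustrative sketch, not part of the original file (assumes the helpers defined above
# are in scope): dividing a padded 336x672 (height x width) channels-last array with
# divide_to_grids yields two 336x336 tiles, read row by row.
import numpy as np
from transformers.image_utils import ChannelDimension

canvas = np.zeros((336, 672, 3), dtype=np.uint8)
tiles = divide_to_grids(canvas, grid_size=336, input_data_format=ChannelDimension.LAST)
print(len(tiles), tiles[0].shape)  # 2 (336, 336, 3)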
106
-
107
- def pad(
108
- image: np.array,
109
- target_size: tuple,
110
- background_color=(127, 127, 127),
111
- input_data_format=None,
112
- ) -> np.array:
113
- """
114
- Pads the input image on the sides (top/bottom and left/right) to match the target height and width.
115
-
116
- Args:
117
- image (np.array): Input image as a NumPy array.
118
- target_size (tuple): Target size as (target_height, target_width).
119
- background_color (tuple, optional): RGB color value used for padding. Defaults to (127, 127, 127).
120
- input_data_format (optional): Optional format specifier (e.g., "channels_first" or "channels_last").
121
-
122
- Returns:
123
- np.array: The padded image with the specified target size.
124
- """
125
- target_height, target_width = target_size
126
- height, width = get_image_size(image, channel_dim=input_data_format)
127
-
128
- # result = np.ones((target_height, target_width, image.shape[2]), dtype=image.dtype) * background_color
129
- result = np.empty((target_height, target_width, image.shape[2]), dtype=image.dtype)
130
- for i in range(image.shape[2]):
131
- result[..., i].fill(background_color[i])
132
-
133
- paste_x = (target_width - width) // 2
134
- paste_y = (target_height - height) // 2
135
-
136
- result[paste_y : paste_y + height, paste_x : paste_x + width, :] = image
137
-
138
- return result
139
-
140
-
141
- def expand2square(
142
- image: np.array,
143
- bboxes_dict=None,
144
- background_color=(127, 127, 127),
145
- input_data_format=None,
146
- ) -> np.array:
147
- """
148
- Expands the input image to a square shape by placing it at the center of a new square canvas,
149
- with padding added to the shorter side (either top/bottom or left/right).
150
-
151
- The image is always centered on the new canvas, and padding is applied symmetrically.
152
-
153
- Args:
154
- image (np.array): Input image as a NumPy array.
155
- bboxes_dict (dict, optional): A dictionary of bounding boxes, where each value is an NDArray of shape (N, 4, 2)
156
- with box coordinates in the format [[xtl, ytl], [xtr, ytr], [xbr, ybr], [xbl, ybl]].
157
- Supports multiple categories (e.g., "ocr", "html") simultaneously.
158
- background_color (tuple, optional): RGB color to fill the padding area. Defaults to (127, 127, 127).
159
- input_data_format (optional): Optional format specifier for image data (e.g., "channels_first" or "channels_last").
160
-
161
- Returns:
162
- np.array: A square-shaped image with the original image centered and padded as needed.
163
-
164
- Example:
165
- >>> _img = np.ones((80, 100, 3), dtype=np.uint8) * 100
- >>> _bboxes_dict = {"words": np.array([[[10, 10], [20, 10], [20, 20], [10, 20]],
- ... [[30, 30], [40, 30], [40, 40], [30, 40]]])}
- >>> _img, _bboxes_dict = expand2square(_img, _bboxes_dict, (255, 255, 255))
- >>> _img.shape
- (100, 100, 3)
- >>> guessed_ocr_bboxes = np.array([[[10, 20], [20, 20], [20, 30], [10, 30]],
- ... [[30, 40], [40, 40], [40, 50], [30, 50]]])
173
- >>> np.testing.assert_array_almost_equal(_bboxes_dict["words"], guessed_ocr_bboxes) is None
174
- True
175
- """
176
- height, width = get_image_size(image, channel_dim=input_data_format)
177
- if width == height:
178
- return image, bboxes_dict
179
- elif width > height:
180
- # result = np.ones((width, width, image.shape[2]), dtype=image.dtype) * background_color
181
- result = np.empty((width, width, image.shape[2]), dtype=image.dtype)
182
- for i in range(image.shape[2]):
183
- result[..., i].fill(background_color[i])
184
-
185
- result[(width - height) // 2 : (width - height) // 2 + height, :] = image
186
- if bboxes_dict is not None:
187
- for key in bboxes_dict:
188
- bboxes_dict[key][:, :, 1] += (width - height) // 2
189
- return result, bboxes_dict
190
- else:
191
- # result = np.ones((height, height, image.shape[2]), dtype=image.dtype) * background_color
192
- result = np.empty((height, height, image.shape[2]), dtype=image.dtype)
193
- for i in range(image.shape[2]):
194
- result[..., i].fill(background_color[i])
195
-
196
- result[:, (height - width) // 2 : (height - width) // 2 + width] = image
197
- if bboxes_dict is not None:
198
- for key in bboxes_dict:
199
- bboxes_dict[key][:, :, 0] += (height - width) // 2
200
- return result, bboxes_dict
201
-
202
-
203
- def resize_longside(
204
- image: np.array,
205
- size: int,
206
- resample: PILImageResampling = PILImageResampling.BICUBIC,
207
- data_format: Optional[Union[str, ChannelDimension]] = None,
208
- input_data_format: Optional[Union[str, ChannelDimension]] = None,
209
- ):
210
- """
211
- Resizes the image so that its longer side matches the specified size, maintaining the original aspect ratio.
212
-
213
- Args:
214
- image (np.array): Input image as a NumPy array.
215
- size (int): Target size for the longer side of the image.
216
- resample (PILImageResampling, optional): Resampling method to use during resizing. Defaults to BICUBIC.
217
- data_format (str or ChannelDimension, optional): Output data format (e.g., "channels_first" or "channels_last").
218
- input_data_format (str or ChannelDimension, optional): Input data format of the image.
219
-
220
- Returns:
221
- np.array: The resized image with its aspect ratio preserved.
222
- """
223
- height, width = get_image_size(image, channel_dim=input_data_format)
224
-
225
- if width == height:
226
- target_height, target_width = size, size
227
- elif width > height:
228
- target_width = size
229
- target_height = math.ceil(height / width * size)
230
- else:
231
- target_width = math.ceil(width / height * size)
232
- target_height = size
233
-
234
- return resize(
235
- image,
236
- size=(target_height, target_width),
237
- resample=resample,
238
- data_format=data_format,
239
- input_data_format=input_data_format,
240
- )
241
-
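# Worked example (illustrative, not part of the original file): with size=512, a
# 768x1024 (height x width) image becomes 384x512, since the width is the longer side
# and ceil(768 / 1024 * 512) = 384; the aspect ratio is preserved.
# resized = resize_longside(image, size=512)  # image: np.ndarray of shape (768, 1024, 3)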
242
-
243
- def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
244
- """
245
- Selects the best-fit resolution from a list of possible resolutions based on the original image size.
246
-
247
- This function, adapted from LLaVA-Next
248
- (https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/llava_next/image_processing_llava_next.py),
249
- evaluates each resolution by computing its effective and wasted area compared to the original size.
250
- The optimal resolution is the one that maximizes the effective area while minimizing unused (wasted) space.
251
-
252
- Args:
253
- original_size (tuple): The original image size in the format (height, width).
254
- possible_resolutions (list): A list of candidate resolutions in the format [(height1, width1), (height2, width2), ...].
255
-
256
- Returns:
257
- tuple: The best-fit resolution in the format (height, width).
258
- """
259
- original_height, original_width = original_size
260
- best_fit = None
261
- max_effective_resolution = 0
262
- min_wasted_resolution = float("inf")
263
-
264
- for height, width in possible_resolutions:
265
- scale = min(width / original_width, height / original_height)
266
- downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
267
- effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
268
- wasted_resolution = (width * height) - effective_resolution
269
-
270
- if effective_resolution > max_effective_resolution or (
271
- effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
272
- ):
273
- max_effective_resolution = effective_resolution
274
- min_wasted_resolution = wasted_resolution
275
- best_fit = (height, width)
276
-
277
- return best_fit
278
-
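# Illustrative sketch, not part of the original file: choosing among three candidate
# canvases for a 500x800 (height x width) image. Scaling into 672x672 keeps a 420x672
# region of content, the largest effective area of the three, so it wins.
best = select_best_resolution((500, 800), [(336, 672), (672, 672), (336, 1008)])
print(best)  # (672, 672)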
279
-
280
- def _get_local_grids_output_size(image: np.array, target_resolution: tuple, input_data_format=None):
281
- """
282
- Computes the number of local grids (patches) along the height and width when resizing an image
283
- to the target resolution.
284
-
285
- Args:
286
- image (np.array): Input image as a NumPy array.
287
- target_resolution (tuple): Target resolution in the format (target_height, target_width).
288
- input_data_format (optional): Optional format specifier (e.g., "channels_first" or "channels_last").
289
-
290
- Returns:
291
- tuple: A tuple (grid_h, grid_w) representing the number of grids along the height and width.
292
- """
293
- original_height, original_width = get_image_size(image, channel_dim=input_data_format)
294
- target_height, target_width = target_resolution
295
-
296
- scale_w = target_width / original_width
297
- scale_h = target_height / original_height
298
-
299
- if scale_w < scale_h:
300
- new_width = target_width
301
- new_height = min(math.ceil(original_height * scale_w), target_height)
302
- else:
303
- new_height = target_height
304
- new_width = min(math.ceil(original_width * scale_h), target_width)
305
-
306
- return new_height, new_width
307
-
308
-
309
- def determine_anyres_num_vision_patches(
310
- num_grids,
311
- image_size,
312
- grid_size,
313
- patch_size,
314
- possible_resolutions,
315
- anyres=False,
316
- unpad=True,
317
- num_queries_vis_abstractor=0,
318
- num_queries_vis_abstractor_slow=0,
319
- is_video=False,
320
- first_last_frames_slow=False, # sample-wise option
321
- is_first_or_last_frames=False, # grid-wise option
322
- ):
323
- """
324
- Computes the number of visual tokens (patches) based on image resolution, grid configuration, and patch size.
325
-
326
- This function supports both fixed-size and any-resolution settings, as well as video-specific configurations
327
- such as handling slow frames and frame position flags.
328
-
329
- Args:
330
- num_grids (int): Number of grids per image (e.g., 1 for 1x1, 4 for 2x2, etc.).
331
- image_size (tuple): The original image size as (height, width).
332
- grid_size (int): Size of each grid in pixels (e.g., 336).
333
- patch_size (int): Size of each vision patch (e.g., 14 for ViT models).
334
- possible_resolutions (list): List of possible resolution tuples [(h1, w1), (h2, w2), ...].
335
- anyres (bool, optional): Whether to use any-resolution mode. Defaults to False.
336
- unpad (bool, optional): Whether to unpad the image before computing patches. Defaults to True.
337
- num_queries_vis_abstractor (int, optional): Number of query tokens for vision abstractor (fast path).
338
- num_queries_vis_abstractor_slow (int, optional): Number of query tokens for vision abstractor (slow path).
339
- is_video (bool, optional): Whether the input is a video. Defaults to False.
340
- first_last_frames_slow (bool, optional): Whether to treat first/last video frames as "slow". Defaults to False.
341
- is_first_or_last_frames (bool, optional): Whether current grid corresponds to first/last frame. Defaults to False.
342
-
343
- Returns:
344
- int: Total number of visual tokens (patches) after processing.
345
- """
346
- if not anyres:
347
- return num_queries_vis_abstractor if num_queries_vis_abstractor > 0 else (grid_size // patch_size) ** 2
348
-
349
- if num_queries_vis_abstractor > 0:
350
- num_patch_per_grid = int(num_queries_vis_abstractor**0.5)
351
- else:
352
- num_patch_per_grid = grid_size // patch_size
353
-
354
- num_global_per_grid = num_patch_per_grid
355
-
356
- # In anyres mode, a global image is included, so there are always at least 2 grids.
357
- # However, for video inputs, there is no global image, so it's possible to have only 1 grid.
358
- # Therefore, the assertion below is commented out:
359
- # assert num_grids > 1
360
-
361
- # Compute the number of vision patches.
362
- height, width = select_best_resolution(image_size, possible_resolutions)
363
-
364
- num_patch_height = (height // grid_size) * num_patch_per_grid
365
- num_patch_width = (width // grid_size) * num_patch_per_grid
366
-
367
- # local images
368
- if unpad:
369
- original_height, original_width = image_size
370
-
371
- original_aspect_ratio = original_width / original_height
372
- current_aspect_ratio = num_patch_width / num_patch_height
373
-
374
- if original_aspect_ratio > current_aspect_ratio:
375
- scale_factor = num_patch_width / original_width
376
- new_height = int(original_height * scale_factor)
377
- padding = (num_patch_height - new_height) // 2
378
- num_patch_height = num_patch_height - padding * 2
379
- else:
380
- scale_factor = num_patch_height / original_height
381
- new_width = int(original_width * scale_factor)
382
- padding = (num_patch_width - new_width) // 2
383
- num_patch_width = num_patch_width - padding * 2
384
-
385
- num_patches = num_patch_width * num_patch_height + num_patch_height
386
- else:
387
- num_patches = num_patch_width * num_patch_height
388
-
389
- # In the "slow" strategy, when applying to first and last frames only, it is applied exclusively to those two frames.
390
- if num_queries_vis_abstractor_slow > 0:
391
- if first_last_frames_slow:
392
- if is_first_or_last_frames:
393
- num_patches += num_queries_vis_abstractor_slow - num_queries_vis_abstractor
394
- else:
395
- num_patches += num_queries_vis_abstractor_slow - num_queries_vis_abstractor
396
- # The slowfast feature is only applicable when unpad is set to False.
397
- assert unpad is False
398
-
399
- # Global image is not included for video inputs.
400
- if not is_video:
401
- num_patches += num_global_per_grid**2
402
-
403
- return num_patches
404
-
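# Worked example (illustrative, not part of the original file): with grid_size=336 and
# patch_size=14, each grid yields a 24x24 patch map. If a 500x800 (height x width) image
# is matched to a 672x672 best-fit resolution, the local patch map is 48x48 before
# unpadding; unpadding for the 1.6 aspect ratio trims it to 30 rows, giving
# 48 * 30 + 30 = 1470 local tokens (one separator token per remaining row), plus a
# 24 * 24 = 576-token global image, i.e. 2046 visual tokens in total.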
405
-
406
- class HCXVisionProcessor(BaseImageProcessor):
407
- r"""
408
- Constructs a VLM image processor.
409
-
410
- This processor is based on [`CLIPImageProcessor`] and incorporates additional techniques
411
- for handling high-resolution images, such as flexible resolution support (`anyres`), unpadding,
412
- square padding, and multi-grid patching strategies.
413
-
414
- Args:
415
- do_resize (bool): Whether to resize the image.
416
- size (Dict[str, int], optional): Target size for resizing, typically with keys `"height"` and `"width"`.
417
- anyres (bool): Whether to enable the any-resolution (`anyres`) feature, which allows flexible resolution handling via grid division.
418
- unpad (bool): When `anyres` is enabled, whether to remove visual tokens corresponding to pure padding regions.
419
- max_num_grids (int): Maximum number of grids allowed per image.
420
- max_image_cnt (int): Maximum number of images that can be processed at once (used for batching).
421
- num_queries_vis_abstractor (int): Number of visual query tokens per grid when using a visual resampler (e.g., Perceiver).
422
- num_queries_vis_abstractor_video_fast (int): Number of visual queries for fast-path video frames.
423
- num_queries_vis_abstractor_video_slow (int): Number of visual queries for slow-path video frames (e.g., first/last).
424
- possible_resolutions (List): List of allowed resolution pairs when `anyres` is enabled. Example: [[336, 336], [336, 672], [672, 336]].
425
- patch_size (int): Patch size for the Vision Transformer (ViT).
426
- pad_to_square (bool): Whether to pad images to a square shape. If `False`, a center crop is applied to fit ViT input.
427
- resample (PILImageResampling): Resampling method to use for resizing. Default is `BICUBIC`.
428
- do_center_crop (bool): Whether to apply center cropping.
429
- crop_size (Dict[str, int], optional): Size for center cropping.
430
- do_rescale (bool): Whether to rescale pixel values.
431
- rescale_factor (float or int): Factor to use for rescaling pixel values (typically `1/255`).
432
- do_normalize (bool): Whether to normalize pixel values using `image_mean` and `image_std`.
433
- image_mean (float or List[float], optional): Mean values for normalization. Can be a single float or list of floats per channel.
434
- image_std (float or List[float], optional): Standard deviation values for normalization. Can be a single float or list of floats per channel.
435
- do_convert_rgb (bool): Whether to convert the input image to RGB.
436
- first_last_frames_slow (bool): Whether to treat the first and last frames of a video as “slow path” (processed differently).
437
-
438
- Attributes:
439
- model_input_names (List[str]): Names of the expected model inputs. Defaults to `["pixel_values"]`.
440
- """
441
-
442
- model_input_names = ["pixel_values"]
443
-
444
- def __init__(
445
- self,
446
- do_resize: bool = True,
447
- size: Dict[str, int] = None,
448
- anyres: bool = False,
449
- unpad: bool = False,
450
- max_num_grids: int = 9,
451
- max_image_cnt: int = 12,
452
- num_queries_vis_abstractor: int = 0,
453
- num_queries_vis_abstractor_video_fast: int = 0,
454
- num_queries_vis_abstractor_video_slow: int = 0,
455
- possible_resolutions: List = [],
456
- patch_size: int = 14,
457
- pad_to_square: bool = True,
458
- resample: PILImageResampling = PILImageResampling.BICUBIC,
459
- do_center_crop: bool = True,
460
- crop_size: Dict[str, int] = None,
461
- do_rescale: bool = True,
462
- rescale_factor: Union[int, float] = 1 / 255,
463
- do_normalize: bool = True,
464
- image_mean: Optional[Union[float, List[float]]] = None,
465
- image_std: Optional[Union[float, List[float]]] = None,
466
- do_convert_rgb: bool = True,
467
- first_last_frames_slow: bool = False,
468
- **kwargs,
469
- ) -> None:
470
- super().__init__(**kwargs)
471
- size = size if size is not None else {"shortest_edge": 512}
472
- size = get_size_dict(size, default_to_square=False)
473
- crop_size = crop_size if crop_size is not None else {"height": 512, "width": 512}
474
- crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
475
-
476
- self.do_resize = do_resize
477
- self.size = size
478
- self.anyres = anyres
479
- self.unpad = unpad
480
- self.max_num_grids = max_num_grids
481
- self.max_image_cnt = max_image_cnt
482
- self.num_queries_vis_abstractor = num_queries_vis_abstractor
483
- self.num_queries_vis_abstractor_video_fast = num_queries_vis_abstractor_video_fast
484
- self.num_queries_vis_abstractor_video_slow = num_queries_vis_abstractor_video_slow
485
- self.possible_resolutions = [_resolution for _resolution in possible_resolutions]
486
- self.patch_size = patch_size
487
- self.pad_to_square = pad_to_square
488
- self.resample = resample
489
- self.do_center_crop = do_center_crop
490
- self.crop_size = crop_size
491
- self.do_rescale = do_rescale
492
- self.rescale_factor = rescale_factor
493
- self.do_normalize = do_normalize
494
- self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
495
- self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
496
- self.do_convert_rgb = do_convert_rgb
497
- self.first_last_frames_slow = first_last_frames_slow
498
-
499
- assert self.crop_size["height"] == self.crop_size["width"]
500
-
501
- def resize(
502
- self,
503
- image: np.ndarray,
504
- size: Dict[str, int],
505
- resample: PILImageResampling = PILImageResampling.BICUBIC,
506
- data_format: Optional[Union[str, ChannelDimension]] = None,
507
- input_data_format: Optional[Union[str, ChannelDimension]] = None,
508
- **kwargs,
509
- ) -> np.ndarray:
510
- """
511
- Resizes the input image to the specified target size.
512
-
513
- Args:
514
- image (np.ndarray): The input image to resize.
515
- size (Dict[str, int]): A dictionary specifying the target size with keys `"height"` and `"width"`.
516
- resample (PILImageResampling, optional): The resampling filter to use. Defaults to `BICUBIC`.
517
- data_format (str or ChannelDimension, optional): The desired output data format (e.g., "channels_last").
518
- input_data_format (str or ChannelDimension, optional): The input data format of the image.
519
- **kwargs: Additional keyword arguments, if any.
520
-
521
- Returns:
522
- np.ndarray: The resized image as a NumPy array.
523
- """
524
- default_to_square = True
525
- if "shortest_edge" in size:
526
- size = size["shortest_edge"]
527
- default_to_square = False
528
- elif "height" in size and "width" in size:
529
- size = (size["height"], size["width"])
530
- else:
531
- raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
532
-
533
- output_size = get_resize_output_image_size(
534
- image,
535
- size=size,
536
- default_to_square=default_to_square,
537
- input_data_format=input_data_format,
538
- )
539
-
540
- return resize(
541
- image,
542
- size=output_size,
543
- resample=resample,
544
- data_format=data_format,
545
- input_data_format=input_data_format,
546
- **kwargs,
547
- )
548
-
549
- def _preprocess(
550
- self,
551
- images: ImageInput,
552
- do_resize: bool = None,
553
- size: Dict[str, int] = None,
554
- resample: PILImageResampling = None,
555
- do_center_crop: bool = None,
556
- crop_size: int = None,
557
- do_rescale: bool = None,
558
- rescale_factor: float = None,
559
- do_normalize: bool = None,
560
- image_mean: Optional[Union[float, List[float]]] = None,
561
- image_std: Optional[Union[float, List[float]]] = None,
562
- data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
563
- input_data_format: Optional[Union[str, ChannelDimension]] = None,
564
- ) -> Image.Image:
565
- """
566
- Applies a sequence of preprocessing operations to the input image(s), including resizing, cropping, rescaling,
567
- normalization, and format conversion.
568
-
569
- This method is typically used internally to prepare images for model input.
570
-
571
- Args:
572
- images (ImageInput): A single image or a batch of images to preprocess.
573
- do_resize (bool, optional): Whether to resize the image(s).
574
- size (Dict[str, int], optional): Target size for resizing, with keys `"height"` and `"width"`.
575
- resample (PILImageResampling, optional): Resampling method to use for resizing.
576
- do_center_crop (bool, optional): Whether to apply center cropping.
577
- crop_size (int, optional): Size of the center crop (applied to both height and width).
578
- do_rescale (bool, optional): Whether to rescale the image pixel values.
579
- rescale_factor (float, optional): Factor to use when rescaling pixel values (e.g., 1/255).
580
- do_normalize (bool, optional): Whether to normalize the image using `image_mean` and `image_std`.
581
- image_mean (float or List[float], optional): Mean value(s) used for normalization.
582
- image_std (float or List[float], optional): Standard deviation value(s) used for normalization.
583
- data_format (ChannelDimension, optional): The desired output data format (e.g., `ChannelDimension.FIRST`).
584
- input_data_format (str or ChannelDimension, optional): The format of the input image(s).
585
-
586
- Returns:
587
- Image.Image: The preprocessed image or batch of images, ready for model input.
588
- """
589
- images = make_list_of_images(images)
590
-
591
- if do_resize:
592
- images = [
593
- self.resize(
594
- image=image,
595
- size=size,
596
- resample=resample,
597
- input_data_format=input_data_format,
598
- )
599
- for image in images
600
- ]
601
-
602
- if do_center_crop:
603
- images = [
604
- self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images
605
- ]
606
-
607
- if do_rescale:
608
- images = [
609
- self.rescale(
610
- image=image,
611
- scale=rescale_factor,
612
- input_data_format=input_data_format,
613
- )
614
- for image in images
615
- ]
616
-
617
- if do_normalize:
618
- images = [
619
- self.normalize(
620
- image=image,
621
- mean=image_mean,
622
- std=image_std,
623
- input_data_format=input_data_format,
624
- )
625
- for image in images
626
- ]
627
-
628
- images = [
629
- to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
630
- ]
631
-
632
- return images
633
-
634
- def _resize_for_local_grids(
635
- self,
636
- image: np.array,
637
- target_resolution: tuple,
638
- resample,
639
- input_data_format: ChannelDimension,
640
- ) -> np.array:
641
- """
642
- Resizes the image to the given target resolution for use in local grid processing.
643
-
644
- This function ensures that the image is properly resized to match the (height, width) specified
645
- in `target_resolution`, using the provided resampling method. It supports channel-first and
646
- channel-last formats based on `input_data_format`.
647
-
648
- Args:
649
- image (np.array): Input image as a NumPy array.
650
- target_resolution (tuple): Target resolution as (height, width) for resizing.
651
- resample: Resampling method to use (e.g., `PILImageResampling.BICUBIC`).
652
- input_data_format (ChannelDimension): Format of the input image (e.g., `ChannelDimension.FIRST` or `LAST`).
653
-
654
- Returns:
655
- np.array: The resized image in NumPy array format.
656
- """
657
- new_height, new_width = _get_local_grids_output_size(image, target_resolution, input_data_format)
658
-
659
- # Resize the image
660
- resized_image = resize(
661
- image,
662
- (new_height, new_width),
663
- resample=resample,
664
- input_data_format=input_data_format,
665
- )
666
-
667
- return resized_image
668
-
669
- def _pad_for_patching(
670
- self,
671
- image: np.array,
672
- target_resolution: tuple,
673
- input_data_format: ChannelDimension,
674
- ) -> np.array:
675
- """
676
- Pads the image to match the target resolution, ensuring compatibility with patch-based models.
677
-
678
- This is typically used to make sure the image dimensions are divisible by the patch size or to
679
- meet specific model input requirements. Padding is applied symmetrically where needed.
680
-
681
- Args:
682
- image (np.array): Input image as a NumPy array.
683
- target_resolution (tuple): The desired resolution after padding, in the format (height, width).
684
- input_data_format (ChannelDimension): Format of the input image (e.g., `ChannelDimension.FIRST` or `LAST`).
685
-
686
- Returns:
687
- np.array: The padded image as a NumPy array.
688
- """
689
- target_height, target_width = target_resolution
690
-
691
- background_color = tuple(int(x * 255) for x in self.image_mean)
692
- padded_image = pad(
693
- image,
694
- target_size=(target_height, target_width),
695
- background_color=background_color,
696
- input_data_format=input_data_format,
697
- )
698
-
699
- return padded_image
700
-
701
- def get_image_grids(
702
- self,
703
- image: np.array,
704
- possible_resolutions,
705
- grid_size: int,
706
- resample: PILImageResampling,
707
- data_format: ChannelDimension,
708
- input_data_format: ChannelDimension,
709
- ) -> List[np.array]:
710
- """
711
- Splits the input image into multiple local grids based on possible resolutions and grid size.
712
-
713
- The function selects the best resolution from the provided list, resizes the image accordingly,
714
- and divides it into non-overlapping grid patches of size (grid_size x grid_size). It is commonly
715
- used for any-resolution (anyres) visual processing.
716
-
717
- Args:
718
- image (np.array): Input image as a NumPy array.
719
- possible_resolutions (List[Tuple[int, int]]): List of allowed resolutions to choose from.
720
- grid_size (int): The size of each grid patch (e.g., 336 pixels).
721
- resample (PILImageResampling): Resampling method used during resizing.
722
- data_format (ChannelDimension): Output data format (e.g., `ChannelDimension.FIRST`).
723
- input_data_format (ChannelDimension): Input data format of the image.
724
-
725
- Returns:
726
- List[np.array]: A list of grid image patches as NumPy arrays.
727
- """
728
- if not isinstance(possible_resolutions, list):
729
- raise ValueError("possible_resolutions must be a list of possible resolutions.")
730
-
731
- image_size = get_image_size(image, channel_dim=input_data_format)
732
- best_resolution = select_best_resolution(image_size, possible_resolutions)
733
- resized_image = self._resize_for_local_grids(
734
- image,
735
- best_resolution,
736
- resample=resample,
737
- input_data_format=input_data_format,
738
- )
739
- padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
740
- local_grids = divide_to_grids(padded_image, grid_size=grid_size, input_data_format=input_data_format)
741
-
742
- # make sure that all patches are in the input data format
743
- local_grids = [
744
- to_channel_dimension_format(grid, channel_dim=data_format, input_channel_dim=input_data_format)
745
- for grid in local_grids
746
- ]
747
-
748
- return local_grids
749
-
750
- def preprocess(
751
- self,
752
- images: ImageInput,
753
- do_resize: bool = None,
754
- size: Dict[str, int] = None,
755
- anyres: bool = None,
756
- unpad: bool = None,
757
- is_video_list: List[bool] = None,
758
- possible_resolutions: List = None,
759
- patch_size: int = None,
760
- pad_to_square: bool = None,
761
- resample: PILImageResampling = None,
762
- do_center_crop: bool = None,
763
- crop_size: int = None,
764
- do_rescale: bool = None,
765
- rescale_factor: float = None,
766
- do_normalize: bool = None,
767
- image_mean: Optional[Union[float, List[float]]] = None,
768
- image_std: Optional[Union[float, List[float]]] = None,
769
- do_convert_rgb: bool = None,
770
- return_tensors: Optional[Union[str, TensorType]] = None,
771
- data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
772
- input_data_format: Optional[Union[str, ChannelDimension]] = None,
773
- is_first_or_last_frames: List[bool] = False,
774
- ):
775
- """
776
- Preprocesses images using HCXVisionProcessor.
777
-
778
- This method prepares images for visual language models by applying resizing, padding, cropping,
779
- normalization, and tokenization into visual patches. In video mode, each frame is converted to
780
- a 1D sequence of patches. The `unpad` option is disabled when processing videos.
781
-
782
- Args:
783
- images (ImageInput): A single image or a batch of images (PIL, NumPy, or tensor format).
784
- do_resize (bool, optional): Whether to resize the image(s).
785
- size (Dict[str, int], optional): Resize target with keys `"height"` and `"width"`.
786
- anyres (bool, optional): Whether to use any-resolution processing with grid splitting.
787
- unpad (bool, optional): Whether to remove visual tokens that belong to padding areas (only in non-video mode).
788
- is_video_list (List[bool], optional): A list indicating which inputs are video frames.
789
- possible_resolutions (List, optional): List of resolution pairs allowed in `anyres` mode.
790
- patch_size (int, optional): Patch size for the Vision Transformer (ViT).
791
- pad_to_square (bool, optional): Whether to pad the image to a square.
792
- resample (PILImageResampling, optional): Resampling method to use for resizing.
793
- do_center_crop (bool, optional): Whether to apply center cropping.
794
- crop_size (int, optional): Target crop size for center cropping.
795
- do_rescale (bool, optional): Whether to rescale image pixel values.
796
- rescale_factor (float, optional): Factor for pixel rescaling, e.g., `1/255`.
797
- do_normalize (bool, optional): Whether to normalize using mean and std.
798
- image_mean (float or List[float], optional): Mean value(s) for normalization.
799
- image_std (float or List[float], optional): Standard deviation(s) for normalization.
800
- do_convert_rgb (bool, optional): Whether to convert the image to RGB.
801
- return_tensors (str or TensorType, optional): Desired output tensor type (e.g., "pt" for PyTorch).
802
- data_format (ChannelDimension, optional): Output data format (e.g., `ChannelDimension.FIRST`).
803
- input_data_format (str or ChannelDimension, optional): Format of the input image.
804
- is_first_or_last_frames (List[bool], optional): Flags indicating whether each image is a first/last video frame.
805
-
806
- Returns:
807
- BatchFeature: A `BatchFeature` whose main fields (each wrapped in an outer batch list) are:
808
- pixel_values (List[torch.Tensor]): A list of 4D image tensors ready for model input.
809
- image_sizes (List[List[int]]): A list of lists containing the original width and height [width, height]
810
- of each image, e.g., `[[width, height], ...]`.
811
- vision_query_lengths (List[int]): A list of integers representing the number of visual tokens
812
- each image contributes to the LLM input.
813
- """
814
- do_resize = do_resize if do_resize is not None else self.do_resize
815
- size = size if size is not None else self.size
816
- size = get_size_dict(size, param_name="size", default_to_square=False)
817
- anyres = anyres if anyres is not None else self.anyres
818
- unpad = unpad if unpad is not None else self.unpad
819
- possible_resolutions = possible_resolutions if possible_resolutions is not None else self.possible_resolutions
820
- patch_size = patch_size if patch_size is not None else self.patch_size
821
- pad_to_square = pad_to_square if pad_to_square is not None else self.pad_to_square
822
- resample = resample if resample is not None else self.resample
823
- do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
824
- crop_size = crop_size if crop_size is not None else self.crop_size
825
- crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
826
- do_rescale = do_rescale if do_rescale is not None else self.do_rescale
827
- rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
828
- do_normalize = do_normalize if do_normalize is not None else self.do_normalize
829
- image_mean = image_mean if image_mean is not None else self.image_mean
830
- image_std = image_std if image_std is not None else self.image_std
831
- do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
832
-
833
- images = make_list_of_images(images)
834
-
835
- if not valid_images(images):
836
- raise ValueError(
837
- "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
838
- "torch.Tensor, tf.Tensor or jax.ndarray."
839
- )
840
-
841
- if do_convert_rgb:
842
- images = [convert_to_rgb(image) for image in images]
843
-
844
- # All transformations expect numpy arrays.
845
- images = [to_numpy_array(image) for image in images]
846
-
847
- if is_scaled_image(images[0]) and do_rescale:
848
- logger.warning_once(
849
- "It looks like you are trying to rescale already rescaled images. If the input"
850
- " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
851
- )
852
-
853
- if input_data_format is None:
854
- # We assume that all images have the same channel dimension format.
855
- input_data_format = infer_channel_dimension_format(images[0])
856
-
857
- new_images = []
858
- image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images]
859
- vision_query_lengths = []
860
-
861
- assert crop_size["height"] == crop_size["width"]
862
-
863
- # Padding operations for the global image can become a bottleneck when the original image width or height is large.
864
- # To mitigate this, the image is first resized such that the longest side is scaled proportionally based on size["shortest_edge"],
865
- # and then padding is applied to reach the target dimensions.
866
- if anyres:
867
- anyres_global_images = copy.deepcopy(images)
868
- if pad_to_square:
869
- background_color = tuple(int(x * 255) for x in self.image_mean)
870
- anyres_global_images = [
871
- resize_longside(
872
- copy.deepcopy(image),
873
- size["shortest_edge"],
874
- resample,
875
- input_data_format,
876
- )
877
- for image in anyres_global_images
878
- ]
879
- anyres_global_images = [
880
- expand2square(
881
- image,
882
- background_color=background_color,
883
- input_data_format=input_data_format,
884
- )[0]
885
- for image in anyres_global_images
886
- ]
887
- else:
888
- anyres_global_images = [
889
- self.resize(
890
- image=image,
891
- size={
892
- "height": size["shortest_edge"],
893
- "width": size["shortest_edge"],
894
- },
895
- resample=resample,
896
- input_data_format=input_data_format,
897
- )
898
- for image in anyres_global_images
899
- ]
900
- else:
901
- anyres_global_images = [None for _ in range(len(images))]
902
- if pad_to_square:
903
- background_color = tuple(int(x * 255) for x in self.image_mean)
904
- images = [
905
- resize_longside(image, size["shortest_edge"], resample, input_data_format) for image in images
906
- ]
907
- images = [
908
- expand2square(
909
- image,
910
- background_color=background_color,
911
- input_data_format=input_data_format,
912
- )[0]
913
- for image in images
914
- ]
915
-
916
- num_queries_vis_abstractors = []
917
- num_queries_vis_abstractors_slow = []
918
- first_last_frames_slows = []
919
-
920
- for image, is_video, anyres_global_image, image_size in zip(
921
- images, is_video_list, anyres_global_images, image_sizes
922
- ):
923
- if is_video:
924
- num_queries_vis_abstractor = self.num_queries_vis_abstractor_video_fast
925
- num_queries_vis_abstractor_slow = self.num_queries_vis_abstractor_video_slow
926
- else:
927
- num_queries_vis_abstractor = self.num_queries_vis_abstractor
928
- num_queries_vis_abstractor_slow = 0
929
-
930
- num_queries_vis_abstractors.append(num_queries_vis_abstractor)
931
- num_queries_vis_abstractors_slow.append(num_queries_vis_abstractor_slow)
932
- first_last_frames_slows.append(self.first_last_frames_slow)
933
-
934
- if anyres:
935
- # convert image into a list of grids
936
- # we intentionally use the same data format as the input data format
937
- image_grids = self.get_image_grids(
938
- image,
939
- possible_resolutions,
940
- grid_size=crop_size["height"],
941
- resample=resample,
942
- data_format=input_data_format,
943
- input_data_format=input_data_format,
944
- )
945
- # Global image (thumbnail) is not used for video inputs.
946
- if not is_video:
947
- image_grids = [anyres_global_image] + image_grids
948
- else:
949
- image_grids = [image]
950
-
951
- pixel_values = self._preprocess(
952
- image_grids,
953
- do_resize=do_resize,
954
- size=size,
955
- resample=resample,
956
- do_center_crop=do_center_crop,
957
- crop_size=crop_size,
958
- do_rescale=do_rescale,
959
- rescale_factor=rescale_factor,
960
- do_normalize=do_normalize,
961
- image_mean=image_mean,
962
- image_std=image_std,
963
- data_format=data_format,
964
- input_data_format=input_data_format,
965
- )
966
-
967
- pixel_values = np.array(pixel_values)
968
- new_images.append(pixel_values)
969
-
970
- num_grids = pixel_values.shape[0]
971
-
972
- vision_query_length = determine_anyres_num_vision_patches(
973
- num_grids=num_grids,
974
- image_size=image_size,
975
- grid_size=crop_size["height"],
976
- patch_size=patch_size,
977
- possible_resolutions=possible_resolutions,
978
- anyres=anyres,
979
- unpad=False if is_video else unpad,
980
- num_queries_vis_abstractor=num_queries_vis_abstractor,
981
- num_queries_vis_abstractor_slow=num_queries_vis_abstractor_slow,
982
- is_video=is_video,
983
- first_last_frames_slow=self.first_last_frames_slow,
984
- is_first_or_last_frames=self.first_last_frames_slow,
985
- )
986
-
987
- vision_query_lengths.append(vision_query_length)
988
-
989
- data = {
990
- "pixel_values": [[torch.tensor(new_image) for new_image in new_images]],
991
- "image_sizes": [[[image_size[1], image_size[0]] for image_size in image_sizes]],
992
- "vision_query_lengths": [vision_query_lengths],
993
- "is_videos": [is_video_list],
994
- "num_queries_vis_abstractors": [num_queries_vis_abstractors],
995
- "num_queries_vis_abstractors_slow": [num_queries_vis_abstractors_slow],
996
- "first_last_frames_slows": [first_last_frames_slows],
997
- }
998
-
999
- return BatchFeature(data=data)
1000
-
1001
- def load_images_videos(self, vlm_chat):
1002
- """
1003
- Loads and prepares images or video frames from a VLM chat input.
1004
-
1005
- This function parses the input `vlm_chat` object, extracts image or video sources,
1006
- and loads them into memory as PIL or NumPy images, ready for preprocessing.
1007
-
1008
- Args:
1009
- vlm_chat: A VLM chat input structure containing multimodal elements
1010
- (e.g., images, videos, URLs, or file paths). The format is typically a list of messages
1011
- with associated media fields.
1012
-
1013
- Returns:
1014
- Tuple[List, List[PIL.Image.Image], List[bool]]:
- The updated chat structure (videos expanded into per-frame entries), the loaded images
- (video frames appended as individual images), and a per-image flag marking video frames.
1016
- """
1017
- vlm_chat = copy.deepcopy(vlm_chat)
1018
-
1019
- new_vlm_chat = []
1020
- all_images = [] # images + images_from_videos
1021
- is_video_list = []
1022
-
1023
- for line in vlm_chat:
1024
- if "content" in line:
1025
- content = line["content"]
1026
-
1027
- if "image" in content:
1028
- if "filename" not in content:
1029
- content["filename"] = f"{uuid.uuid4().hex}.jpg"
1030
- image_pil = load_image(content["image"])
1031
- all_images.append(image_pil)
1032
- is_video_list.append(False)
1033
- new_vlm_chat.append(line)
1034
-
1035
- elif "video" in content:
1036
- video_bytesio = load_video_to_bytesio(content["video"])
1037
- pil_img_frames, video_time_stamp = process_video(
1038
- video_bytesio, self.max_num_grids, self.max_image_cnt, self.crop_size["width"]
1039
- )
1040
- all_images.extend(pil_img_frames)
1041
- is_video_list.extend([True] * len(pil_img_frames))
1042
-
1043
- if "filename" not in content:
1044
- content["filename"] = f"{uuid.uuid4().hex}.mp4"
1045
-
1046
- for i, image_time_stamp in enumerate(video_time_stamp):
1047
- new_line = copy.deepcopy(line)
1048
- basename, ext = os.path.splitext(content["filename"])
1049
- new_line["content"]["filename"] = f"{basename}-{i}{ext}"
1050
- new_line["content"]["video_time_stamp"] = image_time_stamp
1051
-
1052
- if i == len(video_time_stamp) - 1:
1053
- new_line["content"]["is_final_grid"] = True
1054
-
1055
- for last_frame_target_key in ["lens_keywords", "lens_local_keywords", "speech_to_text"]:
1056
- if last_frame_target_key in content:
1057
- new_line["content"][last_frame_target_key] = content[last_frame_target_key]
1058
-
1059
- new_vlm_chat.append(new_line)
1060
- else:
1061
- new_vlm_chat.append(line)
1062
-
1063
- return new_vlm_chat, all_images, is_video_list
1064
-
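# Illustrative usage sketch, not part of the original file; the resolutions and sizes
# below are assumed values rather than the defaults of any particular checkpoint.
from PIL import Image

processor = HCXVisionProcessor(
    anyres=True,
    possible_resolutions=[[512, 512], [512, 1024], [1024, 512]],
    crop_size={"height": 512, "width": 512},
)
image = Image.new("RGB", (800, 500))  # width x height
batch = processor.preprocess([image], is_video_list=[False])
print(batch["pixel_values"][0][0].shape)    # e.g. (3, 3, 512, 512): global view + 1x2 local grids
print(batch["vision_query_lengths"][0][0])  # visual tokens this image contributes to the LLM input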
1065
-
1066
- def process_video(video_bytesio, max_num_grids, max_image_cnt, vit_input_size):
1067
- """
1068
- Processes a video file and extracts frames suitable for vision transformer (ViT) input.
1069
-
1070
- The function reads video data from a BytesIO object, extracts a limited number of frames
1071
- based on `max_num_grids` and `max_image_cnt`, and resizes them to the appropriate ViT input size.
1072
-
1073
- Args:
1074
- video_bytesio (io.BytesIO): A BytesIO object containing the raw video file data.
1075
- max_num_grids (int): The maximum number of grids allowed (e.g., for tiling or patching).
1076
- max_image_cnt (int): The maximum number of frames to extract from the video.
1077
- vit_input_size (int): The desired input size (height and width) for the ViT model.
1078
-
1079
- Returns:
1080
- List[np.ndarray]: A list of processed video frames as NumPy arrays, each resized to (vit_input_size, vit_input_size).
1081
- """
1082
- frames, time_interval = video_decoder(
1083
- video_bytesio, max_num_grids=max_num_grids, max_image_cnt=max_image_cnt, default_interval=0.4
1084
- )
1085
- pil_img_frames, video_time_stamp = combine_frames_into_images(
1086
- frames, time_interval, max_grid_shape=(max_num_grids, 1), vit_input_size=vit_input_size
1087
- )
1088
-
1089
- return pil_img_frames, video_time_stamp
1090
-
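# Illustrative sketch, not part of the original file: decoding a local clip (placeholder
# path) into ViT-sized frame images, using the byte loader defined further below.
video_io = load_video_to_bytesio("clip.mp4")
frames, timestamps = process_video(video_io, max_num_grids=9, max_image_cnt=12, vit_input_size=512)
print(len(frames), timestamps[0])  # number of combined frame images and the first time label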
1091
-
1092
- def load_image(image_src):
1093
- """
1094
- Loads an image from various sources (file path, URL, base64 string, or raw bytes)
1095
- and returns it as a PIL Image object.
1096
-
1097
- Args:
1098
- image_src (str or bytes): The image source. It can be:
1099
- - A local file path
1100
- - A URL
1101
- - A base64-encoded string
1102
- - Raw image bytes
1103
-
1104
- Returns:
1105
- PIL.Image.Image: The loaded image as a PIL Image object.
1106
-
1107
- Raises:
1108
- ValueError: If the image cannot be loaded or the format is unsupported.
1109
- TypeError: If the input is not of type str or bytes.
1110
- """
1111
- try:
1112
- # 1. If input is bytes type
1113
- if isinstance(image_src, bytes):
1114
- return Image.open(io.BytesIO(image_src))
1115
-
1116
- # 2. If input is str type (path, URL, base64)
1117
- if isinstance(image_src, str):
1118
- # 2a. Check if it's a Base64 data URI format ('data:image/...')
1119
- if image_src.startswith("data:image"):
1120
- try:
1121
- # Remove the 'data:image/...;base64,' part and decode
1122
- header, encoded = image_src.split(",", 1)
1123
- image_bytes = base64.b64decode(encoded)
1124
- return Image.open(io.BytesIO(image_bytes))
1125
- except (ValueError, base64.binascii.Error) as e:
1126
- raise ValueError(f"Invalid base64 data URI format: {e}") from e
1127
-
1128
- # 2b. Check if it's a URL format ('http://' or 'https://')
1129
- elif image_src.startswith("http://") or image_src.startswith("https://"):
1130
- try:
1131
- response = requests.get(image_src, stream=True, timeout=10)
1132
- response.raise_for_status() # Raise an exception for HTTP errors
1133
- image_bytes = response.content
1134
- return Image.open(io.BytesIO(image_bytes))
1135
- except requests.exceptions.RequestException as e:
1136
- raise ValueError(f"Error loading image from URL '{image_src}': {e}") from e
1137
-
1138
- # 2c. Assume it's a local file path
1139
- else:
1140
- return Image.open(image_src)
1141
-
1142
- else:
1143
- raise TypeError(f"Unsupported image_src type: {type(image_src)}")
1144
-
1145
- # Common exception handling
1146
- except FileNotFoundError:
1147
- raise ValueError(f"Image loading error: File not found '{image_src}'")
1148
- except UnidentifiedImageError:
1149
- raise ValueError("Image loading error: Cannot identify image file format.")
1150
- except IOError as e:
1151
- raise ValueError(f"Image loading error (I/O): {e}") from e
1152
- except Exception as e:
1153
- raise ValueError(f"Unexpected error during image loading: {e}") from e
1154
-
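# Illustrative sketch, not part of the original file: the same helper accepts a local
# path, a URL, a base64 data URI, or raw bytes (paths and URL below are placeholders).
img_from_path = load_image("sample.jpg")
img_from_url = load_image("https://example.com/sample.jpg")
with open("sample.jpg", "rb") as f:
    img_from_bytes = load_image(f.read())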
1155
-
1156
- def load_video_to_bytesio(video_src):
1157
- """
1158
- Loads video data from various sources (file path, URL, base64 string, or raw bytes)
1159
- and returns an `io.BytesIO` object containing the raw video content.
1160
-
1161
- Args:
1162
- video_src (str or bytes): The video source. Supported formats include:
1163
- - Local file path
1164
- - URL
1165
- - Base64-encoded data URI string
1166
- - Raw video bytes
1167
-
1168
- Returns:
1169
- io.BytesIO: A `BytesIO` object containing the loaded video data.
1170
-
1171
- Raises:
1172
- ValueError: If the video cannot be loaded due to issues such as an invalid path,
1173
- URL failure, malformed base64 string, or unsupported format.
1174
- TypeError: If the input is not a `str` or `bytes` object.
1175
- """
1176
- video_bytes = None
1177
- try:
1178
- # 1. If input is bytes type
1179
- if isinstance(video_src, bytes):
1180
- video_bytes = video_src
1181
-
1182
- # 2. If input is str type (path, URL, base64)
1183
- elif isinstance(video_src, str):
1184
- # 2a. Check if it's a Base64 data URI format ('data:video/...')
1185
- if video_src.startswith("data:video"):
1186
- try:
1187
- # Remove the 'data:video/...;base64,' part and decode
1188
- header, encoded = video_src.split(",", 1)
1189
- video_bytes = base64.b64decode(encoded)
1190
- except (ValueError, base64.binascii.Error) as e:
1191
- raise ValueError(f"Invalid base64 data URI format: {e}") from e
1192
-
1193
- # 2b. Check if it looks like a URL
1194
- elif urlparse(video_src).scheme in ("http", "https"):
1195
- try:
1196
- response = requests.get(
1197
- video_src, stream=True, timeout=30
1198
- ) # Increased timeout for potentially large videos
1199
- response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
1200
- # Read all content from the stream into bytes
1201
- video_bytes = response.content
1202
- except requests.exceptions.MissingSchema:
1203
- # If urlparse thinks it's a scheme but requests disagrees (e.g., "http:/example.com")
1204
- # Treat it as a potential file path below.
1205
- pass
1206
- except requests.exceptions.RequestException as e:
1207
- raise ValueError(f"Error loading video from URL '{video_src}': {e}") from e
1208
-
1209
- # 2c. Assume it's a local file path if not base64 or confirmed URL
1210
- if video_bytes is None: # Only attempt file read if not already loaded as base64 or URL failed gracefully
1211
- # Check if it could potentially be a file path
1212
- # Note: This check is basic. A string like "http:/path/file" might incorrectly be treated as a path here
1213
- # if the requests call failed due to MissingSchema. More robust path validation could be added.
1214
- if (
1215
- os.path.exists(video_src) or "/" in video_src or "\\" in video_src
1216
- ): # Basic check if it resembles a path
1217
- try:
1218
- with open(video_src, "rb") as f:
1219
- video_bytes = f.read()
1220
- except FileNotFoundError:
1221
- raise ValueError(f"Video loading error: File not found at path '{video_src}'")
1222
- except IsADirectoryError:
1223
- raise ValueError(f"Video loading error: Path '{video_src}' is a directory, not a file.")
1224
- except IOError as e:
1225
- raise ValueError(f"Video loading error (I/O) for path '{video_src}': {e}") from e
1226
- else:
1227
- # If it's not base64, not a valid downloadable URL, and doesn't look like a path/doesn't exist
1228
- raise ValueError(f"Unsupported string input format or resource not found: '{video_src}'")
1229
-
1230
- # 3. If the type is unsupported
1231
- else:
1232
- raise TypeError(f"Unsupported video_src type: {type(video_src)}")
1233
-
1234
- # Final check if video_bytes was successfully obtained
1235
- if video_bytes is None:
1236
- raise ValueError(f"Could not load video data from the provided source: {video_src}")
1237
-
1238
- # Return the bytes wrapped in BytesIO
1239
- return io.BytesIO(video_bytes)
1240
-
1241
- # Catch specific exceptions first for better error reporting
1242
- except FileNotFoundError as e: # Should be caught above, but as a safeguard
1243
- raise ValueError(f"Video loading error: File not found '{video_src}'") from e
1244
- except requests.exceptions.RequestException as e: # Already handled, but for clarity
1245
- raise ValueError(f"Video loading error (Network): {e}") from e
1246
- except (ValueError, TypeError) as e: # Re-raise ValueErrors/TypeErrors raised intentionally within the try block
1247
- raise e
1248
- except Exception as e:
1249
- # Catch any other unexpected errors during processing
1250
- raise ValueError(f"Unexpected error during video loading from source '{video_src}': {e}") from e
1251
-
1252
-
1253
- def video_decoder(video_bytesio, max_num_grids, max_image_cnt, default_interval=0.4):
1254
- """
1255
- Decodes video data from a BytesIO object and returns a list of extracted frames.
1256
-
1257
- Args:
1258
- video_bytesio (io.BytesIO): A BytesIO object containing the raw video data.
1259
- max_num_grids (int): Maximum number of grids allowed per image. Used to determine how many frames to extract.
1260
- max_image_cnt (int): Maximum number of frames to extract from the video.
1261
- default_interval (float, optional): Default time interval (in seconds) between frames. Used when frame rate info is unavailable. TODO: make configurable.
1262
-
1263
- Returns:
1264
- Tuple:
1265
- frames (List[PIL.Image.Image]): A list of extracted frames as PIL Images.
1266
- time_interval (float): Time interval (in seconds) between selected frames.
1267
- """
1268
- error_messages = []
1269
- frames = []
1270
-
1271
- # 1. Try decoding the video using Decord.
1272
- try:
1273
- vr = VideoReader(video_bytesio, ctx=cpu(0), num_threads=8)
1274
- fps = vr.get_avg_fps()
1275
- play_time = len(vr) / fps
1276
- total_frames = len(vr)
1277
- frame_indices, time_interval = extract_frame_indices(
1278
- play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=default_interval
1279
- ) # Sample every 0.4 seconds; if the video is too long, apply uniform sampling instead.
1280
- if frame_indices is None:
1281
- frame_indices = range(len(vr)) # Convert all frames.
1282
- batch_frames = vr.get_batch(frame_indices).asnumpy()
1283
- frames = [Image.fromarray(frame).convert("RGB") for frame in batch_frames]
1284
- return frames, time_interval
1285
- except Exception as e:
1286
- print("error with decord")
1287
- error_messages.append(f"Decord 실패: {e}")
1288
-
1289
- # 2. Fallback: Try decoding the video using PyAV.
1290
- try:
1291
- container = av.open(video_bytesio)
1292
- fps = container.streams.video[0].average_rate
1293
- # An av container has no __len__; use the video stream's reported frame count instead.
- total_frames = container.streams.video[0].frames
- play_time = total_frames / fps
1295
- frame_indices, time_interval = extract_frame_indices(
1296
- play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=default_interval
1297
- ) # Sample frames every 0.4 seconds. If the video is long, use uniform sampling to limit the number of frames.
1298
- # Even if frame_indices were assigned using Decord, reprocess them to be compatible with PyAV.
1299
- target_indices = None if frame_indices is None else set(frame_indices)
1300
- frames = []
1301
- for i, frame in enumerate(container.decode(video=0)):
1302
- if target_indices is not None and i not in target_indices:
1303
- continue # Skip frames that are not in the required indices.
1304
- pil_frame = Image.fromarray(frame.to_ndarray(format="rgb24")).convert("RGB")
1305
- frames.append(pil_frame)
1306
- if frames:
1307
- return frames, time_interval
1308
- else:
1309
- raise Exception("Decoding with PyAV succeeded, but no frames were extracted.")
1310
- except Exception as e:
1311
- error_messages.append(f"PyAV failed: {e}")
1312
-
1313
- # 3. Fallback: Try decoding the video using OpenCV.
1314
- try:
1315
- # OpenCV cannot decode a video from an in-memory buffer, so write the bytes to a temporary file first.
- import tempfile  # Local import keeps this fallback self-contained.
-
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
-     tmp_file.write(video_bytesio.getvalue())
-     tmp_video_path = tmp_file.name
- cap = cv2.VideoCapture(tmp_video_path)
1319
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
1320
- fps = cap.get(cv2.CAP_PROP_FPS)
1321
- play_time = total_frames / fps
1322
- frame_indices, time_interval = extract_frame_indices(
1323
- play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=default_interval
1324
- ) # Sample frames every 0.4 seconds; if the video is too long, apply uniform sampling to limit the total number of frames.
1325
- if frame_indices is None:
1326
- frame_indices = range(total_frames) # Convert all frames.
1327
-
1328
- index_set = set(frame_indices) # Convert to a set for faster lookup.
1329
- current_index = 0
1330
-
1331
- while cap.isOpened():
1332
- ret, frame = cap.read()
1333
- if not ret:
1334
- break
1335
- if current_index in index_set:
1336
- frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).convert("RGB"))
1337
- current_index += 1
1338
- if current_index > max(index_set): # Stop processing once all required indices have been handled.
1339
- break
1340
-
1341
- cap.release()
- os.remove(tmp_video_path)  # Delete the temporary file created for OpenCV decoding.
1342
- if frames:
1343
- return frames, time_interval
1344
- except Exception as e:
1345
- error_messages.append(f"OpenCV failed: {e}")
1346
-
1347
- if error_messages:
1348
- raise Exception(f"All decoding attempts failed: {error_messages}")
1349
-
1350
-
1351
- def convert_format_for_multi_image(img, json, convert_key_list=["words", "text", "objects", "entities"]):
1352
- """
1353
- Converts the format of image and annotation data from a single-image dataset to a multi-image dataset format.
1354
-
1355
- Single-image datasets typically return a single image and its associated annotation as individual objects.
1356
- This function wraps them in a dictionary format used by multi-image datasets.
1357
-
1358
- Args:
1359
- img: The input image (e.g., a PIL Image or NumPy array).
1360
- json: The annotation data associated with the image.
1361
- convert_key_list (List[str], optional): A list of keys to extract and convert from the original JSON.
1362
- Defaults to ["words", "text", "objects", "entities"].
1363
-
1364
- Returns:
1365
- Tuple[bool, Dict, Dict]:
- - Whether the input was already in multi-image format.
- - A dictionary mapping image IDs to images (e.g., {"00": img}).
- - The annotation JSON with the listed keys wrapped per image ID (e.g., {"00": ...}).
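-
- Example (illustrative; "words" and "region_boxes" are placeholder annotation keys):
- >>> img = Image.new("RGB", (32, 32))
- >>> anno = {"words": ["hello"], "region_boxes": [[0, 0, 10, 10]]}
- >>> is_multi, imgs, anno = convert_format_for_multi_image(img, anno)
- >>> is_multi, sorted(imgs.keys()), anno["words"]
- (False, ['00'], {'00': ['hello']})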
1368
- """
1369
- is_multi_image_dataset = isinstance(img, dict)
1370
- if not is_multi_image_dataset:
1371
- img = {"00": img}
1372
-
1373
- for convert_key in convert_key_list:
1374
- if convert_key in json:
1375
- json[convert_key] = {"00": json[convert_key]}
1376
-
1377
- for json_key in json:
1378
- if "region" in json_key:
1379
- json[json_key] = {"00": json[json_key]}
1380
-
1381
- return is_multi_image_dataset, img, json
1382
-
1383
-
1384
- def convert_tags_for_video(img, json):
1385
- """
1386
- Converts <video_00> tags to <image_xx> tags based on the number of video frames.
1387
-
1388
- In video datasets, annotations often use a generic <video_00> tag. This function replaces that tag
1389
- with frame-specific tags such as <image_00>, <image_01>, ..., <image_NN> based on the number of frames in `img`.
1390
-
1391
- Args:
1392
- img: A list of video frames (e.g., list of PIL Images or NumPy arrays).
1393
- json: The annotation data containing <video_00> tags to be replaced.
1394
-
1395
- Returns:
1396
- Tuple[List, Dict]: The (unchanged) list of frames and the updated annotation JSON with frame-specific <image_xx> tags.
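-
- Example (illustrative; two blank images stand in for real video frames):
- >>> frames = [Image.new("RGB", (64, 64)) for _ in range(2)]
- >>> anno = {"qa_pairs": [["<video_00> What happens?", "A cat jumps."]]}
- >>> frames, anno = convert_tags_for_video(frames, anno)
- >>> anno["qa_pairs"][0][0]
- '<image_00><image_01> What happens?'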
1397
- """
1398
- image_tag = "".join([f"<image_{idx:02d}>" for idx in range(len(img))])
1399
- # image_tag = "<image_00>" # Use this format to construct and insert image-specific tags.
1400
- for json_key in json:
1401
- if "qa_pairs" in json_key:
1402
- new_qa_pairs = []
1403
- for qa_pair in json[json_key]:
1404
- question = qa_pair[0]
1405
- # Replace <video_00> tags with corresponding <image_xx> tags.
1406
- question = question.replace("<video_00>", image_tag)
1407
- new_qa_pairs.append([question, qa_pair[1]])
1408
- json[json_key] = new_qa_pairs
1409
-
1410
- return img, json
1411
-
1412
-
1413
- def split_list(input_list, split_value):
1414
- """
1415
- Splits a list into sublists using a specified delimiter value.
1416
-
1417
- Each time `split_value` is encountered in `input_list`, a new sublist is started.
1418
- The delimiter itself is not included in the output.
1419
-
1420
- Args:
1421
- input_list (List[Any]): The input list to split.
1422
- split_value (Any): The value used as the delimiter for splitting.
1423
-
1424
- Returns:
1425
- List[List[Any]]: A list of sublists, split by the specified delimiter.
1426
-
1427
- Example:
1428
- >>> split_list(["a", "b", "|", "c", "d", "|", "e"], "|")
1429
- [['a', 'b'], ['c', 'd'], ['e']]
1430
- """
1431
- temp_list = []
1432
- result = []
1433
-
1434
- for value in input_list:
1435
- if value == split_value:
1436
- result.append(temp_list)
1437
- temp_list = []
1438
- else:
1439
- temp_list.append(value)
1440
- result.append(temp_list)
1441
-
1442
- return result
1443
-
1444
-
1445
- def combine_frames_into_images(frames, time_interval, max_grid_shape=(3, 3), vit_input_size=378):
1446
- """
1447
- Combines a sequence of video frames into grid-based images and generates corresponding time range labels.
1448
-
1449
- Frames are grouped and arranged into a grid (e.g., 3x3) such that each combined image contains up to
1450
- `max_grid_shape[0] * max_grid_shape[1]` frames. Each combined image is resized to the given ViT input size.
1451
-
1452
- Args:
1453
- frames (List[PIL.Image.Image]): A list of frames extracted from a video.
1454
- time_interval (float): Time interval (in seconds) between consecutive frames.
1455
- max_grid_shape (Tuple[int, int], optional): The maximum grid shape as (num_cols, num_rows). The second element must be 1, since video frames are tiled into a single horizontal row. Defaults to (3, 3).
1456
- vit_input_size (int, optional): The target size (height and width) for the Vision Transformer input. Defaults to 378.
1457
-
1458
- Returns:
1459
- Tuple:
1460
- image_list (List[PIL.Image.Image]): A list of grid-combined images.
1461
- image_time_stamps (List[str]): A list of time span labels for each combined image,
1462
- e.g., ["0.00s~1.50s", "1.50s~3.00s", ...].
1463
- """
1464
- # grid_size = int(np.sqrt(max_num_grids))
1465
- # assert grid_size**2 == max_num_grids, "max_num_grids must be a perfect square."
1466
- max_num_grids = max_grid_shape[0] * max_grid_shape[1]
1467
- assert (
1468
- max_grid_shape[1] == 1
1469
- ), "For video processing, frames are concatenated horizontally into a single-row image, so max_grid_shape[1] must be 1."
1470
-
1471
- # List to store the resulting combined images.
1472
- image_list = []
1473
-
1474
- # Calculate the number of canvases needed.
1475
- num_frames = len(frames)
1476
- num_canvases = num_frames // max_num_grids
1477
- leftover_frames = num_frames % max_num_grids
1478
-
1479
- time_stamp = 0 # second
1480
- image_time_stamps = []
1481
-
1482
- for canvas_idx in range(num_canvases):
1483
- # Initialize the current canvas.
1484
- combined_image = Image.new(
1485
- "RGB", (vit_input_size * max_grid_shape[0], vit_input_size * max_grid_shape[1]), color=(0, 0, 0)
1486
- )
1487
-
1488
- # Determine the frames to fill in the current canvas.
1489
- start_idx = canvas_idx * max_num_grids
1490
- end_idx = min(start_idx + max_num_grids, num_frames)
1491
-
1492
- for idx in range(start_idx, end_idx):
1493
- img = frames[idx]
1494
-
1495
- # Resize each frame to a square shape.
1496
- img_resized = img.resize((vit_input_size, vit_input_size))
1497
-
1498
- # Calculate the (row, column) position to place the frame within the grid layout.
1499
- local_idx = idx - start_idx
1500
- x_offset = (local_idx % max_grid_shape[0]) * vit_input_size
1501
- y_offset = (local_idx // max_grid_shape[0]) * vit_input_size
1502
-
1503
- # Paste the resized frame at the computed position on the canvas.
1504
- combined_image.paste(img_resized, (x_offset, y_offset))
1505
-
1506
- # Append the current canvas to the result list.
1507
- image_list.append(combined_image)
1508
- frame_cnt = end_idx - start_idx
1509
- image_time_stamps.append(f"{time_stamp:.2f}s~{time_stamp + frame_cnt * time_interval:.2f}s")
1510
- time_stamp += frame_cnt * time_interval
1511
-
1512
- if leftover_frames > 0:
1513
- # Keep canvas_idx defined even when num_canvases == 0 and the loop above did not run.
1514
- canvas_idx = num_canvases
1515
- # Add the remaining frames to the final canvas.
1516
- combined_image = Image.new("RGB", (vit_input_size * leftover_frames, vit_input_size * 1), color=(0, 0, 0))
1517
-
1518
- for idx in range(leftover_frames):
1519
- img = frames[num_canvases * max_num_grids + idx]
1520
-
1521
- # Resize the frame to a square (equal width and height).
1522
- img_resized = img.resize((vit_input_size, vit_input_size))
1523
-
1524
- # Calculate the (row, column) position to place the frame within the grid layout.
1525
- x_offset = (idx % leftover_frames) * vit_input_size
1526
- y_offset = (idx // leftover_frames) * vit_input_size
1527
-
1528
- # Paste the resized frame at the computed position on the canvas.
1529
- combined_image.paste(img_resized, (x_offset, y_offset))
1530
-
1531
- # Add the current canvas to the list of combined images.
1532
- image_list.append(combined_image)
1533
- frame_cnt = leftover_frames
1534
- image_time_stamps.append(f"{time_stamp:.2f}s~{time_stamp + frame_cnt * time_interval:.2f}s")
1535
- time_stamp += frame_cnt * time_interval
1536
-
1537
- return image_list, image_time_stamps
1538
-
1539
-
1540
- def extract_frame_indices(play_time, total_frames, fps, max_num_grids, max_image_cnt, default_interval=0.4):
1541
- """
1542
- Extracts specific frame indices from a video based on duration, frame count, and sampling strategy.
1543
-
1544
- The function determines which frames to extract given the video duration (`play_time`),
1545
- total frame count, and frame rate. It samples frames at regular intervals (default: 0.4s),
1546
- but if the number of frames exceeds the limit defined by `max_num_grids * max_image_cnt`,
1547
- it performs uniform sampling to stay within that limit.
1548
-
1549
- Args:
1550
- play_time (float): Total play time of the video in seconds.
1551
- total_frames (int): Total number of frames in the video.
1552
- fps (float): Frames per second of the video.
1553
- max_num_grids (int): Maximum number of grid cells per combined image.
1554
- max_image_cnt (int): Maximum number of combined images; together with max_num_grids this caps the total frames at max_num_grids * max_image_cnt.
1555
- default_interval (float, optional): Interval in seconds between frame samples. Defaults to 0.4.
1556
-
1557
- Returns:
1558
- Tuple:
1559
- frame_indices (List[int]): A list of selected frame indices.
1560
- time_interval (float): Time interval between selected frames (in seconds).
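-
- Example (illustrative; a 10-second clip at 30 fps capped at 3 * 4 = 12 frames):
- >>> indices, interval = extract_frame_indices(play_time=10.0, total_frames=300, fps=30.0, max_num_grids=3, max_image_cnt=4)
- >>> len(indices), round(interval, 2)
- (12, 0.83)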
1561
- """
1562
-
1563
- # Calculate how many frames to extract with the default interval
1564
- default_frame_count = max(1, int(play_time / default_interval))  # Guard against zero for very short videos.
1565
-
1566
- # Maximum frames allowed based on max_num_grids and max_image_cnt
1567
- max_frames_allowed = max_num_grids * max_image_cnt
1568
-
1569
- # Determine whether we can use the default interval or need uniform sampling
1570
- if default_frame_count <= max_frames_allowed:
1571
- # Default interval is sufficient, extract frames every 0.4 seconds
1572
- frame_interval = max(1, int(total_frames / default_frame_count))
1573
- else:
1574
- # Use uniform sampling to fit within max_frames_allowed
1575
- frame_interval = max(1, int(total_frames / max_frames_allowed))
1576
-
1577
- # Extract frame indices at the calculated interval
1578
- selected_indices = list(range(0, total_frames, frame_interval))
1579
-
1580
- time_interval = frame_interval / fps
1581
-
1582
- # Ensure the number of selected indices does not exceed max_frames_allowed
1583
- return selected_indices[:max_frames_allowed], time_interval