# molmo-7B-O-bnb-4bit / image_preprocessing_molmo.py
"""Image processor class for Molmo"""
"""NOTE: This is a modified version of the original image_preprocessing_molmo.py script that removes all tensorflow-related dependencies."""
from typing import List, Optional, Tuple, Union, Mapping
import numpy as np
import einops
import torch
import torchvision.transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import convert_image_dtype
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ImageInput,
is_valid_image,
)
from transformers.processing_utils import ImagesKwargs
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.utils import TensorType, is_vision_available, logging
logger = logging.get_logger(__name__)
def make_batched_images(images) -> List[List[ImageInput]]:
"""
Accepts images in list or nested list format, and makes a list of images for preprocessing.
Args:
images (`Union[List[List[ImageInput]], List[ImageInput], ImageInput]`):
            The input image, a list of images, or a nested list of images.
Returns:
list: A list of images.
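
    Example (illustrative; NumPy arrays count as valid images):
        >>> img = np.zeros((4, 4, 3), dtype=np.uint8)
        >>> len(make_batched_images([[img, img], [img]]))
        3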
"""
if isinstance(images, (list, tuple)) and isinstance(images[0], (list, tuple)) and is_valid_image(images[0][0]):
return [img for img_list in images for img in img_list]
elif isinstance(images, (list, tuple)) and is_valid_image(images[0]):
return images
elif is_valid_image(images):
return [images]
raise ValueError(f"Could not make batched images from {images}")
def pad_to_bounding_box(
image, offset_height, offset_width, target_height,
target_width, value=0
):
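    """Pad an HWC `image` into a (target_height, target_width) canvas with its
    top-left corner at (offset_height, offset_width), filling with `value`.

    Illustrative example: a (2, 2, 1) array placed at offset (1, 1) inside a
    4x4 canvas comes back as a (4, 4, 1) array.
    """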
height, width = image.shape[:2]
after_padding_width = target_width - offset_width - width
after_padding_height = target_height - offset_height - height
return np.pad(image, [
[offset_height, after_padding_height],
[offset_width, after_padding_width],
[0, 0]
], constant_values=value)
def normalize_image(image, offset, scale):
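    """Normalize `image` in place: subtract per-channel `offset`, then divide by per-channel `scale`."""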
image -= np.array(offset, dtype=np.float32)[None, None, :]
image /= np.array(scale, dtype=np.float32)[None, None, :]
return image
def resize_and_pad(
image: np.ndarray,
desired_output_size: List[int],
resize_method: str = "bilinear",
pad_value: float = 0,
normalize: bool = True,
image_mean: Optional[List[float]] = OPENAI_CLIP_MEAN,
image_std: Optional[List[float]] = OPENAI_CLIP_STD,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Resize and pad the image to the desired output size.
Args:
image (np.ndarray): Input image as a NumPy array.
desired_output_size (List[int]): Desired output size as [height, width].
resize_method (str, optional): Resize interpolation method. Defaults to "bilinear".
pad_value (float, optional): Padding value. Defaults to 0.
normalize (bool, optional): Whether to normalize the image. Defaults to True.
image_mean (Optional[List[float]], optional): Mean for normalization. Defaults to OPENAI_CLIP_MEAN.
image_std (Optional[List[float]], optional): Standard deviation for normalization. Defaults to OPENAI_CLIP_STD.
Returns:
Tuple[np.ndarray, np.ndarray]: Resized and padded image, and image mask.
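
    Example (illustrative):
        A (300, 400, 3) image resized into a 336x336 canvas is scaled by
        min(336/300, 336/400) = 0.84 to (252, 336), then padded vertically,
        so the returned image is (336, 336, 3) and the boolean mask marks
        the 252x336 non-padding region.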
"""
desired_height, desired_width = desired_output_size
height, width = image.shape[:2]
# Calculate scaling factors and determine the scaling factor to maintain aspect ratio
scale_y = desired_height / height
scale_x = desired_width / width
scale = min(scale_x, scale_y)
scaled_height = int(height * scale)
scaled_width = int(width * scale)
# Convert the image to a PyTorch tensor and normalize to [0, 1]
image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
# Define the interpolation mode
if resize_method.lower() == "bilinear":
interpolation = InterpolationMode.BILINEAR
elif resize_method.lower() == "nearest":
interpolation = InterpolationMode.NEAREST
elif resize_method.lower() == "bicubic":
interpolation = InterpolationMode.BICUBIC
elif resize_method.lower() == "lanczos":
interpolation = InterpolationMode.LANCZOS
else:
raise ValueError(f"Unsupported resize method: {resize_method}")
# Resize the image
resized_image = torchvision.transforms.Resize(
[scaled_height, scaled_width],
interpolation=interpolation,
antialias=True
)(image_tensor)
# Clip the image to ensure values are within [0, 1]
resized_image = torch.clamp(resized_image, 0.0, 1.0)
# Convert back to NumPy
resized_image_np = resized_image.permute(1, 2, 0).numpy()
# Calculate padding
top_pad = (desired_height - scaled_height) // 2
bottom_pad = desired_height - scaled_height - top_pad
left_pad = (desired_width - scaled_width) // 2
right_pad = desired_width - scaled_width - left_pad
# Pad the image using NumPy
padded_image = np.pad(
resized_image_np,
pad_width=((top_pad, bottom_pad), (left_pad, right_pad), (0, 0)),
mode='constant',
constant_values=pad_value
)
# Create the image mask
image_mask = np.pad(
np.ones((scaled_height, scaled_width), dtype=bool),
pad_width=((top_pad, bottom_pad), (left_pad, right_pad)),
mode='constant',
constant_values=False
)
if normalize:
padded_image = normalize_image(padded_image, offset=image_mean, scale=image_std)
return padded_image, image_mask
def select_tiling(h, w, patch_size, max_num_patches):
"""Decide how best to divide in image of size [w, h] in up to max_num_patches of size patch_size"""
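    # Illustrative example: for h=1000, w=600, patch_size=336 and max_num_patches=12,
    # the feasible tiling needing the least upscaling is (3, 2): 3*336 = 1008 >= 1000
    # and 2*336 = 672 >= 600, with required scale min(1008/1000, 672/600) = 1.008;
    # the sort below breaks the tie with (3, 3) and (3, 4) in favor of fewer crops.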
tilings = []
for i in range(1, max_num_patches+1):
for j in range(1, max_num_patches+1):
if i*j <= max_num_patches:
tilings.append((i, j))
# sort so argmin and argmax favour smaller tilings in the event of a tie
tilings.sort(key=lambda x: (x[0]*x[1], x[0]))
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
# How much we would need to scale the image to fit exactly in each tiling
    original_size = np.array([h, w], dtype=np.float32)  # [2]
required_scale_d = candidate_resolutions.astype(np.float32) / original_size
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
if np.all(required_scale < 1):
# We are forced to downscale, so try to minimize the amount of downscaling
ix = np.argmax(required_scale)
else:
# Pick the resolution that required the least upscaling so that it most closely fits the image
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
ix = np.argmin(required_scale)
return candidate_tilings[ix]
class MolmoImagesKwargs(ImagesKwargs, total=False):
max_crops: Optional[int]
overlap_margins: Optional[List[int]]
base_image_input_size: Optional[List[int]]
image_token_length_w: Optional[int]
image_token_length_h: Optional[int]
image_patch_size: Optional[int]
image_padding_mask: Optional[bool]
class MolmoImageProcessor(BaseImageProcessor):
"""Preprocess images and multi-model inputs"""
def __init__(
self,
max_crops: int = 12,
overlap_margins: List[int] = (4, 4),
base_image_input_size: List[int] = (336, 336),
image_token_length_w: int = 12,
image_token_length_h: int = 12,
image_patch_size: int = 14,
image_padding_mask: bool = True,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs,
):
super().__init__(**kwargs)
self.max_crops = max_crops
self.overlap_margins = overlap_margins
self.base_image_input_size = base_image_input_size
self.image_token_length_w = image_token_length_w
self.image_token_length_h = image_token_length_h
self.image_patch_size = image_patch_size
self.image_padding_mask = image_padding_mask
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
def image_to_patches_and_tokens(
self,
image: ImageInput,
image_patch_token_id: int,
image_col_token_id: int,
image_start_token_id: int,
image_end_token_id: int,
max_crops: Optional[int] = None,
overlap_margins: Optional[List[int]] = None,
base_image_input_size: Optional[Union[int, List[int]]] = None,
image_token_length_w: Optional[int] = None,
image_token_length_h: Optional[int] = None,
image_patch_size: Optional[int] = None,
):
"""Preprocesses an image
Returns:
crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might
change between images but the other dimension are fixed
tokens: (n_tokens,) int32 tokens, pad tokens indicating where to insert the
patch features, might include other special tokens as well
patch_ordering: (n_crops, n_tokens_per_crop) order image features should be inserted
into the `tokens`, negative values indicates patches features to exclude
padding_mask: (n_crops, n_patches) what percent of each crop is padding, be None
if the image mask is not being used.
"""
if isinstance(base_image_input_size, int):
base_image_input_size = (base_image_input_size, base_image_input_size)
base_image_input_d = image_patch_size
tokens_per_image = image_token_length_w * image_token_length_h
image_base_patch_w = base_image_input_size[1] // base_image_input_d
image_base_patch_h = base_image_input_size[0] // base_image_input_d
original_image_h, original_image_w = image.shape[:2]
crop_size = base_image_input_size[0]
# Discard this many patches from the (left/top, right/bottom) of crops
left_margin, right_margin = overlap_margins
assert left_margin % 2 == 0 # Required for compatibility with 2x2 pooling
total_margin_pixels = base_image_input_d*(right_margin + left_margin) # pixels removed per dim
crop_patches = base_image_input_size[0] // base_image_input_d # patches per crop dim
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
crop_window_size = crop_window_patches * base_image_input_d
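        # Worked example with the defaults (336px crops, 14px patches, margins (4, 4)):
        # crop_patches = 336 // 14 = 24, crop_window_patches = 24 - 8 = 16,
        # crop_window_size = 16 * 14 = 224, total_margin_pixels = 14 * 8 = 112,
        # and (24 + 1) // 2 = 12 matches image_token_length_h/w.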
tiling = select_tiling(
original_image_h - total_margin_pixels,
original_image_w - total_margin_pixels,
crop_window_size,
max_crops
)
src, img_mask = resize_and_pad(
image,
[tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels]
)
        # Now we have to split the image into crops, while keeping track of how each patch in
        # each crop should be ordered in the global image; this requires a lot of tricky bookkeeping
n_crops = tiling[0] * tiling[1]
patches_arr = []
mask_arr = []
patch_ordering_arr = []
# We assume 2x2 pooling, but can allow padding the right/bottom with extra
# patches if the number of patches per side is not even
assert (crop_patches+1)//2 == image_token_length_h
assert (crop_patches+1)//2 == image_token_length_w
on = 0
on_patch = 0
for i in range(tiling[0]):
y0 = i*crop_window_size
if i == 0:
crop_y0 = 0
else:
crop_y0 = left_margin // 2
crop_h = image_base_patch_h - (right_margin + left_margin)
if i == 0:
crop_h += left_margin
if i == (tiling[0]-1):
crop_h += right_margin
for j in range(tiling[1]):
x0 = j*crop_window_size
if j == 0:
crop_x0 = 0
else:
crop_x0 = left_margin // 2
crop_w = image_base_patch_w - (right_margin + left_margin)
if j == 0:
crop_w += left_margin
if j == (tiling[1]-1):
crop_w += right_margin
pooled_w = (crop_w + 1) // 2
pooled_h = (crop_h + 1) // 2
patch_ordering_arr.append(
pad_to_bounding_box(
np.reshape(np.arange(on, on+pooled_h*pooled_w, dtype=np.int32), (pooled_h, pooled_w, 1)),
crop_y0, crop_x0, image_token_length_h, image_token_length_w, value=-1
)[:, :, 0]
)
patches_arr.append(src[y0:y0+crop_size, x0:x0+crop_size])
mask_arr.append(img_mask[y0:y0+crop_size, x0:x0+crop_size])
on += pooled_h*pooled_w
on_patch += 1
patches = np.stack(patches_arr)
patch_ordering = np.stack(patch_ordering_arr)
img_mask = np.stack(mask_arr)
# Switch to [n_crops, n_patches, pixels_per_patch] format
patches = einops.rearrange(
patches, 'p (h dh) (w dw) c -> p (h w) (dh dw c)',
dh=base_image_input_d,
dw=base_image_input_d,
h=image_base_patch_h,
w=image_base_patch_w
)
img_mask = einops.rearrange(
img_mask, 'p (h dh) (w dw) -> p (h w) (dh dw)',
dh=base_image_input_d,
dw=base_image_input_d,
h=image_base_patch_h,
w=image_base_patch_w
)
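        # Each patch's mask entry becomes the fraction of its pixels that are real (non-padding)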
img_mask = img_mask.astype(np.float32).mean(axis=-1)
patch_ordering = np.reshape(patch_ordering, [-1])
valid = patch_ordering >= 0
# Transpose order, to get left-to-right order instead of crop-by-crop order
patch_ordering_rh = np.reshape(
patch_ordering,
[tiling[0], tiling[1], image_token_length_h, image_token_length_w]
)
patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3])
patch_ordering_rh = np.reshape(patch_ordering_rh, [-1])
        # The transpose will screw up which patches are masked; project the
        # new order into the sparse structure of `patch_ordering` to fix this
patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
# Now build the output tokens
h = tiling[0] * crop_window_patches + (right_margin+left_margin)
w = tiling[1] * crop_window_patches + (right_margin+left_margin)
per_row = np.full(
((w+1)//2,),
image_patch_token_id,
)
per_row = np.concatenate([per_row, [image_col_token_id]], 0)
joint = np.tile(per_row, [(h+1)//2])
joint = [
[image_start_token_id],
joint,
[image_end_token_id]
]
# Finally do the same for the global image
resized, _ = resize_and_pad(image, base_image_input_size)
resized = einops.rearrange(
resized, '(h dh) (w dw) c -> (h w) (dh dw c)',
dh=base_image_input_d,
dw=base_image_input_d,
h=image_base_patch_h,
w=image_base_patch_w
)
patches = np.concatenate([np.expand_dims(resized, 0), patches], 0)
# Global image goes first, so the order of patches in previous crops gets increased
patch_ordering = np.where(
patch_ordering >= 0,
patch_ordering + tokens_per_image,
-1
)
patch_ordering = np.concatenate([np.arange(0, tokens_per_image), patch_ordering], 0)
per_row = np.full(
(image_token_length_w,),
image_patch_token_id,
)
per_row = np.concatenate([per_row, [image_col_token_id]], 0)
extra_tokens = np.tile(per_row, [image_token_length_h])
joint = [
[image_start_token_id],
extra_tokens,
[image_end_token_id],
] + joint
joint = np.concatenate(joint, 0)
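        # Append a sentinel row of -1 so the mask's first dimension matches the
        # number of crops now that the global image has been added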
img_mask = np.pad(img_mask, [[0, 1], [0, 0]], constant_values=-1)
return patches, joint, patch_ordering, img_mask
def build_image_input_idx(
self,
image_tokens: np.ndarray,
patch_order: np.ndarray,
image_patch_token_id: int,
no_image: Optional[bool] = None,
image_token_length_w: Optional[int] = None,
image_token_length_h: Optional[int] = None,
):
"""Converts `patch_order` into a mapping of token_id -> patch_id"""
tokens_per_image = image_token_length_w * image_token_length_h
if no_image is not None and no_image:
return np.zeros((0, tokens_per_image), np.int32)
# Indices to insert the patches
image_input_idx = image_tokens == image_patch_token_id
image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32)
if patch_order is not None:
n_tokens = image_input_idx.shape[0]
patch_order = np.reshape(patch_order, [-1])
n_patches = patch_order.shape[0]
valid = patch_order >= 0
n_valid_patches = valid.sum()
assert len(image_input_idx) == n_valid_patches
sorted_patch_ixs = np.zeros([n_tokens], np.int32)
sorted_patch_ixs[patch_order[valid]] = np.arange(n_valid_patches, dtype=np.int32)
# Project the inverted mapping into same sparse structure
sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1)
sorted_patch_ixs_ex[valid] = sorted_patch_ixs
            # Do the gather and then re-mask outputs that were masked in `sorted_patch_ixs`
valid = (sorted_patch_ixs_ex >= 0).astype(np.int32)
image_input_idx = image_input_idx[sorted_patch_ixs_ex*valid]
image_input_idx = image_input_idx*valid - 100*(1 - valid)
image_input_idx = np.reshape(image_input_idx, [-1, tokens_per_image])
return image_input_idx
def preprocess(
self,
image: np.ndarray,
image_patch_token_id: int,
image_col_token_id: int,
image_start_token_id: int,
image_end_token_id: int,
max_crops: Optional[int] = None,
overlap_margins: Optional[List[int]] = None,
base_image_input_size: Optional[Union[int, List[int]]] = None,
image_token_length_w: Optional[int] = None,
image_token_length_h: Optional[int] = None,
image_patch_size: Optional[int] = None,
**kwargs,
):
"""Preprocesses a single image"""
max_crops = max_crops or self.max_crops
overlap_margins = overlap_margins or self.overlap_margins
base_image_input_size = base_image_input_size or self.base_image_input_size
image_token_length_w = image_token_length_w or self.image_token_length_w
image_token_length_h = image_token_length_h or self.image_token_length_h
image_patch_size = image_patch_size or self.image_patch_size
crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(
image,
image_patch_token_id,
image_col_token_id,
image_start_token_id,
image_end_token_id,
max_crops,
overlap_margins,
base_image_input_size,
image_token_length_w,
image_token_length_h,
image_patch_size,
)
patch_idx = self.build_image_input_idx(
image_tokens,
patch_ordering,
image_patch_token_id,
image_token_length_w=image_token_length_w,
image_token_length_h=image_token_length_h,
)
return crops, image_tokens, patch_idx, img_mask
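    # Illustrative usage sketch (the token ids below are hypothetical placeholders;
    # real ids come from the Molmo tokenizer's special-token vocabulary):
    #
    #   processor = MolmoImageProcessor()
    #   image = np.zeros((480, 640, 3), dtype=np.uint8)
    #   crops, tokens, patch_idx, mask = processor.preprocess(
    #       image,
    #       image_patch_token_id=1,
    #       image_col_token_id=2,
    #       image_start_token_id=3,
    #       image_end_token_id=4,
    #   )
    #   # with the defaults each crop has 24*24 = 576 patches of 14*14*3 = 588 values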
def multimodal_preprocess(
self,
images: np.ndarray,
tokens: List[int],
image_idx: np.ndarray,
sequence_length: int,
image_patch_token_id: int,
image_col_token_id: int,
image_start_token_id: int,
image_end_token_id: int,
**kwargs,
):
"""Merge images and text tokens into multi-modal features for the model
:param images: images to use as input
:param tokens: input text tokens
:param image_idx: where to insert the images into `tokens`
:params image_patch_token_id: id to use of tokens that will contain image features
:params image_col_token_id: token id for image column special tokens
:params image_start_token_id: token id for image start special tokens
:params image_end_token_id: token id for image end special tokens
:params kwargs: override preprocessor default args
"""
max_total_crops = kwargs.get("max_crops") or self.max_crops
image_token_length_w = kwargs.get("image_token_length_w") or self.image_token_length_w
image_token_length_h = kwargs.get("image_token_length_h") or self.image_token_length_h
image_patch_size = kwargs.get("image_patch_size") or self.image_patch_size
base_image_input_size = kwargs.get("base_image_input_size") or self.base_image_input_size
image_num_patch = (
base_image_input_size[0] // image_patch_size,
base_image_input_size[1] // image_patch_size,
)
image_padding_mask = kwargs.get("image_padding_mask") or self.image_padding_mask
tokens_per_image = image_token_length_w * image_token_length_h
n_pixels = image_patch_size * image_patch_size * 3
n_patches = image_num_patch[0] * image_num_patch[1]
if images is None:
return {
"input_ids": tokens,
"images": None,
"image_input_idx": None
}
else:
n = len(images)
all_crops = []
all_image_idx = []
out_tokens = []
all_crop_masks = []
for ix in range(n):
token_ix = image_idx[ix]
crops, image_tokens, patch_idx, img_mask = self.preprocess(
images[ix],
image_patch_token_id,
image_col_token_id,
image_start_token_id,
image_end_token_id,
**kwargs,
)
if token_ix == -1: # -1 is an image inserted at the very start
start = 0
token_ix = 0
end = 0
else:
start = 0 if ix == 0 else image_idx[ix-1] + 1
end = token_ix + 1
all_image_idx.append(patch_idx + token_ix)
all_crops.append(crops)
out_tokens.append(tokens[start:token_ix])
out_tokens.append(image_tokens)
if ix == (n - 1):
out_tokens.append(tokens[end:])
if image_padding_mask:
all_crop_masks.append(img_mask)
input_ids = np.concatenate(out_tokens, 0)
images = np.concatenate(all_crops, 0)
image_input_idx = np.concatenate(all_image_idx, 0)
if image_padding_mask:
image_masks = np.concatenate(all_crop_masks, 0)
else:
image_masks = None
out = {
"input_ids": input_ids,
"images": images,
"image_input_idx": image_input_idx
}
if image_masks is not None:
out["image_masks"] = image_masks
return out
MolmoImageProcessor.register_for_auto_class()
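
# Illustrative end-to-end sketch (not executed on import; the token ids and text
# tokens below are hypothetical placeholders):
#
#   processor = MolmoImageProcessor()
#   image = np.zeros((600, 1000, 3), dtype=np.uint8)
#   out = processor.multimodal_preprocess(
#       images=[image],
#       tokens=np.array([10, 11, 12], dtype=np.int32),
#       image_idx=np.array([-1]),  # -1 inserts the image at the very start
#       sequence_length=1536,
#       image_patch_token_id=1,
#       image_col_token_id=2,
#       image_start_token_id=3,
#       image_end_token_id=4,
#   )
#   # out contains "input_ids", "images", "image_input_idx" and, by default, "image_masks"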