|
import math |
|
from typing import ClassVar, List, Optional, Tuple, Union |
|
|
|
import torch |
|
from PIL import Image, ImageOps |
|
from transformers import BatchFeature, LlavaNextProcessor |
|
|
|
|
|
def round_by_factor(number: float, factor: int) -> int:
    """Return the multiple of ``factor`` nearest to ``number``.

    Ties are resolved by Python's built-in ``round`` applied to the quotient
    (round-half-to-even), e.g. ``round_by_factor(10, 4) == 8``.
    """
    nearest_multiple_count = round(number / factor)
    return factor * nearest_multiple_count
|
|
|
|
|
def ceil_by_factor(number: float, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``.

    Example: ``ceil_by_factor(10, 4) == 12``.
    """
    multiples_needed = math.ceil(number / factor)
    return factor * multiples_needed
|
|
|
|
|
def floor_by_factor(number: float, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``.

    Example: ``floor_by_factor(27, 14) == 14``.
    """
    whole_multiples = math.floor(number / factor)
    return factor * whole_multiples
|
|
|
|
|
class GraniteVisionEmbProcessor(LlavaNextProcessor):
    """
    Processor for GraniteVisionEmb.

    Extends ``LlavaNextProcessor`` with helpers that resize document-page
    images, build query prompts padded with query-augmentation tokens, and
    score multi-vector (late-interaction / MaxSim) embeddings.
    """

    # Text prompt paired with every document image in ``process_images``.
    visual_prompt_prefix: ClassVar[str] = "<|user|>\n<image>\nDescribe the image.\n"

    # System turn used by ``format_data``.
    system_message: ClassVar[str] = (
        "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
    )

    # Prefix placed directly before the raw query text in ``process_queries``.
    query_prefix: ClassVar[str] = "Query: "

    # Chat-turn opener placed before ``query_prefix`` in ``process_queries``.
    query_start: ClassVar[str] = "<|user|>\n"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Passed as ``factor`` to the resize helpers; those helpers currently
        # do not apply it.
        self.factor = 14
        # Size bounds used by ``smart_resize`` and ``smart_resize_and_pad``.
        self.min_size = 384
        self.max_size = 384 * 2
        # Number of query-augmentation tokens appended to each query.
        self.suffix_len = 10
        # NOTE(review): presumably the vision encoder's patch size in pixels;
        # not read by any method of this class.
        self.patch_size = 14

    @property
    def query_augmentation_token(self) -> str:
        """
        Return the query augmentation token (the tokenizer's pad token).

        Query augmentation buffers are used as reasoning buffers during inference.
        """
        return self.tokenizer.pad_token

    @staticmethod
    def smart_resize_helper(
        width: int,
        height: int,
        factor: int,
        min_size: int,
        max_size: int,
    ) -> Tuple[int, int]:
        """
        Compute target dimensions for an aspect-preserving resize.

        1. The image is scaled so its shorter side becomes ``min_size``
           (for square images the width is used as the reference side).
        2. If the longer side then exceeds ``max_size``, both sides are scaled
           down again so the longer side equals ``max_size`` (after rounding).

        Note: ``factor`` is accepted for API compatibility but is currently NOT
        applied, so the returned dimensions are not guaranteed to be divisible
        by it.

        :return: ``(new_width, new_height)`` -- width first.
        """
        if height < width:
            scale_factor = min_size / height
        else:
            scale_factor = min_size / width

        new_width = round(width * scale_factor)
        new_height = round(height * scale_factor)

        # Clip so the longer side does not exceed max_size, keeping the ratio.
        if max(new_width, new_height) > max_size:
            clip_factor = max_size / max(new_width, new_height)
            new_width = round(new_width * clip_factor)
            new_height = round(new_height * clip_factor)

        return new_width, new_height

    @staticmethod
    def pad_image_center(image: Image.Image,
                         target_width: int,
                         target_height: int,
                         fill_color=(0, 0, 0)) -> Image.Image:
        """
        Pad the given image so it is centered within the target dimensions.

        Any odd leftover pixel of padding goes to the right/bottom edge.
        Assumes the target is at least as large as the image; a smaller target
        yields negative padding, whose effect depends on ``ImageOps.expand``.

        :param image: PIL Image to be padded.
        :param target_width: The desired width after padding.
        :param target_height: The desired height after padding.
        :param fill_color: Background color (default is black).
        :return: Padded image with centered content, converted to RGB.
        """
        img_width, img_height = image.size

        pad_left = (target_width - img_width) // 2
        pad_top = (target_height - img_height) // 2
        pad_right = target_width - img_width - pad_left
        pad_bottom = target_height - img_height - pad_top

        padded_image = ImageOps.expand(
            image, (pad_left, pad_top, pad_right, pad_bottom), fill=fill_color
        ).convert("RGB")

        return padded_image

    def smart_resize(self, image: Image.Image) -> Image.Image:
        """
        Convert ``image`` to RGB and resize it to the dimensions computed by
        ``smart_resize_helper`` with this processor's ``factor``, ``min_size``
        and ``max_size``.
        """
        width, height = image.size
        # smart_resize_helper returns (width, height); unpack in that order.
        # (Previously unpacked as (height, width), which swapped the output size.)
        resized_width, resized_height = self.smart_resize_helper(
            width=width,
            height=height,
            factor=self.factor,
            min_size=self.min_size,
            max_size=self.max_size,
        )
        return image.convert("RGB").resize((resized_width, resized_height))

    def smart_resize_and_pad(self, image: Image.Image) -> Image.Image:
        """
        Resize ``image`` using this processor's size settings.

        Delegates to ``resize_and_pad_centered_to_long_side``. Despite the name,
        that method does not add padding.
        """
        return self.resize_and_pad_centered_to_long_side(
            image=image,
            factor=self.factor,
            min_size=self.min_size,
            max_size=self.max_size,
            fill_color=0,
        )

    def resize_and_pad_centered_to_long_side(
        self,
        image: Image.Image,
        factor: int,
        min_size: int,
        max_size: int,
        fill_color=0,
    ) -> Image.Image:
        """
        Resize ``image`` so its long side equals ``max_size`` and convert it to RGB.

        The short side is scaled by the same factor, but is never made smaller
        than ``min_size``. If proportional scaling would push it below
        ``min_size``, the short side is set to ``min_size`` instead, so the
        aspect ratio is not preserved in that case. Square images take the
        ``else`` branch, which treats the height as the long side.

        Despite the method name, no padding is applied. ``factor`` and
        ``fill_color`` are accepted for API compatibility but are unused.

        :param image: PIL Image
        :param factor: Unused.
        :param min_size: Lower bound for the short side, or -1 to skip resizing.
        :param max_size: Target size for the long side, or -1 to skip resizing.
        :param fill_color: Unused.
        :return: Resized RGB image (only RGB-converted when resizing is skipped).
        """
        width, height = image.size

        if min_size == -1 or max_size == -1:
            return image.convert("RGB")

        if width > height:
            scale_factor = max_size / width
            target_width = max_size
            max_scale_factor = max(min_size / height, scale_factor)
            target_height = round(height * max_scale_factor)
        else:
            scale_factor = max_size / height
            target_height = max_size
            max_scale_factor = max(min_size / width, scale_factor)
            target_width = round(width * max_scale_factor)

        resized_image = image.resize((target_width, target_height), Image.LANCZOS)
        final_image = resized_image.convert("RGB")

        return final_image

    def resize_and_pad_centered(self,
                                image: Image.Image,
                                factor: int,
                                min_size: int,
                                max_size: int,
                                fill_color=0,
                                ) -> Image.Image:
        """
        Resize ``image`` so its short side equals ``min_size`` and convert it to RGB.

        The long side is scaled by the same factor but capped at ``max_size``.
        When the cap applies, only the long side is reduced, so the aspect
        ratio is not preserved. Square images take the ``else`` branch, which
        treats the height as the short side.

        Despite the method name, no padding is applied. ``factor`` and
        ``fill_color`` are accepted for API compatibility but are unused.
        (The padding computation that used to live here produced values that
        were never used, so it was removed.)

        :param image: PIL Image
        :param factor: Unused.
        :param min_size: Target size for the short side, or -1 to skip resizing.
        :param max_size: Upper bound for the long side, or -1 to skip resizing.
        :param fill_color: Unused.
        :return: Resized RGB image (only RGB-converted when resizing is skipped).
        """
        width, height = image.size

        if min_size == -1 or max_size == -1:
            return image.convert("RGB")

        if width < height:
            scale_factor = min_size / width
            target_width = min_size
            max_scale_factor = min(max_size / height, scale_factor)
            target_height = round(height * max_scale_factor)
        else:
            scale_factor = min_size / height
            target_height = min_size
            max_scale_factor = min(max_size / width, scale_factor)
            target_width = round(width * max_scale_factor)

        resized_image = image.resize((target_width, target_height), Image.LANCZOS)
        final_image = resized_image.convert("RGB")

        return final_image

    def format_data(self, question, image):
        """
        Build a two-turn chat message list.

        The first turn is the system prompt (``system_message``). The second is
        a user turn containing ``image`` followed by the text ``question``.
        """
        return [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                    },
                    {
                        "type": "text",
                        "text": question,
                    },
                ],
            },
        ]

    def format_data_wo_role(self, question, image=None):
        """
        Build a single user-turn chat message list (no system turn).

        The image entry is always included, even when ``image`` is ``None``.
        """
        return [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                    },
                    {
                        "type": "text",
                        "text": question,
                    },
                ],
            }
        ]

    def process_images(
        self,
        images: List[Image.Image],
    ) -> BatchFeature:
        """
        Process document images into model inputs.

        Each image is resized with ``smart_resize_and_pad`` and paired with
        ``visual_prompt_prefix``. The batch is then run through this processor,
        returning PyTorch tensors padded to the longest sequence.
        """
        texts_doc = [self.visual_prompt_prefix for _ in images]
        images = [self.smart_resize_and_pad(image) for image in images]

        batch_doc = self(
            text=texts_doc,
            images=images,
            return_tensors="pt",
            padding="longest",
        )
        return batch_doc

    def process_queries(
        self,
        queries: List[str],
        max_length: int = 2048,
        suffix: Optional[str] = None,
    ) -> BatchFeature:
        """
        Build query prompts and tokenize them.

        Each query becomes ``query_start + query_prefix + query + suffix``,
        followed by a newline. When ``suffix`` is ``None``, it defaults to
        ``suffix_len`` copies of ``query_augmentation_token``.

        The pre-truncation below counts *characters*, whereas ``max_length`` is
        also passed to the tokenizer as a *token* limit (with truncation
        enabled). The character cut is presumably a cheap heuristic to keep
        the suffix on long queries, not an exact token budget.

        :param queries: Raw query strings.
        :param max_length: Character bound for pre-truncation and token bound
            for the tokenizer.
        :param suffix: Text appended to every query.
        :return: Processor output with PyTorch tensors, padded to the longest query.
        """
        if suffix is None:
            suffix = self.query_augmentation_token * self.suffix_len

        processed = []
        for q in queries:
            q = self.query_start + self.query_prefix + q

            if len(q) + len(suffix) > max_length:
                # Leave room for the suffix and the trailing newline. Clamp at 0:
                # a negative stop index would slice from the end of the string.
                q = q[: max(0, max_length - len(suffix) - 1)]
            q += suffix + "\n"
            processed.append(q)

        return self(
            text=processed,
            images=None,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=max_length,
        )

    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.

        Delegates to ``score_multi_vector``; extra keyword arguments (for
        example ``batch_size``) are forwarded to it.
        """
        return self.score_multi_vector(qs, ps, device=device, **kwargs)

    def get_n_patches(
        self,
        image_size: Tuple[int, int],
        patch_size: int,
    ) -> Tuple[int, int]:
        """
        Return ``(n_patches_x, n_patches_y)``.

        Computed as ``image_processor.size["width"] // patch_size`` and
        ``image_processor.size["height"] // patch_size``. ``image_size`` is
        ignored. Assumes ``image_processor.size`` has ``"width"`` and
        ``"height"`` keys -- TODO confirm for this checkpoint's image processor.
        """
        n_patches_x = self.image_processor.size["width"] // patch_size
        n_patches_y = self.image_processor.size["height"] // patch_size

        return n_patches_x, n_patches_y

    def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
        """Return a boolean mask of ``input_ids`` positions equal to ``image_token_id``."""
        return batch_images.input_ids == self.image_token_id

    @staticmethod
    def score_single_vector(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the dot product score for the given single-vector query and passage embeddings.

        The query and passage tensors are each stacked and moved to ``device``
        (when ``device`` is ``None``, they are not moved).

        :return: ``(n_queries, n_passages)`` float32 score matrix.
        :raises ValueError: If ``qs`` or ``ps`` is empty.
        """
        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        qs_stacked = torch.stack(qs).to(device)
        ps_stacked = torch.stack(ps).to(device)

        scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores

    @staticmethod
    def score_multi_vector(
        qs: Union[torch.Tensor, List[torch.Tensor]],
        ps: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: int = 128,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """
        Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
        query embeddings (`qs`) and passage embeddings (`ps`). For us, a passage is the
        image of a document page.

        Because the embedding tensors are multi-vector and can thus have different shapes, they
        should be fed as:
        (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
        (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
            obtained by padding the list of tensors.

        Each batch is zero-padded with ``pad_sequence``. For every query token,
        the maximum similarity over passage tokens is taken, and these maxima
        are summed over query tokens. Padded passage tokens contribute a
        similarity of 0 to the max.

        Args:
            qs (`Union[torch.Tensor, List[torch.Tensor]`): Query embeddings.
            ps (`Union[torch.Tensor, List[torch.Tensor]`): Passage embeddings.
            batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
            device (`Union[str, torch.device]`, *optional*): Device to use for computation. If
                `None`, the padded batches are not moved (``.to(None)``).

        Returns:
            `torch.Tensor`: A float32 tensor of shape `(n_queries, n_passages)` containing the
            scores. The score tensor is saved on the "cpu" device.

        Raises:
            ValueError: If `qs` or `ps` is empty.
        """
        if len(qs) == 0:
            raise ValueError("No queries provided")
        if len(ps) == 0:
            raise ValueError("No passages provided")

        scores_list: List[torch.Tensor] = []

        for i in range(0, len(qs), batch_size):
            scores_batch = []
            qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i: i + batch_size], batch_first=True, padding_value=0).to(
                device
            )
            for j in range(0, len(ps), batch_size):
                ps_batch = torch.nn.utils.rnn.pad_sequence(
                    ps[j: j + batch_size], batch_first=True, padding_value=0
                ).to(device)
                # (b queries, c passages, n query tokens, s passage tokens) -> max over s, sum over n.
                scores_batch.append(torch.einsum("bnd,csd->bcns", qs_batch, ps_batch).max(dim=3)[0].sum(dim=2))
            scores_batch = torch.cat(scores_batch, dim=1).cpu()
            scores_list.append(scores_batch)

        scores = torch.cat(scores_list, dim=0)
        assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"

        scores = scores.to(torch.float32)
        return scores
|
|