izhx committed · verified
Commit facd32f · 1 Parent(s): e93daf4

revert gme_inference.py

Files changed (1)
  1. gme_inference.py +108 -121
gme_inference.py CHANGED
@@ -1,70 +1,44 @@
 from __future__ import annotations
 
-import base64
 import logging
 import math
 import os
-from io import BytesIO
-from typing import Any, Dict, List, Optional, Union
+from typing import Dict, List, Optional
 
-import requests
 import torch
 from PIL import Image
 from torch.utils.data import DataLoader
 from tqdm.autonotebook import tqdm
-from transformers import (
-    AutoConfig,
-    AutoModel,
-    AutoModelForVision2Seq,
-    AutoProcessor,
-    PreTrainedModel,
-    Qwen2VLConfig,
-    Qwen2VLForConditionalGeneration,
-)
-import os
-from collections.abc import Iterable
+from transformers import AutoModelForVision2Seq, AutoProcessor
 
 
-class GmeQwen2VLConfig(Qwen2VLConfig):
-    model_type: str = "gme_qwen2_vl"
-
+class GmeQwen2VL:
     def __init__(
         self,
-        min_image_tokens: int = 256,
-        max_image_tokens: int = 1280,
-        max_length: int = 1800,
+        model_name: str = "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
+        model_path: Optional[str] = None,
         device: str = "cuda" if torch.cuda.is_available() else "cpu",
-        **kwargs: Any,
+        min_image_tokens=256,
+        max_image_tokens=1280,
+        max_length=1800,
+        **kwargs,
     ) -> None:
-        super().__init__(**kwargs)
-        self.min_image_tokens = min_image_tokens
-        self.max_image_tokens = max_image_tokens
-        self.max_length = max_length
-
-
-class GmeQwen2VLForVision2Seq(PreTrainedModel):
-    config_class = GmeQwen2VLConfig
-    base_model_prefix: str = "base"
-
-    def __init__(self, config: GmeQwen2VLConfig, **kwargs: Any) -> None:
-        super().__init__(config)
-        model_name: str = getattr(
-            config, "_name_or_path", "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct"
+        model_name = model_path or model_name
+        self.base = AutoModelForVision2Seq.from_pretrained(
+            model_name, torch_dtype=torch.float16, **kwargs
         )
-
-        self.base = Qwen2VLForConditionalGeneration(config)
-        self.normalize: bool = True
-
-        min_pixels: int = config.min_image_tokens * 28 * 28
-        max_pixels: int = config.max_image_tokens * 28 * 28
-
-        self.max_length: int = config.max_length
+        self.base.eval()
+        self.normalize = True
+        self.device = device
+        min_pixels = min_image_tokens * 28 * 28
+        max_pixels = max_image_tokens * 28 * 28
+        self.max_length = max_length
         self.processor = AutoProcessor.from_pretrained(
             model_name, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
         )
-        self.processor.tokenizer.padding_side = "right"
-        self.defualt_instruction: str = "You are a helpful assistant."
-        self.sep: str = " "
+        self.processor.tokenizer.padding_side = 'right'
+        self.default_instruction = 'You are a helpful assistant.'
+        self.sep = ' '
 
     def forward(
         self,
@@ -78,15 +52,13 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         image_grid_thw: Optional[torch.LongTensor] = None,
         # video_grid_thw: Optional[torch.LongTensor] = None,
         pooling_mask: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs
     ) -> torch.Tensor:
         if inputs_embeds is None:
             inputs_embeds = self.base.model.embed_tokens(input_ids)
         if pixel_values is not None:
             pixel_values = pixel_values.type(self.base.visual.get_dtype())
-            image_embeds = self.base.visual(
-                pixel_values, grid_thw=image_grid_thw
-            ).to(inputs_embeds.device)
+            image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
             image_mask = input_ids == self.base.config.image_token_id
             inputs_embeds[image_mask] = image_embeds
         # if pixel_values_videos is not None:
@@ -106,44 +78,36 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         )
 
         pooling_mask = attention_mask if pooling_mask is None else pooling_mask
-        left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]  # TODO
+        left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])  # TODO
         if left_padding:
             embeddings = outputs.last_hidden_state[:, -1]
         else:
             sequence_lengths = pooling_mask.sum(dim=1) - 1
             batch_size = outputs.last_hidden_state.shape[0]
-            embeddings = outputs.last_hidden_state[
-                torch.arange(batch_size, device=outputs.last_hidden_state.device),
-                sequence_lengths,
-            ]
+            embeddings = outputs.last_hidden_state[torch.arange(
+                batch_size, device=outputs.last_hidden_state.device
+            ), sequence_lengths]
         if self.normalize:
             embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
         return embeddings.contiguous()
 
-    def embed(
-        self,
-        texts: list[str],
-        images: list[Image.Image],
-        is_query=True,
-        instruction=None,
-        **kwargs,
-    ):
+    def embed(self, texts: list[str], images: list[Image.Image], is_query=True, instruction=None, **kwargs):
         self.base.to(self.device)
         # Inputs must be batched
         input_texts, input_images = list(), list()
         for t, i in zip(texts, images):
             if not is_query or instruction is None:
-                instruction = self.defualt_instruction
-            input_str = ""
+                instruction = self.default_instruction
+            input_str = ''
             if i is None:
                 input_images = None  # All examples in the same batch are consistent
             else:
-                input_str += "<|vision_start|><|image_pad|><|vision_end|>"
+                input_str += '<|vision_start|><|image_pad|><|vision_end|>'
                 i = fetch_image(i)
                 input_images.append(i)
             if t is not None:
                 input_str += t
-            msg = f"<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
+            msg = f'<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
             input_texts.append(msg)
 
         inputs = self.processor(
@@ -152,7 +116,7 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
             padding=True,
             truncation=True,
             max_length=self.max_length,
-            return_tensors="pt",
+            return_tensors='pt'
         )
         inputs = {k: v.to(self.device) for k, v in inputs.items()}  # TODO
         with torch.no_grad():
@@ -160,9 +124,7 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         return embeddings
 
     def encode(self, sentences: list[str], *, prompt_name=None, **kwargs):
-        return self.get_fused_embeddings(
-            texts=sentences, prompt_name=prompt_name, **kwargs
-        )
+        return self.get_fused_embeddings(texts=sentences, prompt_name=prompt_name, **kwargs)
 
     def encode_queries(self, queries: List[str], **kwargs):
         embeddings = self.encode(queries, **kwargs)
@@ -178,9 +140,7 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
             ]
         else:
             sentences = [
-                (doc["title"] + self.sep + doc["text"]).strip()
-                if "title" in doc
-                else doc["text"].strip()
+                (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                 for doc in corpus
             ]
         embeddings = self.encode(sentences, is_query=False, **kwargs)
@@ -192,18 +152,13 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
     def get_text_embeddings(self, texts: list[str], **kwargs):
         return self.get_fused_embeddings(texts=texts, **kwargs)
 
-    def get_fused_embeddings(
-        self,
-        texts: list[str] = None,
-        images: list[Image.Image] | DataLoader = None,
-        **kwargs,
-    ):
+    def get_fused_embeddings(self, texts: list[str] = None, images: list[Image.Image] | DataLoader = None, **kwargs):
         if isinstance(images, DataLoader):
             image_loader = images
             batch_size = image_loader.batch_size
             image_loader.dataset.transform = None
         else:
-            batch_size = kwargs.pop("batch_size", 32)
+            batch_size = kwargs.pop('batch_size', 32)
             if images is None:
                 image_loader = None
             else:
@@ -224,18 +179,10 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
 
         all_embeddings = list()
         none_batch = [None] * batch_size
-        show_progress_bar = kwargs.pop("show_progress_bar", True)
-        pbar = tqdm(
-            total=n_batch,
-            disable=not show_progress_bar,
-            mininterval=1,
-            miniters=10,
-            desc="encode",
-        )
-        for n, img_batch in zip(
-            range(0, n_batch * batch_size, batch_size), image_loader
-        ):
-            text_batch = none_batch if texts is None else texts[n : n + batch_size]
+        show_progress_bar = kwargs.pop('show_progress_bar', True)
+        pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc='encode')
+        for n, img_batch in zip(range(0, n_batch * batch_size, batch_size), image_loader):
+            text_batch = none_batch if texts is None else texts[n: n+batch_size]
             img_batch = none_batch if img_batch is None else img_batch
             embeddings = self.embed(texts=text_batch, images=img_batch, **kwargs)
             pbar.update(1)
@@ -249,11 +196,15 @@ def custom_collate_fn(batch):
     return batch
 
 
-# Utility functions (copied from your vision processing code)
-IMAGE_FACTOR: int = 28
-MIN_PIXELS: int = 4 * 28 * 28
-MAX_PIXELS: int = 16384 * 28 * 28
-MAX_RATIO: int = 200
+### Copied from qwen_vl_utils.vision_process.py
+import base64
+from io import BytesIO
+import requests
+
+IMAGE_FACTOR = 28
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
 
 
 def round_by_factor(number: int, factor: int) -> int:
@@ -272,17 +223,16 @@ def floor_by_factor(number: int, factor: int) -> int:
 
 
 def smart_resize(
-    height: int,
-    width: int,
-    factor: int = IMAGE_FACTOR,
-    min_pixels: int = MIN_PIXELS,
-    max_pixels: int = MAX_PIXELS,
+    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
 ) -> tuple[int, int]:
     """
-    Rescales the image so that:
-    1. Both dimensions are divisible by 'factor'.
-    2. Total pixels fall between ['min_pixels', 'max_pixels'].
-    3. Aspect ratio is maintained as closely as possible.
+    Rescales the image so that the following conditions are met:
+
+    1. Both dimensions (height and width) are divisible by 'factor'.
+
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+    3. The aspect ratio of the image is maintained as closely as possible.
     """
     h_bar = max(factor, round_by_factor(height, factor))
     w_bar = max(factor, round_by_factor(width, factor))
@@ -306,31 +256,35 @@ def smart_resize(
     return h_bar, w_bar
 
 
-def fetch_image(
-    image: Union[str, Image.Image], size_factor: int = IMAGE_FACTOR
-) -> Image.Image:
-    image_obj: Optional[Image.Image] = None
+def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+    image_obj = None
     if isinstance(image, Image.Image):
         image_obj = image
-    elif isinstance(image, str) and (
-        image.startswith("http://") or image.startswith("https://")
-    ):
+    elif image.startswith("http://") or image.startswith("https://"):
        image_obj = Image.open(requests.get(image, stream=True).raw)
-    elif isinstance(image, str) and image.startswith("file://"):
+    elif image.startswith("file://"):
         image_obj = Image.open(image[7:])
-    elif isinstance(image, str) and image.startswith("data:image"):
+    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
-    elif isinstance(image, str):
+    else:
        image_obj = Image.open(image)
    if image_obj is None:
-        raise ValueError(
-            f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
-        )
+        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = image_obj.convert("RGB")
+    ## resize
+    # if "resized_height" in ele and "resized_width" in ele:
+    #     resized_height, resized_width = smart_resize(
+    #         ele["resized_height"],
+    #         ele["resized_width"],
+    #         factor=size_factor,
+    #     )
+    # else:
    width, height = image.size
+    # min_pixels = ele.get("min_pixels", MIN_PIXELS)
+    # max_pixels = ele.get("max_pixels", MAX_PIXELS)
    resized_height, resized_width = smart_resize(
        height,
        width,
@@ -339,4 +293,37 @@ def fetch_image(
         max_pixels=MAX_PIXELS,
     )
     image = image.resize((resized_width, resized_height))
+
     return image
+###
+
+
+if __name__ == '__main__':
+    texts = [
+        "What kind of car is this?",
+        "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023."
+    ]
+    images = [
+        'https://upload.wikimedia.org/wikipedia/commons/e/e9/Tesla_Cybertruck_damaged_window.jpg',
+        'https://upload.wikimedia.org/wikipedia/commons/9/95/2024_Tesla_Cybertruck_Foundation_Series%2C_front_left_%28Greenwich%29.jpg',
+    ]
+
+    gme = GmeQwen2VL("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
+
+    # Single-modal embedding
+    e_text = gme.get_text_embeddings(texts=texts)
+    e_image = gme.get_image_embeddings(images=images)
+    print((e_text * e_image).sum(-1))
+    ## tensor([0.2281, 0.6001], dtype=torch.float16)
+
+    # How to set embedding instruction
+    e_query = gme.get_text_embeddings(texts=texts, instruction='Find an image that matches the given text.')
+    # If is_query=False, we always use the default instruction.
+    e_corpus = gme.get_image_embeddings(images=images, is_query=False)
+    print((e_query * e_corpus).sum(-1))
+    ## tensor([0.2433, 0.7051], dtype=torch.float16)
+
+    # Fused-modal embedding
+    e_fused = gme.get_fused_embeddings(texts=texts, images=images)
+    print((e_fused[0] * e_fused[1]).sum())
+    ## tensor(0.6108, dtype=torch.float16)
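
A minimal retrieval sketch, not part of this commit: it assumes the reverted gme_inference.py is importable as a local module and reuses the GmeQwen2VL API shown above (get_text_embeddings, get_image_embeddings); the query instruction string and the Wikimedia image URLs are just the example values from the file.

# Illustrative sketch (assumption: gme_inference.py is on the Python path).
from gme_inference import GmeQwen2VL

queries = ["What kind of car is this?"]
candidate_images = [
    'https://upload.wikimedia.org/wikipedia/commons/e/e9/Tesla_Cybertruck_damaged_window.jpg',
    'https://upload.wikimedia.org/wikipedia/commons/9/95/2024_Tesla_Cybertruck_Foundation_Series%2C_front_left_%28Greenwich%29.jpg',
]

gme = GmeQwen2VL("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
q = gme.get_text_embeddings(texts=queries, instruction='Find an image that matches the given text.')
d = gme.get_image_embeddings(images=candidate_images, is_query=False)

# Embeddings are L2-normalized in forward(), so the dot product is cosine similarity.
scores = q @ d.T                      # shape: (num_queries, num_candidates)
print(scores, scores.argmax(dim=-1))  # best-matching image index per query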