Samoed committed
Commit e0e7250 · verified · 1 Parent(s): 5fa49f4
Files changed (1)
  1. gme_inference.py +70 -31
gme_inference.py CHANGED
@@ -24,9 +24,10 @@ from transformers import (
 import os
 from collections.abc import Iterable
 
+
 class GmeQwen2VLConfig(Qwen2VLConfig):
     model_type: str = "gme_qwen2_vl"
-
+
     def __init__(
         self,
         min_image_tokens: int = 256,
@@ -44,25 +45,27 @@ class GmeQwen2VLConfig(Qwen2VLConfig):
 class GmeQwen2VLForVision2Seq(PreTrainedModel):
     config_class = GmeQwen2VLConfig
     base_model_prefix: str = "base"
-
+
     def __init__(self, config: GmeQwen2VLConfig, **kwargs: Any) -> None:
         super().__init__(config)
-        model_name: str = getattr(config, "_name_or_path", "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
-
+        model_name: str = getattr(
+            config, "_name_or_path", "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct"
+        )
+
         self.base = Qwen2VLForConditionalGeneration(config)
         self.normalize: bool = True
-
+
         min_pixels: int = config.min_image_tokens * 28 * 28
         max_pixels: int = config.max_image_tokens * 28 * 28
-
+
         self.max_length: int = config.max_length
         self.processor = AutoProcessor.from_pretrained(
            model_name, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
         )
-        self.processor.tokenizer.padding_side = 'right'
+        self.processor.tokenizer.padding_side = "right"
         self.defualt_instruction: str = "You are a helpful assistant."
         self.sep: str = " "
-
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -75,13 +78,15 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         image_grid_thw: Optional[torch.LongTensor] = None,
         # video_grid_thw: Optional[torch.LongTensor] = None,
         pooling_mask: Optional[torch.LongTensor] = None,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         if inputs_embeds is None:
             inputs_embeds = self.base.model.embed_tokens(input_ids)
             if pixel_values is not None:
                 pixel_values = pixel_values.type(self.base.visual.get_dtype())
-                image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
+                image_embeds = self.base.visual(
+                    pixel_values, grid_thw=image_grid_thw
+                ).to(inputs_embeds.device)
                 image_mask = input_ids == self.base.config.image_token_id
                 inputs_embeds[image_mask] = image_embeds
             # if pixel_values_videos is not None:
@@ -101,37 +106,44 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         )
 
         pooling_mask = attention_mask if pooling_mask is None else pooling_mask
-        left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])  # TODO
+        left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]  # TODO
         if left_padding:
             embeddings = outputs.last_hidden_state[:, -1]
         else:
             sequence_lengths = pooling_mask.sum(dim=1) - 1
             batch_size = outputs.last_hidden_state.shape[0]
-            embeddings = outputs.last_hidden_state[torch.arange(
-                batch_size, device=outputs.last_hidden_state.device
-            ), sequence_lengths]
+            embeddings = outputs.last_hidden_state[
+                torch.arange(batch_size, device=outputs.last_hidden_state.device),
+                sequence_lengths,
+            ]
         if self.normalize:
             embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
         return embeddings.contiguous()
 
-
-    def embed(self, texts: list[str], images: list[Image.Image], is_query=True, instruction=None, **kwargs):
+    def embed(
+        self,
+        texts: list[str],
+        images: list[Image.Image],
+        is_query=True,
+        instruction=None,
+        **kwargs,
+    ):
         self.base.to(self.device)
         # Inputs must be batched
         input_texts, input_images = list(), list()
         for t, i in zip(texts, images):
             if not is_query or instruction is None:
                 instruction = self.defualt_instruction
-            input_str = ''
+            input_str = ""
             if i is None:
                 input_images = None  # All examples in the same batch are consistent
             else:
-                input_str += '<|vision_start|><|image_pad|><|vision_end|>'
+                input_str += "<|vision_start|><|image_pad|><|vision_end|>"
                 i = fetch_image(i)
                 input_images.append(i)
             if t is not None:
                 input_str += t
-            msg = f'<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
+            msg = f"<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
             input_texts.append(msg)
 
         inputs = self.processor(
@@ -140,7 +152,7 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
             padding=True,
             truncation=True,
             max_length=self.max_length,
-            return_tensors='pt'
+            return_tensors="pt",
         )
         inputs = {k: v.to(self.device) for k, v in inputs.items()}  # TODO
         with torch.no_grad():
@@ -148,7 +160,9 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         return embeddings
 
     def encode(self, sentences: list[str], *, prompt_name=None, **kwargs):
-        return self.get_fused_embeddings(texts=sentences, prompt_name=prompt_name, **kwargs)
+        return self.get_fused_embeddings(
+            texts=sentences, prompt_name=prompt_name, **kwargs
+        )
 
     def encode_queries(self, queries: List[str], **kwargs):
         embeddings = self.encode(queries, **kwargs)
@@ -164,7 +178,9 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
             ]
         else:
             sentences = [
-                (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
+                (doc["title"] + self.sep + doc["text"]).strip()
+                if "title" in doc
+                else doc["text"].strip()
                 for doc in corpus
             ]
         embeddings = self.encode(sentences, is_query=False, **kwargs)
@@ -176,13 +192,18 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
     def get_text_embeddings(self, texts: list[str], **kwargs):
         return self.get_fused_embeddings(texts=texts, **kwargs)
 
-    def get_fused_embeddings(self, texts: list[str] = None, images: list[Image.Image] | DataLoader = None, **kwargs):
+    def get_fused_embeddings(
+        self,
+        texts: list[str] = None,
+        images: list[Image.Image] | DataLoader = None,
+        **kwargs,
+    ):
         if isinstance(images, DataLoader):
             image_loader = images
             batch_size = image_loader.batch_size
             image_loader.dataset.transform = None
         else:
-            batch_size = kwargs.pop('batch_size', 32)
+            batch_size = kwargs.pop("batch_size", 32)
             if images is None:
                 image_loader = None
             else:
@@ -203,10 +224,18 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
 
         all_embeddings = list()
         none_batch = [None] * batch_size
-        show_progress_bar = kwargs.pop('show_progress_bar', True)
-        pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc='encode')
-        for n, img_batch in zip(range(0, n_batch * batch_size, batch_size), image_loader):
-            text_batch = none_batch if texts is None else texts[n: n+batch_size]
+        show_progress_bar = kwargs.pop("show_progress_bar", True)
+        pbar = tqdm(
+            total=n_batch,
+            disable=not show_progress_bar,
+            mininterval=1,
+            miniters=10,
+            desc="encode",
+        )
+        for n, img_batch in zip(
+            range(0, n_batch * batch_size, batch_size), image_loader
+        ):
+            text_batch = none_batch if texts is None else texts[n : n + batch_size]
             img_batch = none_batch if img_batch is None else img_batch
             embeddings = self.embed(texts=text_batch, images=img_batch, **kwargs)
             pbar.update(1)
@@ -215,9 +244,11 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         all_embeddings = torch.cat(all_embeddings, dim=0)
         return all_embeddings
 
+
 def custom_collate_fn(batch):
     return batch
 
+
 # Utility functions (copied from your vision processing code)
 IMAGE_FACTOR: int = 28
 MIN_PIXELS: int = 4 * 28 * 28
@@ -241,7 +272,11 @@ def floor_by_factor(number: int, factor: int) -> int:
 
 
 def smart_resize(
-    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+    height: int,
+    width: int,
+    factor: int = IMAGE_FACTOR,
+    min_pixels: int = MIN_PIXELS,
+    max_pixels: int = MAX_PIXELS,
 ) -> tuple[int, int]:
     """
     Rescales the image so that:
@@ -271,11 +306,15 @@ def smart_resize(
     return h_bar, w_bar
 
 
-def fetch_image(image: Union[str, Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
+def fetch_image(
+    image: Union[str, Image.Image], size_factor: int = IMAGE_FACTOR
+) -> Image.Image:
     image_obj: Optional[Image.Image] = None
     if isinstance(image, Image.Image):
         image_obj = image
-    elif isinstance(image, str) and (image.startswith("http://") or image.startswith("https://")):
+    elif isinstance(image, str) and (
+        image.startswith("http://") or image.startswith("https://")
+    ):
         image_obj = Image.open(requests.get(image, stream=True).raw)
     elif isinstance(image, str) and image.startswith("file://"):
         image_obj = Image.open(image[7:])
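
For orientation, here is a minimal usage sketch of the wrapper this commit reformats. It is not part of the diff: it assumes gme_inference.py is importable from the working directory, that the default checkpoint name used in __init__ ("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct", shown here as a placeholder to substitute with the repo this file ships in) loads via from_pretrained, and that device/dtype handling is left to the caller.

# Hypothetical usage sketch (not part of this commit); assumptions noted above.
import torch
from gme_inference import GmeQwen2VLForVision2Seq

MODEL_ID = "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct"  # placeholder, substitute as needed

model = GmeQwen2VLForVision2Seq.from_pretrained(MODEL_ID)
model.eval()

# Text-only embeddings: get_fused_embeddings() batches the inputs, embed() wraps each
# text in the chat template, and forward() returns L2-normalized last-token vectors
# because self.normalize is True.
texts = ["What is the capital of France?", "Paris is the capital of France."]
with torch.no_grad():
    emb = model.get_text_embeddings(texts=texts, batch_size=2, show_progress_bar=False)
print(emb.shape)                 # (2, hidden_size)
print((emb[0] @ emb[1]).item())  # cosine similarity, since rows are unit-norm

# Fused text+image embeddings would pass PIL images (or http:// / file:// strings,
# which fetch_image() resolves), e.g.:
# from PIL import Image
# fused = model.get_fused_embeddings(texts=texts, images=[Image.open("a.jpg"), Image.open("b.jpg")])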