Samoed committed
Commit 24370b0 · verified · 1 Parent(s): cfeb668

Try to integrate AutoModel

Files changed (2):
  1. config.json +8 -10
  2. gme_inference.py +161 -139
config.json CHANGED
@@ -1,8 +1,10 @@
 {
   "_name_or_path": "gme-Qwen2-VL-2B-Instruct",
-  "architectures": [
-    "Qwen2VLForConditionalGeneration"
-  ],
+  "architectures": ["GmeQwen2VLForVision2Seq"],
+  "auto_map": {
+    "AutoModel": "gme_inference.GmeQwen2VLForVision2Seq",
+    "AutoConfig": "gme_inference.GmeQwen2VLConfig"
+  },
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
   "eos_token_id": 151645,
@@ -13,17 +15,13 @@
   "intermediate_size": 8960,
   "max_position_embeddings": 32768,
   "max_window_layers": 28,
-  "model_type": "qwen2_vl",
+  "model_type": "gme_qwen2_vl",
   "num_attention_heads": 12,
   "num_hidden_layers": 28,
   "num_key_value_heads": 2,
-  "rms_norm_eps": 1e-06,
+  "rms_norm_eps": 1e-6,
   "rope_scaling": {
-    "mrope_section": [
-      16,
-      24,
-      24
-    ],
+    "mrope_section": [16, 24, 24],
     "type": "mrope"
   },
   "rope_theta": 1000000.0,
gme_inference.py CHANGED
@@ -1,45 +1,79 @@
 from __future__ import annotations
 
+import base64
 import logging
 import math
 import os
-from typing import Dict, List, Optional
+from io import BytesIO
+from typing import Any, Dict, List, Optional, Union
 
+import requests
 import torch
 from PIL import Image
 from torch.utils.data import DataLoader
 from tqdm.autonotebook import tqdm
-from transformers import AutoModelForVision2Seq, AutoProcessor
-
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForVision2Seq,
+    AutoProcessor,
+    PreTrainedModel,
+    Qwen2VLConfig,
+    Qwen2VLModel,
+)
+import os
 
-class GmeQwen2VL:
+# Define a config class for our model.
+class GmeQwen2VLConfig(Qwen2VLConfig):
+    model_type: str = "gme_qwen2_vl"
+
     def __init__(
         self,
-        model_name: str = "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
-        model_path: Optional[str] = None,
+        min_image_tokens: int = 256,
+        max_image_tokens: int = 1280,
+        max_length: int = 1800,
         device: str = "cuda" if torch.cuda.is_available() else "cpu",
-        min_image_tokens=256,
-        max_image_tokens=1280,
-        max_length=1800,
-        **kwargs,
+        **kwargs: Any,
     ) -> None:
-        model_name = model_path or model_name
-        self.base = AutoModelForVision2Seq.from_pretrained(
-            model_name, torch_dtype=torch.float16, **kwargs
-        )
-        self.base.eval()
-        self.normalize = True
-        self.device = device
-        min_pixels = min_image_tokens * 28 * 28
-        max_pixels = max_image_tokens * 28 * 28
+        super().__init__(**kwargs)
+        self.min_image_tokens = min_image_tokens
+        self.max_image_tokens = max_image_tokens
         self.max_length = max_length
+        self.device = device
+AutoConfig.register("gme_qwen2_vl", GmeQwen2VLConfig)
+
+
+# Define the model class so that it can be loaded by AutoModel.from_pretrained.
+class GmeQwen2VLForVision2Seq(PreTrainedModel):
+    config_class = GmeQwen2VLConfig
+    base_model_prefix: str = "base"
+
+    def __init__(self, config: GmeQwen2VLConfig, **kwargs: Any) -> None:
+        super().__init__(config)
+        model_name: str = getattr(config, "_name_or_path", "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
+        # Load the underlying vision-to-sequence model.
+        self.base = Qwen2VLModel.from_pretrained(
+            model_name, trust_remote_code=True, **kwargs
+        )
+        self.normalize: bool = True
+        self.device: str = config.device
+
+        min_pixels: int = config.min_image_tokens * 28 * 28
+        max_pixels: int = config.max_image_tokens * 28 * 28
+        self.max_length: int = config.max_length
+
         self.processor = AutoProcessor.from_pretrained(
             model_name, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
         )
-        self.processor.tokenizer.padding_side = 'right'
-        self.defualt_instruction = 'You are a helpful assistant.'
-        self.sep = ' '
-
+        self.processor.tokenizer.padding_side = "right"
+        self.defualt_instruction: str = "You are a helpful assistant."
+        self.sep: str = " "
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> GmeQwen2VLForVision2Seq:
+        config = kwargs.pop("config", GmeQwen2VLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs))
+        return cls(config, **kwargs)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -48,11 +82,9 @@ class GmeQwen2VL:
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.Tensor] = None,
-        # pixel_values_videos: Optional[torch.FloatTensor] = None,
         image_grid_thw: Optional[torch.LongTensor] = None,
-        # video_grid_thw: Optional[torch.LongTensor] = None,
         pooling_mask: Optional[torch.LongTensor] = None,
-        **kwargs
+        **kwargs: Any,
     ) -> torch.Tensor:
         if inputs_embeds is None:
             inputs_embeds = self.base.model.embed_tokens(input_ids)
@@ -61,11 +93,6 @@ class GmeQwen2VL:
             image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
             image_mask = input_ids == self.base.config.image_token_id
             inputs_embeds[image_mask] = image_embeds
-        # if pixel_values_videos is not None:
-        #     pixel_values_videos = pixel_values_videos.type(self.base.visual.get_dtype())
-        #     video_embeds = self.base.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
-        #     video_mask = input_ids == self.base.config.video_token_id
-        #     inputs_embeds[video_mask] = video_embeds
         if attention_mask is not None:
             attention_mask = attention_mask.to(inputs_embeds.device)
 
@@ -78,36 +105,48 @@ class GmeQwen2VL:
         )
 
         pooling_mask = attention_mask if pooling_mask is None else pooling_mask
-        left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])  # TODO
+        left_padding: bool = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])
         if left_padding:
             embeddings = outputs.last_hidden_state[:, -1]
         else:
             sequence_lengths = pooling_mask.sum(dim=1) - 1
             batch_size = outputs.last_hidden_state.shape[0]
-            embeddings = outputs.last_hidden_state[torch.arange(
-                batch_size, device=outputs.last_hidden_state.device
-            ), sequence_lengths]
+            embeddings = outputs.last_hidden_state[
+                torch.arange(batch_size, device=outputs.last_hidden_state.device),
+                sequence_lengths,
+            ]
         if self.normalize:
             embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
         return embeddings.contiguous()
 
-    def embed(self, texts: list[str], images: list[Image.Image], is_query=True, instruction=None, **kwargs):
+    def embed(
+        self,
+        texts: List[str],
+        images: List[Image.Image],
+        is_query: bool = True,
+        instruction: Optional[str] = None,
+        **kwargs: Any,
+    ) -> torch.Tensor:
         self.base.to(self.device)
-        # Inputs must be batched
-        input_texts, input_images = list(), list()
+        input_texts: List[str] = []
+        input_images: List[Image.Image] = []
         for t, i in zip(texts, images):
             if not is_query or instruction is None:
                 instruction = self.defualt_instruction
-            input_str = ''
+            input_str: str = ""
             if i is None:
                 input_images = None  # All examples in the same batch are consistent
             else:
-                input_str += '<|vision_start|><|image_pad|><|vision_end|>'
+                input_str += "<|vision_start|><|image_pad|><|vision_end|>"
                 i = fetch_image(i)
                 input_images.append(i)
             if t is not None:
                 input_str += t
-            msg = f'<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
+            msg: str = (
+                f"<|im_start|>system\n{instruction}<|im_end|>\n"
+                f"<|im_start|>user\n{input_str}<|im_end|>\n"
+                f"<|im_start|>assistant\n<|endoftext|>"
+            )
             input_texts.append(msg)
 
         inputs = self.processor(
@@ -116,22 +155,22 @@ class GmeQwen2VL:
             padding=True,
             truncation=True,
             max_length=self.max_length,
-            return_tensors='pt'
+            return_tensors="pt",
         )
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}  # TODO
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
         with torch.no_grad():
             embeddings = self.forward(**inputs)
         return embeddings
 
-    def encode(self, sentences: list[str], *, prompt_name=None, **kwargs):
-        return self.get_fused_embeddings(texts=sentences, prompt_name=prompt_name, **kwargs)
+    def encode(self, sentences: List[str], **kwargs: Any) -> torch.Tensor:
+        # When no images are provided, we pass a list of Nones.
+        return self.embed(texts=sentences, images=[None] * len(sentences), **kwargs)
 
-    def encode_queries(self, queries: List[str], **kwargs):
-        embeddings = self.encode(queries, **kwargs)
-        return embeddings
+    def encode_queries(self, queries: List[str], **kwargs: Any) -> torch.Tensor:
+        return self.encode(queries, **kwargs)
 
-    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
-        if type(corpus) is dict:
+    def encode_corpus(self, corpus: Union[Dict[str, List[str]], List[Dict[str, str]]], **kwargs: Any) -> torch.Tensor:
+        if isinstance(corpus, dict):
             sentences = [
                 (corpus["title"][i] + self.sep + corpus["text"][i]).strip()
                 if "title" in corpus
@@ -143,68 +182,55 @@ class GmeQwen2VL:
                 (doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
                 for doc in corpus
             ]
-        embeddings = self.encode(sentences, is_query=False, **kwargs)
-        return embeddings
+        return self.encode(sentences, is_query=False, **kwargs)
 
-    def get_image_embeddings(self, images: list[Image.Image] | DataLoader, **kwargs):
+    def get_image_embeddings(self, images: Union[List[Image.Image], DataLoader], **kwargs: Any) -> torch.Tensor:
        return self.get_fused_embeddings(images=images, **kwargs)
 
-    def get_text_embeddings(self, texts: list[str], **kwargs):
+    def get_text_embeddings(self, texts: List[str], **kwargs: Any) -> torch.Tensor:
        return self.get_fused_embeddings(texts=texts, **kwargs)
 
-    def get_fused_embeddings(self, texts: list[str] = None, images: list[Image.Image] | DataLoader = None, **kwargs):
+
+    def get_fused_embeddings(
+        self,
+        texts: Optional[List[str]] = None,
+        images: Optional[Union[List[Image.Image], DataLoader]] = None,
+        **kwargs: Any,
+    ) -> torch.Tensor:
         if isinstance(images, DataLoader):
             image_loader = images
             batch_size = image_loader.batch_size
             image_loader.dataset.transform = None
         else:
-            batch_size = kwargs.pop('batch_size', 32)
+            batch_size = kwargs.pop("batch_size", 32)
             if images is None:
-                image_loader = None
+                # If texts are provided without images, create dummy image batches.
+                image_loader = [None] * ((len(texts) + batch_size - 1) // batch_size)
             else:
-                image_loader = DataLoader(
-                    images,
-                    batch_size=batch_size,
-                    shuffle=False,
-                    collate_fn=custom_collate_fn,
-                    num_workers=min(math.floor(os.cpu_count() / 2), 8),
-                )
-
-        if texts is None:
-            assert image_loader is not None
-            n_batch = len(image_loader)
-        else:
-            n_batch = len(texts) // batch_size + int(len(texts) % batch_size > 0)
-            image_loader = image_loader or [None] * n_batch
+                image_loader = images
 
-        all_embeddings = list()
+        n_batch: int = (len(texts) // batch_size + int(len(texts) % batch_size > 0)) if texts is not None else len(image_loader)
+        all_embeddings: List[torch.Tensor] = []
         none_batch = [None] * batch_size
-        show_progress_bar = kwargs.pop('show_progress_bar', True)
-        pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc='encode')
+        show_progress_bar: bool = kwargs.pop("show_progress_bar", True)
+        pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc="encode")
         for n, img_batch in zip(range(0, n_batch * batch_size, batch_size), image_loader):
-            text_batch = none_batch if texts is None else texts[n: n+batch_size]
+            text_batch: List[Optional[str]] = none_batch if texts is None else texts[n: n + batch_size]
             img_batch = none_batch if img_batch is None else img_batch
             embeddings = self.embed(texts=text_batch, images=img_batch, **kwargs)
             pbar.update(1)
             all_embeddings.append(embeddings.cpu())
         pbar.close()
-        all_embeddings = torch.cat(all_embeddings, dim=0)
-        return all_embeddings
+        return torch.cat(all_embeddings, dim=0)
 
+from transformers import AutoModelForVision2Seq
+AutoModelForVision2Seq.register(GmeQwen2VLConfig, GmeQwen2VLForVision2Seq)
 
-def custom_collate_fn(batch):
-    return batch
-
-
-### Copied from qwen_vl_utils.vision_process.py
-import base64
-from io import BytesIO
-import requests
-
-IMAGE_FACTOR = 28
-MIN_PIXELS = 4 * 28 * 28
-MAX_PIXELS = 16384 * 28 * 28
-MAX_RATIO = 200
+# Utility functions (copied from your vision processing code)
+IMAGE_FACTOR: int = 28
+MIN_PIXELS: int = 4 * 28 * 28
+MAX_PIXELS: int = 16384 * 28 * 28
+MAX_RATIO: int = 200
 
 
 def round_by_factor(number: int, factor: int) -> int:
@@ -226,13 +252,10 @@ def smart_resize(
     height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
 ) -> tuple[int, int]:
     """
-    Rescales the image so that the following conditions are met:
-
-    1. Both dimensions (height and width) are divisible by 'factor'.
-
-    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
-
-    3. The aspect ratio of the image is maintained as closely as possible.
+    Rescales the image so that:
+    1. Both dimensions are divisible by 'factor'.
+    2. Total pixels fall between ['min_pixels', 'max_pixels'].
+    3. Aspect ratio is maintained as closely as possible.
     """
     h_bar = max(factor, round_by_factor(height, factor))
     w_bar = max(factor, round_by_factor(width, factor))
@@ -256,35 +279,27 @@ def smart_resize(
     return h_bar, w_bar
 
 
-def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
-    image_obj = None
+def fetch_image(image: Union[str, Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
+    image_obj: Optional[Image.Image] = None
     if isinstance(image, Image.Image):
         image_obj = image
-    elif image.startswith("http://") or image.startswith("https://"):
+    elif isinstance(image, str) and (image.startswith("http://") or image.startswith("https://")):
         image_obj = Image.open(requests.get(image, stream=True).raw)
-    elif image.startswith("file://"):
+    elif isinstance(image, str) and image.startswith("file://"):
         image_obj = Image.open(image[7:])
-    elif image.startswith("data:image"):
+    elif isinstance(image, str) and image.startswith("data:image"):
         if "base64," in image:
             _, base64_data = image.split("base64,", 1)
             data = base64.b64decode(base64_data)
             image_obj = Image.open(BytesIO(data))
-    else:
+    elif isinstance(image, str):
         image_obj = Image.open(image)
     if image_obj is None:
-        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+        raise ValueError(
+            f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
+        )
     image = image_obj.convert("RGB")
-    ## resize
-    # if "resized_height" in ele and "resized_width" in ele:
-    #     resized_height, resized_width = smart_resize(
-    #         ele["resized_height"],
-    #         ele["resized_width"],
-    #         factor=size_factor,
-    #     )
-    # else:
     width, height = image.size
-    #     min_pixels = ele.get("min_pixels", MIN_PIXELS)
-    #     max_pixels = ele.get("max_pixels", MAX_PIXELS)
     resized_height, resized_width = smart_resize(
         height,
         width,
@@ -293,37 +308,44 @@ def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Im
         max_pixels=MAX_PIXELS,
     )
     image = image.resize((resized_width, resized_height))
-
     return image
-###
 
 
-if __name__ == '__main__':
+# # For backward compatibility, you can add a from_pretrained classmethod.
+# @classmethod
+# def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs: Any) -> GmeQwen2VLForVision2Seq:
+#     config = GmeQwen2VLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+#     return cls(config, **kwargs)
+
+
+# # Monkey-patch the from_pretrained method to our class so that
+# # one can load the model with AutoModel.from_pretrained.
+# GmeQwen2VLForVision2Seq.from_pretrained = from_pretrained.__get__(GmeQwen2VLForVision2Seq)
+
+
+if __name__ == "__main__":
     texts = [
         "What kind of car is this?",
-        "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023."
+        "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
     ]
     images = [
-        'https://en.wikipedia.org/wiki/File:Tesla_Cybertruck_damaged_window.jpg',
-        'https://en.wikipedia.org/wiki/File:2024_Tesla_Cybertruck_Foundation_Series,_front_left_(Greenwich).jpg',
+        "https://en.wikipedia.org/wiki/File:Tesla_Cybertruck_damaged_window.jpg",
+        "https://en.wikipedia.org/wiki/File:2024_Tesla_Cybertruck_Foundation_Series,_front_left_(Greenwich).jpg",
     ]
 
-    gme = GmeQwen2VL("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
-
-    # Single-modal embedding
-    e_text = gme.get_text_embeddings(texts=texts)
-    e_image = gme.get_image_embeddings(images=images)
-    print((e_text * e_image).sum(-1))
-    ## tensor([0.2281, 0.6001], dtype=torch.float16)
-
-    # How to set embedding instruction
-    e_query = gme.get_text_embeddings(texts=texts, instruction='Find an image that matches the given text.')
-    # If is_query=False, we always use the default instruction.
-    e_corpus = gme.get_image_embeddings(images=images, is_query=False)
-    print((e_query * e_corpus).sum(-1))
-    ## tensor([0.2433, 0.7051], dtype=torch.float16)
-
-    # Fused-modal embedding
-    e_fused = gme.get_fused_embeddings(texts=texts, images=images)
-    print((e_fused[0] * e_fused[1]).sum())
-    ## tensor(0.6108, dtype=torch.float16)
+    # You can now load your model with AutoModel as long as your repository's config JSON has the "architectures" field set.
+    model = AutoModel.from_pretrained("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
+    # Alternatively, load it directly via our class:
+    # model = GmeQwen2VLForVision2Seq.from_pretrained("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct")
+
+    # Single-modal embedding examples:
+    e_text = model.get_text_embeddings(texts=texts)
+    e_image = model.get_image_embeddings(images=images)
+    print("Text-Image similarity:", (e_text * e_image).sum(-1))
+    # Example with different instruction:
+    e_query = model.get_text_embeddings(texts=texts, instruction="Find an image that matches the given text.")
+    e_corpus = model.get_image_embeddings(images=images, is_query=False)
+    print("Query-Corpus similarity:", (e_query * e_corpus).sum(-1))
+    # Fused-modal embedding:
+    e_fused = model.get_fused_embeddings(texts=texts, images=images)
+    print("Fused-modal similarity:", (e_fused[0] * e_fused[1]).sum())