aria-dev committed
Commit fb5de81
1 Parent(s): 6a21e23

sync changes from github

Files changed (5)
  1. configuration_aria.py +17 -5
  2. modeling_aria.py +53 -264
  3. moe_lm.py +1 -1
  4. processing_aria.py +23 -1
  5. vision_processor.py +1 -1
configuration_aria.py CHANGED
@@ -17,11 +17,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import logging
+
 from transformers.configuration_utils import PretrainedConfig
 
 from .moe_lm import AriaMoELMConfig
 from .vision_encoder import AriaVisionConfig
 
+logger = logging.getLogger(__name__)
+
 
 # adapted from transformers.models.llava.configuration_llava.LlavaConfig
 class AriaConfig(PretrainedConfig):
@@ -69,6 +73,7 @@ class AriaConfig(PretrainedConfig):
         self.image_token_index = image_token_index
 
         attn_implementation = kwargs.pop("attn_implementation", None)
+        self._attn_implementation = attn_implementation
 
         # Convert the keys and values of projector_patch_to_query_dict to integers
         # This ensures consistency even if they were provided as strings
@@ -78,11 +83,15 @@ class AriaConfig(PretrainedConfig):
 
         if isinstance(vision_config, dict) and "model_type" in vision_config:
             vision_config = AriaVisionConfig(**vision_config)
-        vision_attn_implementation = (
-            "flash_attention_2"
-            if attn_implementation is None
-            else attn_implementation
-        )
+        if attn_implementation is None:
+            vision_attn_implementation = "flash_attention_2"
+        elif attn_implementation == "sdpa":
+            logger.warning(
+                "SDPA is not supported for vit, using flash_attention_2 instead"
+            )
+            vision_attn_implementation = "flash_attention_2"
+        else:
+            vision_attn_implementation = attn_implementation
         vision_config._attn_implementation = vision_attn_implementation
 
         self.vision_config = vision_config
@@ -95,3 +104,6 @@ class AriaConfig(PretrainedConfig):
         text_config._attn_implementation = text_attn_implementation
 
         self.text_config = text_config
+
+        # This is needed for the static kv cache
+        self.num_hidden_layers = self.text_config.num_hidden_layers
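
The functional core of this file's change is the new attention-implementation fallback: the vision tower only supports FlashAttention 2, so both an unset and an "sdpa" attn_implementation resolve to "flash_attention_2", with a warning in the SDPA case. A minimal standalone sketch of that selection rule (illustrative only, not the repository's code):

import logging
from typing import Optional

logger = logging.getLogger(__name__)


def resolve_vision_attn_implementation(attn_implementation: Optional[str]) -> str:
    # Mirrors the branch added in the diff: the ViT path only supports
    # flash_attention_2, so None and "sdpa" both fall back to it.
    if attn_implementation is None:
        return "flash_attention_2"
    if attn_implementation == "sdpa":
        logger.warning("SDPA is not supported for vit, using flash_attention_2 instead")
        return "flash_attention_2"
    return attn_implementation


assert resolve_vision_attn_implementation(None) == "flash_attention_2"
assert resolve_vision_attn_implementation("sdpa") == "flash_attention_2"
assert resolve_vision_attn_implementation("eager") == "eager"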
modeling_aria.py CHANGED
@@ -24,7 +24,6 @@ import torch
 import torch.nn as nn
 from torch import nn
 from transformers import PreTrainedModel
-from transformers.cache_utils import Cache
 from transformers.modeling_outputs import ModelOutput
 from transformers.utils import logging
 
@@ -48,6 +47,7 @@ class AriaPretrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_cache_class = True
+    _supports_static_cache = True
 
     @property
     def _supports_sdpa(self):
@@ -183,138 +183,6 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
         """
         self.language_model.set_aux_loss_coeff(value)
 
-    # copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
-    def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds, input_ids, attention_mask, labels
-    ):
-        """
-        Merge input IDs with image features to create a combined input representation.
-
-        This method handles the complex logic of interleaving text and image tokens,
-        adjusting attention masks and labels accordingly.
-
-        Args:
-            image_features (torch.Tensor): Processed image features.
-            inputs_embeds (torch.Tensor): Text input embeddings.
-            input_ids (torch.Tensor): Input token IDs.
-            attention_mask (torch.Tensor): Attention mask for input tokens.
-            labels (torch.Tensor, optional): Labels for language modeling.
-
-        Returns:
-            tuple: Contains the merged embeddings, updated attention mask,
-                updated labels, and position IDs.
-        """
-        num_images, num_image_patches, embed_dim = image_features.shape
-        batch_size, sequence_length = input_ids.shape
-        left_padding = not torch.sum(
-            input_ids[:, -1] == torch.tensor(self.pad_token_id)
-        )
-        # 1. Create a mask to know where special image tokens are
-        special_image_token_mask = input_ids == self.config.image_token_index
-        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
-        # Compute the maximum embed dimension
-        max_embed_dim = (
-            num_special_image_tokens.max() * (num_image_patches - 1)
-        ) + sequence_length
-        batch_indices, non_image_indices = torch.where(
-            input_ids != self.config.image_token_index
-        )
-
-        # 2. Compute the positions where text should be written
-        # Calculate new positions for text tokens in merged image-text sequence.
-        # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
-        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
-        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
-        new_token_positions = (
-            torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1)
-            - 1
-        )
-        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
-        if left_padding:
-            new_token_positions += nb_image_pad[:, None]  # offset for left padding
-        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
-        # 3. Create the full embedding, already padded to the maximum position
-        final_embedding = torch.zeros(
-            batch_size,
-            max_embed_dim,
-            embed_dim,
-            dtype=inputs_embeds.dtype,
-            device=inputs_embeds.device,
-        )
-        final_attention_mask = torch.zeros(
-            batch_size,
-            max_embed_dim,
-            dtype=attention_mask.dtype,
-            device=inputs_embeds.device,
-        )
-        if labels is not None:
-            final_labels = torch.full(
-                (batch_size, max_embed_dim),
-                self.config.ignore_index,
-                dtype=input_ids.dtype,
-                device=input_ids.device,
-            )
-        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
-        # set the corresponding tensors into their correct target device.
-        target_device = inputs_embeds.device
-        batch_indices, non_image_indices, text_to_overwrite = (
-            batch_indices.to(target_device),
-            non_image_indices.to(target_device),
-            text_to_overwrite.to(target_device),
-        )
-        attention_mask = attention_mask.to(target_device)
-
-        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
-        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
-        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
-            batch_indices, non_image_indices
-        ]
-        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
-            batch_indices, non_image_indices
-        ]
-        if labels is not None:
-            final_labels[batch_indices, text_to_overwrite] = labels[
-                batch_indices, non_image_indices
-            ]
-
-        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
-        image_to_overwrite = torch.full(
-            (batch_size, max_embed_dim),
-            True,
-            dtype=torch.bool,
-            device=inputs_embeds.device,
-        )
-        image_to_overwrite[batch_indices, text_to_overwrite] = False
-        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[
-            :, None
-        ].to(target_device)
-
-        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
-            raise ValueError(
-                f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
-                f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
-            )
-
-        final_embedding[image_to_overwrite] = (
-            image_features.contiguous().reshape(-1, embed_dim).to(target_device)
-        )
-        final_attention_mask |= image_to_overwrite
-        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
-            (final_attention_mask == 0), 1
-        )
-
-        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
-        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
-        indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
-        final_embedding[batch_indices, indices_to_mask] = 0
-
-        if labels is None:
-            final_labels = None
-
-        return final_embedding, final_attention_mask, final_labels, position_ids
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -329,6 +197,8 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        num_logits_to_keep: int = 0,
     ) -> Union[Tuple, AriaCausalLMOutputWithPast]:
         """
         Forward pass of the AriaForConditionalGeneration model.
@@ -371,69 +241,38 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
         # 1. Extra the input embeddings
         inputs_embeds = self.get_input_embeddings()(input_ids)
 
-        # 2. Merge text and images
-        if pixel_values is not None and input_ids.shape[1] != 1:
-            image_outputs, image_attn_mask = self.vision_tower(
-                pixel_values,
-                pixel_mask=pixel_mask,
-            )
-            selected_image_feature = image_outputs.last_hidden_state
-
-            image_features = self.multi_modal_projector(
-                selected_image_feature, attn_mask=image_attn_mask
-            )
-
-            inputs_embeds = inputs_embeds.to(image_features.dtype)
-            (
-                inputs_embeds,
-                attention_mask,
-                labels,
-                position_ids,
-            ) = self._merge_input_ids_with_image_features(
-                image_features, inputs_embeds, input_ids, attention_mask, labels
-            )
-
-        # In case input_ids.shape[1] == 1 & pixel_values != None & past_key_values != None, we are in the case of
-        # generation with cache
-        elif (
-            past_key_values is not None
-            and pixel_values is not None
-            and input_ids.shape[1] == 1
-        ):
-            # Retrieve the first layer to inspect the logits and mask out the hidden states
-            # that are set to 0
-            first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-            # Sum all dimensions of head_dim (-2) to avoid random errors
-            # such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-            batch_index, non_attended_tokens = torch.where(
-                first_layer_past_key_value.float().sum(-2) == 0
-            )
-
-            # Get the target length
-            target_length = input_ids.shape[1]
-            past_length = first_layer_past_key_value.shape[-1]
-
-            extended_attention_mask = torch.ones(
-                (attention_mask.shape[0], past_length),
-                dtype=attention_mask.dtype,
-                device=attention_mask.device,
-            )
-
-            # Filter out only the tokens that can be un-attended, this can happen
-            # if one uses Llava + Fused modules where the cache on the
-            # first iteration is already big enough, or if one passes custom cache
-            valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-            new_batch_index = batch_index[valid_indices]
-            new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-            # Zero-out the places where we don't need to attend
-            extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-            attention_mask = torch.cat(
-                (extended_attention_mask, attention_mask[:, -target_length:]), dim=1
-            )
-            position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+        image_features = None
+        if pixel_values is not None:
+            image_outputs, image_attn_mask = self.vision_tower(
+                pixel_values,
+                pixel_mask=pixel_mask,
+            )
+
+            selected_image_feature = image_outputs.last_hidden_state
+            image_features = self.multi_modal_projector(
+                selected_image_feature, attn_mask=image_attn_mask
+            )
+
+        if image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(
+                inputs_embeds.device, inputs_embeds.dtype
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(
+                special_image_mask, image_features
+            )
 
         outputs = self.language_model(
             attention_mask=attention_mask,
@@ -444,6 +283,8 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
+            num_logits_to_keep=num_logits_to_keep,
         )
 
         logits = outputs[0]
@@ -452,7 +293,11 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
         if labels is not None:
             # Shift so that tokens < n predict n
             if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:]
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
+                    logits.device
+                )
                 shift_logits = logits[..., :-1, :][
                     shift_attention_mask.to(logits.device) != 0
                 ].contiguous()
@@ -487,80 +332,24 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
         past_key_values=None,
         inputs_embeds=None,
         pixel_values=None,
-        pixel_mask=None,
         attention_mask=None,
+        cache_position=None,
+        num_logits_to_keep=None,
         **kwargs,
     ):
-        """
-        Prepare inputs for generation step.
-
-        This method prepares the inputs for the generation step, handling both
-        text and image inputs, and managing the model's cache mechanism.
-
-        Args:
-            input_ids (torch.LongTensor): Input token ids.
-            past_key_values (Cache or List[torch.FloatTensor], optional): Past key values for efficient processing.
-            inputs_embeds (torch.FloatTensor, optional): Input embeddings.
-            pixel_values (torch.FloatTensor, optional): Pixel values of the images.
-            pixel_mask (torch.LongTensor, optional): Mask for the pixel values.
-            attention_mask (torch.Tensor, optional): Attention mask.
-            **kwargs: Additional keyword arguments.
-
-        Returns:
-            dict: A dictionary containing the prepared inputs for the generation step.
-        """
-        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
-            if (
-                attention_mask is not None
-                and attention_mask.shape[1] > input_ids.shape[1]
-            ):
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-            elif self.config.image_token_index in input_ids:
-                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[
-                    :, -(cache_length + input_ids.shape[1]) :
-                ]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
-                "pixel_mask": pixel_mask,
-            }
-        )
+        model_inputs = self.language_model.prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            num_logits_to_keep=num_logits_to_keep,
+            **kwargs,
+        )
+
+        if cache_position[0] == 0:
+            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+            # Otherwise we need pixel values to be passed to model
+            model_inputs["pixel_values"] = pixel_values
+
         return model_inputs
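
The rewritten forward pass no longer interleaves tokens itself: it assumes the processor has already expanded each image placeholder to one token per visual feature, so the merge reduces to a masked_scatter over the placeholder positions. A self-contained toy illustration of that pattern (the token id, shapes, and values below are invented for the example):

import torch

image_token_index = 9  # hypothetical id of the image placeholder token
embed_dim = 4

# one sequence of 6 tokens, two of which are image placeholders
input_ids = torch.tensor([[1, 9, 9, 2, 3, 4]])
inputs_embeds = torch.zeros(1, 6, embed_dim)

# one "image" producing 2 patch features of size embed_dim
image_features = torch.arange(2 * embed_dim, dtype=torch.float32).reshape(1, 2, embed_dim)

# same consistency check as the diff: placeholders must match feature vectors
n_image_tokens = (input_ids == image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]
assert n_image_tokens == n_image_features

special_image_mask = (
    (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
inputs_embeds = inputs_embeds.masked_scatter(
    special_image_mask, image_features.to(inputs_embeds.dtype)
)

# positions 1 and 2 now carry the image features; the rest are untouched
print(inputs_embeds)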
moe_lm.py CHANGED
@@ -146,7 +146,7 @@ def switch_load_balancing_loss_func(
     topk: int,
     moe_aux_loss_coeff: float,
 ):
-    """Calculate the auxiliary loss for better load balacing.
+    """Calculate the auxiliary loss for better load balancing.
     Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.
 
     Args:
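
For context on the docstring fix: the referenced Switch Transformer auxiliary loss encourages balanced routing by multiplying, per expert, the fraction of tokens dispatched to it with the mean router probability it receives. A rough standalone sketch of that formula (argument names, the top-k dispatch convention, and the normalization are assumptions; the repository's switch_load_balancing_loss_func may differ in detail):

import torch


def switch_aux_loss_sketch(router_probs: torch.Tensor, topk: int, coeff: float) -> torch.Tensor:
    # router_probs: [num_tokens, num_experts], softmax outputs of the router.
    num_tokens, num_experts = router_probs.shape
    # fraction of routed (token, slot) pairs that land on each expert: f_e
    topk_idx = router_probs.topk(topk, dim=-1).indices            # [num_tokens, topk]
    dispatch = torch.zeros_like(router_probs).scatter_(-1, topk_idx, 1.0)
    tokens_per_expert = dispatch.sum(dim=0) / (num_tokens * topk)
    # mean router probability assigned to each expert: P_e
    mean_probs = router_probs.mean(dim=0)
    # L = coeff * num_experts * sum_e f_e * P_e
    return coeff * num_experts * torch.sum(tokens_per_expert * mean_probs)


probs = torch.softmax(torch.randn(16, 4), dim=-1)
print(switch_aux_loss_sketch(probs, topk=2, coeff=1e-2))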
processing_aria.py CHANGED
@@ -94,6 +94,7 @@ class AriaProcessor(ProcessorMixin):
         max_image_size: Optional[int] = 980,
         split_image: Optional[bool] = False,
         return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        return_final_prompts: Optional[bool] = False,
     ) -> BatchFeature:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). Please refer to the doctsring
@@ -168,6 +169,24 @@ class AriaProcessor(ProcessorMixin):
                 )
             )
 
+            max_image_size = (
+                max_image_size
+                if max_image_size is not None
+                else self.image_processor.max_image_size
+            )
+            if max_image_size == 490:
+                num_image_tokens = 128
+            elif max_image_size == 980:
+                num_image_tokens = 256
+            else:
+                raise ValueError(
+                    f"max_image_size must be either 490 or 980, got {max_image_size}"
+                )
+            prompt_strings = [
+                sample.replace(self.image_token, self.image_token * num_image_tokens)
+                for sample in prompt_strings
+            ]
+
         else:
             image_inputs = {}
             prompt_strings = text
@@ -180,7 +199,10 @@ class AriaProcessor(ProcessorMixin):
             max_length=max_length,
         )
 
-        return BatchFeature(data={**text_inputs, **image_inputs})
+        if return_final_prompts:
+            return BatchFeature(data={**text_inputs, **image_inputs}), prompt_strings
+        else:
+            return BatchFeature(data={**text_inputs, **image_inputs})
 
     @staticmethod
     def _extract_kwargs(func: callable, **kwargs) -> dict:
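
The added block ties the number of repeated image placeholders in the prompt to max_image_size (128 tokens at 490 px, 256 at 980 px), which is what lets the model-side masked_scatter line up one text token per visual feature. A standalone sketch of that expansion rule (the placeholder string below is invented; the real processor uses self.image_token):

from typing import List

IMAGE_TOKEN = "<image>"  # hypothetical placeholder string


def expand_image_tokens(prompts: List[str], max_image_size: int) -> List[str]:
    # Same rule as the diff: 490 px -> 128 tokens, 980 px -> 256, anything else rejected.
    if max_image_size == 490:
        num_image_tokens = 128
    elif max_image_size == 980:
        num_image_tokens = 256
    else:
        raise ValueError(f"max_image_size must be either 490 or 980, got {max_image_size}")
    return [p.replace(IMAGE_TOKEN, IMAGE_TOKEN * num_image_tokens) for p in prompts]


expanded = expand_image_tokens([f"describe {IMAGE_TOKEN} please"], max_image_size=980)
assert expanded[0].count(IMAGE_TOKEN) == 256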
vision_processor.py CHANGED
@@ -45,7 +45,7 @@ def _select_best_resolution(
     aspect_ratio = img_width / img_height
     best_ratio_diff = float("inf")
     best_ratio_w, best_ratio_h = 1, 1
-    area = np.int32(img_height) * np.int32(img_height)
+    area = np.int32(img_width) * np.int32(img_height)
     for ratio in target_ratios:
         target_aspect_ratio = ratio[0] / ratio[1]
         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
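
This one-token fix computes the image area as width * height instead of height squared; for non-square inputs the old value skewed whatever downstream comparison in _select_best_resolution uses the area (such as breaking ties between equally close aspect ratios). A tiny check of the corrected expression (values are arbitrary):

import numpy as np

img_width, img_height = 1920, 1080

buggy_area = np.int32(img_height) * np.int32(img_height)  # 1080 * 1080
fixed_area = np.int32(img_width) * np.int32(img_height)   # 1920 * 1080

assert int(fixed_area) == img_width * img_height
assert int(buggy_area) != int(fixed_area)  # the two only agree for square images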