aria-dev commited on
Commit
6be5a01
1 Parent(s): 6a3f5b6

fix: add pixel_mask support to generation

Browse files
Files changed (1) hide show
  1. modeling_aria.py +2 -0
modeling_aria.py CHANGED
@@ -332,6 +332,7 @@ class AriaForConditionalGeneration(AriaPretrainedModel, GenerationMixin):
332
  past_key_values=None,
333
  inputs_embeds=None,
334
  pixel_values=None,
 
335
  attention_mask=None,
336
  cache_position=None,
337
  num_logits_to_keep=None,
@@ -351,5 +352,6 @@ class AriaForConditionalGeneration(AriaPretrainedModel, GenerationMixin):
351
  # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
352
  # Otherwise we need pixel values to be passed to model
353
  model_inputs["pixel_values"] = pixel_values
 
354
 
355
  return model_inputs
 
332
  past_key_values=None,
333
  inputs_embeds=None,
334
  pixel_values=None,
335
+ pixel_mask=None,
336
  attention_mask=None,
337
  cache_position=None,
338
  num_logits_to_keep=None,
 
352
  # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
353
  # Otherwise we need pixel values to be passed to model
354
  model_inputs["pixel_values"] = pixel_values
355
+ model_inputs["pixel_mask"] = pixel_mask
356
 
357
  return model_inputs