Add task_prefix_attention_mask argument to _merge_input_ids_with_image_features for better padding handling

#66

This PR adds an optional task_prefix_attention_mask=None argument to the _merge_input_ids_with_image_features function. When a batch of prompts is padded to a common length, passing the processor's attention mask through this argument makes the merged attention mask ignore the padding tokens. Without it, the merged mask is all ones, so padding positions are attended to.

Changes Made:

  1. Added a task_prefix_attention_mask=None argument to the _merge_input_ids_with_image_features function.
  2. Updated the function to use the provided attention mask for the prompt tokens, so that padding tokens are ignored during batch processing.
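
The idea behind the two changes can be sketched roughly as follows. This is a simplified, standalone illustration and not the model's actual code: the function name merge_with_image_features, the tensor shapes, and the choice to put the image tokens before the prompt tokens are all assumptions made for the example.

```python
import torch


def merge_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=None):
    """Hypothetical stand-in for the merge step: concatenate image and prompt
    embeddings and build one attention mask covering both."""
    batch_size, num_image_tokens = image_features.shape[:2]
    device = image_features.device

    # Image tokens are never padding, so their mask is all ones.
    image_mask = torch.ones(batch_size, num_image_tokens, device=device)

    if task_prefix_attention_mask is None:
        # No mask supplied: every prompt position is treated as a real token.
        task_prefix_attention_mask = torch.ones(batch_size, inputs_embeds.shape[1], device=device)
    else:
        # Mask supplied: reuse it so padded prompt positions stay at 0.
        task_prefix_attention_mask = task_prefix_attention_mask.to(device=device, dtype=image_mask.dtype)

    merged_embeds = torch.cat([image_features, inputs_embeds], dim=1)
    merged_mask = torch.cat([image_mask, task_prefix_attention_mask], dim=1)
    return merged_embeds, merged_mask


# Toy shapes (assumed): 3 prompts, 4 image tokens, 6 prompt tokens, hidden size 8.
image_features = torch.zeros(3, 4, 8)
inputs_embeds = torch.zeros(3, 6, 8)
prompt_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 1, 0],
                            [1, 1, 1, 1, 1, 1]])

_, default_mask = merge_with_image_features(image_features, inputs_embeds)
merged_embeds, merged_mask = merge_with_image_features(
    image_features, inputs_embeds, task_prefix_attention_mask=prompt_mask
)
print(default_mask)  # all ones: padding is not masked
print(merged_mask)   # zeros where the prompts were padded
```

Keeping None as the default means existing callers that do not pass a mask get the same all-ones mask as before.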

Below is an example demonstrating the issue and the improvement:

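import requests
import torch
from PIL import Image

# `model` and `processor` are assumed to be loaded already (model on "cuda" in float16).
prompts = ["prompt", "longer prompt", "much much longer prompt"]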

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"

image = Image.open(requests.get(url, stream=True).raw)
images = [image] * len(prompts)

inputs = processor(text=prompts, images=images, return_tensors="pt", padding=True).to("cuda", torch.float16)

inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
image_features = model._encode_image(inputs.pixel_values)

print(inputs.input_ids)
# Output (the trailing 1s are padding tokens):
# tensor([[    0, 12501,  3320,     2,     1,     1],
#         [    0,  3479,   254, 14302,     2,     1],
#         [    0, 28431,   203,  1181, 14302,     2]], device='cuda:0')

# Before change: the mask is all ones, so padding positions are not masked
inputs_embeds, attention_mask = model._merge_input_ids_with_image_features(image_features, inputs_embeds)
print(attention_mask[:, -10:])
# Output:
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')

# After change: padding positions at the end of the shorter prompts are masked (0)
inputs_embeds, attention_mask = model._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=inputs.attention_mask)
print(attention_mask[:, -10:])
# Output:
# tensor([[1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
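
As a quick sanity check, the zeros in the "After change" output should line up with the padding tokens in input_ids. The sketch below hard-codes the tensors printed above and assumes the padding token id is 1, which is read off the printed input_ids rather than taken from the tokenizer configuration.

```python
import torch

PAD_TOKEN_ID = 1  # assumption: inferred from the printed input_ids above

# input_ids as printed in the example
input_ids = torch.tensor([[0, 12501, 3320, 2, 1, 1],
                          [0, 3479, 254, 14302, 2, 1],
                          [0, 28431, 203, 1181, 14302, 2]])

# Last 10 columns of the merged attention mask, as printed after the change
after_tail = torch.tensor([[1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
                           [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
                           [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

# A position should be attended to exactly when it is not a padding token.
expected_prompt_mask = (input_ids != PAD_TOKEN_ID).float()

# The prompt occupies the last input_ids.shape[1] columns of the merged mask.
prompt_len = input_ids.shape[1]
masks_match = torch.equal(after_tail[:, -prompt_len:], expected_prompt_mask)
print(masks_match)
```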