sync changes from github

Changed files:
- configuration_aria.py (+17, -5)
- modeling_aria.py (+53, -264)
- moe_lm.py (+1, -1)
- processing_aria.py (+23, -1)
- vision_processor.py (+1, -1)

configuration_aria.py
CHANGED
@@ -17,11 +17,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import logging
+
 from transformers.configuration_utils import PretrainedConfig
 
 from .moe_lm import AriaMoELMConfig
 from .vision_encoder import AriaVisionConfig
 
+logger = logging.getLogger(__name__)
+
 
 # adapted from transformers.models.llava.configuration_llava.LlavaConfig
 class AriaConfig(PretrainedConfig):
@@ -69,6 +73,7 @@ class AriaConfig(PretrainedConfig):
         self.image_token_index = image_token_index
 
         attn_implementation = kwargs.pop("attn_implementation", None)
+        self._attn_implementation = attn_implementation
 
         # Convert the keys and values of projector_patch_to_query_dict to integers
         # This ensures consistency even if they were provided as strings
@@ -78,11 +83,15 @@ class AriaConfig(PretrainedConfig):
 
         if isinstance(vision_config, dict) and "model_type" in vision_config:
             vision_config = AriaVisionConfig(**vision_config)
-        [5 removed lines not captured in this view; one contained "flash_attention_2"]
+        if attn_implementation is None:
+            vision_attn_implementation = "flash_attention_2"
+        elif attn_implementation == "sdpa":
+            logger.warning(
+                "SDPA is not supported for vit, using flash_attention_2 instead"
+            )
+            vision_attn_implementation = "flash_attention_2"
+        else:
+            vision_attn_implementation = attn_implementation
         vision_config._attn_implementation = vision_attn_implementation
 
         self.vision_config = vision_config
@@ -95,3 +104,6 @@ class AriaConfig(PretrainedConfig):
         text_config._attn_implementation = text_attn_implementation
 
         self.text_config = text_config
+
+        # This is needed for the static kv cache
+        self.num_hidden_layers = self.text_config.num_hidden_layers
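
For reference, the attention-implementation fallback introduced above can be read in isolation as the following sketch. The helper name `resolve_vision_attn_implementation` and the standalone form are mine for illustration; in the repo this logic lives inline in `AriaConfig.__init__`.

import logging
from typing import Optional

logger = logging.getLogger(__name__)


def resolve_vision_attn_implementation(attn_implementation: Optional[str]) -> str:
    """Mirror the selection rule from AriaConfig: the vision tower only runs
    with flash_attention_2, so both None and "sdpa" fall back to it."""
    if attn_implementation is None:
        return "flash_attention_2"
    if attn_implementation == "sdpa":
        logger.warning("SDPA is not supported for vit, using flash_attention_2 instead")
        return "flash_attention_2"
    return attn_implementation


assert resolve_vision_attn_implementation(None) == "flash_attention_2"
assert resolve_vision_attn_implementation("sdpa") == "flash_attention_2"
assert resolve_vision_attn_implementation("eager") == "eager"
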
modeling_aria.py
CHANGED
@@ -24,7 +24,6 @@ import torch
 import torch.nn as nn
 from torch import nn
 from transformers import PreTrainedModel
-from transformers.cache_utils import Cache
 from transformers.modeling_outputs import ModelOutput
 from transformers.utils import logging
 
@@ -48,6 +47,7 @@ class AriaPretrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_cache_class = True
+    _supports_static_cache = True
 
     @property
     def _supports_sdpa(self):
@@ -183,138 +183,6 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
         """
         self.language_model.set_aux_loss_coeff(value)
 
-    # copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
-    def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds, input_ids, attention_mask, labels
-    ):
-        """
-        Merge input IDs with image features to create a combined input representation.
-
-        This method handles the complex logic of interleaving text and image tokens,
-        adjusting attention masks and labels accordingly.
-
-        Args:
-            image_features (torch.Tensor): Processed image features.
-            inputs_embeds (torch.Tensor): Text input embeddings.
-            input_ids (torch.Tensor): Input token IDs.
-            attention_mask (torch.Tensor): Attention mask for input tokens.
-            labels (torch.Tensor, optional): Labels for language modeling.
-
-        Returns:
-            tuple: Contains the merged embeddings, updated attention mask,
-                updated labels, and position IDs.
-        """
-        num_images, num_image_patches, embed_dim = image_features.shape
-        batch_size, sequence_length = input_ids.shape
-        left_padding = not torch.sum(
-            input_ids[:, -1] == torch.tensor(self.pad_token_id)
-        )
-        # 1. Create a mask to know where special image tokens are
-        special_image_token_mask = input_ids == self.config.image_token_index
-        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
-        # Compute the maximum embed dimension
-        max_embed_dim = (
-            num_special_image_tokens.max() * (num_image_patches - 1)
-        ) + sequence_length
-        batch_indices, non_image_indices = torch.where(
-            input_ids != self.config.image_token_index
-        )
-
-        # 2. Compute the positions where text should be written
-        # Calculate new positions for text tokens in merged image-text sequence.
-        # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
-        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
-        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
-        new_token_positions = (
-            torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1)
-            - 1
-        )
-        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
-        if left_padding:
-            new_token_positions += nb_image_pad[:, None]  # offset for left padding
-        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
-        # 3. Create the full embedding, already padded to the maximum position
-        final_embedding = torch.zeros(
-            batch_size,
-            max_embed_dim,
-            embed_dim,
-            dtype=inputs_embeds.dtype,
-            device=inputs_embeds.device,
-        )
-        final_attention_mask = torch.zeros(
-            batch_size,
-            max_embed_dim,
-            dtype=attention_mask.dtype,
-            device=inputs_embeds.device,
-        )
-        if labels is not None:
-            final_labels = torch.full(
-                (batch_size, max_embed_dim),
-                self.config.ignore_index,
-                dtype=input_ids.dtype,
-                device=input_ids.device,
-            )
-        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
-        # set the corresponding tensors into their correct target device.
-        target_device = inputs_embeds.device
-        batch_indices, non_image_indices, text_to_overwrite = (
-            batch_indices.to(target_device),
-            non_image_indices.to(target_device),
-            text_to_overwrite.to(target_device),
-        )
-        attention_mask = attention_mask.to(target_device)
-
-        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
-        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
-        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
-            batch_indices, non_image_indices
-        ]
-        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
-            batch_indices, non_image_indices
-        ]
-        if labels is not None:
-            final_labels[batch_indices, text_to_overwrite] = labels[
-                batch_indices, non_image_indices
-            ]
-
-        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
-        image_to_overwrite = torch.full(
-            (batch_size, max_embed_dim),
-            True,
-            dtype=torch.bool,
-            device=inputs_embeds.device,
-        )
-        image_to_overwrite[batch_indices, text_to_overwrite] = False
-        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[
-            :, None
-        ].to(target_device)
-
-        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
-            raise ValueError(
-                f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
-                f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
-            )
-
-        final_embedding[image_to_overwrite] = (
-            image_features.contiguous().reshape(-1, embed_dim).to(target_device)
-        )
-        final_attention_mask |= image_to_overwrite
-        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_(
-            (final_attention_mask == 0), 1
-        )
-
-        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
-        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
-        indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
-        final_embedding[batch_indices, indices_to_mask] = 0
-
-        if labels is None:
-            final_labels = None
-
-        return final_embedding, final_attention_mask, final_labels, position_ids
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -329,6 +197,8 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        num_logits_to_keep: int = 0,
     ) -> Union[Tuple, AriaCausalLMOutputWithPast]:
         """
         Forward pass of the AriaForConditionalGeneration model.
@@ -371,69 +241,38 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
             # 1. Extra the input embeddings
             inputs_embeds = self.get_input_embeddings()(input_ids)
 
-            [removed lines not captured in this view]
-                selected_image_feature = image_outputs.last_hidden_state
-
-                image_features = self.multi_modal_projector(
-                    selected_image_feature, attn_mask=image_attn_mask
-                )
-
-                inputs_embeds = inputs_embeds.to(image_features.dtype)
-                (
-                    inputs_embeds,
-                    attention_mask,
-                    labels,
-                    position_ids,
-                ) = self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, labels
-                )
-
-            # In case input_ids.shape[1] == 1 & pixel_values != None & past_key_values != None, we are in the case of
-            # generation with cache
-            elif (
-                past_key_values is not None
-                and pixel_values is not None
-                and input_ids.shape[1] == 1
-            ):
-                # Retrieve the first layer to inspect the logits and mask out the hidden states
-                # that are set to 0
-                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-                # Sum all dimensions of head_dim (-2) to avoid random errors
-                # such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-                batch_index, non_attended_tokens = torch.where(
-                    first_layer_past_key_value.float().sum(-2) == 0
-                )
-
-                # Get the target length
-                target_length = input_ids.shape[1]
-                past_length = first_layer_past_key_value.shape[-1]
-
-                extended_attention_mask = torch.ones(
-                    (attention_mask.shape[0], past_length),
-                    dtype=attention_mask.dtype,
-                    device=attention_mask.device,
-                )
-            [removed lines not captured in this view]
-                new_batch_index = batch_index[valid_indices]
-                new_non_attended_tokens = non_attended_tokens[valid_indices]
-            [removed lines not captured in this view]
+        image_features = None
+        if pixel_values is not None:
+            image_outputs, image_attn_mask = self.vision_tower(
+                pixel_values,
+                pixel_mask=pixel_mask,
+            )
+
+            selected_image_feature = image_outputs.last_hidden_state
+            image_features = self.multi_modal_projector(
+                selected_image_feature, attn_mask=image_attn_mask
+            )
+
+        if image_features is not None:
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
+
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(
+                inputs_embeds.device, inputs_embeds.dtype
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(
+                special_image_mask, image_features
+            )
 
         outputs = self.language_model(
             attention_mask=attention_mask,
@@ -444,6 +283,8 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
+            num_logits_to_keep=num_logits_to_keep,
         )
 
         logits = outputs[0]
@@ -452,7 +293,11 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
         if labels is not None:
             # Shift so that tokens < n predict n
             if attention_mask is not None:
-                [removed line not captured in this view]
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
+                    logits.device
+                )
                 shift_logits = logits[..., :-1, :][
                     shift_attention_mask.to(logits.device) != 0
                 ].contiguous()
@@ -487,80 +332,24 @@ class AriaForConditionalGeneration(AriaPretrainedModel):
         past_key_values=None,
         inputs_embeds=None,
         pixel_values=None,
-        pixel_mask=None,
         attention_mask=None,
+        cache_position=None,
+        num_logits_to_keep=None,
         **kwargs,
     ):
-        [removed docstring lines not captured in this view]
-            pixel_values (torch.FloatTensor, optional): Pixel values of the images.
-            pixel_mask (torch.LongTensor, optional): Mask for the pixel values.
-            attention_mask (torch.Tensor, optional): Attention mask.
-            **kwargs: Additional keyword arguments.
-
-        Returns:
-            dict: A dictionary containing the prepared inputs for the generation step.
-        """
-        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
-            if (
-                attention_mask is not None
-                and attention_mask.shape[1] > input_ids.shape[1]
-            ):
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-            elif self.config.image_token_index in input_ids:
-                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[
-                    :, -(cache_length + input_ids.shape[1]) :
-                ]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
-                "pixel_mask": pixel_mask,
-            }
-        )
+        model_inputs = self.language_model.prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            num_logits_to_keep=num_logits_to_keep,
+            **kwargs,
+        )
+
+        if cache_position[0] == 0:
+            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+            # Otherwise we need pixel values to be passed to model
+            model_inputs["pixel_values"] = pixel_values
+
         return model_inputs
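
The key simplification above is that `_merge_input_ids_with_image_features` is gone: the processor now emits one image placeholder token per projected image feature, so the model can simply overwrite those embedding slots with `masked_scatter`. Below is a toy-sized sketch of that mechanism; the shapes, token id, and values are invented for illustration and are not taken from the repo.

import torch

# Toy setup: batch of 1, sequence of 6 tokens, hidden size 4. Token id 9
# stands in for config.image_token_index, and the processor has already
# repeated it once per projected image feature (3 of them here).
image_token_index = 9
input_ids = torch.tensor([[1, 9, 9, 9, 2, 3]])
inputs_embeds = torch.zeros(1, 6, 4)

# Stand-in for the multi_modal_projector output:
# shape [num_images, num_queries, hidden] = [1, 3, 4].
image_features = torch.arange(12, dtype=torch.float32).reshape(1, 3, 4)

# Same sanity check as the model: one placeholder token per feature vector.
n_image_tokens = (input_ids == image_token_index).sum().item()
n_image_features = image_features.shape[0] * image_features.shape[1]
assert n_image_tokens == n_image_features

# Broadcast the boolean mask over the hidden dimension and scatter the
# flattened image features into the placeholder positions, in order.
special_image_mask = (
    (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

print(inputs_embeds[0, 1:4])  # the three placeholder rows now hold the image features

The new `cache_position` and `num_logits_to_keep` arguments, together with `_supports_static_cache`, are what let generation delegate to the language model's own `prepare_inputs_for_generation` and a static KV cache instead of the hand-rolled cache bookkeeping that was deleted.
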
moe_lm.py
CHANGED
@@ -146,7 +146,7 @@ def switch_load_balancing_loss_func(
     topk: int,
     moe_aux_loss_coeff: float,
 ):
-    """Calculate the auxiliary loss for better load
+    """Calculate the auxiliary loss for better load balancing.
     Please refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details.
 
     Args:
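
The docstring being fixed describes the Switch Transformer auxiliary load-balancing loss. As a reminder of what that loss computes, here is a hedged sketch written from the paper (arXiv:2101.03961), not copied from moe_lm.py; the function name, the top-k handling, and the exact normalization are assumptions and may differ from the repo's `switch_load_balancing_loss_func`.

import torch
import torch.nn.functional as F


def switch_aux_loss_sketch(router_logits: torch.Tensor, topk: int, coeff: float) -> torch.Tensor:
    """Switch-Transformer-style load balancing loss (assumed formulation).

    router_logits: [num_tokens, num_experts] pre-softmax router scores.
    f_i = fraction of routing assignments sent to expert i (top-k counting),
    P_i = mean router probability of expert i; loss = E * sum(f_i * P_i) * coeff.
    """
    num_experts = router_logits.shape[-1]
    probs = F.softmax(router_logits, dim=-1)                    # [tokens, experts]
    topk_indices = probs.topk(topk, dim=-1).indices             # [tokens, topk]
    dispatch = F.one_hot(topk_indices, num_experts).sum(dim=1)  # [tokens, experts]
    tokens_per_expert = dispatch.float().mean(dim=0)            # ~ f_i
    mean_probs = probs.mean(dim=0)                              # P_i
    return num_experts * torch.sum(tokens_per_expert * mean_probs) * coeff


aux = switch_aux_loss_sketch(torch.randn(32, 8), topk=2, coeff=1e-2)
print(float(aux))
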
processing_aria.py
CHANGED
@@ -94,6 +94,7 @@ class AriaProcessor(ProcessorMixin):
         max_image_size: Optional[int] = 980,
         split_image: Optional[bool] = False,
         return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
+        return_final_prompts: Optional[bool] = False,
     ) -> BatchFeature:
         """
         Main method to prepare for the model one or several sequences(s) and image(s). Please refer to the doctsring
@@ -168,6 +169,24 @@ class AriaProcessor(ProcessorMixin):
                 )
             )
 
+            max_image_size = (
+                max_image_size
+                if max_image_size is not None
+                else self.image_processor.max_image_size
+            )
+            if max_image_size == 490:
+                num_image_tokens = 128
+            elif max_image_size == 980:
+                num_image_tokens = 256
+            else:
+                raise ValueError(
+                    f"max_image_size must be either 490 or 980, got {max_image_size}"
+                )
+            prompt_strings = [
+                sample.replace(self.image_token, self.image_token * num_image_tokens)
+                for sample in prompt_strings
+            ]
+
         else:
             image_inputs = {}
             prompt_strings = text
@@ -180,7 +199,10 @@ class AriaProcessor(ProcessorMixin):
             max_length=max_length,
         )
 
-        [removed return statement not captured in this view]
+        if return_final_prompts:
+            return BatchFeature(data={**text_inputs, **image_inputs}), prompt_strings
+        else:
+            return BatchFeature(data={**text_inputs, **image_inputs})
 
     @staticmethod
     def _extract_kwargs(func: callable, **kwargs) -> dict:
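
The new processor logic expands every image placeholder into a fixed number of image tokens (128 when `max_image_size` is 490, 256 when it is 980) so the text side reserves exactly one position per projected image feature, which is what the `masked_scatter` path in modeling_aria.py relies on. A minimal standalone sketch follows, assuming a literal "<image>" placeholder string; the repo takes the placeholder from `self.image_token`.

from typing import List

# Mapping used by the hunk above: 490px -> 128 image tokens, 980px -> 256.
NUM_IMAGE_TOKENS = {490: 128, 980: 256}


def expand_image_tokens(prompts: List[str], max_image_size: int, image_token: str = "<image>") -> List[str]:
    """Repeat the image placeholder so the text reserves one position per
    projected image feature that the model will scatter in later."""
    if max_image_size not in NUM_IMAGE_TOKENS:
        raise ValueError(f"max_image_size must be either 490 or 980, got {max_image_size}")
    n = NUM_IMAGE_TOKENS[max_image_size]
    return [p.replace(image_token, image_token * n) for p in prompts]


expanded = expand_image_tokens(["Describe <image> please."], max_image_size=490)
assert expanded[0].count("<image>") == 128
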
vision_processor.py
CHANGED
@@ -45,7 +45,7 @@ def _select_best_resolution(
     aspect_ratio = img_width / img_height
     best_ratio_diff = float("inf")
     best_ratio_w, best_ratio_h = 1, 1
-    area = np.int32(
+    area = np.int32(img_width) * np.int32(img_height)
     for ratio in target_ratios:
         target_aspect_ratio = ratio[0] / ratio[1]
         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
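
The restored `area` value is the input image's pixel area, which aspect-ratio bucketing code of this kind typically uses as a tie-breaker between candidate tiling grids. The sketch below shows such a selection loop; the tie-break condition and the `patch_size` parameter are assumptions modeled on similar LLaVA-style tiling implementations, not a verified copy of `_select_best_resolution`.

import numpy as np


def select_best_ratio(img_width, img_height, target_ratios, patch_size=490):
    """Pick the (w, h) tiling grid whose aspect ratio is closest to the image's.

    On a tie, prefer the larger grid only when the image has enough pixels to
    roughly fill it (an assumed heuristic, not taken from vision_processor.py).
    """
    aspect_ratio = img_width / img_height
    best_ratio_diff = float("inf")
    best_ratio_w, best_ratio_h = 1, 1
    area = np.int32(img_width) * np.int32(img_height)  # the value restored by the fix above
    for ratio_w, ratio_h in target_ratios:
        target_aspect_ratio = ratio_w / ratio_h
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio_w, best_ratio_h = ratio_w, ratio_h
        elif ratio_diff == best_ratio_diff and area > 0.5 * patch_size * patch_size * ratio_w * ratio_h:
            best_ratio_w, best_ratio_h = ratio_w, ratio_h
    return best_ratio_w, best_ratio_h


print(select_best_ratio(1920, 1080, [(1, 1), (2, 1), (1, 2), (2, 2)]))  # -> (2, 1)
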