merge

- config.json +1 -1
- model.safetensors +0 -3
- modeling_aria.py +25 -135
config.json
CHANGED
@@ -10,7 +10,7 @@
   "model_type": "aria",
   "num_attention_heads": 24,
   "num_hidden_layers": 16,
-  "torch_dtype": "
+  "torch_dtype": "float32",
   "transformers_version": "4.45.0",
   "use_cache": true,
   "vocab_size": 17727,
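Note on the `torch_dtype` change: this config field records the dtype the checkpoint was saved in, and (as of the pinned transformers 4.45.0) it is only consulted when a caller passes `torch_dtype="auto"`; otherwise weights load in float32. A minimal sketch of checking the resolved dtype after this commit; the repo id "user/aria" is a placeholder, and loading the custom Aria code is assumed to require `trust_remote_code=True`:

# Illustrative only; "user/aria" is a placeholder repo id, not this repository's.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("user/aria", trust_remote_code=True)
print(config.torch_dtype)  # expected: torch.float32 after this change

# "auto" defers to config.torch_dtype; omitting torch_dtype loads float32 regardless.
model = AutoModelForCausalLM.from_pretrained(
    "user/aria", torch_dtype="auto", trust_remote_code=True
)
print(next(model.parameters()).dtype)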
model.safetensors
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9057480d90c91e0b9000f365ceafcbd7e21cd1940dc4bb25f1bd328cbe26c28f
-size 2634170640
modeling_aria.py
CHANGED
@@ -180,13 +180,11 @@ class TransformerBlock(nn.Module):
             xk, xv, self.layer_idx, cache_kwargs
         )
 
-        # scaled_dot_product_attention expects: (b_sz, n_head, s_len, d_head)
         att = F.scaled_dot_product_attention(
             query=xq,
             key=xk,
             value=xv,
-            attn_mask=attention_mask,
-            # is_causal=True,
+            attn_mask=attention_mask[..., : xk.shape[2]],
         )
 
         # Reshape for out: (b_sz, s_len, n_head, d_head)
@@ -215,6 +213,7 @@ class AriaModel(AriaPreTrainedModel):
         super().__init__(model_config)
         self.model_config = model_config
         self.freqs_cis = None
+        self.causal_mask = None
 
         self.tok_embeddings = nn.Embedding(
             num_embeddings=model_config.vocab_size,
@@ -341,13 +340,10 @@ class AriaModel(AriaPreTrainedModel):
         position_ids = cache_position.unsqueeze(0)
         hidden_states = inputs_embeds
 
-        causal_mask = self._update_causal_mask(
-            attention_mask,
-            inputs_embeds,
-            cache_position,
-            past_key_values,
-            output_attentions,
-        )
+        if self.causal_mask is None:
+            self.causal_mask = precompute_causal_mask(
+                max_seq_len=self.model_config.max_seq_len,
+            ).to(input_ids.device)
 
         if self.freqs_cis is None:
             self.freqs_cis = precompute_freqs_cis(
@@ -360,6 +356,19 @@ class AriaModel(AriaPreTrainedModel):
 
         freqs_cis = self.freqs_cis[cache_position]
 
+        if use_cache is True:
+            causal_mask = self.causal_mask[None, None, cache_position]
+        else:
+            causal_mask = self.causal_mask[None, None, :seq_length, :seq_length]
+
+        if attention_mask is not None:
+            pad_len = causal_mask.shape[3] - attention_mask.shape[1]
+            padded_attention_mask = F.pad(attention_mask, (0, pad_len), value=1)
+            padded_attention_mask = padded_attention_mask[:, None, None, :]
+            padded_attention_mask = padded_attention_mask.bool()
+
+            causal_mask = causal_mask & padded_attention_mask
+
         kwargs = {
             "position_ids": position_ids,
             "past_key_values": past_key_values,
@@ -432,130 +441,6 @@ class AriaModel(AriaPreTrainedModel):
             attentions=all_attentions,
         )
 
-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        if self.model_config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and (attention_mask == 0.0).any():
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = (
-            past_key_values.get_seq_length()
-            if past_key_values is not None
-            else 0
-        )
-        using_static_cache = isinstance(past_key_values, StaticCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if (
-            self.model_config._attn_implementation == "sdpa"
-            and not using_static_cache
-            and not output_attentions
-        ):
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype, device = input_tensor.dtype, input_tensor.device
-        sequence_length = input_tensor.shape[1]
-        if using_static_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = (
-            self._prepare_4d_causal_attention_mask_with_cache_position(
-                attention_mask,
-                sequence_length=sequence_length,
-                target_length=target_length,
-                dtype=dtype,
-                device=device,
-                cache_position=cache_position,
-                batch_size=input_tensor.shape[0],
-            )
-        )
-
-        if (
-            self.model_config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(
-                causal_mask, min_dtype
-            )
-
-        return causal_mask
-
-    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        device: torch.device,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length),
-                fill_value=min_dtype,
-                dtype=dtype,
-                device=device,
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(
-                target_length, device=device
-            ) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(
-                batch_size, 1, -1, -1
-            )
-            if attention_mask is not None:
-                causal_mask = (
-                    causal_mask.clone()
-                )  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = (
-                    causal_mask[:, :, :, :mask_length]
-                    + attention_mask[:, None, None, :]
-                )
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[
-                    :, :, :, :mask_length
-                ].masked_fill(padding_mask, min_dtype)
-
-        return causal_mask
-
 
 class AriaForCausalLM(AriaPreTrainedModel, GenerationMixin):
     """Transformer decoder with head for language modelling.
@@ -732,6 +617,12 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
         )
 
 
+def precompute_causal_mask(max_seq_len: int):
+    return torch.tril(
+        torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
+    ).cuda()
+
+
 def precompute_freqs_cis(
     seq_len: int,
     n_elem: int,
@@ -749,7 +640,6 @@ def precompute_freqs_cis(
     return cache.to(dtype=dtype)
 
 
-@torch.jit.script
 def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
     """
     In-place RoPE. Credits to Katherine Crowson: