loua19 committed
Commit 0c2eef3 · 1 Parent(s): 42d10df
Files changed (3)
  1. config.json +1 -1
  2. model.safetensors +0 -3
  3. modeling_aria.py +25 -135
config.json CHANGED
@@ -10,7 +10,7 @@
   "model_type": "aria",
   "num_attention_heads": 24,
   "num_hidden_layers": 16,
-  "torch_dtype": "bfloat16",
+  "torch_dtype": "float32",
   "transformers_version": "4.45.0",
   "use_cache": true,
   "vocab_size": 17727,
model.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9057480d90c91e0b9000f365ceafcbd7e21cd1940dc4bb25f1bd328cbe26c28f
-size 2634170640
modeling_aria.py CHANGED
@@ -180,13 +180,11 @@ class TransformerBlock(nn.Module):
                 xk, xv, self.layer_idx, cache_kwargs
             )

-        # scaled_dot_product_attention expects: (b_sz, n_head, s_len, d_head)
         att = F.scaled_dot_product_attention(
             query=xq,
             key=xk,
             value=xv,
-            attn_mask=attention_mask,
-            # is_causal=True,
+            attn_mask=attention_mask[..., : xk.shape[2]],
         )

         # Reshape for out: (b_sz, s_len, n_head, d_head)
@@ -215,6 +213,7 @@ class AriaModel(AriaPreTrainedModel):
         super().__init__(model_config)
         self.model_config = model_config
         self.freqs_cis = None
+        self.causal_mask = None

         self.tok_embeddings = nn.Embedding(
             num_embeddings=model_config.vocab_size,
@@ -341,13 +340,10 @@ class AriaModel(AriaPreTrainedModel):
         position_ids = cache_position.unsqueeze(0)
         hidden_states = inputs_embeds

-        causal_mask = self._update_causal_mask(
-            attention_mask,
-            inputs_embeds,
-            cache_position,
-            past_key_values,
-            output_attentions,
-        )
+        if self.causal_mask is None:
+            self.causal_mask = precompute_causal_mask(
+                max_seq_len=self.model_config.max_seq_len,
+            ).to(input_ids.device)

         if self.freqs_cis is None:
             self.freqs_cis = precompute_freqs_cis(
@@ -360,6 +356,19 @@ class AriaModel(AriaPreTrainedModel):

         freqs_cis = self.freqs_cis[cache_position]

+        if use_cache is True:
+            causal_mask = self.causal_mask[None, None, cache_position]
+        else:
+            causal_mask = self.causal_mask[None, None, :seq_length, :seq_length]
+
+        if attention_mask is not None:
+            pad_len = causal_mask.shape[3] - attention_mask.shape[1]
+            padded_attention_mask = F.pad(attention_mask, (0, pad_len), value=1)
+            padded_attention_mask = padded_attention_mask[:, None, None, :]
+            padded_attention_mask = padded_attention_mask.bool()
+
+            causal_mask = causal_mask & padded_attention_mask
+
         kwargs = {
             "position_ids": position_ids,
             "past_key_values": past_key_values,
@@ -432,130 +441,6 @@ class AriaModel(AriaPreTrainedModel):
             attentions=all_attentions,
         )

-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        if self.model_config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and (attention_mask == 0.0).any():
-                return attention_mask
-            return None
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = (
-            past_key_values.get_seq_length()
-            if past_key_values is not None
-            else 0
-        )
-        using_static_cache = isinstance(past_key_values, StaticCache)
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if (
-            self.model_config._attn_implementation == "sdpa"
-            and not using_static_cache
-            and not output_attentions
-        ):
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype, device = input_tensor.dtype, input_tensor.device
-        sequence_length = input_tensor.shape[1]
-        if using_static_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = (
-            self._prepare_4d_causal_attention_mask_with_cache_position(
-                attention_mask,
-                sequence_length=sequence_length,
-                target_length=target_length,
-                dtype=dtype,
-                device=device,
-                cache_position=cache_position,
-                batch_size=input_tensor.shape[0],
-            )
-        )
-
-        if (
-            self.model_config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(
-                causal_mask, min_dtype
-            )
-
-        return causal_mask
-
-    @staticmethod
-    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        device: torch.device,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length),
-                fill_value=min_dtype,
-                dtype=dtype,
-                device=device,
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(
-                target_length, device=device
-            ) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(
-                batch_size, 1, -1, -1
-            )
-            if attention_mask is not None:
-                causal_mask = (
-                    causal_mask.clone()
-                )  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = (
-                    causal_mask[:, :, :, :mask_length]
-                    + attention_mask[:, None, None, :]
-                )
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[
-                    :, :, :, :mask_length
-                ].masked_fill(padding_mask, min_dtype)
-
-        return causal_mask
-

 class AriaForCausalLM(AriaPreTrainedModel, GenerationMixin):
     """Transformer decoder with head for language modelling.
@@ -732,6 +617,12 @@ class AriaForSequenceEmbedding(AriaPreTrainedModel):
         )


+def precompute_causal_mask(max_seq_len: int):
+    return torch.tril(
+        torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)
+    ).cuda()
+
+
 def precompute_freqs_cis(
     seq_len: int,
     n_elem: int,
@@ -749,7 +640,6 @@ def precompute_freqs_cis(
     return cache.to(dtype=dtype)


-@torch.jit.script
 def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
     """
     In-place RoPE. Credits to Katherine Crowson:
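
Taken together, the modeling_aria.py changes replace the transformers-style _update_causal_mask / _prepare_4d_causal_attention_mask_with_cache_position machinery with a precomputed boolean mask: a lower-triangular (max_seq_len, max_seq_len) buffer is built once, its rows are selected by cache_position when the KV cache is in use (or a square seq_length slice otherwise), the 2D padding attention_mask is right-padded and ANDed in, and the result is cropped to the key length before scaled_dot_product_attention. A standalone sketch of that flow follows; the concrete shapes, batch size, and position values are chosen for illustration only and are not taken from the file (the real code also moves the mask to the model's device, e.g. via .cuda()):

# Standalone sketch of the new masking path; shapes and values below are illustrative.
import torch
import torch.nn.functional as F


def precompute_causal_mask(max_seq_len: int) -> torch.Tensor:
    # Boolean lower-triangular mask: True means "may attend".
    return torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))


max_seq_len, b_sz, seq_length, kv_len, d_head = 16, 2, 4, 8, 8
causal = precompute_causal_mask(max_seq_len)

# Cached decoding: keep only the rows for the positions currently being decoded.
cache_position = torch.arange(kv_len - seq_length, kv_len)  # e.g. positions 4..7
mask = causal[None, None, cache_position]                   # (1, 1, seq_length, max_seq_len)

# Fold in the 2D padding mask (1 = real token, 0 = padding), right-padded with 1s.
attention_mask = torch.ones(b_sz, kv_len, dtype=torch.long)
pad_len = mask.shape[3] - attention_mask.shape[1]
padded = F.pad(attention_mask, (0, pad_len), value=1)[:, None, None, :].bool()
mask = mask & padded                                         # (b_sz, 1, seq_length, max_seq_len)

# The attention call then slices the mask to the key length, as in the new attn_mask line.
q = torch.randn(b_sz, 1, seq_length, d_head)
k = torch.randn(b_sz, 1, kv_len, d_head)
v = torch.randn(b_sz, 1, kv_len, d_head)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[..., : k.shape[2]])
print(out.shape)  # torch.Size([2, 1, 4, 8])

Because indexing with cache_position keeps all max_seq_len key columns, the final slice down to k.shape[2] is what keeps the mask aligned with the number of cached keys actually present.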