spuliz committed
Commit 6c44c12 · 1 Parent(s): 61d3fe4
Files changed (1):
  1. modeling_deepseek.py +271 -4
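This commit replaces the stubbed-out `DeepseekV3FlashAttention2` with a full flash-attention implementation and keeps it registered in `ATTENTION_CLASSES`. For orientation, a minimal, hypothetical usage sketch of how that code path is normally selected at load time; the repo id, dtype, and an installed `flash-attn` 2.x are assumptions for illustration and are not part of the diff:

import torch
from transformers import AutoModelForCausalLM

# Hypothetical example: repo id and dtype are placeholders, not taken from this commit.
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-V3",                # assumed repo id, for illustration only
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # routed to DeepseekV3FlashAttention2 via ATTENTION_CLASSES
    trust_remote_code=True,                   # required for the custom modeling_deepseek.py code
)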
modeling_deepseek.py CHANGED
@@ -651,12 +651,280 @@ class DeepseekV3Attention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
+# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3
 class DeepseekV3FlashAttention2(DeepseekV3Attention):
     """
-    Omitted for brevity - see original code above if you want flash attention integration
+    DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stay
+    untouched. The only required change would be on the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
     """
-    # Implementation remains the same as above...
-    pass
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # DeepseekV3FlashAttention2 attention does not support output_attentions
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+
+            # overwrite attention_mask with padding_mask
+            attention_mask = kwargs.pop("padding_mask")
+
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+        kv = (
+            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .transpose(1, 2)
+        )
+
+        k_nope, value_states = torch.split(
+            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+        kv_seq_len = value_states.shape[-2]
+
+        kv_seq_len = value_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
+        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+
+        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+
+        if self.q_head_dim != self.v_head_dim:
+            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
+        # therefore the input hidden states get silently cast to float32. Hence, we need
+        # to cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (DeepseekV3RMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            else:
+                target_dtype = (
+                    self.q_proj.weight.dtype
+                    if self.q_lora_rank is None
+                    else self.q_a_proj.weight.dtype
+                )
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = self._flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            softmax_scale=self.softmax_scale,
+        )
+        if self.q_head_dim != self.v_head_dim:
+            attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+        attn_output = attn_output.reshape(
+            bsz, q_len, self.num_heads * self.v_head_dim
+        ).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+        first unpad the input, then compute the attention scores and pad the final attention scores.
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            (
+                query_states,
+                key_states,
+                value_states,
+                indices_q,
+                cu_seq_lens,
+                max_seq_lens,
+            ) = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(
+                attn_output_unpad, indices_q, batch_size, query_length
+            )
+        else:
+            attn_output = flash_attn_func(
+                query_states,
+                key_states,
+                value_states,
+                dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+        return attn_output
+
+    def _upad_input(
+        self, query_layer, key_layer, value_layer, attention_mask, query_length
+    ):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
+                indices_k,
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
+                query_layer, attention_mask
+            )
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
 
 
 ATTENTION_CLASSES = {
@@ -664,7 +932,6 @@ ATTENTION_CLASSES = {
     "flash_attention_2": DeepseekV3FlashAttention2,
 }
 
-
 class DeepseekV3DecoderLayer(nn.Module):
     def __init__(self, config: DeepseekV3Config, layer_idx: int):
        super().__init__()
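One detail worth calling out in the new forward pass: MLA-style attention gives queries and keys a head dim of `qk_nope_head_dim + qk_rope_head_dim`, which is larger than `v_head_dim`, while the flash-attention kernels expect q, k and v to share a head dim. The diff therefore zero-pads `value_states` up to `q_head_dim` before the kernel call and slices the output back to `v_head_dim`. The sketch below reproduces that trick with PyTorch's built-in SDPA instead of flash-attn; the shapes and dims are illustrative only and it is not part of the commit:

import torch
import torch.nn.functional as F

bsz, n_heads, q_len = 2, 4, 16
q_head_dim, v_head_dim = 192, 128  # illustrative MLA-style dims (q/k wider than v)

q = torch.randn(bsz, n_heads, q_len, q_head_dim)
k = torch.randn(bsz, n_heads, q_len, q_head_dim)
v = torch.randn(bsz, n_heads, q_len, v_head_dim)

# Zero-pad v on the head dim so q/k/v match, mirroring F.pad(value_states, [0, q_head_dim - v_head_dim]).
v_padded = F.pad(v, [0, q_head_dim - v_head_dim])

out = F.scaled_dot_product_attention(q, k, v_padded, is_causal=True)

# The padded output columns are attention weights times zeros, i.e. exactly zero,
# so slicing them off recovers the result, mirroring attn_output[:, :, :, : self.v_head_dim].
out = out[..., :v_head_dim]
assert out.shape == (bsz, n_heads, q_len, v_head_dim)

The padded-input branch of `_flash_attention_forward` relies on `_upad_input` to pack the batch into one variable-length sequence and to build the cumulative sequence lengths consumed by `flash_attn_varlen_func`. A small sketch of that bookkeeping, assuming the usual semantics of the `_get_unpad_data` helper defined elsewhere in `modeling_deepseek.py`:

import torch
import torch.nn.functional as F

# Toy padding mask for a batch of 3 with true lengths 2, 4 and 1 (1 = token, 0 = padding).
attention_mask = torch.tensor([
    [1, 1, 0, 0],
    [1, 1, 1, 1],
    [1, 0, 0, 0],
])

seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                      # tensor([2, 4, 1])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()  # positions of real tokens
max_seqlen_in_batch = int(seqlens.max())                                     # 4
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 2, 6, 7])

# flash_attn_varlen_func reads the packed (total_tokens, num_heads, head_dim) tensors
# and uses cu_seqlens_q / cu_seqlens_k as per-sample offsets into them.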