fix
modeling_deepseek.py +271 -4
modeling_deepseek.py
CHANGED
@@ -651,12 +651,280 @@ class DeepseekV3Attention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
+# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3
 class DeepseekV3FlashAttention2(DeepseekV3Attention):
     """
-
+    DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays
+    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
     """
-
-
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        # DeepseekV3FlashAttention2 attention does not support output_attentions
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
+
+            # overwrite attention_mask with padding_mask
+            attention_mask = kwargs.pop("padding_mask")
+
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.q_lora_rank is None:
+            q = self.q_proj(hidden_states)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
+        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x head_dim x hidden_dim
+        # therefore we just need to keep the original shape
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
+        kv = (
+            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+            .transpose(1, 2)
+        )
+
+        k_nope, value_states = torch.split(
+            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+        kv_seq_len = value_states.shape[-2]
+
+        kv_seq_len = value_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
+
+        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
+        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe
+
+        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
+        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
+        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe
+
+        if self.q_head_dim != self.v_head_dim:
+            value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim])
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        dropout_rate = self.attention_dropout if self.training else 0.0
+
+        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+        # therefore the input hidden states gets silently casted in float32. Hence, we need
+        # cast them back in the correct dtype just to be sure everything works as expected.
+        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+        # in fp32. (DeepseekV3RMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            # Handle the case where the model is quantized
+            if hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            elif torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            else:
+                target_dtype = (
+                    self.q_proj.weight.dtype
+                    if self.q_lora_rank is None
+                    else self.q_a_proj.weight.dtype
+                )
+
+            logger.warning_once(
+                f"The input hidden states seems to be silently casted in float32, this might be related to"
+                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = self._flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=dropout_rate,
+            softmax_scale=self.softmax_scale,
+        )
+        if self.q_head_dim != self.v_head_dim:
+            attn_output = attn_output[:, :, :, : self.v_head_dim]
+
+        attn_output = attn_output.reshape(
+            bsz, q_len, self.num_heads * self.v_head_dim
+        ).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`int`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV3FlashAttention2 __init__.
+            causal = self.is_causal and query_length != 1
+
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            (
+                query_states,
+                key_states,
+                value_states,
+                indices_q,
+                cu_seq_lens,
+                max_seq_lens,
+            ) = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(
+                attn_output_unpad, indices_q, batch_size, query_length
+            )
+        else:
+            attn_output = flash_attn_func(
+                query_states,
+                key_states,
+                value_states,
+                dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+        return attn_output
+
+    def _upad_input(
+        self, query_layer, key_layer, value_layer, attention_mask, query_length
+    ):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+            indices_k,
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
+                indices_k,
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
+                query_layer, attention_mask
+            )
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
 
 
 ATTENTION_CLASSES = {
@@ -664,7 +932,6 @@ ATTENTION_CLASSES = {
     "flash_attention_2": DeepseekV3FlashAttention2,
 }
 
-
 class DeepseekV3DecoderLayer(nn.Module):
     def __init__(self, config: DeepseekV3Config, layer_idx: int):
         super().__init__()
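A minimal usage sketch (not taken from this commit, just an assumption about how the class above is typically reached): with transformers' remote-code loading, passing attn_implementation="flash_attention_2" routes attention through the "flash_attention_2" entry of ATTENTION_CLASSES, i.e. DeepseekV3FlashAttention2. The repo id below is a placeholder.

import torch
from transformers import AutoModelForCausalLM

model_id = "deepseek-ai/DeepSeek-V3"  # placeholder repo id, for illustration only

# bf16 (or fp16) is required because the flash-attn kernels do not run in fp32
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # selects DeepseekV3FlashAttention2 from ATTENTION_CLASSES
    trust_remote_code=True,
)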