zhouzaida committed
Commit 9e6c322
Parent: a98910a

can set attn_implementation

Files changed (2)
  1. configuration_kimi_vl.py +33 -21
  2. modeling_kimi_vl.py +9 -10
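
What this commit enables, in a minimal sketch: the attention backend for Kimi-VL can now be requested explicitly and is restricted to "eager" or "flash_attention_2". The repo id below is illustrative, and exactly how transformers routes attn_implementation into the remote KimiVLConfig depends on the library version, so treat this as a usage sketch rather than the canonical loading path.

    # Usage sketch (assumptions: illustrative repo id, recent transformers,
    # and the remote-code files from this commit). attn_implementation must be
    # "eager" or "flash_attention_2"; "sdpa" is no longer accepted.
    from transformers import AutoModelForCausalLM, AutoProcessor

    model_id = "moonshotai/Kimi-VL-A3B-Instruct"  # illustrative
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype="auto",
        attn_implementation="flash_attention_2",  # or "eager"
    )
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
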
configuration_kimi_vl.py CHANGED
@@ -6,6 +6,7 @@ logger = logging.get_logger(__name__)
 
 DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
 
+
 class DeepseekV3Config(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
@@ -122,30 +123,30 @@ class DeepseekV3Config(PretrainedConfig):
         vocab_size=129280,
         hidden_size=7168,
         intermediate_size=18432,
-        moe_intermediate_size = 2048,
+        moe_intermediate_size=2048,
         num_hidden_layers=61,
         num_nextn_predict_layers=1,
         num_attention_heads=128,
         num_key_value_heads=128,
-        n_shared_experts = 1,
-        n_routed_experts = 256,
-        ep_size = 1,
-        routed_scaling_factor = 2.5,
-        kv_lora_rank = 512,
-        q_lora_rank = 1536,
-        qk_rope_head_dim = 64,
-        v_head_dim = 128,
-        qk_nope_head_dim = 128,
-        topk_method = 'noaux_tc',
-        n_group = 8,
-        topk_group = 4,
-        num_experts_per_tok = 8,
-        moe_layer_freq = 1,
-        first_k_dense_replace = 3,
-        norm_topk_prob = True,
-        scoring_func = 'sigmoid',
-        aux_loss_alpha = 0.001,
-        seq_aux = True,
+        n_shared_experts=1,
+        n_routed_experts=256,
+        ep_size=1,
+        routed_scaling_factor=2.5,
+        kv_lora_rank=512,
+        q_lora_rank=1536,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        qk_nope_head_dim=128,
+        topk_method="noaux_tc",
+        n_group=8,
+        topk_group=4,
+        num_experts_per_tok=8,
+        moe_layer_freq=1,
+        first_k_dense_replace=3,
+        norm_topk_prob=True,
+        scoring_func="sigmoid",
+        aux_loss_alpha=0.001,
+        seq_aux=True,
         hidden_act="silu",
         max_position_embeddings=4096,
         initializer_range=0.02,
@@ -252,7 +253,7 @@ class KimiVLConfig(PretrainedConfig):
         ignore_index: int = -100,
         media_placeholder_token_id: int = 163605,
         pad_token_id: int = 0,
-        **kwargs
+        **kwargs,
     ):
         if vision_config is None:
             vision_config = MoonViTConfig()
@@ -269,4 +270,15 @@ class KimiVLConfig(PretrainedConfig):
         self.ignore_index = ignore_index
         self.media_placeholder_token_id = media_placeholder_token_id
 
+        attn_implementation = kwargs.get("attn_implementation")
+        if attn_implementation is not None:
+            if attn_implementation in ["eager", "flash_attention_2"]:
+                self._attn_implementation = attn_implementation
+                self.vision_config._attn_implementation = attn_implementation
+                self.text_config._attn_implementation = attn_implementation
+            else:
+                raise ValueError(
+                    f"Invalid attention implementation: {attn_implementation}"
+                )
+
         super().__init__(pad_token_id=pad_token_id, **kwargs)
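
For the config change itself, a minimal sketch of the new validation, assuming configuration_kimi_vl.py and its dependencies from this repo are importable locally:

    # Sketch: the new __init__ check propagates the chosen implementation to
    # both sub-configs and rejects anything outside the supported pair.
    from configuration_kimi_vl import KimiVLConfig

    config = KimiVLConfig(attn_implementation="flash_attention_2")
    assert config.vision_config._attn_implementation == "flash_attention_2"
    assert config.text_config._attn_implementation == "flash_attention_2"

    try:
        KimiVLConfig(attn_implementation="sdpa")  # not in ["eager", "flash_attention_2"]
    except ValueError as err:
        print(err)  # expected: Invalid attention implementation: sdpa

Note that the check reads attn_implementation from **kwargs, so it only runs when the caller actually passes the argument.
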
modeling_kimi_vl.py CHANGED
@@ -145,19 +145,13 @@ def multihead_attention(
     return attn_out
 
 
-def sdpa_attention(
+def eager_attention(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     q_cu_seqlens: Optional[torch.Tensor] = None,
     k_cu_seqlens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    """SDPA attention.
-
-    Args:
-        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
-            or (tot_seqlens, num_heads, head_dim) if packing.
-    """
     seq_length = q.shape[0]
     attention_mask = torch.zeros(
         [1, seq_length, seq_length], device=q.device, dtype=torch.bool
@@ -171,7 +165,12 @@ def sdpa_attention(
     q = q.transpose(0, 1)
     k = k.transpose(0, 1)
     v = v.transpose(0, 1)
-    attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
+
+    attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
+    attn_weight += attention_mask
+    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
+
+    attn_output = attn_weight @ v
     attn_output = attn_output.transpose(0, 1)
     attn_output = attn_output.reshape(seq_length, -1)
     return attn_output
@@ -179,7 +178,7 @@ def sdpa_attention(
 
 VL_VISION_ATTENTION_FUNCTIONS = {
     "flash_attention_2": multihead_attention,
-    "sdpa": sdpa_attention,
+    "eager": eager_attention,
 }
 
 
@@ -412,7 +411,7 @@ class MoonVitEncoderLayer(nn.Module):
         hidden_dim: int,
         mlp_dim: int,
         *,
-        attn_implementation: str = "sdpa",
+        attn_implementation: str = "eager",
         activation=F.gelu,
         attn_bias: bool = False,
     ):
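
The replacement of sdpa_attention with eager_attention boils down to writing the softmax(QK^T / sqrt(d)) V product out by hand. The self-contained sketch below checks that math against PyTorch's SDPA on random tensors; shapes and the all-zeros additive mask are illustrative and do not reproduce the cu_seqlens-based block mask built in modeling_kimi_vl.py.

    # Standalone sketch of the eager attention math, checked against SDPA.
    # Illustrative shapes; the additive float mask here stands in for the
    # block-diagonal mask the model builds from cu_seqlens.
    import math
    import torch
    import torch.nn.functional as F

    def eager_attention_sketch(q, k, v, attn_mask):
        # q, k, v: (num_heads, seq_len, head_dim); attn_mask: additive float mask
        attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
        attn_weight = attn_weight + attn_mask
        attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
        return attn_weight @ v

    torch.manual_seed(0)
    num_heads, seq_len, head_dim = 4, 16, 32
    q, k, v = (torch.randn(num_heads, seq_len, head_dim) for _ in range(3))
    attn_mask = torch.zeros(1, seq_len, seq_len)  # 0.0 = attend, -inf = masked

    ref = F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p=0.0)
    print(torch.allclose(eager_attention_sketch(q, k, v, attn_mask), ref, atol=1e-5))  # True

In the model, this path is selected when attn_implementation="eager" via the VL_VISION_ATTENTION_FUNCTIONS mapping, while "flash_attention_2" still dispatches to multihead_attention.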