zhouzaida committed · Commit 725901f · verified · Parent: a98910a

can set attn_implementation (#8)


- can set attn_implementation (9e6c3226b877a6be05e385279d56fcf26a0f9fab)
- add sdpa back (7718375747c38ef6a6e957a615edd4b3df495282)
- add blank (d869dc5ea79e62a0697c986a0fdeab12860c65bf)
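This commit makes the attention backend selectable instead of hard-coded: `KimiVLConfig` now accepts an `attn_implementation` keyword and propagates it to both the vision and text sub-configs, and the vision encoder gains a pure-PyTorch `eager` attention path alongside `flash_attention_2` and `sdpa`. A minimal usage sketch, assuming a Kimi-VL checkpoint that ships this code (the checkpoint id below is an assumption, not part of this commit):

```python
from transformers import AutoModelForCausalLM

# attn_implementation flows into KimiVLConfig, which propagates it to the
# vision tower and the DeepseekV3 text backbone (see the diffs below).
model = AutoModelForCausalLM.from_pretrained(
    "moonshotai/Kimi-VL-A3B-Instruct",  # assumed checkpoint id
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)
```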

Files changed (2):

1. configuration_kimi_vl.py  +33 -21
2. modeling_kimi_vl.py  +33 -1
configuration_kimi_vl.py CHANGED

```diff
@@ -6,6 +6,7 @@ logger = logging.get_logger(__name__)
 
 DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
 
+
 class DeepseekV3Config(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
@@ -122,30 +123,30 @@ class DeepseekV3Config(PretrainedConfig):
         vocab_size=129280,
         hidden_size=7168,
         intermediate_size=18432,
-        moe_intermediate_size = 2048,
+        moe_intermediate_size=2048,
         num_hidden_layers=61,
         num_nextn_predict_layers=1,
         num_attention_heads=128,
         num_key_value_heads=128,
-        n_shared_experts = 1,
-        n_routed_experts = 256,
-        ep_size = 1,
-        routed_scaling_factor = 2.5,
-        kv_lora_rank = 512,
-        q_lora_rank = 1536,
-        qk_rope_head_dim = 64,
-        v_head_dim = 128,
-        qk_nope_head_dim = 128,
-        topk_method = 'noaux_tc',
-        n_group = 8,
-        topk_group = 4,
-        num_experts_per_tok = 8,
-        moe_layer_freq = 1,
-        first_k_dense_replace = 3,
-        norm_topk_prob = True,
-        scoring_func = 'sigmoid',
-        aux_loss_alpha = 0.001,
-        seq_aux = True,
+        n_shared_experts=1,
+        n_routed_experts=256,
+        ep_size=1,
+        routed_scaling_factor=2.5,
+        kv_lora_rank=512,
+        q_lora_rank=1536,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        qk_nope_head_dim=128,
+        topk_method="noaux_tc",
+        n_group=8,
+        topk_group=4,
+        num_experts_per_tok=8,
+        moe_layer_freq=1,
+        first_k_dense_replace=3,
+        norm_topk_prob=True,
+        scoring_func="sigmoid",
+        aux_loss_alpha=0.001,
+        seq_aux=True,
         hidden_act="silu",
         max_position_embeddings=4096,
         initializer_range=0.02,
@@ -252,7 +253,7 @@ class KimiVLConfig(PretrainedConfig):
         ignore_index: int = -100,
         media_placeholder_token_id: int = 163605,
         pad_token_id: int = 0,
-        **kwargs
+        **kwargs,
     ):
         if vision_config is None:
             vision_config = MoonViTConfig()
@@ -269,4 +270,15 @@ class KimiVLConfig(PretrainedConfig):
         self.ignore_index = ignore_index
         self.media_placeholder_token_id = media_placeholder_token_id
 
+        attn_implementation = kwargs.get("attn_implementation")
+        if attn_implementation is not None:
+            if attn_implementation in ["eager", "flash_attention_2"]:
+                self._attn_implementation = attn_implementation
+                self.vision_config._attn_implementation = attn_implementation
+                self.text_config._attn_implementation = attn_implementation
+            else:
+                raise ValueError(
+                    f"Invalid attention implementation: {attn_implementation}"
+                )
+
         super().__init__(pad_token_id=pad_token_id, **kwargs)
```
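The new block at the end of `KimiVLConfig.__init__` validates the requested backend and pushes it down to both sub-configs. Note that the whitelist covers only `"eager"` and `"flash_attention_2"`; `"sdpa"` is rejected here even though the modeling file keeps an `"sdpa"` entry in its dispatch table. A small sketch of the resulting behaviour, assuming `configuration_kimi_vl.py` is importable from a local checkout:

```python
from configuration_kimi_vl import KimiVLConfig  # local checkout assumed

# A whitelisted value is propagated to the top-level, vision, and text configs.
cfg = KimiVLConfig(attn_implementation="eager")
assert cfg._attn_implementation == "eager"
assert cfg.vision_config._attn_implementation == "eager"
assert cfg.text_config._attn_implementation == "eager"

# Anything outside the whitelist raises at construction time.
try:
    KimiVLConfig(attn_implementation="sdpa")
except ValueError as err:
    print(err)  # Invalid attention implementation: sdpa
```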
modeling_kimi_vl.py CHANGED

```diff
@@ -177,9 +177,41 @@ def sdpa_attention(
     return attn_output
 
 
+def eager_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    q_cu_seqlens: Optional[torch.Tensor] = None,
+    k_cu_seqlens: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    seq_length = q.shape[0]
+    attention_mask = torch.zeros(
+        [1, seq_length, seq_length], device=q.device, dtype=torch.bool
+    )
+    for i in range(1, len(q_cu_seqlens)):
+        attention_mask[
+            ...,
+            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
+            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
+        ] = True
+    q = q.transpose(0, 1)
+    k = k.transpose(0, 1)
+    v = v.transpose(0, 1)
+
+    attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
+    attn_weight = attn_weight.masked_fill(~attention_mask, float("-inf"))
+    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
+
+    attn_output = attn_weight @ v
+    attn_output = attn_output.transpose(0, 1)
+    attn_output = attn_output.reshape(seq_length, -1)
+    return attn_output
+
+
 VL_VISION_ATTENTION_FUNCTIONS = {
     "flash_attention_2": multihead_attention,
     "sdpa": sdpa_attention,
+    "eager": eager_attention,
 }
@@ -412,7 +444,7 @@ class MoonVitEncoderLayer(nn.Module):
         hidden_dim: int,
         mlp_dim: int,
         *,
-        attn_implementation: str = "sdpa",
+        attn_implementation: str = "eager",
         activation=F.gelu,
         attn_bias: bool = False,
     ):
```
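For intuition about `q_cu_seqlens` in the new `eager_attention`: the vision tokens of several images are packed along a single sequence axis, and the cumulative lengths mark the boundaries, so the mask built in the loop is block-diagonal and tokens never attend across image boundaries. This gives a fallback that needs neither flash-attn nor SDPA kernels, which is presumably why the `MoonVitEncoderLayer` default moves from `"sdpa"` to `"eager"`. A toy illustration of the masking (shapes and values are illustrative only):

```python
import torch

# Two packed sequences of lengths 3 and 2 -> cumulative boundaries [0, 3, 5].
q_cu_seqlens = torch.tensor([0, 3, 5])
seq_length = int(q_cu_seqlens[-1])

# Same construction as in eager_attention above: each sequence may only
# attend to itself, giving a block-diagonal boolean mask.
mask = torch.zeros(1, seq_length, seq_length, dtype=torch.bool)
for i in range(1, len(q_cu_seqlens)):
    start, end = q_cu_seqlens[i - 1], q_cu_seqlens[i]
    mask[..., start:end, start:end] = True

print(mask.int()[0])
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1]], dtype=torch.int32)
```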