can set attn_implementation (#8)

- can set attn_implementation (9e6c3226b877a6be05e385279d56fcf26a0f9fab)
- add sdpa back (7718375747c38ef6a6e957a615edd4b3df495282)
- add blank (d869dc5ea79e62a0697c986a0fdeab12860c65bf)

- configuration_kimi_vl.py  +33 -21
- modeling_kimi_vl.py  +33 -1
configuration_kimi_vl.py
CHANGED

@@ -6,6 +6,7 @@ logger = logging.get_logger(__name__)
 
 DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
 
+
 class DeepseekV3Config(PretrainedConfig):
     r"""
     This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek

@@ -122,30 +123,30 @@ class DeepseekV3Config(PretrainedConfig):
         vocab_size=129280,
         hidden_size=7168,
         intermediate_size=18432,
-        moe_intermediate_size
+        moe_intermediate_size=2048,
         num_hidden_layers=61,
         num_nextn_predict_layers=1,
         num_attention_heads=128,
         num_key_value_heads=128,
-        n_shared_experts
-        n_routed_experts
-        ep_size
-        routed_scaling_factor
-        kv_lora_rank
-        q_lora_rank
-        qk_rope_head_dim
-        v_head_dim
-        qk_nope_head_dim
-        topk_method
-        n_group
-        topk_group
-        num_experts_per_tok
-        moe_layer_freq
-        first_k_dense_replace
-        norm_topk_prob
-        scoring_func
-        aux_loss_alpha
-        seq_aux
+        n_shared_experts=1,
+        n_routed_experts=256,
+        ep_size=1,
+        routed_scaling_factor=2.5,
+        kv_lora_rank=512,
+        q_lora_rank=1536,
+        qk_rope_head_dim=64,
+        v_head_dim=128,
+        qk_nope_head_dim=128,
+        topk_method="noaux_tc",
+        n_group=8,
+        topk_group=4,
+        num_experts_per_tok=8,
+        moe_layer_freq=1,
+        first_k_dense_replace=3,
+        norm_topk_prob=True,
+        scoring_func="sigmoid",
+        aux_loss_alpha=0.001,
+        seq_aux=True,
         hidden_act="silu",
         max_position_embeddings=4096,
         initializer_range=0.02,

@@ -252,7 +253,7 @@ class KimiVLConfig(PretrainedConfig):
         ignore_index: int = -100,
         media_placeholder_token_id: int = 163605,
         pad_token_id: int = 0,
-        **kwargs
+        **kwargs,
    ):
        if vision_config is None:
            vision_config = MoonViTConfig()

@@ -269,4 +270,15 @@ class KimiVLConfig(PretrainedConfig):
        self.ignore_index = ignore_index
        self.media_placeholder_token_id = media_placeholder_token_id
 
+        attn_implementation = kwargs.get("attn_implementation")
+        if attn_implementation is not None:
+            if attn_implementation in ["eager", "flash_attention_2"]:
+                self._attn_implementation = attn_implementation
+                self.vision_config._attn_implementation = attn_implementation
+                self.text_config._attn_implementation = attn_implementation
+            else:
+                raise ValueError(
+                    f"Invalid attention implementation: {attn_implementation}"
+                )
+
        super().__init__(pad_token_id=pad_token_id, **kwargs)
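A minimal usage sketch of the new option (the checkpoint name, auto classes, and dtype below are illustrative assumptions, not part of this commit; through this config path only "eager" and "flash_attention_2" are accepted):

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM

    model_id = "moonshotai/Kimi-VL-A3B-Instruct"  # assumed example checkpoint

    # Passing attn_implementation as a config kwarg reaches KimiVLConfig.__init__,
    # which propagates it to both vision_config and text_config.
    config = AutoConfig.from_pretrained(
        model_id,
        attn_implementation="flash_attention_2",  # or "eager"
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        config=config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )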
modeling_kimi_vl.py
CHANGED

@@ -177,9 +177,41 @@ def sdpa_attention(
     return attn_output
 
 
+def eager_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    q_cu_seqlens: Optional[torch.Tensor] = None,
+    k_cu_seqlens: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    seq_length = q.shape[0]
+    attention_mask = torch.zeros(
+        [1, seq_length, seq_length], device=q.device, dtype=torch.bool
+    )
+    for i in range(1, len(q_cu_seqlens)):
+        attention_mask[
+            ...,
+            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
+            q_cu_seqlens[i - 1] : q_cu_seqlens[i],
+        ] = True
+    q = q.transpose(0, 1)
+    k = k.transpose(0, 1)
+    v = v.transpose(0, 1)
+
+    attn_weight = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
+    attn_weight += attention_mask
+    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32).to(q.dtype)
+
+    attn_output = attn_weight @ v
+    attn_output = attn_output.transpose(0, 1)
+    attn_output = attn_output.reshape(seq_length, -1)
+    return attn_output
+
+
 VL_VISION_ATTENTION_FUNCTIONS = {
     "flash_attention_2": multihead_attention,
     "sdpa": sdpa_attention,
+    "eager": eager_attention,
 }
 
 

@@ -412,7 +444,7 @@ class MoonVitEncoderLayer(nn.Module):
         hidden_dim: int,
         mlp_dim: int,
         *,
-        attn_implementation: str = "
+        attn_implementation: str = "eager",
         activation=F.gelu,
         attn_bias: bool = False,
     ):
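A small sketch of how the vision attention backends are dispatched, assuming eager_attention and VL_VISION_ATTENTION_FUNCTIONS from the file above are in scope (the same dict also dispatches "sdpa" and "flash_attention_2"); the tensor sizes are illustrative. Here q/k/v are packed as (total_tokens, num_heads, head_dim) and cu_seqlens marks the sequence boundaries:

    import torch

    num_heads, head_dim = 4, 32
    # Two packed vision sequences of 16 and 9 tokens -> 25 tokens total.
    cu_seqlens = torch.tensor([0, 16, 25], dtype=torch.int32)
    q = torch.randn(25, num_heads, head_dim)
    k = torch.randn(25, num_heads, head_dim)
    v = torch.randn(25, num_heads, head_dim)

    attn_fn = VL_VISION_ATTENTION_FUNCTIONS["eager"]
    out = attn_fn(q, k, v, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens)
    print(out.shape)  # torch.Size([25, 128]) == (total_tokens, num_heads * head_dim)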