Commit
·
fe340b5
1
Parent(s):
1805272
use torchtitan moe impl
Browse files
- modeling_deepseek.py +15 -1
modeling_deepseek.py
CHANGED
@@ -59,6 +59,8 @@ from .configuration_deepseek import DeepseekV3Config
|
|
59 |
import torch.distributed as dist
|
60 |
import numpy as np
|
61 |
|
|
|
|
|
62 |
if is_flash_attn_2_available():
|
63 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
64 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
@@ -1150,8 +1152,20 @@ class DeepseekV3DecoderLayer(nn.Module):
|
|
1150 |
config=config, layer_idx=layer_idx
|
1151 |
)
|
1152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1153 |
self.mlp = (
|
1154 |
-
|
1155 |
if (
|
1156 |
config.n_routed_experts is not None
|
1157 |
and layer_idx >= config.first_k_dense_replace
|
|
|
59 |
import torch.distributed as dist
|
60 |
import numpy as np
|
61 |
|
62 |
+
from torchtitan.models.moe import MoE, MoEArgs
|
63 |
+
|
64 |
if is_flash_attn_2_available():
|
65 |
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
66 |
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
|
|
1152 |
config=config, layer_idx=layer_idx
|
1153 |
)
|
1154 |
|
1155 |
+
moe_args = MoEArgs(
|
1156 |
+
num_experts=config.n_routed_experts,
|
1157 |
+
num_shared_experts=config.n_shared_experts,
|
1158 |
+
score_func=config.scoring_func,
|
1159 |
+
route_norm=config.norm_topk_prob,
|
1160 |
+
route_scale=config.routed_scaling_factor,
|
1161 |
+
score_before_experts=False,
|
1162 |
+
top_k=config.num_experts_per_tok,
|
1163 |
+
use_grouped_mm=True,
|
1164 |
+
load_balance_coeff=1e-3,
|
1165 |
+
)
|
1166 |
+
|
1167 |
self.mlp = (
|
1168 |
+
MoE(moe_args, dim=config.hidden_size, hidden_dim=config.moe_intermediate_size)
|
1169 |
if (
|
1170 |
config.n_routed_experts is not None
|
1171 |
and layer_idx >= config.first_k_dense_replace
|