Jackmin108 committed
Commit fe340b5 · Parent: 1805272

use torchtitan moe impl

Files changed (1): modeling_deepseek.py +15 -1
modeling_deepseek.py CHANGED
@@ -59,6 +59,8 @@ from .configuration_deepseek import DeepseekV3Config
 import torch.distributed as dist
 import numpy as np
 
+from torchtitan.models.moe import MoE, MoEArgs
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -1150,8 +1152,20 @@ class DeepseekV3DecoderLayer(nn.Module):
             config=config, layer_idx=layer_idx
         )
 
+        moe_args = MoEArgs(
+            num_experts=config.n_routed_experts,
+            num_shared_experts=config.n_shared_experts,
+            score_func=config.scoring_func,
+            route_norm=config.norm_topk_prob,
+            route_scale=config.routed_scaling_factor,
+            score_before_experts=False,
+            top_k=config.num_experts_per_tok,
+            use_grouped_mm=True,
+            load_balance_coeff=1e-3,
+        )
+
         self.mlp = (
-            DeepseekV3MoE(config)
+            MoE(moe_args, dim=config.hidden_size, hidden_dim=config.moe_intermediate_size)
             if (
                 config.n_routed_experts is not None
                 and layer_idx >= config.first_k_dense_replace
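
For context, a minimal standalone sketch (not part of the commit) of the drop-in swap the diff performs. The MoEArgs field mapping mirrors the diff one-to-one; the sizes are toy values rather than DeepSeek-V3's real config, use_grouped_mm is set to False here only so the sketch does not require grouped-GEMM support, weight initialization is omitted, and the forward call assumes torchtitan's MoE is tensor-in/tensor-out like the DeepseekV3MoE module it replaces:

import torch
from torchtitan.models.moe import MoE, MoEArgs

# Same fields the diff fills from DeepseekV3Config, here with toy values.
moe_args = MoEArgs(
    num_experts=8,              # config.n_routed_experts
    num_shared_experts=1,       # config.n_shared_experts
    score_func="sigmoid",       # config.scoring_func
    route_norm=True,            # config.norm_topk_prob
    route_scale=2.5,            # config.routed_scaling_factor
    score_before_experts=False,
    top_k=2,                    # config.num_experts_per_tok
    use_grouped_mm=False,       # diff uses True; False avoids the grouped-GEMM requirement
    load_balance_coeff=1e-3,
)

# dim <- config.hidden_size, hidden_dim <- config.moe_intermediate_size
moe = MoE(moe_args, dim=64, hidden_dim=128)

x = torch.randn(2, 16, 64)  # (batch, seq_len, hidden_size)
y = moe(x)                  # assumed: same shape as x, like self.mlp(hidden_states)
print(y.shape)

Note that score_before_experts=False scales expert outputs by the router weights after the expert computation, which appears to match the original DeepseekV3MoE behavior of weighting expert outputs rather than expert inputs.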