Update modeling_deepseek.py
Browse files- modeling_deepseek.py +8 -4
modeling_deepseek.py
CHANGED
@@ -522,10 +522,14 @@ class DeepseekV3MoE(nn.Module):
|
|
522 |
topk_idx, topk_weight = self.gate(hidden_states)
|
523 |
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
524 |
flat_topk_idx = topk_idx.view(-1)
|
525 |
-
if not self.training:
|
526 |
-
|
527 |
-
|
528 |
-
|
|
|
|
|
|
|
|
|
529 |
return y
|
530 |
|
531 |
@torch.no_grad()
|
|
|
522 |
topk_idx, topk_weight = self.gate(hidden_states)
|
523 |
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
524 |
flat_topk_idx = topk_idx.view(-1)
|
525 |
+
# if not self.training:
|
526 |
+
y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape)
|
527 |
+
try:
|
528 |
+
if self.config.n_shared_experts is not None:
|
529 |
+
y = y + self.shared_experts(identity)
|
530 |
+
except Exception as e:
|
531 |
+
if self.config.n_shared_experts is not None:
|
532 |
+
y = self.shared_experts(identity)
|
533 |
return y
|
534 |
|
535 |
@torch.no_grad()
|