Why not distinguish between sequence_length equal to 1 and greater than 1 in the MoE module's forward function?

#27 · opened by nifeng154

When I used the following code in the forward function of Qwen3MoeSparseMoeBlock:

```python
if sequence_length == 1:
    # Decoding a single token: only the experts the router actually selected
    # need to be visited.
    for expert_idx in selected_experts[0]:
        expert_layer = self.experts[expert_idx]
        idx, top_x = torch.where(expert_mask[expert_idx])

        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
else:
    for expert_idx in range(self.num_experts):
        expert_layer = self.experts[expert_idx]
        idx, top_x = torch.where(expert_mask[expert_idx])

        # Index the correct hidden states and compute the expert hidden state for
        # the current expert. We need to make sure to multiply the output hidden
        # states by `routing_weights` on the corresponding tokens (top-1 and top-2).
        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
        current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

        # However, `index_add_` only supports torch tensors for indexing, so we use
        # the `top_x` tensor here.
        final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
```

It significantly improved generation speed, since during single-token decoding only the top-k selected experts are visited instead of looping over all `self.num_experts`.
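For context, here is a minimal, self-contained sketch of the idea (the `TinyExpert` module, the `run` helper, and the toy dimensions are illustrative assumptions, not code from the Qwen3 sources). It checks that, for a single routed token, visiting only the top-k selected experts produces the same output as the full loop over every expert, while skipping the per-expert masking and indexing work for experts the router did not pick:

```python
# Minimal sketch, assuming toy sizes; TinyExpert is a stand-in for the real expert MLP.
import torch
import torch.nn as nn
import torch.nn.functional as F

num_experts, top_k, hidden_dim = 8, 2, 16

class TinyExpert(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return self.ff(x)

torch.manual_seed(0)
experts = nn.ModuleList(TinyExpert(hidden_dim) for _ in range(num_experts))
gate = nn.Linear(hidden_dim, num_experts)

hidden_states = torch.randn(1, hidden_dim)            # one decoded token
routing_weights = F.softmax(gate(hidden_states), dim=-1)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

def run(expert_indices):
    """Accumulate the weighted expert outputs for the experts in `expert_indices`."""
    out = torch.zeros_like(hidden_states)
    for expert_idx in expert_indices:
        idx, top_x = torch.where(expert_mask[expert_idx])
        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
        current = experts[expert_idx](current_state) * routing_weights[top_x, idx, None]
        out.index_add_(0, top_x, current.to(hidden_states.dtype))
    return out

full = run(range(num_experts))              # reference path: visit every expert
short = run(selected_experts[0].tolist())   # shortcut: visit only the selected experts
print(torch.allclose(full, short))          # True: identical output, fewer expert calls
```

Note that the shortcut as written assumes a single sequence is being decoded; with a larger batch, different tokens may select different experts, so the loop would need the union of selected experts rather than `selected_experts[0]`.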
