File size: 11,421 Bytes
de31e1d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 |
# coding=utf-8
# Copyright 2025 Charles O. Goddard, The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The following monkeypatches were applied by Doctor Shotgun:
#
# Liger Kernel (https://github.com/linkedin/Liger-Kernel):
# 1. Liger RMSNorm
# 2. Liger RoPE
# 3. Liger SwiGLUMLP
# 4. Liger Fused Linear Cross-Entropy
"""PyTorch Qwen3 model with shared expert support."""
from typing import List, Optional, Union
import torch
from torch import nn
import torch.nn.functional as F
# Liger Patch #
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerQwen3MoeSwiGLUMLP
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
# Monkeypatch the upstream Qwen3-MoE modeling module with Liger kernels.
# Rebinding attributes on a module object is process-wide: any code that later
# looks these names up on transformers.models.qwen3_moe.modeling_qwen3_moe
# (presumably including the upstream Qwen3-MoE classes, via their module
# globals) receives the Liger implementations, not just this file.
# Ordering matters: these assignments run before the
# `from transformers.models.qwen3_moe.modeling_qwen3_moe import (...)` further
# down, so the Qwen3MoeMLP / Qwen3MoeRMSNorm names imported there are already
# bound to the Liger classes.
import transformers.models.qwen3_moe.modeling_qwen3_moe
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeRMSNorm = LigerRMSNorm  # Liger RMSNorm kernel
transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeMLP = LigerQwen3MoeSwiGLUMLP  # Liger SwiGLU MLP
transformers.models.qwen3_moe.modeling_qwen3_moe.apply_rotary_pos_emb = liger_rotary_pos_emb  # Liger RoPE
# Liger Patch #
from transformers.modeling_outputs import (
MoeCausalLMOutputWithPast,
MoeModelOutputWithPast,
)
from transformers.activations import ACT2FN
from transformers.utils import logging
from transformers.models.mixtral.modeling_mixtral import (
load_balancing_loss_func,
)
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
Qwen3MoeMLP,
Qwen3MoeRMSNorm,
Qwen3MoeAttention,
Qwen3MoeDecoderLayer,
Qwen3MoeModel,
Qwen3MoeForCausalLM,
)
from .configuration_qwen3_shared_moe import Qwen3SharedMoeConfig
import scattermoe
# Module-level logger from transformers' logging utilities (not referenced by
# the code in this section of the file).
logger = logging.get_logger(__name__)
class Qwen3SharedMoeSparseMoeBlock(nn.Module):
    """Sparse MoE feed-forward block with an optional always-on shared expert.

    Each token is routed to the top ``config.num_experts_per_tok`` experts of a
    ScatterMoE ``GLUMLP``. When ``config.shared_expert_intermediate_size`` is
    set, a dense ``Qwen3MoeMLP`` is additionally applied to every token and its
    output is added (ungated) to the routed-expert output.
    """

    def __init__(self, config: Qwen3SharedMoeConfig):
        super().__init__()
        self.config = config
        # Router: produces one logit per expert for every token.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        shared_size = config.shared_expert_intermediate_size
        self.shared_expert = (
            None
            if shared_size is None
            else Qwen3MoeMLP(config, intermediate_size=shared_size)
        )
        # Routed experts, evaluated by ScatterMoE's GLU MLP.
        self.moe_mlp = scattermoe.mlp.GLUMLP(
            input_size=config.hidden_size,
            hidden_size=config.moe_intermediate_size,
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            activation=ACT2FN[config.hidden_act],
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and add the shared-expert output.

        Router handling follows the upstream ``Qwen3MoeSparseMoeBlock``; expert
        evaluation goes through ScatterMoE instead.

        Args:
            hidden_states: a 3-D tensor ``(batch, seq_len, hidden_size)``
                (exactly three dims are unpacked below).

        Returns:
            Note: despite the annotation, a tuple ``(output, router_logits)``:
            ``output`` has the input's shape and ``router_logits`` is
            ``(batch * seq_len, num_experts)``.
        """
        bsz, seq_len, width = hidden_states.shape
        tokens = hidden_states.view(-1, width)

        router_logits = self.gate(tokens)
        # Softmax in float32, then keep the k most probable experts per token.
        probs = F.softmax(router_logits, dim=1, dtype=torch.float)
        top_probs, top_ids = torch.topk(probs, self.config.num_experts_per_tok, dim=-1)
        if self.config.norm_topk_prob:
            # Renormalise so each token's selected weights sum to one.
            top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
        top_probs = top_probs.to(tokens.dtype)

        out = self.moe_mlp(tokens, top_probs, top_ids)
        if self.shared_expert is not None:
            out = out + self.shared_expert(tokens)
        return out.reshape(bsz, seq_len, width), router_logits
class Qwen3SharedMoeDecoderLayer(Qwen3MoeDecoderLayer, nn.Module):
    """Qwen3-MoE decoder layer whose sparse MLP is the shared-expert block.

    The parent constructor runs first. The attention, the MLP and both norms
    are then rebuilt here, so that sparse layers use
    ``Qwen3SharedMoeSparseMoeBlock`` instead of the upstream sparse block.
    The forward pass is inherited unchanged from ``Qwen3MoeDecoderLayer``.
    """

    def __init__(self, config: Qwen3SharedMoeConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3MoeAttention(config, layer_idx)
        # A layer uses experts unless it is listed as MLP-only; otherwise every
        # `decoder_sparse_step`-th layer (counting from 1) is sparse.
        is_sparse = (
            layer_idx not in config.mlp_only_layers
            and config.num_experts > 0
            and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if is_sparse:
            self.mlp = Qwen3SharedMoeSparseMoeBlock(config)
        else:
            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
        eps = config.rms_norm_eps
        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=eps)
class Qwen3SharedMoeModel(Qwen3MoeModel):
    """Qwen3-MoE backbone whose decoder layers are ``Qwen3SharedMoeDecoderLayer``."""

    config_class = Qwen3SharedMoeConfig

    def __init__(self, config: Qwen3SharedMoeConfig):
        super().__init__(config)
        # Replace the upstream decoder layers built by the parent constructor.
        self.layers = nn.ModuleList(
            [
                Qwen3SharedMoeDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        # The parent presumably already ran its post-construction hook
        # (transformers convention) on the layers that were just discarded.
        # Run it again so the replacement layers also receive the model's
        # weight initialisation and the usual HF post-init setup.
        self.post_init()
class Qwen3SharedMoeForCausalLM(Qwen3MoeForCausalLM):
    """Causal-LM head on ``Qwen3SharedMoeModel`` with a Liger fused-loss path.

    When ``skip_logits`` is in effect (by default: training mode with
    ``labels`` or ``shift_labels``), the loss is computed with Liger's fused
    linear cross-entropy directly from hidden states, and logits are not
    materialised.
    """

    config_class = Qwen3SharedMoeConfig

    def __init__(self, config):
        super().__init__(config)
        # Swap in the shared-expert backbone for the one the parent built.
        self.model = Qwen3SharedMoeModel(config)
        self.num_experts = config.num_experts
        # The parent presumably ran post_init() (transformers convention)
        # against the backbone that was just replaced, so any weight tying
        # (lm_head <-> embed_tokens) pointed at the old embeddings. Re-run it
        # for the new backbone.
        self.post_init()

    # Liger Patch #
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        skip_logits: Optional[bool] = None,
        **kwargs,
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).
        skip_logits (`bool`, *optional*):
            If True, compute the loss with Liger's fused linear cross-entropy and do not materialise logits
            (`logits` is `None` in the output). Requires `labels` or a `shift_labels` keyword argument; a
            `ValueError` is raised otherwise. When `None`, it defaults to True in training mode if `labels` or
            `shift_labels` is given, and to False otherwise.
        shift_labels (`torch.LongTensor`, *optional*, passed via `**kwargs`):
            Labels that are already shifted. They are only used by the Liger fused-loss path and are not
            forwarded to the decoder.

        Returns:
            `MoeCausalLMOutputWithPast` with `loss` (language-modeling loss, plus
            `router_aux_loss_coef * aux_loss` when router logits are requested and a loss was computed),
            `aux_loss`, `logits` (`None` on the fused-loss path), `past_key_values`, `hidden_states`,
            `attentions` and `router_logits`.

        Example (the checkpoint path is a placeholder):

        ```python
        from transformers import AutoTokenizer

        model = Qwen3SharedMoeForCausalLM.from_pretrained("path/to/qwen3-shared-moe")
        tokenizer = AutoTokenizer.from_pretrained("path/to/qwen3-shared-moe")
        inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt")
        generate_ids = model.generate(inputs.input_ids, max_length=30)
        tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # `shift_labels` is consumed only by the loss; keep it out of the decoder call.
        shift_labels = kwargs.pop("shift_labels", None)
        # Fail fast (before the expensive decoder pass) on an impossible request.
        if skip_logits and labels is None and shift_labels is None:
            raise ValueError("skip_logits is True, but labels and shift_labels are None")

        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        kept_hidden_states = hidden_states[:, slice_indices, :]

        logits = None
        loss = None
        if skip_logits is None:
            skip_logits = self.training and (labels is not None or shift_labels is not None)

        if skip_logits:
            # Fused lm_head + cross-entropy: the loss comes straight from the hidden states.
            loss = LigerForCausalLMLoss(
                hidden_states=kept_hidden_states,
                lm_head_weight=self.lm_head.weight,
                labels=labels,
                shift_labels=shift_labels,
                hidden_size=self.config.hidden_size,
                **kwargs,
            )
        else:  # inference / evaluation: materialise logits
            logits = self.lm_head(kept_hidden_states)
            if labels is not None:
                loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            # Add the router balancing term whenever a LM loss exists. This
            # includes the shift_labels-only fused path, where `labels` is None.
            # Out-of-place so the loss tensor from the loss function is not mutated.
            if loss is not None:
                loss = loss + self.router_aux_loss_coef * aux_loss.to(loss.device)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
    # Liger Patch #
|