|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""PyTorch Qwen3-MoE model with an optional shared expert and ScatterMoE-based routed experts."""
|
|
|
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn

import scattermoe
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
)
from transformers.models.mixtral.modeling_mixtral import (
    load_balancing_loss_func,
)
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
    Qwen3MoeAttention,
    Qwen3MoeDecoderLayer,
    Qwen3MoeForCausalLM,
    Qwen3MoeMLP,
    Qwen3MoeModel,
    Qwen3MoeRMSNorm,
)
from transformers.utils import logging

from .configuration_qwen3_shared_moe import Qwen3SharedMoeConfig
|
|
|
|
|
# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)
|
|
|
|
|
class Qwen3SharedMoeSparseMoeBlock(nn.Module):
    """Top-k routed MoE MLP (ScatterMoE kernels) plus an optional dense shared expert.

    Each token is routed to ``config.num_experts_per_tok`` experts that are
    executed by ``scattermoe.mlp.GLUMLP``. When
    ``config.shared_expert_intermediate_size`` is not ``None``, a dense
    ``Qwen3MoeMLP`` is also applied to every token and its output is added
    (without any gating) to the routed-expert output.
    """

    def __init__(self, config: Qwen3SharedMoeConfig):
        super().__init__()
        self.config = config
        # Router: produces one logit per expert for each token.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

        # Optional dense expert that every token goes through.
        if config.shared_expert_intermediate_size is not None:
            self.shared_expert = Qwen3MoeMLP(
                config, intermediate_size=config.shared_expert_intermediate_size
            )
        else:
            self.shared_expert = None

        # Routed experts, evaluated with ScatterMoE's gated-linear-unit MLP.
        self.moe_mlp = scattermoe.mlp.GLUMLP(
            input_size=config.hidden_size,
            hidden_size=config.moe_intermediate_size,
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            activation=ACT2FN[config.hidden_act],
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the routed experts (and the shared expert, if configured).

        Args:
            hidden_states: tensor of shape ``(batch_size, sequence_length, hidden_dim)``.

        Returns:
            ``(output, router_logits)``: ``output`` has the same shape as the
            input; ``router_logits`` has shape
            ``(batch_size * sequence_length, num_experts)``.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # reshape (rather than view) also accepts non-contiguous inputs.
        hidden_states = hidden_states.reshape(-1, hidden_dim)

        router_logits = self.gate(hidden_states)

        # Softmax over experts in float32, then keep the top-k experts per token.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.config.num_experts_per_tok, dim=-1
        )
        if self.config.norm_topk_prob:
            # Renormalise so the selected experts' weights sum to 1 per token.
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        output = self.moe_mlp(hidden_states, routing_weights, selected_experts)
        if self.shared_expert is not None:
            output = output + self.shared_expert(hidden_states)

        output = output.reshape(batch_size, sequence_length, hidden_dim)
        return output, router_logits
|
|
|
|
|
# ``nn.Module`` is not listed as an extra base: it is already part of
# ``Qwen3MoeDecoderLayer``'s MRO, so the resulting MRO is identical.
class Qwen3SharedMoeDecoderLayer(Qwen3MoeDecoderLayer):
    """Qwen3-MoE decoder layer whose sparse MLP is a ``Qwen3SharedMoeSparseMoeBlock``.

    ``super().__init__`` builds the self-attention block and both RMSNorms
    (``input_layernorm`` / ``post_attention_layernorm``) from ``config`` and
    ``layer_idx``, so they are not rebuilt here. Only the MLP is re-selected,
    using the same sparse/dense rule as the parent layer.
    """

    def __init__(self, config: Qwen3SharedMoeConfig, layer_idx: int):
        # NOTE(review): the parent constructor also allocates its own MLP, which
        # is discarded below. Avoiding that would require bypassing the parent
        # constructor entirely.
        super().__init__(config, layer_idx)
        self.hidden_size = config.hidden_size

        # A layer is sparse unless it is listed in ``mlp_only_layers``; sparse
        # layers occur every ``decoder_sparse_step`` layers when experts exist.
        is_sparse_layer = (layer_idx not in config.mlp_only_layers) and (
            config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
        )
        if is_sparse_layer:
            self.mlp = Qwen3SharedMoeSparseMoeBlock(config)
        else:
            self.mlp = Qwen3MoeMLP(config, intermediate_size=config.intermediate_size)
|
|
|
|
|
class Qwen3SharedMoeModel(Qwen3MoeModel):
    """Qwen3-MoE backbone whose decoder stack uses ``Qwen3SharedMoeDecoderLayer``."""

    config_class = Qwen3SharedMoeConfig

    def __init__(self, config: Qwen3SharedMoeConfig):
        # The parent constructor builds the embeddings, final norm and a
        # standard decoder stack, and finishes with ``post_init()``. Its
        # decoder stack is replaced below.
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                Qwen3SharedMoeDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        # Re-run post_init so the newly created layers also go through the
        # model's weight-initialisation hooks; the parent's call only saw the
        # layers that were just discarded. transformers skips modules already
        # marked as initialised, and init is disabled under from_pretrained.
        self.post_init()
|
|
|
|
|
class Qwen3SharedMoeForCausalLM(Qwen3MoeForCausalLM):
    """Causal language model built on ``Qwen3SharedMoeModel``.

    Computes LM logits and, when ``labels`` are given, the LM loss. When
    router logits are requested, it also computes the load-balancing auxiliary
    loss and adds it (scaled by ``router_aux_loss_coef``) to the LM loss.
    """

    config_class = Qwen3SharedMoeConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen3SharedMoeModel(config)
        self.num_experts = config.num_experts
        # The parent constructor already ran post_init() against the backbone
        # that was just replaced. Run it again so weight initialisation and
        # (when configured) input/output embedding tying target the new
        # backbone instead of the discarded one.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> MoeCausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:
            `MoeCausalLMOutputWithPast` with `loss` (LM loss plus the scaled auxiliary loss when both labels and
            router logits are present), `aux_loss`, `logits`, and the backbone's cache / hidden states /
            attentions / router logits.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM

        >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        # Fall back to the config defaults for any flag left as None.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_router_logits is None:
            output_router_logits = self.config.output_router_logits
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        outputs: MoeModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        # An int keeps the last `logits_to_keep` positions (0 -> slice(0, None),
        # i.e. all positions); a tensor is used directly as position indices.
        if isinstance(logits_to_keep, int):
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # NOTE: self.loss_function / self.vocab_size / self.lm_head are not
            # set in this class; presumably provided by the parent class.
            loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

        aux_loss = None
        if output_router_logits:
            # NOTE: num_experts_per_tok and router_aux_loss_coef are not set in
            # this class; presumably assigned by Qwen3MoeForCausalLM.__init__.
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                # Move aux_loss to the LM loss's device before adding.
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
|
|
|
|