|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Ernie VL model""" |
|
import re |
|
import math |
|
import itertools |
|
from dataclasses import dataclass |
|
from collections import defaultdict |
|
from copy import deepcopy |
|
from functools import partial |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from transformers.activations import ACT2FN |
|
from transformers.generation import GenerationMixin |
|
from transformers.modeling_outputs import ModelOutput |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import logging |
|
from .configuration_ernie_45t_vl import ( |
|
DFNRopeVisionTransformerConfig, |
|
Ernie4_5_MoEConfig, |
|
Ernie4_5_VLMoEConfig, |
|
) |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
__all__ = [ |
|
"Ernie4_5_VLMoeForConditionalGeneration", |
|
"DFNRopeVisionTransformerPreTrainedModel", |
|
"VariableResolutionResamplerModel", |
|
] |
|
|
|
|
|
class TokenType: |
|
"""token type definition""" |
|
|
|
text = 0 |
|
image = 1 |
|
video = 2 |
|
|
|
|
|
class UniqueNameGuard: |
|
"""name guard""" |
|
|
|
def __init__(self, prefix=""): |
|
self.prefix = prefix |
|
self.counter = {} |
|
|
|
def __enter__(self): |
|
return self |
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb): |
|
pass |
|
|
|
def get_unique_name(self, name): |
|
"""get unique name""" |
|
if name not in self.counter: |
|
self.counter[name] = 0 |
|
else: |
|
self.counter[name] += 1 |
|
return f"{self.prefix}{name}_{self.counter[name]}" |
|
|
|
|
|
class RopeEmbedding(nn.Module): |
|
""" |
|
Rotary Position Embedding (RoPE) implementation for transformer models. |
|
|
|
RoPE encodes absolute positional information with rotation matrices and |
|
naturally incorporates relative position information in self-attention. |
|
|
|
Args: |
|
head_dim (int): Dimension size of each attention head |
|
compression_ratio (float, optional): Sequence length compression ratio. Defaults to 1.0. |
|
base (int, optional): Base value for frequency calculation. Defaults to 10000. |
|
|
|
Attributes: |
|
head_dim (int): Dimension size of each attention head |
|
compression_ratio (float): Sequence length compression factor |
|
base (int): Base value for frequency calculation |
|
""" |
|
|
|
def __init__(self, head_dim, compression_ratio=1.0, base=10000, freq_allocation=0): |
|
""" |
|
Initialize RoPE embedding layer. |
|
|
|
Args: |
|
head_dim: Dimension of each attention head |
|
compression_ratio: Scaling factor for position indices |
|
            base: Base value for frequency calculation
            freq_allocation: Number of rotary frequencies reserved for the temporal axis in 3D RoPE
|
""" |
|
super().__init__() |
|
self.head_dim = head_dim |
|
self.compression_ratio = compression_ratio |
|
self.base = base |
|
|
|
|
|
self.freq_allocation = freq_allocation |
|
|
|
def forward(self, seq_length, position_ids=None): |
|
""" |
|
Compute rotary position embeddings for given sequence length. |
|
|
|
Args: |
|
seq_length (int): Maximum sequence length |
|
position_ids (Tensor, optional): Custom position indices. Defaults to None. |
|
|
|
Returns: |
|
Tensor: Rotary position embeddings of shape [1, 1, seq_length, head_dim] |
|
""" |
|
indices = torch.arange(0, self.head_dim, 2, dtype=torch.float32) |
|
indices = 1 / self.base ** (indices / self.head_dim) |
|
if position_ids is None: |
|
position_ids = torch.arange( |
|
0, seq_length, 1, dtype=torch.float32 |
|
).unsqueeze(1) |
|
position_ids = position_ids / self.compression_ratio |
|
sinusoid_inp = position_ids * indices.unsqueeze(0) |
|
else: |
|
position_ids = position_ids / self.compression_ratio |
|
seq_length = position_ids.shape[-1] |
|
sinusoid_inp = position_ids.unsqueeze(-1).to( |
|
torch.float32 |
|
) * indices.unsqueeze(0) |
|
pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1) |
|
pos_emb = pos_emb.view(-1, 1, seq_length, self.head_dim) |
|
pos_emb = pos_emb.detach() |
|
return pos_emb |
|
|
|
def apply_rotary(self, rp, q, k): |
|
""" |
|
Apply rotary position embeddings to queries and keys. |
|
|
|
Args: |
|
rp (Tensor): Rotary position embeddings |
|
q (Tensor): Query tensor [batch, heads, seq_len, dim] |
|
k (Tensor): Key tensor [batch, heads, seq_len, dim] |
|
|
|
Returns: |
|
Tuple[Tensor, Tensor]: Rotated queries and keys |
|
""" |
|
sin, cos = torch.chunk(rp, 2, dim=-1) |
|
|
|
sin_pos = torch.stack([sin, sin], dim=-1).reshape(rp.shape) |
|
|
|
cos_pos = torch.stack([cos, cos], dim=-1).reshape(rp.shape) |
|
|
|
rotate_half_q = torch.stack( |
|
[-q[:, :, :, 1::2], q[:, :, :, 0::2]], dim=-1 |
|
).reshape(q.shape) |
|
query = (q.to(torch.float32) * cos_pos) + ( |
|
rotate_half_q.to(torch.float32) * sin_pos |
|
) |
|
|
|
rotate_half_k = torch.stack( |
|
[-k[:, :, :, 1::2], k[:, :, :, 0::2]], dim=-1 |
|
).reshape(k.shape) |
|
key = (k.to(torch.float32) * cos_pos) + ( |
|
rotate_half_k.to(torch.float32) * sin_pos |
|
) |
|
return query, key |
|
|
|
def apply_rotary_3d(self, rp, q, k, position_ids): |
|
""" |
|
rope 3d rotary |
|
|
|
args: |
|
rp: [1, max_seqlen, 1, head_dim] |
|
q: [bsz, seqlen, head, head_dim] |
|
k: [bsz, seqlen, head, head_dim] |
|
position_ids: [bsz, seqlen, 3] |
|
""" |
|
current_device = q.device |
|
        sin, cos = torch.chunk(rp, 2, dim=-1)
|
assert position_ids.shape[:1] == q.shape[:1] |
|
batch_indices = torch.arange(end=position_ids.shape[0]) |
|
batch_indices = batch_indices[..., None] |
|
sin = sin.tile(position_ids.shape[0], 1, 1, 1).to(device=position_ids.device) |
|
cos = cos.tile(position_ids.shape[0], 1, 1, 1).to(device=position_ids.device) |
|
|
|
assert self.freq_allocation != 0 |
|
sin_t = sin[batch_indices, position_ids[..., 0], :, -self.freq_allocation :] |
|
sin_h = sin[ |
|
batch_indices, |
|
position_ids[..., 1], |
|
:, |
|
: self.head_dim // 2 - self.freq_allocation : 2, |
|
] |
|
sin_w = sin[ |
|
batch_indices, |
|
position_ids[..., 2], |
|
:, |
|
1 : self.head_dim // 2 - self.freq_allocation : 2, |
|
] |
|
sin_hw = torch.stack([sin_h, sin_w], dim=-1).reshape( |
|
sin_h.shape[:-1] + (sin_h.shape[-1] * 2,) |
|
) |
|
sin_thw = torch.cat([sin_hw, sin_t], dim=-1) |
|
|
|
cos_t = cos[batch_indices, position_ids[..., 0], :, -self.freq_allocation :] |
|
cos_h = cos[ |
|
batch_indices, |
|
position_ids[..., 1], |
|
:, |
|
: self.head_dim // 2 - self.freq_allocation : 2, |
|
] |
|
cos_w = cos[ |
|
batch_indices, |
|
position_ids[..., 2], |
|
:, |
|
1 : self.head_dim // 2 - self.freq_allocation : 2, |
|
] |
|
cos_hw = torch.stack([cos_h, cos_w], dim=-1).reshape( |
|
cos_h.shape[:-1] + (cos_h.shape[-1] * 2,) |
|
) |
|
cos_thw = torch.cat([cos_hw, cos_t], dim=-1) |
|
|
|
|
|
sin_pos = ( |
|
torch.stack([sin_thw, sin_thw], dim=-1) |
|
.reshape(sin_thw.shape[:3] + (sin_thw.shape[-1] * 2,)) |
|
.to(current_device) |
|
) |
|
|
|
cos_pos = ( |
|
torch.stack([cos_thw, cos_thw], dim=-1) |
|
.reshape(cos_thw.shape[:3] + (cos_thw.shape[-1] * 2,)) |
|
.to(current_device) |
|
) |
|
|
|
|
|
rotate_half_q = torch.stack( |
|
[-q[:, :, :, 1::2], q[:, :, :, 0::2]], dim=-1 |
|
).reshape(q.shape) |
|
query = (q.to(torch.float32) * cos_pos) + ( |
|
rotate_half_q.to(torch.float32) * sin_pos |
|
) |
|
|
|
rotate_half_k = torch.stack( |
|
[-k[:, :, :, 1::2], k[:, :, :, 0::2]], dim=-1 |
|
).reshape(k.shape) |
|
key = (k.to(torch.float32) * cos_pos) + ( |
|
rotate_half_k.to(torch.float32) * sin_pos |
|
) |
|
return query, key |
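# Illustrative usage sketch (not part of the model): how the 1-D RoPE path above is
# typically exercised. The helper name, batch/head sizes and sequence length are
# assumptions chosen only for demonstration.
def _example_rope_usage():
    head_dim, seq_len = 64, 8
    rope = RopeEmbedding(head_dim)
    rp = rope(seq_len)  # sin/cos table of shape [1, 1, seq_len, head_dim]
    q = torch.randn(2, 4, seq_len, head_dim)  # [batch, heads, seq, head_dim]
    k = torch.randn(2, 4, seq_len, head_dim)
    q_rot, k_rot = rope.apply_rotary(rp, q, k)  # same shapes, rotated in float32
    return q_rot.shape, k_rot.shape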
|
|
|
|
|
class Ernie4_5_MLP(nn.Module): |
|
""" |
|
Ernie4_5_MLP - Gated Multi-Layer Perceptron module used in Ernie model. |
|
""" |
|
|
|
def __init__(self, config, layer_idx=0): |
|
""" |
|
Initialize the MLP module with configuration options. |
|
|
|
Args: |
|
config (Ernie4_5_Config): Model configurations. |
|
layer_idx (int): Index of current layer (default: 0) |
|
""" |
|
super().__init__() |
|
self.config = config |
|
self.hidden_size = config.hidden_size |
|
self.intermediate_size = config.intermediate_size |
|
|
|
self.gate_proj = nn.Linear( |
|
self.hidden_size, self.intermediate_size, bias=config.use_bias |
|
) |
|
self.up_proj = nn.Linear( |
|
self.hidden_size, self.intermediate_size, bias=config.use_bias |
|
) |
|
self.down_proj = nn.Linear( |
|
self.intermediate_size, self.hidden_size, bias=config.use_bias |
|
) |
|
|
|
def forward(self, x): |
|
""" |
|
Forward pass through the MLP module. |
|
|
|
Args: |
|
x (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] |
|
|
|
Returns: |
|
Tensor: Output tensor of shape [batch_size, seq_len, hidden_size] |
|
""" |
|
current_device = self.gate_proj.weight.data.device |
|
x = x.to(current_device) |
|
down_proj = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) |
|
return down_proj |
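# Reference note: the gated MLP above computes down_proj(silu(gate_proj(x)) * up_proj(x)),
# i.e. a SwiGLU-style feed-forward block.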
|
|
|
|
|
class Ernie4_5_Attention(nn.Module): |
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
|
def __init__(self, config, layer_idx=0): |
|
"""Initialize the attention layer. |
|
|
|
Args: |
|
config (Ernie4_5_Config): Model configuration. |
|
layer_idx (int, optional): Index in transformer stack. Defaults to 0. |
|
""" |
|
super().__init__() |
|
self.layer_idx = layer_idx |
|
self.hidden_size = config.hidden_size |
|
self.num_heads = config.num_attention_heads |
|
self.num_key_value_heads = config.num_key_value_heads |
|
self.head_dim = self.hidden_size // self.num_heads |
|
self.is_gqa = ( |
|
self.num_key_value_heads is not None |
|
and self.num_key_value_heads != self.num_heads |
|
) |
|
|
|
self.freq_allocation = getattr(config, "freq_allocation", 0) |
|
assert ( |
|
self.freq_allocation is not None |
|
), "freq_allocation must be provided if rope_3d is on." |
|
|
|
if config.tensor_parallel_degree > 1: |
|
assert ( |
|
self.num_heads % config.tensor_parallel_degree == 0 |
|
), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" |
|
self.num_heads = self.num_heads // config.tensor_parallel_degree |
|
if self.is_gqa: |
|
assert ( |
|
self.num_key_value_heads % config.tensor_parallel_degree == 0 |
|
), f"num_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" |
|
self.num_key_value_heads = ( |
|
self.num_key_value_heads // config.tensor_parallel_degree |
|
) |
|
q_hidden_size = self.head_dim * self.num_heads |
|
if self.is_gqa: |
|
logger.info( |
|
f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}" |
|
) |
|
assert ( |
|
self.num_heads % self.num_key_value_heads == 0 |
|
), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}" |
|
kv_hidden_size = self.head_dim * self.num_key_value_heads |
|
else: |
|
kv_hidden_size = self.head_dim * self.num_heads |
|
|
|
self.q_proj = nn.Linear(self.hidden_size, q_hidden_size, bias=config.use_bias) |
|
self.k_proj = nn.Linear(self.hidden_size, kv_hidden_size, bias=config.use_bias) |
|
self.v_proj = nn.Linear(self.hidden_size, kv_hidden_size, bias=config.use_bias) |
|
|
|
self.o_proj = nn.Linear( |
|
self.hidden_size, |
|
self.hidden_size, |
|
bias=config.use_bias, |
|
) |
|
|
|
self.rotary_emb = RopeEmbedding( |
|
self.head_dim, |
|
compression_ratio=config.compression_ratio, |
|
base=config.rope_theta, |
|
freq_allocation=self.freq_allocation, |
|
) |
|
self.config = config |
|
self.attn_func = self.core_attn |
|
|
|
def forward( |
|
self, |
|
hidden_states, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
attn_mask_start_row_indices: Optional[torch.Tensor] = None, |
|
position_ids: Optional[Tuple[torch.Tensor]] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
token_type_ids: Optional[Tuple[torch.Tensor]] = None, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
"""Compute attention outputs. |
|
|
|
Args: |
|
hidden_states (torch.Tensor): Input tensor [bsz, seq_len, hidden_size] |
|
past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached key/value states |
|
attention_mask (Optional[torch.Tensor]): Attention mask tensor |
|
attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length attention indices |
|
position_ids (Optional[torch.Tensor]): Position indices for RoPE |
|
output_attentions (bool): Return attention weights if True |
|
use_cache (bool): Cache key/value states if True |
|
|
|
Returns: |
|
Tuple containing: |
|
- attention_output: [bsz, seq_len, hidden_size] |
|
- attention_weights: Optional attention probabilities |
|
- updated_key_value_cache: Optional updated cache |
|
""" |
|
if token_type_ids is not None: |
|
token_type_ids = token_type_ids[:, :-1] |
|
|
|
bsz, q_len, _ = hidden_states.shape |
|
query_states = self.q_proj(hidden_states).reshape( |
|
[bsz, q_len, -1, self.head_dim] |
|
) |
|
key_states = self.k_proj(hidden_states).reshape([bsz, q_len, -1, self.head_dim]) |
|
value_states = self.v_proj(hidden_states).reshape( |
|
[bsz, q_len, -1, self.head_dim] |
|
) |
|
|
|
attn_output, attn_weights, past_key_value = self.rope_attn( |
|
query_states=query_states, |
|
key_states=key_states, |
|
value_states=value_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
output_attentions=output_attentions, |
|
past_key_value=past_key_value, |
|
use_cache=use_cache, |
|
attn_mask_start_row_indices=attn_mask_start_row_indices, |
|
) |
|
attn_output = self.o_proj(attn_output) |
|
|
|
if not output_attentions: |
|
attn_weights = None |
|
|
|
return attn_output, attn_weights, past_key_value |
|
|
|
def repeat_kv(self, hidden_states, n_rep): |
|
""" |
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
|
""" |
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
|
if n_rep == 1: |
|
return hidden_states |
|
hidden_states = hidden_states[:, :, None, :, :].expand( |
|
batch, num_key_value_heads, n_rep, slen, head_dim |
|
) |
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
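    # Example (illustrative): with n_rep = 4, repeat_kv maps key/value states of shape
    # [batch=2, kv_heads=2, seq=16, head_dim=64] to [2, 8, 16, 64] so they line up
    # with 8 query heads in the GQA path.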
|
|
|
def core_attn( |
|
self, |
|
q, |
|
k, |
|
v, |
|
attention_mask=None, |
|
attn_mask_start_row_indices=None, |
|
seq_length=None, |
|
): |
|
"""Standard self-attention implementation. |
|
|
|
Args: |
|
q (torch.Tensor): Query tensor |
|
k (torch.Tensor): Key tensor |
|
v (torch.Tensor): Value tensor |
|
attention_mask (Optional[torch.Tensor]): Attention mask |
|
attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices |
|
seq_length (Optional[int]): Sequence length |
|
|
|
Returns: |
|
Tuple[torch.Tensor, torch.Tensor]: Attention output and weights |
|
""" |
|
origin_dtype = q.dtype |
|
|
|
q = q.permute(0, 2, 1, 3) |
|
k = k.permute(0, 2, 1, 3) |
|
v = v.permute(0, 2, 1, 3) |
|
|
|
scale_qk_coeff = getattr(self.config, "scale_qk_coeff", 1.0) * ( |
|
self.head_dim**0.5 |
|
) |
|
|
|
q = q / scale_qk_coeff |
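        # Note: q is pre-divided by scale_qk_coeff * sqrt(head_dim) and the attention
        # logits are multiplied back by scale_qk_coeff below, so the net scaling is the
        # usual 1 / sqrt(head_dim) while keeping intermediate magnitudes smaller.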
|
|
|
|
|
if self.is_gqa: |
|
|
|
repeat_factor = self.num_heads // self.num_key_value_heads |
|
k = self.repeat_kv(k, repeat_factor) |
|
v = self.repeat_kv(v, repeat_factor) |
|
|
|
product = torch.matmul(q, k.transpose(-2, -1)) |
|
|
|
product = product.to(torch.float32) |
|
if getattr(self.config, "scale_qk_coeff", 1.0) != 1.0: |
|
product = product * getattr(self.config, "scale_qk_coeff", 1.0) |
|
|
|
if attention_mask is not None: |
|
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) |
|
attention_mask = attention_mask.to(torch.float32) |
|
product = product + attention_mask |
|
weights = F.softmax(product, dim=-1) |
|
else: |
|
seq_len = product.size(-1) |
|
mask = torch.triu( |
|
torch.ones((seq_len, seq_len), dtype=torch.bool, device=product.device), |
|
diagonal=1, |
|
) |
|
product = product.masked_fill(mask, float("-inf")) |
|
weights = F.softmax(product, dim=-1) |
|
|
|
weights = weights.to(origin_dtype) |
|
|
|
if getattr(self.config, "attention_probs_dropout_prob", 0.0) > 0: |
|
weights = F.dropout( |
|
weights, |
|
self.config.attention_probs_dropout_prob, |
|
training=self.training, |
|
) |
|
|
|
out = torch.matmul(weights, v) |
|
|
|
|
|
out = out.permute(0, 2, 1, 3) |
|
out = out.contiguous().view(out.size(0), out.size(1), -1) |
|
|
|
return out, weights |
|
|
|
def rope_attn( |
|
self, |
|
query_states, |
|
key_states, |
|
value_states, |
|
attention_mask, |
|
position_ids, |
|
output_attentions=False, |
|
past_key_value=None, |
|
use_cache=False, |
|
attn_mask_start_row_indices=None, |
|
): |
|
"""Attention computation with rotary embeddings. |
|
|
|
Args: |
|
|
query_states (torch.Tensor): Query states |
|
key_states (torch.Tensor): Key states |
|
value_states (torch.Tensor): Value states |
|
attention_mask (Optional[torch.Tensor]): Attention mask |
|
position_ids (Optional[torch.Tensor]): Position indices |
|
output_attentions (bool): Return attention weights |
|
past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached states |
|
use_cache (bool): Cache new states |
|
attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices |
|
|
|
Returns: |
|
Tuple containing: |
|
- attention_output: Result tensor |
|
- attention_weights: Optional weights |
|
- updated_key_value_cache: Optional cache |
|
""" |
|
|
|
query_states_dtype = query_states.dtype |
|
|
|
        assert position_ids is not None, "3D RoPE requires position_ids"
|
kv_seq_len = position_ids.max() + 1 |
|
offset = 0 |
|
if past_key_value is not None: |
|
offset = position_ids.max() |
|
kv_seq_len = position_ids.max() + 1 |
|
position_ids = position_ids[:, -1:, :] |
|
|
|
cos_sin = self.rotary_emb(kv_seq_len).permute([0, 2, 1, 3]) |
|
if offset > 0 and position_ids is None: |
|
cos_sin = cos_sin[:, offset:] |
|
query_states, key_states = self.rotary_emb.apply_rotary_3d( |
|
cos_sin, query_states, key_states, position_ids |
|
) |
|
|
|
query_states = query_states.to(query_states_dtype) |
|
key_states = key_states.to(query_states_dtype) |
|
if past_key_value is not None: |
|
|
|
key_states = torch.cat([past_key_value[0], key_states], dim=1) |
|
value_states = torch.cat([past_key_value[1], value_states], dim=1) |
|
|
|
|
|
past_key_value = [key_states, value_states] if use_cache else None |
|
seq_length = query_states.shape[1] |
|
attn_output, attn_weights = self.attn_func( |
|
query_states, |
|
key_states, |
|
value_states, |
|
attention_mask, |
|
attn_mask_start_row_indices, |
|
seq_length, |
|
) |
|
|
|
return attn_output, attn_weights, past_key_value |
|
|
|
|
|
class FusedDropoutImpl(nn.Module): |
|
""" |
|
Fused dropout implementation with residual connection support. |
|
|
|
This layer combines dropout and residual addition in a single operation for better performance, |
|
particularly on GPU devices. The dropout is conditionally applied based on the probability. |
|
|
|
Args: |
|
prob (float): Dropout probability (between 0 and 1) |
|
mode (str): Dropout mode, either 'upscale_in_train' or 'downscale_in_infer' |
|
|
|
Attributes: |
|
prob (float): Stores the dropout probability |
|
mode (str): Stores the dropout mode |
|
dropout (nn.Dropout): The actual dropout layer instance |
|
""" |
|
|
|
def __init__(self, prob, mode): |
|
""" |
|
Initialize the fused dropout layer. |
|
|
|
Args: |
|
prob (float): Dropout probability (0 means no dropout) |
|
mode (str): Dropout mode ('upscale_in_train' or 'downscale_in_infer') |
|
""" |
|
super().__init__() |
|
        self.prob = prob
        self.mode = mode
|
self.dropout = nn.Dropout(p=prob) |
|
|
|
def forward(self, x, y): |
|
""" |
|
Forward pass of the fused dropout layer. |
|
|
|
Args: |
|
x (Tensor): Input tensor to potentially apply dropout on |
|
y (Tensor): Residual tensor to add to the (possibly dropped out) x |
|
|
|
Returns: |
|
Tensor: Result of x (with optional dropout) + y |
|
""" |
|
if self.prob > 0: |
|
x = self.dropout(x) |
|
output = x + y |
|
|
|
return output |
|
|
|
|
|
class RMSNorm(nn.Module): |
|
""" |
|
Root Mean Square Layer Normalization (RMSNorm) implementation. |
|
|
|
RMSNorm is a simplified version of LayerNorm that focuses on the root mean square of inputs, |
|
omitting the mean-centering operation. This provides computational efficiency while maintaining |
|
good performance. |
|
|
|
""" |
|
|
|
def __init__(self, config): |
|
""" |
|
Initialize RMSNorm layer. |
|
|
|
Args: |
|
config (Ernie4_5_Config): Model configuration. |
|
""" |
|
super().__init__() |
|
self.hidden_size = config.hidden_size |
|
self.weight = nn.Parameter( |
|
torch.ones(self.hidden_size, dtype=torch.get_default_dtype()) |
|
) |
|
self.variance_epsilon = config.rms_norm_eps |
|
|
|
def forward(self, hidden_states): |
|
""" |
|
Apply RMS normalization to input hidden states. |
|
|
|
Args: |
|
hidden_states (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] |
|
|
|
Returns: |
|
Tensor: Normalized output tensor of same shape as input |
|
|
|
Note: |
|
- computes RMSNorm manually: |
|
1. Compute variance of features |
|
2. Apply reciprocal square root normalization |
|
3. Scale by learned weight parameter |
|
- Maintains original dtype for numerical stability during computation |
|
""" |
|
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) |
|
hidden_states = torch.rsqrt(variance + self.variance_epsilon) * hidden_states |
|
return hidden_states.to(self.weight.dtype) * self.weight |
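# Reference note: RMSNorm above computes y = x / sqrt(mean(x^2) + eps) * weight,
# with the second-moment statistics accumulated in float32 before the result is
# cast back to the parameter dtype.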
|
|
|
|
|
class Ernie4_5_MoeMLP(Ernie4_5_MLP): |
|
"""Mixture of Experts (MoE) variant of ERNIE's MLP layer.""" |
|
|
|
def __init__(self, config, layer_idx=0): |
|
"""Initialize the MoE MLP layer. |
|
|
|
Args: |
|
config (Ernie4_5_MoEConfig): Configuration for MoE architecture. |
|
layer_idx (int): Index of current layer in transformer stack |
|
""" |
|
|
|
if getattr(config, "disable_ffn_model_parallel", False): |
|
config = deepcopy(config) |
|
config.tensor_parallel_degree = 1 |
|
|
|
super().__init__(config, layer_idx=layer_idx) |
|
self.moe_dropout_prob = config.moe_dropout_prob |
|
|
|
def forward(self, x): |
|
"""Forward pass through MoE MLP layer. |
|
|
|
Args: |
|
            x (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
|
or [seq_len, hidden_size] |
|
|
|
Returns: |
|
            torch.Tensor: Output tensor with same shape as input
|
""" |
|
current_device = self.gate_proj.weight.data.device |
|
x = x.to(current_device) |
|
x = F.silu(self.gate_proj(x)) * self.up_proj(x) |
|
if self.moe_dropout_prob > 0: |
|
x = F.dropout(input=x, p=self.moe_dropout_prob) |
|
ret = self.down_proj(x) |
|
return ret |
|
|
|
|
|
def masked_fill(x, mask, value): |
|
""" |
|
Fills elements of the input tensor with a given value where mask is True. |
|
""" |
|
return torch.where(mask, torch.full_like(x, value), x) |
|
|
|
|
|
def _squared_l2_norm(x: torch.Tensor) -> torch.Tensor: |
|
"""Computes 0.5 * sum(x^2)""" |
|
return 0.5 * torch.sum(x * x) |
|
|
|
|
|
@torch.no_grad() |
|
def compute_optimal_transport(M, r, c, lam=1.0, epsilon=1e-8, max_iters: int = 10): |
|
""" |
|
Computes optimal transport matrix and Sinkhorn distance using Sinkhorn-Knopp algorithm. |
|
""" |
|
n, _ = M.shape |
|
P = F.softmax(-M / lam, dim=1) |
|
u = torch.zeros(n, dtype=torch.float32, device=M.device) |
|
|
|
for _ in range(max_iters): |
|
P_sum_1 = P.sum(1) |
|
if (u - P_sum_1).abs().max() < epsilon: |
|
break |
|
u = P_sum_1 |
|
P *= (r / (u + 1e-8)).unsqueeze(1) |
|
P *= (c / (P.sum(0) + 1e-8)).unsqueeze(0) |
|
|
|
P = torch.where(~P.isnan(), P, torch.zeros_like(P)) |
|
    return P, torch.sum(P * M)  # transport plan and its (unregularized) transport cost
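# Illustrative usage sketch (not part of the model): a typical Sinkhorn-Knopp call
# for balanced second-expert assignment. The helper name, cost matrix and marginals
# are toy assumptions for demonstration only.
def _example_sinkhorn_usage():
    num_tokens, num_experts = 6, 3
    cost = torch.rand(num_tokens, num_experts)
    r = torch.full((num_tokens,), 1.0 / num_tokens)    # uniform token marginal
    c = torch.full((num_experts,), 1.0 / num_experts)  # uniform expert marginal
    plan, _ = compute_optimal_transport(cost, r, c, lam=0.1)
    # rows of `plan` approximately sum to r and columns to c after the iterations
    return plan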
|
|
|
|
|
class Top2Gate(nn.Module): |
|
""" |
|
    Gate module implementing Top-2 gating as described in the GShard paper.
|
""" |
|
|
|
def __init__(self, config, layer_idx: int, group=None, gate_weight=None) -> None: |
|
""" |
|
Initialize the MoE (Mixture of Experts) layer. |
|
|
|
Args: |
|
config: Model configuration containing MoE parameters |
|
layer_idx: Index of this layer in the model |
|
group: Distributed communication group |
|
gate_weight: Optional pre-existing gate weight tensor |
|
""" |
|
super().__init__() |
|
self.config = config |
|
|
|
self.model_dim = config.hidden_size |
|
self.num_experts = config.moe_num_experts |
|
self.num_experts_tensor = ( |
|
sum(config.moe_num_experts) |
|
if config.multimodel_experts |
|
else config.moe_num_experts |
|
) |
|
|
|
self.cap = config.moe_capacity |
|
self.group = group |
|
|
|
self.layer_idx = layer_idx |
|
|
|
self.sinkhorn_2gate = config.sinkhorn_2gate |
|
self.sinkhorn_temp = config.sinkhorn_temp |
|
self.use_correction_bias = config.moe_use_aux_free |
|
        self.use_token_type_bias = getattr(config, "moe_use_token_type_bias", False)
|
|
|
self.act = partial(F.softmax, dim=-1) |
|
|
|
self.no_jitter = True |
|
self.expert_drop = False |
|
self.eye_matrix = None |
|
self.eye_matrix_size = None |
|
self.norm_gate_logits = config.moe_norm_gate_logits |
|
self.one = torch.ones([], dtype=torch.float32) |
|
|
|
self.moe_aux_loss_lambda = torch.tensor(config.moe_aux_loss_lambda).to( |
|
dtype=torch.float32 |
|
) |
|
self.moe_z_loss_lambda = torch.tensor(config.moe_z_loss_lambda).to( |
|
dtype=torch.float32 |
|
) |
|
self.moe_orthogonal_loss_lambda = torch.tensor( |
|
config.moe_orthogonal_loss_lambda |
|
).to(dtype=torch.float32) |
|
|
|
if self.moe_aux_loss_lambda.ndim == 0: |
|
self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.unsqueeze(0) |
|
if self.moe_z_loss_lambda.ndim == 0: |
|
self.moe_z_loss_lambda = self.moe_z_loss_lambda.unsqueeze(0) |
|
if self.moe_orthogonal_loss_lambda.ndim == 0: |
|
self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.unsqueeze( |
|
0 |
|
) |
|
|
|
self.experts_type_ids = None |
|
|
|
self.eps = torch.tensor([1e-12]).to(dtype=torch.float32) |
|
if config.multimodel_experts: |
|
            if getattr(config, "moe_use_hard_gate", False):
|
self.num_experts_list = [] |
|
self.experts_type_mask = [] |
|
|
|
experts_ids = torch.zeros( |
|
[sum(self.num_experts)], dtype=torch.int64 |
|
).reshape((1, -1)) |
|
offset = 0 |
|
for i, expert_num in enumerate(self.num_experts): |
|
experts_ids[:, offset : offset + expert_num] = i |
|
offset += expert_num |
|
self.experts_type_ids = experts_ids.reshape([-1]) |
|
logger.info( |
|
f"use moe_use_hard_gate, experts_ids: {self.experts_type_ids}" |
|
) |
|
for i, expert_num in enumerate(self.num_experts): |
|
self.experts_type_mask.append( |
|
self.experts_type_ids == i, |
|
) |
|
self.num_experts_list.append(expert_num) |
|
else: |
|
|
|
assert ( |
|
not config.moe_group_experts |
|
), "group_experts must use hard_gate when multimodel_experts is True" |
|
else: |
|
self.num_experts_list = [self.num_experts] |
|
|
|
if gate_weight is not None: |
|
self.weight = gate_weight |
|
|
|
assert ( |
|
not self.config.moe_use_token_type_bias |
|
), "gate_weights is from outside, token_type_bias can't be used" |
|
logger.info("moe use gate_weight from outside") |
|
|
|
self._cast_to_low_precision = False |
|
self._cast_to_low_precison = False |
|
else: |
|
self._create_gate_parameter() |
|
logger.info( |
|
f"{config.moe_gate}: w/ capacity: {self.cap} experts:{self.num_experts} " |
|
f"use_token_type_bias:{self.use_token_type_bias} " |
|
f"gate_act:{config.moe_gate_act} " |
|
f"norm_gate_logits={self.norm_gate_logits} use_correction_bias={self.use_correction_bias}" |
|
) |
|
|
|
def _create_gate_parameter(self): |
|
""" |
|
Create gate weight parameter. |
|
""" |
|
if self.config.multimodel_experts: |
|
|
|
self.moe_z_loss_lambda = self.moe_z_loss_lambda.expand( |
|
len(self.num_experts) |
|
) |
|
self.moe_aux_loss_lambda = self.moe_aux_loss_lambda.expand( |
|
len(self.num_experts) |
|
) |
|
self.moe_orthogonal_loss_lambda = self.moe_orthogonal_loss_lambda.expand( |
|
len(self.num_experts) |
|
) |
|
|
|
for i, num_experts in enumerate(self.num_experts): |
|
if i == 1: |
|
with UniqueNameGuard(f"mm_gate_{self.layer_idx}_"): |
|
p = nn.Parameter( |
|
torch.empty( |
|
self.model_dim, |
|
num_experts, |
|
dtype=torch.float32, |
|
device="cpu", |
|
) |
|
) |
|
nn.init.xavier_uniform_(p) |
|
else: |
|
p = nn.Parameter( |
|
torch.empty( |
|
self.model_dim, |
|
num_experts, |
|
dtype=torch.float32, |
|
device="cpu", |
|
) |
|
) |
|
nn.init.xavier_uniform_(p) |
|
self.register_parameter( |
|
"weight" if i == 0 else f"weight_{i}", |
|
p, |
|
) |
|
else: |
|
self.weight = nn.Parameter( |
|
torch.empty(self.model_dim, self.num_experts, dtype=torch.float32) |
|
) |
|
nn.init.xavier_uniform_(self.weight) |
|
|
|
self._cast_to_low_precision = False |
|
self._cast_to_low_precison = False |
|
|
|
def get_gate_weight(self, transform_weight, is_multimodel=True): |
|
""" |
|
在`multimodel_experts` 的情况下,将多个 weights merge 成一个整体 |
|
transform_weight: bool, 按照 local-expert id 将 多模态 weight 交叠 |
|
""" |
|
if not is_multimodel or not self.config.multimodel_experts: |
|
return self.weight |
|
else: |
|
return torch.cat( |
|
[ |
|
getattr(self, "weight" if i == 0 else f"weight_{i}") |
|
for i in range(len(self.num_experts)) |
|
], |
|
-1, |
|
) |
|
|
|
def forward( |
|
self, |
|
input: torch.Tensor, |
|
token_type_ids: torch.Tensor = None, |
|
transform_weight: bool = True, |
|
correction_bias: torch.Tensor = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
""" |
|
Forward pass through the gate. |
|
|
|
Args: |
|
input: Input tensor of shape [Seq, Dim] |
|
token_type_ids: Token type IDs tensor of shape [Seq] |
|
transform_weight: Whether to transform weights for multimodal experts |
|
correction_bias: Bias tensor for correction |
|
|
|
Returns: |
|
tuple: (capacity, dispatch_mask, combine_weights, scatter_index, router_loss, logits) |
|
""" |
|
orig_dtype = input.dtype |
|
current_device = input.device |
|
weight = self.get_gate_weight(transform_weight) |
|
|
|
logits = F.linear( |
|
input.to(dtype=torch.float32, device=current_device), |
|
weight.T.to(dtype=torch.float32, device=current_device), |
|
) |
|
|
|
( |
|
capacity, |
|
dispatch_mask, |
|
combine_weights, |
|
scatter_index, |
|
l_aux, |
|
l_zloss, |
|
) = self.top2_gating( |
|
logits, |
|
correction_bias=( |
|
correction_bias.to(device=current_device) |
|
if correction_bias is not None |
|
else None |
|
), |
|
) |
|
|
|
combine_weights = combine_weights.to(orig_dtype) |
|
return capacity, dispatch_mask, combine_weights, scatter_index, None, logits |
|
|
|
def get_capacity(self, num_tokens, cap_factor=None, is_multimodel=True): |
|
""" |
|
Calculate capacity based on number of tokens. |
|
|
|
Args: |
|
num_tokens: Number of input tokens |
|
cap_factor: Optional capacity factor override |
|
|
|
Returns: |
|
int: Calculated capacity |
|
""" |
|
if is_multimodel and self.config.multimodel_experts: |
|
num_experts = sum(self.num_experts_list) |
|
elif isinstance(self.num_experts, (list, tuple)): |
|
num_experts = self.num_experts[0] |
|
else: |
|
num_experts = self.num_experts |
|
if cap_factor is not None: |
|
cap = cap_factor |
|
else: |
|
if self.training: |
|
cap = self.cap[0] |
|
elif num_tokens < num_experts: |
|
cap = self.cap[2] |
|
else: |
|
cap = self.cap[1] |
|
|
|
capacity = int(cap * num_tokens // num_experts) |
|
assert ( |
|
capacity > 0 |
|
), f"requires capacity to >= 0. cap={cap}, num_tokens={num_tokens}" |
|
return capacity |
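    # Worked example (illustrative): with cap = 1.2, num_tokens = 1024 and
    # num_experts = 64, capacity = int(1.2 * 1024 // 64) = 19 slots per expert.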
|
|
|
def top2_gating(self, logits, cap=None, correction_bias=None): |
|
""" |
|
Implement Top2 gating mechanism. |
|
|
|
Args: |
|
logits: Input logits tensor |
|
cap: Optional capacity override |
|
correction_bias: Bias tensor for correction |
|
|
|
Returns: |
|
tuple: (capacity, dispatch_masks, combine_weights, scatter_indexes, loss_aux, loss_z) |
|
|
|
Note: |
|
capacity: The maximum number that each token can be dispatched. |
|
dispatch_masks: Masks used for dispatching. The first element is the mask for the first |
|
type of tokens; the second element is the mask for the second type of tokens. |
|
combine_weights: Weights used for combining. The first element is the weight for the first |
|
type of tokens; the second element is the weight for the second type of tokens. |
|
scatter_indexes: Indexes used for scattering. The first element is the index for the first |
|
type of tokens; the second element is the index for the second type of tokens. |
|
loss_aux: Auxiliary loss. |
|
loss_z: Z loss. |
|
""" |
|
gates = self.act(logits) |
|
|
|
|
|
assert logits.ndim == 2, logits.shape |
|
num_tokens = gates.shape[0] |
|
num_experts = gates.shape[1] |
|
|
|
capacity = self.get_capacity(logits.shape[0], cap) |
|
current_device = logits.device |
|
|
|
|
|
score_for_argmax = ( |
|
gates + correction_bias.unsqueeze(0) |
|
if correction_bias is not None |
|
else gates |
|
) |
|
indices1_s = torch.argmax(score_for_argmax, dim=1) |
|
mask1 = F.one_hot(indices1_s, num_classes=num_experts).to( |
|
dtype=torch.int64, device=current_device |
|
) |
|
|
|
|
|
|
|
if self.training and not self.no_jitter: |
|
gumbels = ( |
|
-torch.empty_like( |
|
logits, |
|
device=current_device, |
|
) |
|
.exponential_() |
|
.log() |
|
) |
|
logits_w_noise = logits + gumbels |
|
else: |
|
logits_w_noise = logits |
|
|
|
logits_except1 = masked_fill( |
|
logits_w_noise, |
|
mask1.to(dtype=torch.bool, device=current_device), |
|
float("-inf"), |
|
) |
|
score_for_argmax = ( |
|
self.act(logits_except1) + correction_bias.unsqueeze(0) |
|
if correction_bias is not None |
|
else logits_except1 |
|
) |
|
indices2_s_original = torch.argmax(score_for_argmax, dim=1) |
|
|
|
if self.training and self.sinkhorn_2gate: |
|
r = ( |
|
torch.ones(num_tokens, dtype=torch.float32, device=current_device) |
|
/ num_tokens |
|
) |
|
c_mask_sum = mask1.to(dtype=torch.float32, device=current_device).sum(0) |
|
c = capacity - c_mask_sum |
|
c = torch.maximum(c, torch.zeros_like(c, device=current_device)) |
|
c_sum = c.sum() |
|
if c_sum > 0: |
|
c = c / c_sum |
|
else: |
|
c = torch.ones_like(c, device=current_device) / num_experts |
|
|
|
pi, _ = compute_optimal_transport( |
|
-logits_except1.to(dtype=torch.float32, device=current_device).detach(), |
|
r, |
|
c, |
|
lam=self.sinkhorn_temp, |
|
) |
|
pi = masked_fill( |
|
pi, mask1.to(dtype=torch.bool, device=current_device), float("-inf") |
|
) |
|
indices2_s = torch.argmax(pi, dim=1) |
|
else: |
|
indices2_s = indices2_s_original |
|
|
|
        mask2 = F.one_hot(indices2_s, num_classes=num_experts).to(
|
dtype=torch.int64, device=current_device |
|
) |
|
|
|
|
|
locations1 = ( |
|
torch.cumsum(mask1, dim=0) - 1 |
|
) |
|
locations2 = torch.cumsum(mask2, dim=0) - 1 |
|
|
|
locations2 += torch.sum(mask1, dim=0, keepdim=True) |
|
|
|
|
|
mask1 = mask1 * (locations1 < capacity).to( |
|
dtype=torch.int64, device=current_device |
|
) |
|
mask2 = mask2 * (locations2 < capacity).to( |
|
dtype=torch.int64, device=current_device |
|
) |
|
|
|
|
|
locations1_s = torch.sum(locations1 * mask1, dim=1) |
|
locations2_s = torch.sum(locations2 * mask2, dim=1) |
|
|
|
|
|
mask1_float = mask1.to(dtype=torch.float32, device=current_device) |
|
mask2_float = mask2.to(dtype=torch.float32, device=current_device) |
|
gates1_s = (gates * mask1_float).sum(dim=-1) |
|
gates2_s = (gates * mask2_float).sum(dim=-1) |
|
|
|
|
|
if self.norm_gate_logits: |
|
denom_s = gates1_s + gates2_s |
|
|
|
denom_s = torch.clamp(denom_s, min=1e-6) |
|
gates1_s /= denom_s |
|
gates2_s /= denom_s |
|
if self.training and self.expert_drop: |
|
|
|
gates2_s = torch.where( |
|
2 * gates2_s < torch.rand_like(gates2_s, device=current_device), |
|
torch.zeros_like(gates2_s, device=current_device), |
|
gates2_s, |
|
) |
|
|
|
|
|
gates1 = gates1_s.unsqueeze(1) * mask1_float |
|
gates2 = gates2_s.unsqueeze(1) * mask2_float |
|
|
|
combine1_weight, expert1_index = torch.max(gates1, dim=-1, keepdim=True) |
|
scatter1_index = expert1_index.squeeze(-1) * capacity + locations1_s |
|
scatter1_index = scatter1_index.to(dtype=torch.int64, device=current_device) |
|
dispatch1_mask = combine1_weight.to( |
|
dtype=torch.bool, device=current_device |
|
).detach() |
|
|
|
combine2_weight, expert2_index = torch.max(gates2, dim=-1, keepdim=True) |
|
scatter2_index = expert2_index.squeeze(-1) * capacity + locations2_s |
|
scatter2_index = scatter2_index.to(dtype=torch.int64, device=current_device) |
|
dispatch2_mask = combine2_weight.to( |
|
dtype=torch.bool, device=current_device |
|
).detach() |
|
|
|
|
|
return ( |
|
capacity, |
|
torch.cat((dispatch1_mask, dispatch2_mask), 1), |
|
torch.cat((combine1_weight, combine2_weight), 1), |
|
torch.stack((scatter1_index, scatter2_index), 1), |
|
None, |
|
None, |
|
) |
|
|
|
def _cal_orthogonal_loss_opt_each_weight(self, weight, use_group): |
|
""" |
|
Calculate optimized orthogonal loss for each weight. |
|
|
|
Args: |
|
weight: Weight tensor |
|
use_group: Whether to use expert groups |
|
|
|
Returns: |
|
Tensor: Calculated orthogonal loss |
|
""" |
|
if weight.dtype != torch.float32: |
|
weight = weight.to(torch.float32) |
|
|
|
wnorm = torch.norm(weight, p=2, dim=1) |
|
weight = weight / torch.maximum(wnorm, self.eps.to(weight.device)).unsqueeze(1) |
|
|
|
if use_group: |
|
weight = weight.reshape( |
|
[self.config.moe_k, -1, weight.shape[1]] |
|
) |
|
eye_matrix = torch.eye( |
|
weight.shape[1], dtype=weight.dtype, device=weight.device |
|
).unsqueeze(0) |
|
else: |
|
eye_matrix = torch.eye( |
|
weight.shape[0], dtype=weight.dtype, device=weight.device |
|
) |
|
|
|
        weight_matmul = torch.matmul(weight, weight.transpose(-1, -2))
|
|
|
orthogonal_loss = weight_matmul - eye_matrix |
|
orthogonal_loss = _squared_l2_norm(orthogonal_loss) / ( |
|
orthogonal_loss.size(0) * orthogonal_loss.size(1) |
|
) |
|
return orthogonal_loss |
|
|
|
|
|
class TopKGate(Top2Gate): |
|
""" |
|
Fused version of TopK gate for improved performance. |
|
""" |
|
|
|
def forward( |
|
self, |
|
input: torch.Tensor, |
|
token_type_ids=None, |
|
transform_weight=True, |
|
is_multimodel=True, |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
""" |
|
Forward pass for fused gate. |
|
|
|
Args: |
|
input: Input tensor |
|
token_type_ids: Token type IDs |
|
transform_weight: Whether to transform weights |
|
|
|
Returns: |
|
            Tensor: Gate logits of shape [Seq, num_experts]
|
""" |
|
current_device = input.device |
|
weight = self.get_gate_weight(transform_weight, is_multimodel=is_multimodel) |
|
|
|
logits = F.linear( |
|
input.to(dtype=torch.float32, device=current_device), |
|
weight.T.to(dtype=torch.float32, device=current_device), |
|
) |
|
if self.use_token_type_bias: |
|
assert token_type_ids is not None |
|
assert ( |
|
token_type_ids.max() < self.bias.shape[0] |
|
), f"token_type_ids {token_type_ids.max()} >= bias shape {self.bias.shape[0]}" |
|
bias = self.bias[token_type_ids] |
|
logits = logits + bias |
|
|
|
return logits |
|
|
|
|
|
gate_class = dict( |
|
top2=Top2Gate, |
|
topk=TopKGate, |
|
) |
|
|
|
|
|
def get_gate( |
|
config: Ernie4_5_MoEConfig, |
|
expert: nn.Module, |
|
layer_idx: int, |
|
) -> Tuple[nn.Module, nn.ModuleList]: |
|
"""Initialize and distribute MoE (Mixture of Experts) components. |
|
|
|
Creates gate layer and distributed expert network for MoE architecture. |
|
|
|
Args: |
|
config (Ernie4_5_MoEConfig): Configuration for MoE architecture |
|
expert (nn.Module): Prototype expert network to be replicated |
|
layer_idx (int): Index of current layer in transformer stack |
|
|
|
Returns: |
|
Tuple[nn.Module, nn.ModuleList]: |
|
- gate: Initialized gate layer for routing |
|
- experts: ModuleList containing expert networks |
|
""" |
|
moe_num_experts = ( |
|
sum(config.moe_num_experts) |
|
if config.multimodel_experts |
|
else config.moe_num_experts |
|
) |
|
experts = nn.ModuleList([]) |
|
|
|
for expert_id, (experts_num, fc) in enumerate(expert): |
|
experts_to_append = [] |
|
if not hasattr(fc, "__len__"): |
|
experts_to_append.append(fc) |
|
if expert_id == 1: |
|
with UniqueNameGuard("_mm_deepcopy"): |
|
for _ in range(experts_num - 1): |
|
experts_to_append.append(deepcopy(fc)) |
|
else: |
|
for _ in range(experts_num - 1): |
|
experts_to_append.append(deepcopy(fc)) |
|
else: |
|
experts_to_append = fc |
|
for ex in experts_to_append: |
|
for p in ex.parameters(): |
|
p.expert_type = f"expert_type_{expert_id}" |
|
index = 0 |
|
for i in range(experts_num): |
|
if i // experts_num == 0: |
|
experts.append(experts_to_append[index]) |
|
index += 1 |
|
else: |
|
experts.append(None) |
|
|
|
assert ( |
|
len(experts) == moe_num_experts |
|
), f"experts.len={len(experts)} != experts_num={experts_num}" |
|
logger.info(f"MOE-GATE:-{config.moe_gate}") |
|
|
|
gate = gate_class[config.moe_gate.lower()](config, layer_idx=layer_idx) |
|
|
|
if config.multimodel_experts and config.moe_use_hard_gate and moe_num_experts > 2: |
|
lm_experts = experts[: config.moe_num_experts[0]] |
|
lm_gate = gate |
|
else: |
|
if config.multimodel_experts and config.moe_use_hard_gate: |
|
lm_gate, lm_experts = gate, experts |
|
else: |
|
lm_gate, lm_experts = None, None |
|
|
|
logger.info(f"LM-experts-{lm_experts} -- experts-{experts}") |
|
|
|
return gate, experts, lm_gate, lm_experts |
|
|
|
|
|
class MoEStatics(nn.Module): |
|
""" |
|
Stores MoE (Mixture of Experts) statistics |
|
and expert usage information. |
|
""" |
|
|
|
def __init__(self, config, layer_idx): |
|
""" |
|
Initialize MoE statistics tracking. |
|
|
|
Args: |
|
config: Model configuration containing MoE parameters |
|
layer_idx: Index of the MoE layer in the model |
|
""" |
|
super().__init__() |
|
self._cast_to_low_precision = False |
|
self._cast_to_low_precison = False |
|
num_experts = ( |
|
config.moe_num_experts[0] |
|
if config.multimodel_experts |
|
else config.moe_num_experts |
|
) |
|
if config.multimodel_experts: |
|
assert ( |
|
len(set(config.moe_num_experts)) == 1 |
|
), "assume expert group has same size, got: {config.moe_num_experts}" |
|
|
|
with UniqueNameGuard(f"mm_layer_{layer_idx}_"): |
|
num_experts_groups = ( |
|
len(config.moe_num_experts) if config.multimodel_experts else 1 |
|
) |
|
p = nn.Parameter( |
|
torch.zeros(num_experts_groups, num_experts, dtype=torch.float32), |
|
requires_grad=False, |
|
) |
|
self.e_score_correction_bias = p |
|
p = torch.zeros(num_experts_groups, num_experts, dtype=torch.int64) |
|
self.expert_usage = p |
|
|
|
|
|
def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): |
|
""" |
|
Reorders input tensor based on gate results with capacity truncation and padding. |
|
|
|
Args: |
|
x (Tensor): Input tensor of shape [Seq, Dim] |
|
dispatch_mask (Tensor): Dispatching mask of shape [Seq, 2] |
|
scatter_index (Tensor): Scatter indices of shape [Seq, 2] |
|
num_experts (int): Number of experts |
|
capacity (int): Capacity per expert |
|
|
|
Returns: |
|
Tensor: Dispatched output tensor of shape [Expert*Capacity, Dim] |
|
""" |
|
output = None |
|
orig_dtype = x.dtype |
|
scatter_index_unbound = [scatter_index[:, 0], scatter_index[:, 1]] |
|
dispatch_mask_unbound = [dispatch_mask[:, 0], dispatch_mask[:, 1]] |
|
|
|
for i_scatter_index, i_dispatch_mask in zip( |
|
scatter_index_unbound, dispatch_mask_unbound |
|
): |
|
updates = x * i_dispatch_mask.unsqueeze(-1).to(orig_dtype) |
|
init_output = torch.zeros( |
|
num_experts * capacity, x.shape[-1], dtype=orig_dtype, device=x.device |
|
) |
|
|
|
index = i_scatter_index.unsqueeze(-1).expand(-1, x.shape[-1]) |
|
if output is None: |
|
output = init_output.scatter_add(0, index, updates) |
|
else: |
|
output = output + init_output.scatter_add(0, index, updates) |
|
if output.dtype != orig_dtype: |
|
output = output.to(orig_dtype) |
|
return output |
|
|
|
|
|
def combining(x, combine_weights, scatter_index): |
|
""" |
|
Combines and aggregates input matrix using combination weights. |
|
|
|
Args: |
|
x (Tensor): Input tensor of shape [num_experts * capacity, dim] |
|
combine_weights (Tensor): Combination weights of shape [seq, 2] |
|
scatter_index (Tensor): Scatter indices of shape [seq, 2] |
|
|
|
Returns: |
|
Tensor: Combined output tensor of shape [seq, dim] |
|
""" |
|
dim = x.shape[-1] |
|
|
|
current_device = scatter_index.device |
|
x = x.to(current_device) |
|
scatter_index = scatter_index.reshape([-1]) |
|
num_k = combine_weights.shape[-1] |
|
|
|
combine_weights = combine_weights.unsqueeze(1).to(current_device) |
|
|
|
x = x[scatter_index].reshape([-1, num_k, dim]) |
|
|
|
return torch.matmul(combine_weights, x).squeeze( |
|
1 |
|
) |
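# Illustrative usage sketch (not part of the model): `dispatching` followed by
# `combining` round-trips tokens through the per-expert buffers. The helper name,
# routing indices and weights below are toy assumptions for demonstration only.
def _example_dispatch_combine_roundtrip():
    seq, dim, num_experts, capacity = 4, 8, 2, 4
    x = torch.randn(seq, dim)
    # token i is routed to slot i of expert 0 and slot i of expert 1
    scatter_index = torch.stack(
        [torch.arange(seq), capacity + torch.arange(seq)], dim=1
    )
    dispatch_mask = torch.ones(seq, 2, dtype=torch.bool)
    buffers = dispatching(x, dispatch_mask, scatter_index, num_experts, capacity)
    # with identity "experts", equal top-2 combine weights recover the input tokens
    combine_weights = torch.full((seq, 2), 0.5)
    recovered = combining(buffers, combine_weights, scatter_index)
    return torch.allclose(recovered, x, atol=1e-6)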
|
|
|
|
|
class MOELayer(nn.Module): |
|
""" |
|
Mixture of Experts layer implementation based on GShard paper. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
gate: nn.Module, |
|
experts: List[nn.Module], |
|
layer_idx: int, |
|
shared_experts: Optional[List[nn.Module]] = None, |
|
group=None, |
|
recompute: bool = False, |
|
k: int = 2, |
|
all_to_all_dropout: float = 0, |
|
group_experts: bool = False, |
|
moe_statics=None, |
|
moe_num_experts=None, |
|
): |
|
""" |
|
Initialize MoE layer. |
|
|
|
Args: |
|
gate: Gate network for expert selection |
|
experts: List of expert networks |
|
layer_idx: Index of this layer in the model |
|
group: Distributed communication group |
|
recompute: Whether to enable recomputation |
|
k: Number of experts to select per token |
|
all_to_all_dropout: Dropout rate for all-to-all communication |
|
group_experts: Whether to group experts |
|
moe_statics: MoE statistics tracking object |
|
""" |
|
super().__init__() |
|
self.gate = gate |
|
self.layer_idx = layer_idx |
|
|
|
if isinstance(experts, nn.ModuleList): |
|
self.experts = experts |
|
else: |
|
logger.info(f"using fused experts, type={type(experts)}") |
|
self.experts = experts |
|
self.shared_experts = shared_experts |
|
|
|
self.group = group |
|
self.k = k |
|
self.all_to_all_dropout = all_to_all_dropout |
|
self.use_correction_bias = moe_statics is not None |
|
self.moe_statics = moe_statics |
|
if self.use_correction_bias: |
|
logger.info( |
|
f"using correction bias, aux-coef:{self.gate.config.moe_aux_loss_lambda}" |
|
) |
|
assert self.gate.config.moe_use_aux_free |
|
|
|
try: |
|
self.world_size = torch.distributed.get_world_size() |
|
self.rank = torch.distributed.get_rank() |
|
        except Exception:
|
self.world_size = 1 |
|
self.rank = 0 |
|
if self.world_size < 1: |
|
self.world_size = 1 |
|
if self.rank < 0: |
|
self.rank = 0 |
|
|
|
self.multimodal_experts = ( |
|
isinstance(moe_num_experts, (tuple, list)) and len(moe_num_experts) > 1 |
|
) |
|
self.num_local_experts = len(self.experts) // self.world_size |
|
if self.multimodal_experts: |
|
self.num_local_multimodal_experts = [ |
|
num // self.world_size for num in moe_num_experts |
|
] |
|
self.multimodal_expert_index = [0] + list( |
|
itertools.accumulate(moe_num_experts) |
|
) |
|
|
|
self.input_preprocess = self.output_postprocess = None |
|
self.group_experts = group_experts |
|
self.config = self.gate.config |
|
self.zero = torch.tensor(0).to(dtype=torch.float32) |
|
|
|
def forward_experts(self, dispatched_input): |
|
""" |
|
Forward pass through experts sequentially. |
|
|
|
Args: |
|
dispatched_input: Input tensor of shape [num_experts, capacity, dim] |
|
|
|
Returns: |
|
Tensor: Expert outputs of shape [num_experts, capacity, dim] |
|
""" |
|
|
|
if not self.multimodal_experts: |
|
true_experts = self.experts[ |
|
self.rank |
|
* self.num_local_experts : (self.rank + 1) |
|
* self.num_local_experts |
|
] |
|
else: |
|
true_experts = [] |
|
for i, num in enumerate(self.num_local_multimodal_experts): |
|
current_modal_experts = self.experts[ |
|
self.multimodal_expert_index[i] : self.multimodal_expert_index[ |
|
i + 1 |
|
] |
|
] |
|
true_experts.extend( |
|
current_modal_experts[self.rank * num : (self.rank + 1) * num] |
|
) |
|
|
|
dispatched_input = dispatched_input.reshape( |
|
[self.world_size, self.num_local_experts, -1, dispatched_input.shape[-1]] |
|
) |
|
current_device = dispatched_input.device |
|
expert_outputs = [] |
|
if isinstance(self.experts, nn.ModuleList): |
|
chunks = dispatched_input.permute(1, 0, 2, 3).contiguous().unbind(0) |
|
assert len(chunks) == len( |
|
true_experts |
|
), f"{len(chunks)}, {len(true_experts)}" |
|
for chunk, expert in zip(chunks, true_experts): |
|
expert_outputs.append(expert(chunk)) |
|
else: |
|
dispatched_input = dispatched_input.permute(1, 0, 2, 3).contiguous() |
|
orig_shape = dispatched_input.shape |
|
chunks = dispatched_input.reshape(orig_shape[0], -1, orig_shape[-1]) |
|
chunks = self.experts(chunks) |
|
chunks = chunks.reshape(orig_shape[:-1] + (chunks.shape[-1],)).unbind(0) |
|
expert_outputs.extend(chunks) |
|
|
|
for i, expert_output in enumerate(expert_outputs): |
|
expert_outputs[i] = expert_output.to(current_device) |
|
expert_output = torch.stack(expert_outputs, dim=1) |
|
return expert_output |
|
|
|
def moe_gate_dispatch( |
|
self, |
|
x: torch.Tensor, |
|
gate_logits: torch.Tensor, |
|
k: int, |
|
capacity: Optional[int], |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
|
"""dispatch input to experts based on gate logits""" |
|
|
|
S, H = x.shape |
|
E = gate_logits.shape[1] |
|
device = x.device |
|
topk_prob, topk_idx = torch.topk(gate_logits, k, dim=-1) |
|
combine_weights = topk_prob |
|
expert_id = topk_idx |
|
y = x.new_zeros((E, capacity, H)) |
|
scatter_index = x.new_full((k, S), -1, dtype=torch.int32) |
|
|
|
slot_counter = torch.zeros(E, dtype=torch.int32, device=device) |
|
|
|
for tok in range(S): |
|
for route in range(k): |
|
e = expert_id[tok, route].item() |
|
slot = slot_counter[e].item() |
|
if slot >= capacity: |
|
combine_weights[tok, route] = 0.0 |
|
continue |
|
|
|
scatter_index[route, tok] = e * capacity + slot |
|
y[e, slot] = x[tok] |
|
slot_counter[e] += 1 |
|
|
|
expert_offset = torch.cumsum(slot_counter, 0, dtype=torch.int64) |
|
|
|
return y, combine_weights, scatter_index, expert_offset, expert_id |
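    # Example (illustrative): with k = 2 and capacity = 3, a token whose top-2 experts
    # are (5, 0) is written to flattened buffer rows 5 * 3 + slot and 0 * 3 + slot;
    # once an expert's 3 slots are full, later routes to it get combine weight 0
    # (i.e. the token is dropped for that route).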
|
|
|
def gate_and_dispatch(self, input, token_type_ids=None, is_multimodel=True): |
|
""" |
|
Calculate gate and dispatch inputs. |
|
|
|
Args: |
|
input: Input tensor of shape [seq, dim] |
|
|
|
Returns: |
|
tuple: (dispatched_input, combine_weights, dispatch_mask, |
|
scatter_index, router_loss, gate_logits, gate_prob) |
|
""" |
|
d_model = input.shape[1] |
|
if isinstance(self.gate, (TopKGate)): |
|
capacity = self.gate.get_capacity( |
|
input.shape[0], is_multimodel=is_multimodel |
|
) |
|
if token_type_ids is not None: |
|
token_type_ids = token_type_ids.reshape([-1]) |
|
gate_logits = self.gate( |
|
input, token_type_ids=token_type_ids, is_multimodel=is_multimodel |
|
) |
|
prob = self.gate.act(gate_logits) |
|
( |
|
dispatched_input, |
|
combine_weights_unnorm, |
|
scatter_index, |
|
dispatch_mask, |
|
_, |
|
) = self.moe_gate_dispatch(input, prob, k=self.k, capacity=capacity) |
|
dispatch_mask = torch.diff(F.pad(dispatch_mask, (1, 0))) |
|
|
|
            scatter_index = scatter_index.detach()

            dispatch_mask = dispatch_mask.detach()
|
|
|
scatter_index = scatter_index.transpose(0, 1) |
|
combine_weights = combine_weights_unnorm / torch.clamp( |
|
combine_weights_unnorm.sum(dim=-1, keepdim=True), min=1e-12 |
|
) |
|
combine_weights = combine_weights.to(dtype=dispatched_input.dtype) |
|
|
|
else: |
|
( |
|
capacity, |
|
dispatch_mask, |
|
combine_weights, |
|
scatter_index, |
|
router_loss, |
|
gate_logits, |
|
) = self.gate( |
|
input, |
|
) |
|
prob = None |
|
dispatched_input = dispatching( |
|
input, |
|
dispatch_mask, |
|
scatter_index, |
|
num_experts=self.world_size * self.num_local_experts, |
|
capacity=capacity, |
|
) |
|
|
|
dispatched_input = dispatched_input.reshape( |
|
[self.world_size * self.num_local_experts, capacity, d_model] |
|
) |
|
|
|
dispatch_mask = dispatch_mask.detach() |
|
scatter_index = scatter_index.detach() |
|
return ( |
|
dispatched_input, |
|
combine_weights, |
|
dispatch_mask, |
|
scatter_index, |
|
None, |
|
gate_logits, |
|
prob, |
|
) |
|
|
|
def combine_expert_output(self, expert_output, combine_weights, scatter_index): |
|
""" |
|
Combine expert outputs using combination weights. |
|
|
|
Args: |
|
expert_output: Expert outputs [num_experts, capacity, dim] |
|
combine_weights: Combination weights |
|
scatter_index: Scatter indices |
|
|
|
Returns: |
|
Tensor: Combined output [seqlen, dim] |
|
""" |
|
expert_output = expert_output.reshape( |
|
[-1, expert_output.shape[-1]] |
|
) |
|
|
|
combined_output = combining(expert_output, combine_weights, scatter_index) |
|
|
|
if self.output_postprocess is not None: |
|
combined_output = self.output_postprocess(combined_output) |
|
|
|
return combined_output |
|
|
|
def forward( |
|
self, |
|
input: torch.Tensor, |
|
token_type_ids=None, |
|
is_multimodel=True, |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
|
""" |
|
Forward pass through MoE layer. |
|
|
|
Args: |
|
input: Input tensor of shape [s, d] |
|
|
|
Returns: |
|
tuple: (output, combine_weights, router_loss, gate_logits) |
|
""" |
|
if input.dim() == 3: |
|
orig_shape = input.shape |
|
input = input.reshape([-1, input.shape[-1]]) |
|
else: |
|
orig_shape = None |
|
assert ( |
|
input.dim() == 2 |
|
), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" |
|
if token_type_ids is not None: |
|
token_type_ids = token_type_ids.clone()[:, :-1] |
|
|
|
assert self.gate is not None |
|
|
|
gate_input = input |
|
|
|
( |
|
dispatched_input, |
|
combine_weights, |
|
dispatch_mask, |
|
scatter_index, |
|
router_loss, |
|
gate_logits, |
|
gate_prob, |
|
) = self.gate_and_dispatch( |
|
gate_input, token_type_ids, is_multimodel=is_multimodel |
|
) |
|
|
|
if self.shared_experts is not None: |
|
shared_out = self.shared_experts(input) |
|
|
|
expert_out = self.forward_experts(dispatched_input) |
|
|
|
combined_output = self.combine_expert_output( |
|
expert_out, combine_weights, scatter_index |
|
) |
|
|
|
if self.shared_experts is not None: |
|
combined_output += shared_out |
|
|
|
if orig_shape: |
|
combined_output = combined_output.clone().reshape( |
|
orig_shape[:-1] + (combined_output.shape[-1],) |
|
) |
|
return combined_output, combine_weights, None, gate_logits |
|
|
|
|
|
class MOEAllGatherLayerV2(MOELayer): |
|
""" |
|
MoE Layer with allgather implement. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
gate: nn.Module, |
|
experts: List[nn.Module], |
|
layer_idx, |
|
shared_experts: Optional[List[nn.Module]] = None, |
|
group=None, |
|
recompute=False, |
|
k=2, |
|
enable_reverse_token_drop=False, |
|
all_to_all_dropout=0, |
|
group_experts=False, |
|
use_expert_out_alltoall=True, |
|
use_expert_alltoall_overlap=False, |
|
use_padding=True, |
|
dense_token_type=3, |
|
moe_statics=None, |
|
moe_num_experts=None, |
|
): |
|
super().__init__( |
|
gate, |
|
experts, |
|
layer_idx, |
|
shared_experts, |
|
group, |
|
recompute, |
|
k, |
|
all_to_all_dropout, |
|
group_experts, |
|
moe_statics, |
|
moe_num_experts, |
|
) |
|
self.enable_reverse_token_drop = enable_reverse_token_drop |
|
self.is_allgather_moe_layer = True |
|
self.use_padding = use_padding |
|
|
|
self.send_rank = None |
|
self.local_expert_id = None |
|
self.dense_experts = None |
|
self.dense_token_type = dense_token_type |
|
self.capacity_tensor = None |
|
logger.info( |
|
f"uisng MOEAllGatherLayerV2, use_expert_out_alltoall={use_expert_out_alltoall}, " |
|
f"use_padding={use_padding}, use_expert_alltoall_overlap={use_expert_alltoall_overlap} " |
|
f"enable_reverse_token_drop={self.enable_reverse_token_drop}" |
|
) |
|
self.two = torch.tensor(2).to(dtype=torch.float32) |
|
self.zero = torch.tensor(0).to(dtype=torch.float32) |
|
|
|
def forward( |
|
self, |
|
input: torch.Tensor, |
|
token_type_ids=None, |
|
use_dense_expert=False, |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
"""Implements forward pass for Mixture-of-Experts (MoE) layer with distributed communication. |
|
|
|
Core Functionality: |
|
- Processes input through gating network to determine expert assignments |
|
- Combines expert outputs and calculates routing loss |
|
|
|
Key Features: |
|
1. Supports both dense and sparse expert computation modes |
|
2. Implements fused gating and dispatch for performance optimization |
|
3. Handles sequence length padding/unpadding for irregular inputs |
|
4. Enables communication-computation overlap through asynchronous operations |
|
|
|
Args: |
|
input (Tensor): Input tensor of shape [seq_len, hidden_dim] |
|
token_type_ids: Optional segmentation markers for heterogeneous inputs |
|
use_dense_expert: Flag to enable dense expert computation bypass |
|
|
|
Returns: |
|
tuple: ( |
|
combined_output: Aggregated expert outputs [seq_len, hidden_dim], |
|
                combine_weights: Expert combination coefficients,
                router_loss: Always None in this implementation,
                gate_logits: Raw gating logits,
            )
|
""" |
|
use_fuse = isinstance(self.gate, (TopKGate)) |
|
assert use_fuse |
|
if input.ndim == 3: |
|
orig_shape = input.shape |
|
input = input.reshape([-1, input.shape[-1]]) |
|
else: |
|
orig_shape = None |
|
|
|
assert ( |
|
len(input.shape) == 2 |
|
), f"input Tensor must have dimensions: (s)equence, (d)im, got:{input.shape}" |
|
dispatch_token_type_ids = None |
|
global_dense_expert_mask = None |
|
if token_type_ids is not None: |
|
token_type_ids = token_type_ids[:, :-1].reshape([-1]) |
|
dispatch_token_type_ids = token_type_ids |
|
if use_dense_expert: |
|
global_dense_expert_mask = ( |
|
dispatch_token_type_ids == self.dense_token_type |
|
) |
|
|
|
assert self.gate is not None |
|
|
|
( |
|
dispatched_input, |
|
global_hidden_states, |
|
local_combine_weights, |
|
expert_num_global_no_token_drop, |
|
expert_num_global, |
|
expert_num_global_list, |
|
local_scatter_index, |
|
scatter_index_rev, |
|
router_loss, |
|
(gate_logits, gate_prob), |
|
(gate_logits_mm, gate_prob_mm), |
|
expert_num_local, |
|
) = self.fused_gate_and_dispatch( |
|
input, token_type_ids, global_dense_expert_mask |
|
) |
|
|
|
seqlen_this_mp = input.shape[0] |
|
if len(scatter_index_rev): |
|
recv_rank_local = scatter_index_rev // seqlen_this_mp |
|
else: |
|
recv_rank_local = scatter_index_rev |
|
|
|
if self.send_rank is None: |
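# Lazily build static routing metadata on first use: a destination rank and a local
# expert id for each of the `capacity * num_local_experts` dispatch slots (single-rank layout here).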
|
capacity = self.gate.get_capacity(input.shape[0]) |
|
self.send_rank = ( |
|
torch.arange(1) |
|
.repeat_interleave(capacity * self.num_local_experts) |
|
.to(torch.int32) |
|
) |
|
self.local_expert_id = ( |
|
torch.arange(self.num_local_experts) |
|
.repeat_interleave(capacity) |
|
.repeat(1) |
|
.to(self.send_rank.dtype) |
|
) |
|
send_rank = self.send_rank |
|
local_expert_id = self.local_expert_id |
|
|
|
expert_outs = self.forward_experts(*dispatched_input) |
|
for e in expert_outs: |
|
if e is not None: |
|
current_device = e.device |
|
break |
|
expert_outs = torch.cat( |
|
[e.to(current_device) for e in expert_outs if e is not None], dim=0 |
|
) |
|
|
|
|
|
combined_output = self.combine_expert_output( |
|
expert_outs, local_combine_weights, local_scatter_index |
|
) |
|
|
|
if self.shared_experts is not None: |
|
shared_out = self.shared_experts(input).to(combined_output.device) |
|
combined_output += shared_out |
|
|
|
if orig_shape: |
|
combined_output = combined_output.reshape( |
|
*orig_shape[:-1], combined_output.shape[-1] |
|
) |
|
|
|
return combined_output, local_combine_weights, None, gate_logits |
|
|
|
def _expand_modality_expert_id( |
|
self, |
|
expert_id: torch.Tensor, |
|
seqlen: int, |
|
k: int, |
|
num_expert_per_modality: int, |
|
group_size: int, |
|
modality_offset: int, |
|
is_group_expert: bool, |
|
) -> torch.Tensor: |
|
""" |
|
expert_id: tensor of shape (seqlen, k), containing expert ids |
|
Returns: tensor of same shape, with updated expert ids |
|
""" |
|
device = expert_id.device |
|
expert_id = expert_id.clone() |
|
|
|
if is_group_expert: |
|
|
|
offsets = (torch.arange(k, device=device) * group_size).view( |
|
1, k |
|
) |
|
expert_id += offsets |
|
|
|
if num_expert_per_modality <= 0: |
|
return expert_id |
|
|
|
|
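# Remap the per-modality expert id into the global layout, where each rank holds its text
# experts followed by its multimodal experts. Illustrative example with num_expert_per_modality=4:
# text expert 5 -> rank 1, slot 1 -> 1 * 8 + 1 + 0 = 9, while multimodal expert 5 -> 1 * 8 + 1 + 4 = 13.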
|
rank = expert_id // num_expert_per_modality |
|
expert_id_in_rank = expert_id % num_expert_per_modality |
|
|
|
|
|
expert_id_out = ( |
|
rank * (num_expert_per_modality * 2) |
|
+ expert_id_in_rank |
|
+ modality_offset * num_expert_per_modality |
|
) |
|
|
|
return expert_id_out |
|
|
|
def expand_modality_expert_id( |
|
self, |
|
expert_id, |
|
num_expert_per_modality, |
|
group_size, |
|
modality_offset, |
|
is_group_expert, |
|
): |
|
"""expand expert id for modality aware moe layer""" |
|
seq_len, k = expert_id.shape |
|
|
|
return self._expand_modality_expert_id( |
|
expert_id, |
|
seq_len, |
|
k, |
|
num_expert_per_modality, |
|
group_size, |
|
modality_offset, |
|
is_group_expert, |
|
) |
|
|
|
def fused_gate_logits_process_fused( |
|
self, gate_logits_lm, gate_logits_mm=None, token_type_ids=None |
|
): |
|
"""Process gating logits for expert selection in Mixture-of-Experts (MoE) layers. |
|
|
|
Core Functionality: |
|
- Transforms raw gating logits into expert selection weights and IDs |
|
- Supports both grouped and standard expert selection modes |
|
- Handles bias correction for improved expert load balancing |
|
|
|
Args: |
|
gate_logits_lm (Tensor): Raw text-expert gating scores of shape [batch_size, total_experts]

gate_logits_mm (Tensor, optional): Raw multimodal-expert gating scores

token_type_ids (Tensor, optional): Per-token modality ids used to switch between text and multimodal routing



Returns:

tuple: (

weight_and_expert: Combined tensor containing selection weights

and expert IDs [batch_size, 2*top_k],

prob_lm: Flattened text-expert probabilities [batch_size, total_experts],

prob_mm: Multimodal-expert probabilities, or None for text-only inputs

)
|
""" |
|
top_k = self.k |
|
num_expert_per_rank_per_modality = gate_logits_lm.shape[-1] |
|
group_size = gate_logits_lm.shape[-1] // top_k |
|
if self.group_experts: |
|
assert not self.use_correction_bias |
|
gate_logits_lm = gate_logits_lm.reshape( |
|
[gate_logits_lm.shape[0], top_k, -1] |
|
) |
|
prob_lm = self.gate.act(gate_logits_lm) |
|
prob_lm_ = prob_lm |
|
weight_lm, expert_id_lm = prob_lm_.topk(k=1, dim=-1) |
|
weight_lm = weight_lm.reshape([gate_logits_lm.shape[0], -1]) |
|
group_size = gate_logits_lm.shape[-1] |
|
expert_id_lm = expert_id_lm.squeeze(-1) |
|
else: |
|
prob_lm = self.gate.act(gate_logits_lm) |
|
if self.use_correction_bias: |
|
prob_lm_ = prob_lm + self.moe_statics.e_score_correction_bias[ |
|
0 |
|
].detach().to(prob_lm.device) |
|
else: |
|
prob_lm_ = prob_lm |
|
weight_lm, expert_id_lm = prob_lm_.topk(k=top_k, dim=-1) |
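# With the aux-loss-free correction bias, experts are selected from the biased scores,
# but the combine weights are re-read from the original (unbiased) probabilities.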
|
|
|
if self.use_correction_bias: |
|
batch_idx = ( |
|
torch.arange(prob_lm_.shape[0]).unsqueeze(-1).expand_as(expert_id_lm) |
|
) |
|
weight_lm = prob_lm[batch_idx, expert_id_lm] |
|
|
|
expert_id_lm = self.expand_modality_expert_id( |
|
expert_id_lm, |
|
num_expert_per_modality=( |
|
num_expert_per_rank_per_modality if token_type_ids is not None else 0 |
|
), |
|
group_size=group_size, |
|
modality_offset=0, |
|
is_group_expert=self.group_experts, |
|
) |
|
expert_id_lm = expert_id_lm.reshape(weight_lm.shape) |
|
lm_weight_and_expert_id = torch.cat( |
|
[weight_lm, expert_id_lm.to(torch.float32)], -1 |
|
) |
|
|
|
if token_type_ids is None or gate_logits_mm is None: |
|
return ( |
|
lm_weight_and_expert_id, |
|
prob_lm.reshape([prob_lm.shape[0], -1]), |
|
None, |
|
) |
|
|
|
prob_mm = self.gate.act(gate_logits_mm) |
|
if self.use_correction_bias: |
|
prob_mm_ = prob_mm + self.moe_statics.e_score_correction_bias[ |
|
1 |
|
].detach().to(prob_lm.device) |
|
else: |
|
prob_mm_ = prob_mm |
|
weight_mm, expert_id_mm = prob_mm_.topk(k=top_k, dim=-1) |
|
if self.use_correction_bias: |
|
batch_idx = ( |
|
torch.arange(prob_mm_.shape[0]).unsqueeze(-1).expand_as(expert_id_mm)
|
) |
|
weight_mm = prob_mm[batch_idx, expert_id_mm] |
|
|
|
expert_id_mm = self.expand_modality_expert_id( |
|
expert_id_mm, |
|
num_expert_per_modality=num_expert_per_rank_per_modality, |
|
group_size=group_size, |
|
modality_offset=1, |
|
is_group_expert=False, |
|
) |
|
expert_id_mm = expert_id_mm.reshape(weight_mm.shape) |
|
mm_weight_and_expert_id = torch.cat( |
|
[weight_mm, expert_id_mm.to(torch.float32)], -1 |
|
) |
|
weight_and_expert = torch.where( |
|
(token_type_ids == 0).unsqueeze(-1), |
|
lm_weight_and_expert_id.to(token_type_ids.device), |
|
mm_weight_and_expert_id.to(token_type_ids.device), |
|
) |
|
return weight_and_expert, prob_lm.reshape([prob_lm.shape[0], -1]), prob_mm |
|
|
|
def moe_gate_dispatch_partial_nosoftmaxtopk( |
|
self, |
|
x, |
|
combine_weights, |
|
expert_id, |
|
k, |
|
num_experts, |
|
): |
|
""" |
|
Pure-PyTorch reference implementation of the MoE gate dispatch (softmax/top-k are applied by the caller).
|
""" |
|
device = x.device |
|
dtype = x.dtype |
|
num_rows, hidden_size = x.shape |
|
k = expert_id.shape[1] |
|
expert_ids_flat = expert_id.reshape(-1) |
|
combine_weights_flat = combine_weights.reshape(-1) |
|
|
|
expanded_token_ids = torch.arange(num_rows * k, device=device) |
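# Stable-sort the flattened (token, slot) assignments by expert id so that all tokens routed
# to the same expert become contiguous; this is a plain-PyTorch stand-in for the fused dispatch kernel.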
|
|
|
sorted_expert_ids, sorted_indices = torch.sort(expert_ids_flat, stable=True) |
|
sorted_indices = sorted_indices.to(expanded_token_ids.device) |
|
|
|
sorted_expanded_token_ids = expanded_token_ids[sorted_indices] |
|
|
|
expert_nums_local = torch.zeros(num_experts, dtype=torch.int64, device=device) |
|
|
|
for expert_idx in range(num_experts): |
|
count = (sorted_expert_ids == expert_idx).sum().item() |
|
expert_nums_local[expert_idx] = count |
|
|
|
total_dispatched_tokens = torch.cumsum(expert_nums_local, dim=0)[-1].item() |
|
|
|
y = x[sorted_indices // k] |
|
|
|
scatter_index = torch.full((k, num_rows), -1, dtype=torch.int32, device=device) |
|
|
|
for expanded_idx, sorted_pos in zip(

sorted_expanded_token_ids, range(total_dispatched_tokens)

):
|
token_idx = expanded_idx // k |
|
k_idx = expanded_idx % k |
|
scatter_index[k_idx, token_idx] = sorted_pos |
|
|
|
scatter_index_rev = sorted_indices // k |
|
|
|
combine_weights_out = combine_weights.clone() |
|
|
|
return ( |
|
y, |
|
combine_weights_out, |
|
scatter_index, |
|
scatter_index_rev, |
|
expert_nums_local, |
|
expert_nums_local, |
|
) |
|
|
|
def fused_gate_and_dispatch( |
|
self, input, token_type_ids=None, global_dense_expert_mask=None |
|
): |
|
"""Implements fused expert gating and token dispatch logic for Mixture-of-Experts (MoE) layers. |
|
|
|
Core Functionality: |
|
- Computes expert selection probabilities and routing weights |
|
- Performs distributed token-to-expert assignment |
|
- Handles communication and synchronization in model-parallel environments |
|
|
|
Args: |
|
input (Tensor): Input tensor of shape [seq_len, hidden_dim] |
|
|
|
Returns: |
|
tuple: ( |
|
dispatched_input: Expert-assigned tokens [num_experts, capacity, hidden_dim], |
|
global_hidden_states: Full sequence representations, |
|
local_combine_weights: Local expert combination weights, |
|
expert_num_global_notrunc: Global expert token counts (without capacity truncation), |
|
expert_num_global: Actual expert token counts, |
|
expert_num_global_list: Per-expert token counts, |
|
local_scatter_index: Local token reorganization indices, |
|
scatter_index_rev: Reverse scattering indices, |
|
router_loss: Calculated routing loss, |
|
gate_outputs: Raw gating network outputs, |
|
expert_num_local: Local expert utilization counts |
|
) |
|
""" |
|
seqlen, d_model = input.shape |
|
args = () |
|
if token_type_ids is not None: |
|
token_type_ids = token_type_ids.reshape([-1]) |
|
args = (token_type_ids,) |
|
|
|
router_loss = torch.zeros([1], dtype=torch.float32) |
|
top_k = self.k |
|
|
|
def build_weights_and_expert_id(input): |
|
nonlocal token_type_ids, args |
|
logits = self.gate(input, *args, transform_weight=False) |
|
if self.config.multimodel_experts: |
|
gate_logits_lm, gate_logits_mm = logits.chunk(2, dim=-1) |
|
else: |
|
gate_logits_lm, gate_logits_mm = logits, None |
|
|
|
weight_and_expert, gate_prob_lm, gate_prob_mm = (
|
self.fused_gate_logits_process_fused( |
|
gate_logits_lm, |
|
gate_logits_mm, |
|
token_type_ids if global_dense_expert_mask is None else None, |
|
) |
|
) |
|
return ( |
|
weight_and_expert,
|
gate_logits_lm, |
|
gate_logits_mm, |
|
gate_prob_lm, |
|
gate_prob_mm, |
|
) |
|
|
|
capacity = self.gate.get_capacity(input.shape[0]) * self.world_size |
|
global_hidden_states = input |
|
( |
|
combine_weights_and_expert_id, |
|
gate_logits_lm, |
|
gate_logits_mm, |
|
gate_prob_lm, |
|
gate_prob_mm, |
|
) = build_weights_and_expert_id(input) |
|
|
|
combine_weights_unnorm, expert_id = combine_weights_and_expert_id.chunk( |
|
2, dim=-1 |
|
) |
|
expert_id = expert_id.to(torch.int32) |
|
num_experts = ( |
|
sum(self.config.moe_num_experts) |
|
if isinstance(self.config.moe_num_experts, (tuple, list)) |
|
else self.config.moe_num_experts |
|
) |
|
if global_dense_expert_mask is not None: |
|
combine_weights_unnorm[global_dense_expert_mask] = 0.0 |
|
expert_id[global_dense_expert_mask] = num_experts |
|
num_experts += 1 |
|
|
|
( |
|
dispatched_input, |
|
combine_weights_unnorm, |
|
scatter_index, |
|
scatter_index_rev, |
|
expert_num_global, |
|
expert_num_local, |
|
) = self.moe_gate_dispatch_partial_nosoftmaxtopk( |
|
global_hidden_states, |
|
combine_weights_unnorm, |
|
expert_id, |
|
top_k, |
|
num_experts, |
|
) |
|
|
|
if self.use_correction_bias: |
|
if self.gate.config.multimodel_experts: |
|
|
|
for i in range(len(self.moe_statics.expert_usage)): |
|
self.moe_statics.expert_usage[i] += ( |
|
expert_num_local[self.gate.experts_type_mask[i]] |
|
.detach() |
|
.to(self.moe_statics.expert_usage.device) |
|
) |
|
else: |
|
|
|
self.moe_statics.expert_usage[0] += expert_num_local.detach().to( |
|
self.moe_statics.expert_usage.device |
|
) |
|
|
|
|
|
if scatter_index_rev.ndim == 0: |
|
assert not self.use_padding |
|
scatter_index_rev = torch.empty([0], dtype=scatter_index_rev.dtype) |
|
|
|
expert_num_global_notrunc = expert_num_global |
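# The un-truncated counts above are kept for load-balancing statistics; the routed
# counts below are clamped to the gate capacity.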
|
self.capacity_tensor = torch.tensor(capacity).to(dtype=expert_num_global.dtype) |
|
expert_num_global = torch.minimum(expert_num_global, self.capacity_tensor) |
|
|
|
if global_dense_expert_mask is not None: |
|
expert_num_global = expert_num_global[:-1] |
|
expert_num_local = expert_num_local[:-1] |
|
expert_num_global_notrunc = expert_num_global_notrunc[:-1] |
|
|
|
scatter_index = scatter_index.transpose(1, 0) |
|
scatter_index = scatter_index.to(combine_weights_unnorm.device) |
|
|
|
last_local_expert = 0 |
|
expert_offset_global = expert_num_global.cumsum(-1) |
|
|
|
expert_num_global_list = expert_num_global |
|
if self.use_padding: |
|
offset = last_local_expert * capacity |
|
else: |
|
offset = 0 |
|
local_combine_weights_unnorm = combine_weights_unnorm.contiguous() |
|
local_scatter_index = torch.where( |
|
combine_weights_unnorm > 0.0, |
|
scatter_index + offset, |
|
scatter_index, |
|
) |
|
if self.gate.norm_gate_logits: |
|
local_combine_weights = local_combine_weights_unnorm / torch.clip( |
|
local_combine_weights_unnorm.sum(-1, keepdim=True), min=1e-12 |
|
) |
|
else: |
|
local_combine_weights = local_combine_weights_unnorm |
|
local_combine_weights = local_combine_weights.to(dispatched_input.dtype) |
|
if self.use_padding: |
|
dispatched_input = dispatched_input.reshape( |
|
[self.num_local_experts, -1, d_model] |
|
) |
|
dispatched_input = dispatched_input.unbind(0) |
|
else: |
|
s = 0 |
|
e = self.num_local_experts |
|
expert_num_local = expert_num_local.tolist()[s:e] |
|
expert_num_local_valid = [i for i in expert_num_local if i > 0] |
|
valid_pos = [j for j, i in enumerate(expert_num_local) if i > 0] |
|
if expert_num_local_valid: |
|
dispatched_input_list = dispatched_input.split(expert_num_local_valid) |
|
dispatched_input = [None] * len(expert_num_local) |
|
for p, t in zip(valid_pos, dispatched_input_list): |
|
dispatched_input[p] = t |
|
else: |
|
dispatched_input = [dispatched_input] + ( |
|
[None] * (len(expert_num_local) - 1) |
|
) |
|
|
|
expert_num_global_list = expert_num_global_list.tolist() |
|
|
|
return ( |
|
dispatched_input, |
|
global_hidden_states, |
|
local_combine_weights, |
|
expert_num_global_notrunc, |
|
expert_num_global, |
|
expert_num_global_list, |
|
local_scatter_index, |
|
scatter_index_rev, |
|
router_loss, |
|
(gate_logits_lm, gate_prob_lm), |
|
(gate_logits_mm, gate_prob_mm), |
|
expert_num_local, |
|
) |
|
|
|
def forward_experts(self, *dispatched_input): |
|
"""Execute expert model computations in sequence for Mixture-of-Experts (MoE) layer. |
|
|
|
Core Functionality: |
|
- Distributes dispatched tokens to local expert models |
|
- Handles empty expert inputs with zero-initialized fallback |
|
- Maintains gradient flow for expert outputs |
|
- Aggregates outputs from all active experts |
|
|
|
Args: |
|
*dispatched_input: Variable-length expert-specific input tensors |
|
|
|
Returns: |
|
list: Expert output tensors (None for inactive experts) |
|
|
|
Implementation Details: |
|
1. Processes valid expert inputs through the corresponding local expert models

2. Appends None for experts that received no tokens

3. Folds any placeholder (no-token) outputs into the first active expert's output to preserve gradient flow
|
""" |
|
expert_outputs = [] |
|
assert isinstance(self.experts, nn.ModuleList), type(self.experts) |
|
|
|
no_tokens_expert_outputs = [] |
|
true_experts = self.experts[ |
|
self.rank |
|
* self.num_local_experts : (self.rank + 1) |
|
* self.num_local_experts |
|
] |
|
for iexpert, chunk in enumerate(dispatched_input): |
|
if chunk is None: |
|
expert_outputs.append(None) |
|
continue |
|
|
|
expert_out = true_experts[iexpert](chunk.contiguous()) |
|
expert_outputs.append(expert_out) |
|
|
|
if len(no_tokens_expert_outputs) > 0: |
|
first_has_tokens_idx = 0 |
|
for idx, expert_out in enumerate(expert_outputs): |
|
if expert_out is not None: |
|
first_has_tokens_idx = idx |
|
break |
|
for idx, expert_out in enumerate(no_tokens_expert_outputs): |
|
expert_outputs[first_has_tokens_idx] += expert_out |
|
|
|
return expert_outputs |
|
|
|
|
|
class Ernie4_5_DecoderLayer(nn.Module): |
|
"""A single transformer decoder layer in ERNIE-MoE model. |
|
|
|
Contains self-attention and feed-forward components with optional MoE (Mixture of Experts) |
|
support, residual connections, and layer normalization. |
|
""" |
|
|
|
_keep_in_fp32_modules = ["mlp.gate", "e_score_correction_bias"] |
|
|
|
def __init__(self, config, layer_idx): |
|
"""Initialize the decoder layer. |
|
|
|
Args: |
|
config (Ernie4_5_MoEConfig): Model configuration. |
|
layer_idx (int): Index of this layer in the transformer stack |
|
""" |
|
super().__init__() |
|
self.hidden_size = config.hidden_size |
|
self.layer_idx = layer_idx |
|
self.config = config |
|
self.use_moe = config.use_moe |
|
self.self_attn = Ernie4_5_Attention(config, layer_idx) |
|
|
|
moe_layer_start_index = ( |
|
min(config.moe_layer_start_index) |
|
if isinstance(config.moe_layer_start_index, (tuple, list)) |
|
else config.moe_layer_start_index |
|
) |
|
moe_layer_end_index = ( |
|
max(config.moe_layer_end_index) |
|
if isinstance(config.moe_layer_end_index, (tuple, list)) |
|
else config.moe_layer_end_index |
|
) |
|
|
|
if ( |
|
self.use_moe |
|
and ((layer_idx + 1) % config.moe_layer_interval == 0) |
|
and layer_idx >= moe_layer_start_index |
|
and layer_idx <= moe_layer_end_index |
|
): |
|
gate, experts, lm_gate, lm_experts, moe_statics = ( |
|
self._init_gate_and_experts(layer_idx) |
|
) |
|
shared_experts = ( |
|
self._init_shared_experts() |
|
if hasattr(config, "moe_num_shared_experts") |
|
else None |
|
) |
|
|
|
dense_experts = None |
|
moe_cls = MOELayer |
|
if config.moe_multimodal_dispatch_use_allgather: |
|
logger.info("Enable MOEAllGatherLayerV2!") |
|
moe_cls = partial( |
|
MOEAllGatherLayerV2, |
|
use_expert_out_alltoall="alltoall" |
|
in config.moe_multimodal_dispatch_use_allgather, |
|
use_padding=False, |
|
enable_reverse_token_drop=config.moe_reverse_token_drop, |
|
dense_token_type=config.moe_dense_experts_token_type_id, |
|
) |
|
else: |
|
assert ( |
|
dense_experts is None |
|
), "only `MOEAllGatherLayerV2` can process dense experts" |
|
|
|
self.mlp = moe_cls( |
|
gate=gate, |
|
experts=experts, |
|
layer_idx=layer_idx, |
|
shared_experts=shared_experts, |
|
group=config.moe_group, |
|
recompute=False, |
|
k=config.moe_k, |
|
all_to_all_dropout=config.moe_all_to_all_dropout, |
|
group_experts=False, |
|
moe_statics=moe_statics, |
|
moe_num_experts=config.moe_num_experts, |
|
) |
|
|
|
_mlp_text = MOELayer( |
|
gate=lm_gate, |
|
experts=lm_experts, |
|
layer_idx=layer_idx, |
|
shared_experts=shared_experts, |
|
group=config.moe_group, |
|
recompute=False, |
|
k=config.moe_k, |
|
all_to_all_dropout=config.moe_all_to_all_dropout, |
|
group_experts=False, |
|
moe_statics=moe_statics, |
|
moe_num_experts=config.moe_num_experts, |
|
) |
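# _mlp_text is exposed through a lambda rather than assigned directly, so it is not
# registered as a child module of this layer; self.mlp_text() returns it on demand.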
|
self.mlp_text = ( |
|
lambda: _mlp_text |
|
) |
|
else: |
|
self.mlp = Ernie4_5_MLP(config) |
|
|
|
Norm = RMSNorm |
|
|
|
self.input_layernorm = Norm(config) |
|
self.post_attention_layernorm = Norm(config) |
|
|
|
self.residual_add1 = FusedDropoutImpl( |
|
config.hidden_dropout_prob, mode="upscale_in_train" |
|
) |
|
self.residual_add2 = FusedDropoutImpl( |
|
config.hidden_dropout_prob, mode="upscale_in_train" |
|
) |
|
|
|
def _init_shared_experts(self): |
|
"""init shared experts |
|
|
|
Returns: |
|
_type_: _description_ |
|
""" |
|
cfg = deepcopy(self.config) |
|
if cfg.moe_num_shared_experts > 0: |
|
if cfg.moe_intermediate_size: |
|
inter_size = ( |
|
next(iter(cfg.moe_intermediate_size)) |
|
if isinstance(cfg.moe_intermediate_size, (tuple, list)) |
|
else cfg.moe_intermediate_size |
|
) |
|
cfg.intermediate_size = inter_size * cfg.moe_num_shared_experts |
|
else: |
|
cfg.intermediate_size = ( |
|
cfg.intermediate_size * cfg.moe_num_shared_experts |
|
) |
|
cfg.disable_ffn_model_parallel = False |
|
shared_experts = Ernie4_5_MoeMLP(cfg, True) |
|
else: |
|
shared_experts = None |
|
return shared_experts |
|
|
|
def _init_gate_and_experts(self, layer_idx): |
|
"""Initialize MoE gate and expert networks. |
|
|
|
Args: |
|
layer_idx (int): Current layer index |
|
|
|
Returns: |
|
Tuple: Contains: |
|
- gate: MoE routing gate |
|
- experts: List of expert networks |
|
- moe_statics: Optional statistics tracker |
|
""" |
|
cfg = deepcopy(self.config) |
|
fc_cls = Ernie4_5_MoeMLP |
|
if cfg.moe_intermediate_size: |
|
if isinstance(cfg.moe_intermediate_size, (tuple, list)): |
|
assert isinstance(cfg.moe_num_experts, (tuple, list)) and len( |
|
cfg.moe_num_experts |
|
) == len(cfg.moe_intermediate_size) |
|
fc = [] |
|
for _i, (num_experts, intermediate_size) in enumerate( |
|
zip(cfg.moe_num_experts, cfg.moe_intermediate_size) |
|
): |
|
ex_cfg = deepcopy(cfg) |
|
ex_cfg.intermediate_size = intermediate_size |
|
cur_modality_start_layer_idx = ( |
|
cfg.moe_layer_start_index[_i] |
|
if isinstance(cfg.moe_layer_start_index, (tuple, list)) |
|
else cfg.moe_layer_start_index |
|
) |
|
cur_modality_end_layer_idx = ( |
|
cfg.moe_layer_end_index[_i] |
|
if isinstance(cfg.moe_layer_end_index, (tuple, list)) |
|
else cfg.moe_layer_end_index |
|
) |
|
if ( |
|
layer_idx >= cur_modality_start_layer_idx |
|
and layer_idx <= cur_modality_end_layer_idx |
|
): |
|
if _i == 1: |
|
with UniqueNameGuard(f"mm_expert_{layer_idx}_") as guard: |
|
fc.append((num_experts, fc_cls(ex_cfg))) |
|
else: |
|
fc.append((num_experts, fc_cls(ex_cfg))) |
|
else: |
|
logger.info( |
|
f"moe multimodal experts use Identity layer_idx: {layer_idx}" |
|
) |
|
fc.append((num_experts, nn.Identity())) |
|
else: |
|
cfg.intermediate_size = cfg.moe_intermediate_size |
|
fc = [(cfg.moe_num_experts, fc_cls(cfg, layer_idx))] |
|
else: |
|
fc = [(cfg.moe_num_experts, fc_cls(cfg, layer_idx))] |
|
if cfg.multimodel_experts: |
|
gate, experts, lm_gate, lm_experts = get_gate(self.config, fc, layer_idx) |
|
else: |
|
gate, experts = get_gate(self.config, fc, layer_idx) |
|
lm_gate, lm_experts = None, None |
|
|
|
|
|
if cfg.moe_use_aux_free: |
|
moe_statics = MoEStatics(cfg, layer_idx) |
|
else: |
|
moe_statics = None |
|
return gate, experts, lm_gate, lm_experts, moe_statics |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
attn_mask_start_row_indices: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
token_type_ids: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = False, |
|
past_key_value: Optional[Tuple[torch.Tensor]] = None, |
|
use_cache: Optional[bool] = False, |
|
output_gate_logits=True, |
|
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: |
|
"""Forward pass through the decoder layer. |
|
|
|
Args: |
|
hidden_states (torch.Tensor): Input tensor [batch_size, seq_len, hidden_size] |
|
attention_mask (Optional[torch.Tensor]): Attention mask tensor |
|
attn_mask_start_row_indices (Optional[torch.Tensor]): Indices for variable length attention |
|
position_ids (Optional[torch.Tensor]): Position indices for rotary embeddings |
|
output_attentions (Optional[bool]): Whether to return attention weights |
|
past_key_value (Optional[Tuple[torch.Tensor]]): Cached key/value states |
|
use_cache (Optional[bool]): Whether to cache key/value states |
|
output_gate_logits (bool): Whether to return MoE gate logits |
|
|
|
Returns: |
|
Union: Various output combinations depending on arguments: |
|
- Base case: Hidden states tensor |
|
- With attention: Tuple of (hidden_states, attention_weights) |
|
- With cache: Tuple of (hidden_states, cached_key_value) |
|
- With MoE: May include gate logits in output tuple |
|
""" |
|
residual = hidden_states |
|
|
|
if token_type_ids is not None: |
|
is_multimodel_token = token_type_ids.any() |
|
has_dense_experts_token = ( |
|
token_type_ids == self.config.moe_dense_experts_token_type_id |
|
).any() |
|
is_multimodel_token_cpu = is_multimodel_token.cpu() |
|
has_dense_experts_token_cpu = has_dense_experts_token.cpu() |
|
else: |
|
is_multimodel_token_cpu = None |
|
has_dense_experts_token_cpu = None |
|
|
|
hidden_states = self.input_layernorm(hidden_states) |
|
|
|
|
|
(hidden_states, self_attn_weights, present_key_value, *router_loss_attn) = ( |
|
self.self_attn( |
|
hidden_states=hidden_states, |
|
past_key_value=past_key_value, |
|
attention_mask=attention_mask, |
|
attn_mask_start_row_indices=attn_mask_start_row_indices, |
|
position_ids=position_ids, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
token_type_ids=token_type_ids, |
|
) |
|
) |
|
hidden_states = self.residual_add1(hidden_states, residual) |
|
|
|
|
|
residual = hidden_states |
|
hidden_states = self.post_attention_layernorm(hidden_states) |
|
|
|
if isinstance(self.mlp, MOELayer): |
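# Route through the multimodal MoE only when the batch actually contains non-text tokens;
# otherwise fall back to the text-only MoE layer (self.mlp_text).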
|
if is_multimodel_token_cpu: |
|
hidden_states, _, router_loss, gate_logits = self.mlp( |
|
hidden_states, token_type_ids |
|
) |
|
else: |
|
hidden_states, _, router_loss, gate_logits = self.mlp_text()( |
|
hidden_states, None, is_multimodel=False |
|
) |
|
else: |
|
hidden_states = self.mlp(hidden_states) |
|
gate_logits, router_loss = None, None |
|
|
|
hidden_states = self.residual_add2(hidden_states, residual) |
|
|
|
outputs = (hidden_states,) |
|
|
|
if output_attentions: |
|
outputs += (self_attn_weights,) |
|
|
|
if use_cache: |
|
outputs += (present_key_value,) |
|
|
|
if self.use_moe: |
|
|
|
if router_loss_attn: |
|
router_loss_attn = router_loss_attn[0] |
|
router_loss = router_loss + router_loss_attn |
|
|
|
if output_gate_logits: |
|
outputs += (gate_logits,) |
|
|
|
|
|
if type(outputs) is tuple and len(outputs) == 1: |
|
outputs = outputs[0] |
|
|
|
return outputs |
|
|
|
|
|
class Ernie4_5_PretrainedModel(PreTrainedModel): |
|
"""Base class for ERNIE pretrained models.""" |
|
|
|
config_class = Ernie4_5_MoEConfig |
|
base_model_prefix = "ernie" |
|
_no_split_modules = ["Ernie4_5_DecoderLayer"] |
|
|
|
|
|
|
|
class Ernie4_5_Model(Ernie4_5_PretrainedModel): |
|
"""The core ERNIE transformer model with MoE (Mixture of Experts) support.""" |
|
|
|
def __init__(self, config: Ernie4_5_MoEConfig): |
|
"""Initialize the ERNIE model architecture. |
|
|
|
Args: |
|
config (Ernie4_5_MoEConfig): Model configuration. |
|
""" |
|
super().__init__(config) |
|
self.padding_idx = config.pad_token_id |
|
self.vocab_size = config.vocab_size |
|
self.hidden_size = config.hidden_size |
|
self.config = config |
|
|
|
self.embed_tokens = nn.Embedding( |
|
self.vocab_size, |
|
self.hidden_size, |
|
) |
|
|
|
self.layers = nn.ModuleList( |
|
[Ernie4_5_DecoderLayer(config, i) for i in range(config.num_hidden_layers)] |
|
) |
|
Norm = RMSNorm |
|
self.norm = Norm(config) |
|
|
|
self.gradient_checkpointing = False |
|
|
|
def get_input_embeddings(self): |
|
"""Get the input embedding layer. |
|
|
|
Returns: |
|
nn.Embedding: The embedding layer for input tokens |
|
""" |
|
return self.embed_tokens |
|
|
|
def set_input_embeddings(self, value): |
|
"""Set new input embeddings. |
|
|
|
Args: |
|
value (nn.Embedding): New embedding layer to use |
|
""" |
|
self.embed_tokens = value |
|
|
|
def forward( |
|
self, |
|
input_ids=None, |
|
position_ids=None, |
|
token_type_ids=None, |
|
attention_mask=None, |
|
attn_mask_start_row_indices=None, |
|
inputs_embeds=None, |
|
use_cache=None, |
|
past_key_values=None, |
|
output_attentions=False, |
|
output_hidden_states=None, |
|
return_dict=False, |
|
): |
|
"""Forward pass through the ERNIE model. |
|
|
|
Args: |
|
input_ids (Optional[torch.Tensor]): Input token IDs |
|
position_ids (Optional[torch.Tensor]): Position indices |
|
attention_mask (Optional[torch.Tensor]): Attention mask |
|
attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length attention indices |
|
inputs_embeds (Optional[torch.Tensor]): Precomputed embeddings |
|
use_cache (Optional[bool]): Whether to cache key/value states |
|
past_key_values (Optional[Tuple[Tuple[torch.Tensor]]]): Cached key/value states |
|
output_attentions (Optional[bool]): Whether to output attention weights |
|
output_hidden_states (Optional[bool]): Whether to output all hidden states |
|
return_dict (Optional[bool]): Whether to return dict or tuple |
|
|
|
Returns: |
|
Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: |
|
Various outputs depending on configuration, including: |
|
- last_hidden_state: Final layer hidden states |
|
- past_key_values: Cached key/value states if use_cache=True |
|
- hidden_states: All hidden states if output_hidden_states=True |
|
- attentions: Attention weights if output_attentions=True |
|
- router_loss: MoE router loss if use_moe=True |
|
- gate_logits: MoE gate logits if use_moe=True |
|
""" |
|
output_attentions = ( |
|
output_attentions |
|
if output_attentions is not None |
|
else self.config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
) |
|
|
|
|
|
if input_ids is not None and inputs_embeds is not None: |
|
raise ValueError( |
|
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" |
|
) |
|
elif input_ids is not None: |
|
_, seq_length = input_ids.shape |
|
elif inputs_embeds is not None: |
|
_, seq_length, _ = inputs_embeds.shape |
|
else: |
|
raise ValueError( |
|
"You have to specify either decoder_input_ids or decoder_inputs_embeds" |
|
) |
|
|
|
if past_key_values is None: |
|
past_key_values = tuple([None] * len(self.layers)) |
|
|
|
seq_length_with_past = seq_length |
|
cache_length = 0 |
|
if past_key_values[0] is not None: |
|
cache_length = past_key_values[0][0].shape[1] |
|
seq_length_with_past += cache_length |
|
if inputs_embeds is None: |
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype) |
|
|
|
hidden_states = inputs_embeds |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
next_decoder_cache = () if use_cache else None |
|
if getattr(self.config, "use_moe", False): |
|
all_router_loss = torch.tensor(0.0).to(device=inputs_embeds.device) |
|
else: |
|
all_router_loss = None |
|
all_gate_logits = () |
|
|
|
for idx, decoder_layer in enumerate(self.layers):
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
past_key_value = ( |
|
past_key_values[idx] if past_key_values is not None else None |
|
) |
|
|
|
layer_outputs = decoder_layer( |
|
hidden_states, |
|
attention_mask, |
|
attn_mask_start_row_indices, |
|
position_ids, |
|
token_type_ids, |
|
output_attentions, |
|
past_key_value, |
|
use_cache, |
|
) |
|
|
|
if isinstance(layer_outputs, (tuple, list)): |
|
hidden_states = layer_outputs[0] |
|
else: |
|
hidden_states = layer_outputs |
|
|
|
if use_cache: |
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) |
|
|
|
if output_attentions: |
|
all_self_attns += (layer_outputs[1],) |
|
if self.config.use_moe: |
|
layer_outputs, gate_logits = layer_outputs[:-1], layer_outputs[-1] |
|
all_gate_logits = all_gate_logits + (gate_logits,) |
|
|
|
if past_key_value is not None: |
|
hidden_states = hidden_states[:, -1:, :] |
|
|
|
hidden_states = self.norm(hidden_states) |
|
|
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
next_cache = next_decoder_cache if use_cache else None |
|
|
|
if not return_dict: |
|
return tuple( |
|
v |
|
for v in [ |
|
hidden_states, |
|
next_cache, |
|
all_hidden_states, |
|
all_self_attns, |
|
all_router_loss, |
|
all_gate_logits, |
|
] |
|
if v is not None |
|
) |
|
|
|
|
|
return BaseModelOutputWithPastAndCrossAttentions( |
|
last_hidden_state=hidden_states, |
|
past_key_values=next_cache, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
cross_attentions=None, |
|
router_loss=all_router_loss, |
|
gate_logits=all_gate_logits, |
|
) |
|
|
|
|
|
def parallel_matmul( |
|
x, |
|
y, |
|
bias=None, |
|
transpose_y=False, |
|
): |
|
""" |
|
Performs parallel matrix multiplication with tensor model parallelism support. |
|
|
|
Args: |
|
x (torch.Tensor): Input tensor with shape [batch_size, seq_len, hidden_size] |
|
y (torch.Tensor): Weight matrix

bias (Optional[torch.Tensor]): Optional bias tensor

transpose_y (bool): Whether to transpose `y` before multiplication



Returns:

torch.Tensor: The matrix product `x @ y` (with `y` optionally transposed), plus `bias` if provided
|
""" |
|
if transpose_y: |
|
logits = torch.matmul(x, y.T) |
|
else: |
|
logits = torch.matmul(x, y) |
|
if bias is not None: |
|
logits += bias |
|
return logits |
|
|
|
|
|
def calc_lm_head_logits( |
|
config, hidden_states, weight, bias, tensor_parallel_output=None, training=True |
|
): |
|
""" |
|
Calculate language model head logits with support for various parallelization strategies. |
|
|
|
This is the core function that computes the final output logits for a language model, |
|
handling sequence parallelism and tensor parallelism configurations. |
|
|
|
Args: |
|
config (Ernie4_5_Config): Model configuration. |
|
hidden_states (Tensor): Hidden states from the transformer layers |
|
weight (Tensor): Weight matrix for the language model head |
|
bias (Tensor): Bias vector for the language model head |
|
tensor_parallel_output (bool, optional): Override for tensor parallel output behavior. |
|
If None, uses config.tensor_parallel_output. |
|
Defaults to None. |
|
training (bool, optional): Whether in training mode. Defaults to True. |
|
|
|
Returns: |
|
Tensor: The computed logits for language modeling. |
|
""" |
|
if tensor_parallel_output is None: |
|
tensor_parallel_output = config.tensor_parallel_output |
|
logits = parallel_matmul( |
|
hidden_states, |
|
weight, |
|
bias=bias, |
|
transpose_y=config.tie_word_embeddings, |
|
) |
|
|
|
return logits |
|
|
|
|
|
def calc_multimodal_logits( |
|
last_hidden_state: torch.Tensor, |
|
lm_head_weight: torch.Tensor, |
|
lm_head_bias: torch.Tensor, |
|
mm_head_weight: torch.Tensor, |
|
mm_head_bias: torch.Tensor, |
|
token_type_ids_shifted: torch.Tensor, |
|
config: Ernie4_5_VLMoEConfig, |
|
): |
|
""" |
|
Calculate logits for pure-text tokens, multimodal-text tokens, and image tokens.

Args:

last_hidden_state: Hidden states of the last layer (split along the sequence dimension when sequence parallelism is enabled).

...

token_type_ids_shifted: Token type ids shifted to the label positions (not split for sequence parallelism).

They select the lm-head that matches each token's label type.

Note: In an interleaved image/text id sequence, the last text token predicts an image id

and vice versa, so the lm-head weight must be chosen according to the label's type.
|
""" |
|
|
|
|
|
assert last_hidden_state.shape[:2] == token_type_ids_shifted.shape, ( |
|
last_hidden_state.shape, |
|
token_type_ids_shifted.shape, |
|
) |
|
parallel_matmul_tp = partial( |
|
parallel_matmul, |
|
) |
|
|
|
if mm_head_weight is None: |
|
if config.use_recompute_loss_fn: |
|
return last_hidden_state, None, None |
|
score_text = parallel_matmul_tp(last_hidden_state, lm_head_weight, lm_head_bias) |
|
return score_text, None, None |
|
|
|
image_mask_shifted = token_type_ids_shifted == TokenType.image |
|
text_pos_shifted = token_type_ids_shifted == TokenType.text |
|
|
|
if text_pos_shifted.any().item():
|
score_text = parallel_matmul_tp( |
|
last_hidden_state[text_pos_shifted], lm_head_weight, lm_head_bias |
|
) |
|
else: |
|
score_text = None |
|
|
|
if mm_head_weight is not None and image_mask_shifted.any().item():
|
score_image = parallel_matmul_tp( |
|
last_hidden_state[image_mask_shifted], mm_head_weight, mm_head_bias |
|
) |
|
else: |
|
score_image = None |
|
|
|
return score_text, score_image, None |
|
|
|
|
|
class Ernie4_5_MoeLMHead(nn.Module): |
|
"""Language model head for ERNIE with support for tensor parallelism.""" |
|
|
|
def __init__(self, config): |
|
"""Initialize the language model head. |
|
|
|
Args: |
|
config (Ernie4_5_Config): Model configuration containing: |
|
- vocab_size: Size of vocabulary |
|
- hidden_size: Dimension of hidden states |
|
# - tensor_parallel_degree: Degree of tensor parallelism |
|
- tie_word_embeddings: Whether to tie input/output embeddings |
|
- weight_share_add_bias: Whether to add bias when weight sharing |
|
- use_bias: Whether to use bias term |
|
- use_recompute_loss_fn: Whether to defer logits computation to loss function |
|
- use_sparse_head_and_loss_fn: Whether to use sparse head computation |
|
""" |
|
|
|
super(Ernie4_5_MoeLMHead, self).__init__() |
|
self.config = config |
|
if config.tensor_parallel_degree > 1: |
|
vocab_size = config.vocab_size // config.tensor_parallel_degree |
|
else: |
|
vocab_size = config.vocab_size |
|
|
|
if config.tie_word_embeddings: |
|
self.weight = nn.Parameter( |
|
torch.empty( |
|
vocab_size, config.hidden_size, dtype=torch.get_default_dtype() |
|
) |
|
) |
|
else: |
|
self.weight = nn.Parameter( |
|
torch.empty( |
|
config.hidden_size, vocab_size, dtype=torch.get_default_dtype() |
|
) |
|
) |
|
nn.init.xavier_uniform_(self.weight) |
|
|
|
logger.info( |
|
f"output-weight:{self.weight.shape} tie_word_embeddings:{config.tie_word_embeddings}" |
|
) |
|
|
|
if config.weight_share_add_bias and config.use_bias: |
|
self.bias = nn.Parameter( |
|
torch.zeros(vocab_size, dtype=torch.get_default_dtype()) |
|
) |
|
else: |
|
self.bias = None |
|
|
|
|
|
self.weight.is_distributed = ( |
|
True if (vocab_size != config.vocab_size) else False |
|
) |
|
if config.weight_share_add_bias and config.use_bias: |
|
self.bias.is_distributed = ( |
|
True if (vocab_size != config.vocab_size) else False |
|
) |
|
|
|
if self.weight.is_distributed: |
|
self.weight.split_axis = 1 |
|
if ( |
|
config.weight_share_add_bias |
|
and config.use_bias |
|
and self.bias.is_distributed |
|
): |
|
self.bias.split_axis = 0 |
|
|
|
if self.config.use_recompute_loss_fn: |
|
logger.info( |
|
"Using recompute_loss_fn, the calculation of logits will be moved into " |
|
"loss_fn for memory optimization" |
|
) |
|
|
|
def forward(self, hidden_states, tensor_parallel_output=None): |
|
"""Project hidden states to vocabulary logits. |
|
|
|
Args: |
|
hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size] |
|
tensor_parallel_output (Optional[bool]): Whether to output parallel results. Defaults to None. |
|
|
|
Returns: |
|
torch.Tensor: Logits tensor of shape [batch_size, seq_len, vocab_size].

When `use_recompute_loss_fn` or `use_sparse_head_and_loss_fn` is enabled, the logits

computation is deferred to the loss function and (hidden_states, weight, bias) is returned instead.
|
""" |
|
return calc_lm_head_logits( |
|
self.config, |
|
hidden_states, |
|
self.weight, |
|
self.bias, |
|
tensor_parallel_output, |
|
training=self.training, |
|
) |
|
|
|
|
|
class Ernie4_5_MoeForCausalLM(Ernie4_5_PretrainedModel, GenerationMixin): |
|
"""ERNIE Mixture of Experts (MoE) model for causal language modeling.""" |
|
|
|
_keys_to_ignore_on_load_missing = [r"lm_head.weight"] |
|
|
|
def __init__(self, config): |
|
""" |
|
Initializes the ERNIE MoE model for causal language modeling. |
|
|
|
Args: |
|
config (dict): Model configuration. |
|
""" |
|
super().__init__(config) |
|
|
|
|
|
|
|
new_initializer_range = math.sqrt(0.3333 / config.hidden_size) |
|
logger.info( |
|
f"change initializer-range from {config.initializer_range} to {new_initializer_range}" |
|
) |
|
config.initializer_range = new_initializer_range |
|
self.config = config |
|
self.model = Ernie4_5_Model(config) |
|
self.lm_head = Ernie4_5_MoeLMHead(config) |
|
|
|
self.tie_weights() |
|
|
|
def get_input_embeddings(self): |
|
"""Returns the input embeddings layer.""" |
|
return self.model.embed_tokens |
|
|
|
def set_input_embeddings(self, value): |
|
"""Sets the input embeddings layer.""" |
|
self.model.embed_tokens = value |
|
|
|
def get_output_embeddings(self): |
|
"""Returns the output embeddings (LM head).""" |
|
return self.lm_head |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
"""Sets the output embeddings layer.""" |
|
self.lm_head = new_embeddings |
|
|
|
def set_decoder(self, decoder): |
|
"""Sets the ERNIE decoder model.""" |
|
self.model = decoder |
|
|
|
def get_decoder(self): |
|
"""Get the transformer decoder. |
|
|
|
Returns: |
|
nn.Layer: The decoder module |
|
""" |
|
return self.model |
|
|
|
def prepare_attention_mask_for_generation( |
|
self, input_ids, pad_token_id, eos_token_id |
|
): |
|
"""Avoid using attention_mask with flash_attn on generation.""" |
|
if self.config.use_flash_attention: |
|
return None |
|
return super().prepare_attention_mask_for_generation( |
|
input_ids, pad_token_id, eos_token_id |
|
) |
|
|
|
|
|
class VisionMlp(nn.Module): |
|
"""VisionMLP""" |
|
|
|
def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None: |
|
super().__init__() |
|
self.fc1 = nn.Linear(dim, hidden_dim) |
|
self.act = ACT2FN[hidden_act] |
|
self.fc2 = nn.Linear(hidden_dim, dim) |
|
|
|
def forward(self, x) -> torch.Tensor: |
|
""" |
|
Args: |
|
x (torch.Tensor): input tensor |
|
|
|
Returns: |
|
torch.Tensor: VisionMLP output tensor |
|
""" |
|
return self.fc2(self.act(self.fc1(x))) |
|
|
|
|
|
class PatchEmbed(nn.Module): |
|
"""PatchEmbed""" |
|
|
|
def __init__( |
|
self, |
|
patch_size: int = 14, |
|
in_channels: int = 3, |
|
embed_dim: int = 1152, |
|
) -> None: |
|
""" |
|
Args: |
|
patch_size (int, optional): patch size. Defaults to 14. |
|
in_channels (int, optional): number of channels. Defaults to 3. |
|
embed_dim (int, optional): embedding dimension. Defaults to 1152. |
|
""" |
|
super().__init__() |
|
self.patch_size = patch_size |
|
self.in_channels = in_channels |
|
self.embed_dim = embed_dim |
|
self.proj = nn.Linear( |
|
in_channels * patch_size * patch_size, embed_dim, bias=False |
|
) |
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Args: |
|
hidden_states (torch.Tensor): hidden states |
|
|
|
Returns: |
|
torch.Tensor: output tensor |
|
""" |
|
target_dtype = self.proj.weight.dtype |
|
|
|
hidden_states = self.proj(hidden_states.to(target_dtype)) |
|
|
|
return hidden_states |
|
|
|
|
|
class VisionRotaryEmbedding(nn.Module): |
|
"""VisionRotaryEmbedding""" |
|
|
|
def __init__(self, dim: int, theta: float = 10000.0) -> None: |
|
""" |
|
Args: |
|
dim (int): the dimension of each token. |
|
theta (float, optional): the frequency factor. Defaults to 10000.0. |
|
""" |
|
super().__init__() |
|
self.inv_freq = 1.0 / theta ** ( |
|
torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim |
|
) |
|
|
|
def forward(self, seqlen: int) -> torch.Tensor: |
|
""" |
|
Args: |
|
seqlen (int): length of sequence. |
|
|
|
Returns: |
|
torch.Tensor: rotary position embedding |
|
""" |
|
seq = torch.arange(seqlen).to(self.inv_freq.dtype) |
|
freqs = torch.outer(input=seq, vec2=self.inv_freq) |
|
return freqs |
|
|
|
|
|
def rotate_half(x): |
|
"""Rotates half the hidden dims of the input.""" |
|
x1 = x[..., : x.shape[-1] // 2] |
|
x2 = x[..., x.shape[-1] // 2 :] |
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
|
def apply_rotary_pos_emb_vision( |
|
tensor: torch.Tensor, freqs: torch.Tensor |
|
) -> torch.Tensor: |
|
"""Applies Rotary Position Embedding to the input tensors. |
|
|
|
Args: |
|
tensor (torch.Tensor): The input tensor. |
|
freqs (torch.Tensor): The frequencies used for the rotation. |
|
Returns: |
|
output (torch.Tensor): the tensor rotated using the Rotary Position Embedding. |
|
""" |
|
orig_dtype = tensor.dtype |
|
|
|
tensor = tensor.type(dtype=torch.float32) |
|
cos = freqs.cos() |
|
sin = freqs.sin() |
|
cos = cos.unsqueeze(1).tile(1, 1, 2).unsqueeze(0).type(dtype=torch.float32) |
|
sin = sin.unsqueeze(1).tile(1, 1, 2).unsqueeze(0).type(dtype=torch.float32) |
|
output = tensor * cos + rotate_half(tensor) * sin |
|
output = output.to(orig_dtype) |
|
return output |
|
|
|
|
|
class VisionAttention(nn.Module): |
|
"""VisionAttention""" |
|
|
|
def __init__(self, dim: int, num_heads: int = 16) -> None: |
|
super().__init__() |
|
self.num_heads = num_heads |
|
self.qkv = nn.Linear(dim, dim * 3, bias=True) |
|
self.proj = nn.Linear(dim, dim) |
|
self.head_dim = dim // num_heads |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
cu_seqlens: torch.Tensor, |
|
rotary_pos_emb: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
|
"""forward function for vision attention""" |
|
seq_length = hidden_states.shape[0] |
|
qkv = ( |
|
self.qkv(hidden_states) |
|
.reshape([seq_length, 3, self.num_heads, -1]) |
|
.permute(1, 0, 2, 3) |
|
) |
|
q, k, v = qkv.unbind(dim=0)
|
|
|
q = apply_rotary_pos_emb_vision(q.unsqueeze(dim=0), rotary_pos_emb).squeeze( |
|
dim=0 |
|
) |
|
k = apply_rotary_pos_emb_vision(k.unsqueeze(dim=0), rotary_pos_emb).squeeze( |
|
dim=0 |
|
) |
|
|
|
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() |
|
|
|
attention_mask = torch.full( |
|
[1, seq_length, seq_length], |
|
torch.finfo(q.dtype).min, |
|
device=q.device, |
|
dtype=q.dtype, |
|
) |
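# Build a block-diagonal additive mask so each token only attends to tokens of its own
# image/video segment (segment boundaries are given by cu_seqlens).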
|
for i in range(1, len(cu_seqlens)): |
|
attention_mask[ |
|
..., |
|
cu_seqlens[i - 1] : cu_seqlens[i], |
|
cu_seqlens[i - 1] : cu_seqlens[i], |
|
] = 0 |
|
|
|
q = q.transpose(0, 1) |
|
k = k.transpose(0, 1) |
|
v = v.transpose(0, 1) |
|
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim) |
|
attn_weights = attn_weights + attention_mask |
|
attn_weights = nn.functional.softmax( |
|
attn_weights, dim=-1, dtype=torch.float32 |
|
).to(q.dtype) |
|
attn_output = torch.matmul(attn_weights, v) |
|
attn_output = attn_output.transpose(0, 1) |
|
attn_output = attn_output.reshape(seq_length, -1) |
|
attn_output = self.proj(attn_output) |
|
return attn_output |
|
|
|
|
|
class DFNRopeVisionBlock(nn.Module): |
|
"""DFNRopeVisionBlock""" |
|
|
|
def __init__(self, config, attn_implementation: str = "sdpa") -> None: |
|
""" |
|
Args: |
|
config (dict): model configuration. |
|
attn_implementation (str, optional): attention implementation. Defaults to "sdpa". |
|
""" |
|
super().__init__() |
|
self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6) |
|
self.norm2 = nn.LayerNorm(config.embed_dim, eps=1e-6) |
|
mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio) |
|
|
|
self.attn = VisionAttention(config.embed_dim, num_heads=config.num_heads) |
|
self.mlp = VisionMlp( |
|
dim=config.embed_dim, |
|
hidden_dim=mlp_hidden_dim, |
|
hidden_act=config.hidden_act, |
|
) |
|
self.config = config |
|
|
|
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor: |
|
""" |
|
Args: |
|
hidden_states(torch.Tensor): hidden states |
|
cu_seqlens (torch.Tensor): cumulative sequence lengths |
|
rotary_pos_emb: rotary position embedding |
|
|
|
Returns: |
|
torch.Tensor: output tensor |
|
""" |
|
hidden_states = hidden_states + self.attn( |
|
self.norm1(hidden_states), |
|
cu_seqlens=cu_seqlens, |
|
rotary_pos_emb=rotary_pos_emb, |
|
) |
|
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) |
|
return hidden_states |
|
|
|
|
|
class DFNRopeVisionTransformerPreTrainedModel(PreTrainedModel): |
|
"""DFNRopeVisionTransformerPreTrainedModel""" |
|
|
|
config_class = DFNRopeVisionTransformerConfig |
|
_tp_plan = {} |
|
|
|
def __init__(self, config) -> None: |
|
""" |
|
Args: |
|
config (dict): model configuration |
|
""" |
|
super().__init__(config) |
|
self.spatial_merge_size = config.spatial_merge_size |
|
|
|
self.patch_embed = PatchEmbed( |
|
patch_size=config.patch_size, |
|
in_channels=config.in_channels, |
|
embed_dim=config.embed_dim, |
|
) |
|
|
|
head_dim = config.embed_dim // config.num_heads |
|
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2) |
|
|
|
self.blocks = nn.ModuleList( |
|
[DFNRopeVisionBlock(config) for _ in range(config.depth)] |
|
) |
|
|
|
assert ( |
|
config.hidden_size == config.embed_dim |
|
), "in DFNRope, vit's config.hidden must be equal to config.embed_dim" |
|
self.ln = nn.LayerNorm(config.hidden_size, eps=1e-6) |
|
|
|
def rot_pos_emb(self, grid_thw, num_pad=0): |
|
"""rot_pos_emb |
|
|
|
Args: |
|
grid_thw (torch.Tensor): grid thw of input |
|
|
|
Returns: |
|
torch.Tensor: rotary position embedding |
|
""" |
|
pos_ids = [] |
|
grid_hw_array = np.array(grid_thw.cpu(), dtype=np.int64) |
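# For each image/video, build (h, w) position ids reordered so that patches inside the same
# spatial_merge_size x spatial_merge_size window are adjacent.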
|
for t, h, w in grid_hw_array: |
|
hpos_ids = np.arange(h).reshape([-1, 1]) |
|
hpos_ids = np.tile(hpos_ids, (1, w)) |
|
hpos_ids = hpos_ids.reshape( |
|
h // self.spatial_merge_size, |
|
self.spatial_merge_size, |
|
w // self.spatial_merge_size, |
|
self.spatial_merge_size, |
|
) |
|
hpos_ids = np.transpose(hpos_ids, (0, 2, 1, 3)) |
|
hpos_ids = hpos_ids.flatten() |
|
|
|
wpos_ids = np.arange(w).reshape([1, -1]) |
|
wpos_ids = np.tile(wpos_ids, (h, 1)) |
|
wpos_ids = wpos_ids.reshape( |
|
h // self.spatial_merge_size, |
|
self.spatial_merge_size, |
|
w // self.spatial_merge_size, |
|
self.spatial_merge_size, |
|
) |
|
wpos_ids = np.transpose(wpos_ids, (0, 2, 1, 3)) |
|
wpos_ids = wpos_ids.flatten() |
|
|
|
stacked_ids = np.stack([hpos_ids, wpos_ids], axis=-1) |
|
tiled_ids = np.tile(stacked_ids, (t, 1)) |
|
pos_ids.append(tiled_ids) |
|
|
|
pos_ids = np.concatenate(pos_ids, axis=0) |
|
if num_pad > 0: |
|
pos_ids = np.concatenate( |
|
[pos_ids, np.zeros((num_pad, 2), dtype=pos_ids.dtype)] |
|
) |
|
max_grid_size = np.amax(grid_hw_array[:, 1:]) |
|
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) |
|
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(start_dim=1) |
|
return rotary_pos_emb |
|
|
|
def forward( |
|
self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0 |
|
) -> torch.Tensor: |
|
""" |
|
Args: |
|
hidden_states (torch.Tensor): input tensor |
|
grid_thw (torch.Tensor): grid thw of input |
|
num_pad (int): number of padding tokens |
|
|
|
Returns: |
|
torch.Tensor: output tensor |
|
""" |
|
hidden_states = self.patch_embed(hidden_states) |
|
|
|
rotary_pos_emb = self.rot_pos_emb(grid_thw, num_pad=num_pad) |
|
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) |
|
|
|
cu_seqlens = torch.repeat_interleave( |
|
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] |
|
).cumsum(dim=0, dtype=torch.int32) |
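# cu_seqlens holds cumulative token counts per frame (h * w tokens for each of the t frames)
# and is used to build the block-wise attention mask inside each vision block.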
|
|
|
if num_pad > 0: |
|
cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0) |
|
cu_seqlens[-1] = cu_seqlens[-2] + num_pad |
|
else: |
|
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) |
|
|
|
for idx, blk in enumerate(self.blocks): |
|
hidden_states = blk( |
|
hidden_states, |
|
cu_seqlens=cu_seqlens, |
|
rotary_pos_emb=rotary_pos_emb, |
|
) |
|
|
|
ret = self.ln(hidden_states) |
|
return ret |
|
|
|
|
|
class VariableResolutionResamplerModel(nn.Module): |
|
""" |
|
VariableResolutionResamplerModel: resampler that supports variable-resolution image and video inputs.
|
""" |
|
|
|
def __init__(self, in_dim, out_dim, spatial_conv_size, temporal_conv_size, config): |
|
super().__init__() |
|
self.in_dim = in_dim |
|
self.out_dim = out_dim |
|
self.config = config |
|
self.spatial_conv_size = spatial_conv_size |
|
self.temporal_conv_size = temporal_conv_size |
|
self.use_temporal_conv = config.use_temporal_conv |
|
|
|
|
|
self.spatial_dim = self.in_dim * self.spatial_conv_size * self.spatial_conv_size |
|
|
|
self.temporal_dim = ( |
|
self.in_dim |
|
* self.spatial_conv_size |
|
* self.spatial_conv_size |
|
* self.temporal_conv_size |
|
) |
|
|
|
|
|
with UniqueNameGuard("mm_resampler_") as guard: |
|
|
|
self.spatial_linear = nn.Sequential( |
|
nn.Linear(self.spatial_dim, self.spatial_dim), |
|
nn.GELU(), |
|
nn.Linear(self.spatial_dim, self.spatial_dim), |
|
nn.LayerNorm(self.spatial_dim, eps=1e-6), |
|
) |
|
|
|
if self.use_temporal_conv: |
|
self.temporal_linear = nn.Sequential( |
|
nn.Linear(self.temporal_dim, self.spatial_dim), |
|
nn.GELU(), |
|
nn.Linear(self.spatial_dim, self.spatial_dim), |
|
nn.LayerNorm(self.spatial_dim, eps=1e-6), |
|
) |
|
|
|
self.mlp = nn.Linear(self.spatial_dim, self.out_dim) |
|
|
|
out_config = deepcopy(config) |
|
out_config.hidden_size = out_dim |
|
self.after_norm = RMSNorm(out_config) |
|
|
|
def spatial_conv_reshape(self, x, spatial_conv_size): |
|
""" |
|
Reshape before the linear layers to imitate a strided convolution over spatial patches.
|
""" |
|
S, C = x.shape |
|
x = x.reshape([-1, C * (spatial_conv_size**2)]) |
|
return x |
|
|
|
def forward(self, x, image_mask, token_type_ids, image_type_ids, grid_thw): |
|
""" |
|
x: image_features |
|
image_mask: [B] |
|
token_type_ids: [B]
|
image_type_ids: [B_image] |
|
grid_thw: [B_image, 3] |
|
""" |
|
assert image_type_ids is not None |
|
|
|
def fwd_spatial(x): |
|
""" |
|
x in the shape of [S, H] |
|
S is ordered in the following way: [ [patch_h*patch_w (row-major traversal)] * patch_time] |
|
H is simply hidden |
|
""" |
|
x = self.spatial_conv_reshape(x, self.spatial_conv_size) |
|
|
|
x = self.spatial_linear(x) |
|
|
|
return x |
|
|
|
def fwd_placeholder(x, grid_thw, to_tensor=False): |
|
""" |
|
x: [S, H] |
|
grid_thw: [S, 3] |
|
the second dimension: [t, h, w] |
|
""" |
|
|
|
grid_thw_cpu = grid_thw.cpu().numpy() |
|
grid_t, grid_hw = grid_thw_cpu[:, 0], grid_thw_cpu[:, 1:] |
|
grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size**2) |
|
|
|
tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // (self.spatial_conv_size**2) |
|
batch_offset = np.empty( |
|
tokens_per_img_or_vid.size, dtype=tokens_per_img_or_vid.dtype |
|
) |
|
batch_offset[0] = 0 |
|
batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1] |
|
|
|
assert ( |
|
self.temporal_conv_size == 2 |
|
), f"Hard Code: temporal_conv_size==2, got:{self.temporal_conv_size}" |
|
|
|
|
|
slice_offsets = [] |
|
for temporal_size, spatial_size, b_offset in zip(

grid_t, grid_hw_after_conv, batch_offset

):

for temp_offset in range(0, temporal_size, 2):
|
slice_offsets.append( |
|
np.arange( |
|
b_offset + (temp_offset) * spatial_size, |
|
b_offset + (temp_offset + 1) * spatial_size, |
|
) |
|
) |
|
slice_offsets = torch.tensor(np.concatenate(slice_offsets, axis=-1)).to( |
|
x.device |
|
) |
|
|
|
slice_offsets2 = [] |
|
for temporal_size, spatial_size, b_offset in zip(

grid_t, grid_hw_after_conv, batch_offset

):

for temp_offset in range(

1 if temporal_size > 1 else 0, temporal_size, 2
|
): |
|
slice_offsets2.append( |
|
np.arange( |
|
b_offset + (temp_offset) * spatial_size, |
|
b_offset + (temp_offset + 1) * spatial_size, |
|
) |
|
) |
|
slice_offsets2 = torch.tensor(np.concatenate(slice_offsets2, axis=-1)).to( |
|
x.device |
|
) |
|
|
|
x_timestep_1 = torch.index_select(x, dim=0, index=slice_offsets) |
|
x_timestep_2 = torch.index_select(x, dim=0, index=slice_offsets2) |
|
x = torch.concat([x_timestep_1, x_timestep_2], dim=-1) |
|
return x |
|
|
|
def fwd_temporal(x): |
|
x = self.temporal_linear(x) |
|
return x |
|
|
|
def fwd_mlp(x): |
|
x = self.mlp(x) |
|
x = self.after_norm(x) |
|
return x |
|
|
|
x = fwd_spatial(x) |
|
if self.use_temporal_conv: |
|
x = fwd_placeholder(x, grid_thw) |
|
x = fwd_temporal(x) |
|
x = fwd_mlp(x) |
|
return x |
|
|
|
|
|
class Ernie4_5_MoeVLHead(Ernie4_5_MoeLMHead): |
|
"""Ernie4_5_MoeVLHead""" |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.config = config |
|
if config.mm_vocab_size > 0: |
|
mm_vocab_config = deepcopy(config) |
|
mm_vocab_config.vocab_size = config.mm_vocab_size |
|
assert mm_vocab_config.vocab_size > 0, mm_vocab_config |
|
assert ( |
|
mm_vocab_config.im_patch_id >= mm_vocab_config.max_text_id |
|
), mm_vocab_config |
|
self.mm_head = Ernie4_5_MoeLMHead(mm_vocab_config) |
|
else: |
|
self.mm_head = None |
|
|
|
def forward(self, hidden_state, token_type_ids_labels, use_cache=False): |
|
""" |
|
Args: |
|
hidden_state(torch.Tensor): hidden state |
|
token_type_ids_labels(torch.Tensor): token type ids aligned with the label positions
|
use_cache(bool): whether to use cache, default is False |
|
|
|
Returns: |
|
logits_text(torch.Tensor): text logits |
|
logits_image(torch.Tensor): image logits |
|
""" |
|
if not use_cache: |
|
mm_head_weight = self.mm_head.weight if self.mm_head is not None else None |
|
mm_head_bias = self.mm_head.bias if self.mm_head is not None else None |
|
logits_text, logits_image, _ = calc_multimodal_logits( |
|
hidden_state, |
|
self.weight, |
|
self.bias, |
|
mm_head_weight, |
|
mm_head_bias, |
|
token_type_ids_labels, |
|
self.config, |
|
) |
|
return logits_text, logits_image, None |
|
else: |
|
|
|
return ( |
|
parallel_matmul( |
|
hidden_state[:, -1:, :], |
|
self.weight, |
|
self.bias, |
|
transpose_y=self.config.tie_word_embeddings, |
|
), |
|
None, |
|
None, |
|
) |
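# Illustrative usage sketch (hypothetical tensors, not part of the model):
#
#   head = Ernie4_5_MoeVLHead(config)
#   # scoring path: logits are split per modality via token_type_ids_labels
#   logits_text, logits_image, _ = head(hidden, token_type_ids_labels, use_cache=False)
#   # decoding path: only text logits for the last position are computed
#   next_token_logits, _, _ = head(hidden, token_type_ids_labels, use_cache=True)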
|
|
|
|
|
class Ernie4_5_VLMoeForConditionalGeneration(Ernie4_5_MoeForCausalLM): |
|
"""Ernie4_5_VLMoeForConditionalGeneration""" |
|
|
|
config_class = Ernie4_5_VLMoEConfig |
|
main_input_name = "pixel_values" |
|
_keep_in_fp16_modules = ["vision_model"] |
|
_tp_plan = {} |
|
|
|
def __init__( |
|
self, config: Ernie4_5_VLMoEConfig, vision_model=None, resampler_model=None |
|
): |
|
""" |
|
initialize Ernie4_5_VLMoeForConditionalGeneration |
|
|
|
Args: |
|
config(Ernie4_5_VLMoEConfig): Model configuration. |
|
vision_model (nn.Module, optional): vision encoder module |

resampler_model (nn.Module, optional): resampler module |
|
""" |
|
super().__init__(config) |
|
|
|
self.vision_model = DFNRopeVisionTransformerPreTrainedModel( |
|
config.vision_config |
|
) |
|
|
|
self.model.resampler_model = VariableResolutionResamplerModel( |
|
config.pixel_hidden_size, |
|
config.hidden_size, |
|
config.spatial_conv_size, |
|
config.temporal_conv_size, |
|
config=config, |
|
) |
|
|
|
self.image_preprocess = None |
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
|
self.post_init() |
|
|
|
def add_image_preprocess(self, processor): |
|
"""add image preprocess""" |
|
logger.info("image preprocess is set") |
|
|
|
image_preprocess = processor.image_processor |
|
image_preprocess.image_mean_tensor = torch.tensor( |
|
image_preprocess.image_mean, dtype=torch.float32 |
|
).reshape([1, 3, 1, 1]) |
|
image_preprocess.image_std_tensor = torch.tensor( |
|
image_preprocess.image_std, dtype=torch.float32 |
|
).reshape([1, 3, 1, 1]) |
|
image_preprocess.rescale_factor = torch.tensor( |
|
image_preprocess.rescale_factor, dtype=torch.float32 |
|
) |
|
image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze( |
|
[-2, -1] |
|
).repeat_interleave(self.config.vision_config.patch_size**2 * 1, -1) |
|
image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze( |
|
[-2, -1] |
|
).repeat_interleave(self.config.vision_config.patch_size**2 * 1, -1) |
|
|
|
self.image_preprocess = image_preprocess |
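# Illustrative sketch (assuming patch_size == 14, not part of the model): the
# per-channel mean/std are flattened so they broadcast over patches that arrive
# as flattened channel-major rows:
#
#   image_mean [1, 3, 1, 1] -> squeeze -> [1, 3]
#                           -> repeat_interleave(14**2, -1) -> [1, 3 * 196] == [1, 588]
#
# which lets vision_forward normalise uint8 pixel patches on the fly.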
|
|
|
def vision_forward( |
|
self, |
|
images, |
|
image_position_ids, |
|
image_attention_mask, |
|
grid_thw, |
|
): |
|
"""vision_forward""" |
|
if self.image_preprocess is not None: |
|
assert images.dtype == torch.uint8, images.dtype |
|
current_device = images.device |
|
self.image_preprocess.image_mean_tensor = ( |
|
self.image_preprocess.image_mean_tensor.to(current_device) |
|
) |
|
self.image_preprocess.image_std_tensor = ( |
|
self.image_preprocess.image_std_tensor.to(current_device) |
|
) |
|
images = self.image_preprocess.rescale_factor * images.to(torch.float32) |
|
images = ( |
|
images - self.image_preprocess.image_mean_tensor |
|
) / self.image_preprocess.image_std_tensor |
|
images = images.to(torch.bfloat16) |
|
else: |
|
assert images.dtype == torch.bfloat16, images.dtype |
|
|
|
if grid_thw is not None: |
|
grid_thw = grid_thw[grid_thw > 0].reshape([-1, 3]) |
|
grid_thw = F.pad( |
|
torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), |
|
[1, 0, 0, 0], |
|
value=1, |
|
) |
|
image_features = self.vision_model(images, grid_thw) |
|
return image_features |
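# Illustrative sketch (not part of the model): grid_thw is expanded from
# per-item [t, h, w] rows to per-frame [1, h, w] rows before the vision tower:
#
#   grid_thw == [[2, 8, 8], [1, 4, 4]]
#   -> keep [h, w] and repeat each row t times: [[8, 8], [8, 8], [4, 4]]
#   -> F.pad(..., value=1) prepends t == 1:     [[1, 8, 8], [1, 8, 8], [1, 4, 4]]
#
# so every video frame is encoded as an independent image by the ViT.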
|
|
|
def vision_mapping_forward( |
|
self, |
|
token_type_ids, |
|
token_type_ids_w_video, |
|
input_ids, |
|
mm_input_ids, |
|
image_features, |
|
inputs_embeds, |
|
image_type_ids, |
|
grid_thw, |
|
): |
|
"""vision_mapping_forward""" |
|
image_mask = input_ids == self.config.im_patch_id |
|
image_features = self.model.resampler_model( |
|
image_features, |
|
image_mask, |
|
token_type_ids_w_video, |
|
image_type_ids, |
|
grid_thw, |
|
) |
|
|
|
if image_features.dim() == 3:  # flatten a batched [B, N, C] resampler output |
|
B, N, C = image_features.shape |
|
image_features = image_features.reshape([B * N, C]).to(inputs_embeds.dtype) |
|
|
|
inputs_embeds[image_mask.to(inputs_embeds.device)] = image_features.to( |
|
inputs_embeds.device |
|
) |
|
return inputs_embeds |
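# Illustrative sketch (hypothetical shapes, not part of the model): resampled
# image features are scattered into the placeholder positions of the text
# embedding via a boolean mask:
#
#   input_ids     : [1, 7] with im_patch_id at positions 2..5
#   image_mask    : [[F, F, T, T, T, T, F]]
#   inputs_embeds : [1, 7, hidden_size]
#   image_features (after the resampler): [4, hidden_size]
#
#   inputs_embeds[image_mask] = image_features   # 4 rows replaced in place
#
# so the number of im_patch_id tokens must match the number of resampled rows.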
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
input_ids, |
|
images=None, |
|
use_cache=False, |
|
past_key_values=None, |
|
inputs_embeds=None, |
|
image_position_ids=None, |
|
image_attention_mask=None, |
|
token_type_ids=None, |
|
image_type_ids=None, |
|
grid_thw=None, |
|
**kwargs, |
|
): |
|
""" |
|
Prepare inputs for the decoder that can be used for generation. |
|
|
|
Args: |
|
input_ids (torch.Tensor): Input ids. |

images (torch.Tensor): Images. Defaults to None. |

use_cache (bool): Whether to use cache. Defaults to False. |

past_key_values (list): Past key values. Defaults to None. |

inputs_embeds (torch.Tensor): Input embeddings. Defaults to None. |

image_position_ids (torch.Tensor): Image position ids. Defaults to None. |

image_attention_mask (torch.Tensor): Image attention mask. Defaults to None. |

token_type_ids (torch.Tensor): Token type ids. Defaults to None. |

image_type_ids (torch.Tensor): Image type ids. Defaults to None. |

grid_thw (torch.Tensor): Per-image/video patch grid sizes [t, h, w]. Defaults to None. |
|
""" |
|
if past_key_values: |
|
input_ids = input_ids[:, -1:] |
|
token_type_ids = token_type_ids[:, -1:] |
|
image_type_ids = ( |
|
image_type_ids[:, -1:] if image_type_ids is not None else None |
|
) |
|
|
|
attention_mask = kwargs.get("attention_mask", None) |
|
|
|
|
|
if inputs_embeds is not None and past_key_values is None: |
|
model_inputs = {"inputs_embeds": inputs_embeds} |
|
else: |
|
model_inputs = {"input_ids": input_ids} |
|
|
|
model_inputs.update( |
|
{ |
|
"past_key_values": past_key_values, |
|
"use_cache": True, |
|
"attention_mask": attention_mask, |
|
"images": images, |
|
"image_position_ids": image_position_ids, |
|
"image_attention_mask": image_attention_mask, |
|
"image_type_ids": image_type_ids, |
|
"token_type_ids": torch.cat( |
|
[ |
|
token_type_ids, |
|
torch.zeros( |
|
[len(token_type_ids), 1], dtype=token_type_ids.dtype |
|
).to(token_type_ids.device), |
|
], |
|
dim=-1, |
|
), |
|
"grid_thw": grid_thw, |
|
} |
|
) |
|
if self.config.rope_3d: |
|
model_inputs.update({"position_ids": kwargs["position_ids"]}) |
|
|
|
return model_inputs |
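# Illustrative sketch (not part of the model): the token_type_ids handed back to
# forward() get one extra text (0) column appended, which preserves the invariant
# token_type_ids.shape[1] == input_ids.shape[1] + 1 that forward() asserts, e.g.
#
#   once a cache exists: input_ids [B, 1], token_type_ids [B, 1] -> concat zeros -> [B, 2]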
|
|
|
def _post_init(self, original_init, *args, **kwargs): |
|
""" |
|
Label the multimodal parameters of the model (only the LM head and embedding here); |

expert parameters are already labeled. |
|
""" |
|
super()._post_init(original_init, *args, **kwargs) |
|
if getattr(self.lm_head, "mm_head", None) is not None: |
|
self.lm_head.mm_head.weight.expert_type = "expert_type_1" |
|
if getattr(self.lm_head.mm_head, "bias", None) is not None: |
|
self.lm_head.mm_head.bias.expert_type = "expert_type_1" |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.Tensor, |
|
position_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
past_key_values: Optional[List[torch.Tensor]] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
labels: Optional[torch.Tensor] = None, |
|
images: Optional[torch.Tensor] = None, |
|
ignored_index: Optional[int] = 0, |
|
return_dict: Optional[bool] = None, |
|
image_position_ids: Optional[torch.Tensor] = None, |
|
image_attention_mask: Optional[torch.Tensor] = None, |
|
token_type_ids: Optional[torch.Tensor] = None, |
|
image_type_ids: Optional[torch.Tensor] = None, |
|
grid_thw: Optional[torch.Tensor] = None, |
|
**kwargs, |
|
): |
|
""" |
|
Forward for Ernie4_5_VLMoeForConditionalGeneration |
|
|
|
Args: |
|
input_ids (torch.Tensor): Input ids. |
|
position_ids (Optional[torch.Tensor], optional): Position ids. Defaults to None. |
|
attention_mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None. |
|
past_key_values (Optional[List[torch.Tensor]], optional): Past key values. Defaults to None. |
|
use_cache (Optional[bool], optional): Use cache. Defaults to None. |
|
output_attentions (Optional[bool], optional): Output attentions. Defaults to None. |
|
output_hidden_states (Optional[bool], optional): Output hidden states. Defaults to None. |
|
labels (Optional[torch.Tensor], optional): Labels. Defaults to None. |
|
images (Optional[torch.Tensor]): Images. Defaults to None. |
|
ignored_index (Optional[int], optional): Ignored index. Defaults to 0. |
|
return_dict (Optional[bool], optional): Return dict. Defaults to None. |
|
image_position_ids (Optional[torch.Tensor], optional): Image position ids. Defaults to None. |
|
image_attention_mask (Optional[torch.Tensor], optional): Image attention mask. Defaults to None. |
|
token_type_ids (Optional[torch.Tensor], optional): Token type ids. Defaults to None. |
|
image_type_ids (Optional[torch.Tensor], optional): Image type ids. Defaults to None. |
|
grid_thw (Optional[torch.Tensor], optional): Per-image/video patch grid sizes [t, h, w]. Defaults to None. |
|
""" |
|
if grid_thw is not None: |
|
grid_thw = grid_thw[grid_thw > 0].reshape([-1, 3]) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
) |
|
|
|
image_mask = input_ids == self.config.im_patch_id |
|
|
|
image_rate = image_mask.to(torch.float32).mean() |
|
|
|
if past_key_values is None: |
|
if images is not None: |
|
assert (image_mask).any().item(), ( |
|
image_mask.detach().cpu().numpy().tolist(), |
|
input_ids.detach().cpu().numpy().tolist(), |
|
self.config.im_patch_id, |
|
images.shape, |
|
) |
|
image_features = self.vision_forward( |
|
images, |
|
image_position_ids, |
|
image_attention_mask, |
|
grid_thw, |
|
) |
|
else: |
|
image_features = None |
|
else: |
|
image_features = None |
|
if token_type_ids is None: |
|
token_type_ids = image_mask.to(torch.int64) |
|
token_type_ids_labels = torch.cat( |
|
[token_type_ids[:, 1:], token_type_ids[:, -1:]], 1 |
|
) |
|
else: |
|
assert ( |
|
token_type_ids.shape[1] == input_ids.shape[1] + 1 |
|
), f"token_type:{token_type_ids.shape}, ids:{input_ids.shape}" |
|
token_type_ids_labels = token_type_ids[..., 1:] |
|
|
|
lm_input_ids = input_ids.clone() |
|
mm_input_ids = input_ids.clone() |
|
|
|
inputs_embeds = self.model.embed_tokens(lm_input_ids) |
|
token_type_ids_w_video = token_type_ids[..., :-1].clone() |
|
token_type_ids[token_type_ids == TokenType.video] = TokenType.image |
|
|
|
if images is not None and image_features is not None: |
|
inputs_embeds = self.vision_mapping_forward( |
|
token_type_ids, |
|
token_type_ids_w_video, |
|
input_ids, |
|
mm_input_ids, |
|
image_features, |
|
inputs_embeds, |
|
image_type_ids, |
|
grid_thw, |
|
) |
|
|
|
|
outputs = self.model( |
|
position_ids=position_ids, |
|
attention_mask=None, |
|
token_type_ids=token_type_ids, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=use_cache, |
|
past_key_values=past_key_values, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=True, |
|
) |
|
|
|
if not use_cache: |
|
assert outputs.last_hidden_state.shape[:2] == token_type_ids_labels.shape, ( |
|
outputs.last_hidden_state.shape, |
|
token_type_ids_labels.shape, |
|
) |
|
if self.config.use_recompute_loss_fn: |
|
logits = outputs.last_hidden_state |
|
else: |
|
logits = self.lm_head(outputs.last_hidden_state) |
|
else: |
|
logits = self.lm_head(outputs.last_hidden_state[:, -1:, :]) |
|
|
|
router_loss = outputs.router_loss |
|
|
|
|
|
loss = None |
|
return CausalLMOutputWithCrossAttentions( |
|
loss=loss, |
|
logits=logits, |
|
past_key_values=outputs.past_key_values, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
router_loss=outputs.router_loss, |
|
) |
|
|
|
@staticmethod |
|
def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False): |
|
"""_resolve_prefix_keys""" |
|
|
|
state_keys_map = {} |
|
|
|
state_keys_base = set(state_keys_base) |
|
state_keys_real = set(state_keys_real) |
|
|
|
for key in state_keys_base: |
|
for x in state_keys_real: |
|
if "mm_embed_tokens" in x: |
|
if "mm_embed_tokens" in key: |
|
state_keys_map[key] = x |
|
break |
|
elif x.endswith(key): |
|
state_keys_map[key] = x |
|
break |
|
if key not in state_keys_map: |
|
if not ignore_error: |
|
logger.error(f"could not find name {key} in loaded state dict!") |
|
else: |
|
state_keys_real.remove(state_keys_map[key]) |
|
|
|
return state_keys_map |
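# Illustrative sketch (hypothetical key names, not part of the model):
#
#   _resolve_prefix_keys({"layers.0.mlp.weight"}, {"model.layers.0.mlp.weight"})
#   -> {"layers.0.mlp.weight": "model.layers.0.mlp.weight"}
#
# Keys are matched by suffix; "mm_embed_tokens" entries are paired by name
# regardless of any prefix.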
|
|
|
|
|
@dataclass |
|
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): |
|
""" |
|
Base class for model outputs with past key values and cross attention layers, |
|
with additional support for router components in mixture-of-experts models. |
|
|
|
This extends the base model output to include: |
|
1. Router-related outputs for expert selection |
|
2. Maintains all existing functionality from the parent class |
|
""" |
|
|
|
last_hidden_state: Optional[Tuple[torch.Tensor]] = None |
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None |
|
hidden_states: Optional[Tuple[torch.Tensor]] = None |
|
attentions: Optional[Tuple[torch.Tensor]] = None |
|
cross_attentions: Optional[Tuple[torch.Tensor]] = None |
|
router_loss: Optional[torch.Tensor] = None |
|
gate_logits: Optional[Tuple[torch.Tensor]] = None |
|
|
|
|
|
@dataclass |
|
class CausalLMOutputWithCrossAttentions(ModelOutput): |
|
""" |
|
Base class for causal language model (or autoregressive) outputs. |
|
|
|
Args: |
|
loss (`torch.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided): |
|
Language modeling loss (for next-token prediction). |
|
logits (`torch.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`): |
|
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). |
|
hidden_states (`tuple(torch.Tensor)`, *optional*, returned when `output_hidden_states=True` |
|
is passed or when `config.output_hidden_states=True`): |
|
Tuple of `torch.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + |
|
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. |
|
|
|
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. |
|
attentions (`tuple(torch.Tensor)`, *optional*, returned when `output_attentions=True` is passed or |
|
when `config.output_attentions=True`): |
|
Tuple of `torch.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, |
|
sequence_length)`. |
|
|
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention |
|
heads. |
|
router_loss (Optional[torch.Tensor]): |
|
The routing loss computed by the gating network in mixture-of-experts models. |
|
This is typically the load balancing loss that encourages equal expert utilization. |
|
None when not using mixture-of-experts routing. |
|
""" |
|
|
|
loss: Optional[torch.Tensor] = None |
|
logits: Optional[torch.Tensor] = None |
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None |
|
hidden_states: Optional[Tuple[torch.Tensor]] = None |
|
attentions: Optional[Tuple[torch.Tensor]] = None |
|
router_loss: Optional[Tuple[torch.Tensor]] = None |
|
|