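"""Dropless mixture-of-experts (dMoE) layer in pure PyTorch.

A learned router scores every token against every expert and keeps the
top-k experts per token; a dropless expert block then runs each expert
over exactly the tokens routed to it, so no token is dropped for
capacity reasons.
"""
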
from functools import partial
from typing import Callable, Optional, Union

import torch
import torch.nn.functional as F

# Default expert activation: GELU with the tanh approximation.
DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')
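
# Assigns tokens to experts round-robin, ignoring the learned routing
# decision — presumably a debugging/benchmarking aid that produces a
# perfectly balanced expert load. Wrapping it in an autograd.Function
# keeps the assignment out of the autograd graph.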
class _UniformExpertAssignment(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: torch.Tensor, num_experts: int):
        out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
        out = torch.remainder(out, num_experts)
        return out.view(x.shape)
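
# Linear router: projects each token to per-expert logits, softmaxes them,
# and returns the dense score matrix plus the top-k weights and indices.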
class LearnedRouter(torch.nn.Module):

    def __init__(
        self,
        hidden_size: int,
        moe_num_experts: int,
        moe_top_k: int,
        moe_jitter_eps: Optional[float],
        moe_normalize_expert_weights: Optional[Union[int, float]],
        uniform_expert_assignment: bool,
        device: Optional[torch.device],
    ) -> None:
        super().__init__()
        self.hidden_size: int = hidden_size
        self.moe_num_experts: int = moe_num_experts
        self.moe_top_k: int = moe_top_k
        self.moe_jitter_eps: Optional[float] = moe_jitter_eps
        self.moe_normalize_expert_weights: Optional[Union[int, float]] = moe_normalize_expert_weights
        self.uniform_expert_assignment: bool = uniform_expert_assignment
        self.layer: torch.nn.Module = torch.nn.Linear(hidden_size, moe_num_experts, bias=False, device=device)

    def jitter(self, x: torch.Tensor) -> torch.Tensor:
        # Multiplicative noise drawn uniformly from [1 - eps, 1 + eps],
        # applied to the router input during training only.
        assert self.moe_jitter_eps is not None
        low: float = 1.0 - self.moe_jitter_eps
        high: float = 1.0 + self.moe_jitter_eps
        noise: torch.Tensor = torch.rand(x.size(), dtype=x.dtype, device=x.device)
        return low + noise * (high - low)

    def _top_k(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        if self.moe_top_k == 1:
            # max() is cheaper than topk() for k=1.
            values, indices = scores.max(dim=-1)
            return (values.unsqueeze(-1), indices.unsqueeze(-1))
        return torch.topk(scores, self.moe_top_k, dim=-1)

    def forward(self, x: torch.Tensor):
        if self.training and self.moe_jitter_eps is not None:
            x = x * self.jitter(x)
        # Flatten to (num_tokens, hidden_size) and softmax over experts.
        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
        expert_weights, top_experts = self._top_k(scores)
        if self.moe_normalize_expert_weights:
            # Renormalize the selected top-k weights with the given p-norm.
            expert_weights = expert_weights / torch.norm(
                expert_weights,
                p=self.moe_normalize_expert_weights,
                dim=-1,
                keepdim=True,
            )
        top_experts = (_UniformExpertAssignment.apply(top_experts, self.moe_num_experts)
                       if self.uniform_expert_assignment else top_experts)
        scores = scores.to(x.dtype)
        expert_weights = expert_weights.to(x.dtype)
        return (scores, expert_weights, top_experts)
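
# Standard two-layer feed-forward expert. The weights for all experts are
# stacked in a single parameter so one expert's slice can be selected by
# index.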
class MLP(torch.nn.Module):

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        moe_num_experts: int,
        activation_fn: Callable,
        device: Optional[torch.device],
    ) -> None:
        super().__init__()
        self.moe_num_experts: int = moe_num_experts
        self.ffn_hidden_size: int = ffn_hidden_size
        self.hidden_size: int = hidden_size
        self.activation_fn: Callable = activation_fn
        # Each expert's slice of w1/w2 is (ffn_hidden_size, hidden_size).
        self.w1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.w2 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        # (tokens, hidden) @ (hidden, ffn) -> (tokens, ffn)
        before_activation = x @ expert_w1.t()
        layer_1_output = self.activation_fn(before_activation)
        # w2 is stored as (ffn, hidden), so no transpose is needed here:
        # (tokens, ffn) @ (ffn, hidden) -> (tokens, hidden)
        output = layer_1_output @ expert_w2
        return output
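
# GLU-style expert (as in GEGLU/SwiGLU variants): two parallel input
# projections, one gated by the activation of the other.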
class GLU(torch.nn.Module):

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        moe_num_experts: int,
        activation_fn: Callable,
        device: Optional[torch.device],
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        # w1/v1 form the gated pair; w2 projects back to hidden_size.
        self.w1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.v1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.w2 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.activation_fn = activation_fn

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        # Gated linear unit: activation(x @ w1^T) * (x @ v1^T), projected back.
        x1 = x.matmul(expert_w1.t())
        x2 = x.matmul(expert_v1.t())
        x1 = self.activation_fn(x1)
        x1 = x1 * x2
        x1 = x1.matmul(expert_w2)
        return x1
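
# Applies the chosen expert implementation to every token, grouped by
# expert, and recombines the results weighted by the router's scores.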
class DroplessMLP(torch.nn.Module):

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        mlp_type: str,
        moe_num_experts: int,
        activation_fn: Callable,
        bias: bool,
        device: Optional[torch.device],
    ) -> None:
        super().__init__()
        self.moe_num_experts = moe_num_experts
        # NOTE: `bias` is accepted for interface compatibility but is not
        # used by either expert implementation.
        if mlp_type == 'mlp':
            self.mlp = MLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, moe_num_experts=moe_num_experts, activation_fn=activation_fn, device=device)
        elif mlp_type == 'glu':
            self.mlp = GLU(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, moe_num_experts=moe_num_experts, activation_fn=activation_fn, device=device)
        else:
            raise ValueError(f'Received unknown mlp_type={mlp_type!r}')

    def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor) -> torch.Tensor:
        # `scores` is unused here but kept for interface parity with the router.
        in_shape = x.shape
        hidden_size = in_shape[-1]
        x = x.view(-1, hidden_size)
        out = torch.zeros_like(x)
        # expert_mask[e, k, t] == 1 iff token t selected expert e at rank k.
        expert_mask = torch.nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
        for expert_idx in range(self.moe_num_experts):
            # Gather the (rank, token) pairs assigned to this expert.
            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.shape[0] == 0:
                continue
            token_list = token_idx.tolist()
            topk_list = topk_idx.tolist()
            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
            mlp_output = self.mlp(expert_tokens, expert_idx)
            # Scale each expert output by its routing weight, then
            # scatter-add back into the per-token accumulator.
            expert_weights = expert_weights.to(mlp_output.device)
            expert_out = mlp_output * expert_weights[token_list, topk_list, None]
            out = out.to(mlp_output.device)
            token_idx = token_idx.to(mlp_output.device)
            out.index_add_(0, token_idx, expert_out)
        out = out.view(in_shape)
        return out
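
# The complete dropless-MoE layer: a LearnedRouter picks experts per token
# and a DroplessMLP executes them.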
class dMoE(torch.nn.Module):

    def __init__(
        self,
        device: Optional[torch.device],
        hidden_size: int = 1024,
        ffn_hidden_size: int = 4096,
        moe_num_experts: int = 1,
        moe_top_k: int = 1,
        mlp_type: str = 'mlp',
        activation_fn: Callable = DEFAULT_ACTIVATION_FN,
        moe_jitter_eps: Optional[float] = None,
        moe_normalize_expert_weights: Optional[Union[int, float]] = None,
        uniform_expert_assignment: bool = False,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.router = LearnedRouter(hidden_size, moe_num_experts=moe_num_experts, moe_top_k=moe_top_k, moe_jitter_eps=moe_jitter_eps, moe_normalize_expert_weights=moe_normalize_expert_weights, uniform_expert_assignment=uniform_expert_assignment, device=device)
        self.experts = DroplessMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, mlp_type=mlp_type, moe_num_experts=moe_num_experts, activation_fn=activation_fn, bias=bias, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scores, expert_weights, top_experts = self.router(x)
        return self.experts(x, scores, expert_weights, top_experts)
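
# A minimal smoke test of the layer. The (batch, seq_len, hidden_size)
# input shape and the hyperparameters below are illustrative assumptions,
# not values prescribed by this module.
if __name__ == '__main__':
    torch.manual_seed(0)
    moe = dMoE(
        device=torch.device('cpu'),
        hidden_size=64,
        ffn_hidden_size=128,
        moe_num_experts=4,
        moe_top_k=2,
        mlp_type='glu',
    )
    x = torch.randn(2, 8, 64)  # (batch, seq_len, hidden_size)
    y = moe(x)
    assert y.shape == x.shape  # The layer preserves the input shape.
    print(y.shape)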