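"""Pure-PyTorch implementation of a dropless mixture-of-experts (dMoE) layer:
a learned top-k router plus per-expert MLP/GLU feed-forward blocks, with tokens
dispatched densely so that none are dropped."""
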
from functools import partial
from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')

class _UniformExpertAssignment(torch.autograd.Function):
    """Ignores the router's choices and assigns token i to expert
    i % num_experts (round-robin). Operates on integer indices, so no
    backward is defined."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, num_experts: int):
        out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
        out = torch.remainder(out, num_experts)
        return out.view(x.shape)
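
# A minimal sketch (not part of the original module) of the round-robin
# override: with 4 experts and top_k=1, token i is sent to expert i % 4,
# whatever the router picked. The input values are illustrative.
def _demo_uniform_assignment() -> None:
    router_picks = torch.tensor([[3], [0], [2], [1], [3], [0]])  # (tokens, top_k)
    assigned = _UniformExpertAssignment.apply(router_picks, 4)
    assert assigned.squeeze(-1).tolist() == [0, 1, 2, 3, 0, 1]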

class LearnedRouter(torch.nn.Module):
    """Linear top-k router: scores every token against every expert with a
    bias-free linear layer, then keeps the moe_top_k highest-scoring experts
    per token."""

    def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int,
                 moe_jitter_eps: Optional[float],
                 moe_normalize_expert_weights: Optional[Union[int, float]],
                 uniform_expert_assignment: bool,
                 device: Optional[torch.device]) -> None:
        super().__init__()
        self.hidden_size: int = hidden_size
        self.moe_num_experts: int = moe_num_experts
        self.moe_top_k: int = moe_top_k
        self.moe_jitter_eps: Optional[float] = moe_jitter_eps
        self.moe_normalize_expert_weights: Optional[Union[int, float]] = moe_normalize_expert_weights
        self.uniform_expert_assignment: bool = uniform_expert_assignment
        self.layer: torch.nn.Module = torch.nn.Linear(hidden_size, moe_num_experts, bias=False, device=device)

    def jitter(self, x: torch.Tensor) -> torch.Tensor:
        # Multiplicative noise: each element is scaled by a factor drawn
        # uniformly from [1 - eps, 1 + eps].
        assert self.moe_jitter_eps is not None
        low: float = 1.0 - self.moe_jitter_eps
        high: float = 1.0 + self.moe_jitter_eps
        noise: torch.Tensor = torch.rand(x.size(), dtype=x.dtype, device=x.device)
        return low + noise * (high - low)

    def _top_k(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # max() is cheaper than a full topk() for the common top-1 case.
        if self.moe_top_k == 1:
            values, indices = scores.max(dim=-1)
            return (values.unsqueeze(-1), indices.unsqueeze(-1))
        return torch.topk(scores, self.moe_top_k, dim=-1)

    def forward(self, x: torch.Tensor):
        if self.training and self.moe_jitter_eps is not None:
            x = x * self.jitter(x)
        # Flatten (batch, seq, hidden) to (tokens, hidden) and softmax over experts.
        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
        expert_weights, top_experts = self._top_k(scores)
        if self.moe_normalize_expert_weights:
            # Renormalize the selected weights with the configured p-norm so
            # they sum (for p=1) to one over the top-k experts.
            expert_weights = expert_weights / torch.norm(
                expert_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True)
        top_experts = (_UniformExpertAssignment.apply(top_experts, self.moe_num_experts)
                       if self.uniform_expert_assignment else top_experts)
        scores = scores.to(x.dtype)
        expert_weights = expert_weights.to(x.dtype)
        return (scores, expert_weights, top_experts)
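
# A minimal sketch (not part of the original module) of the router's outputs,
# assuming a toy batch of 2 sequences of length 5 with hidden_size=8 and 4 experts.
def _demo_router() -> None:
    router = LearnedRouter(hidden_size=8, moe_num_experts=4, moe_top_k=2,
                           moe_jitter_eps=None, moe_normalize_expert_weights=1,
                           uniform_expert_assignment=False, device=None)
    scores, expert_weights, top_experts = router(torch.randn(2, 5, 8))
    assert scores.shape == (10, 4)          # one softmax row per flattened token
    assert expert_weights.shape == (10, 2)  # top-k weights, L1-normalized here
    assert top_experts.shape == (10, 2)     # chosen expert ids per token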

class MLP(torch.nn.Module):
    """Per-expert two-layer feed-forward block. Weights for all experts are
    stacked in w1/w2 and sliced by expert index at forward time."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int,
                 activation_fn: Callable, device: Optional[torch.device]) -> None:
        super().__init__()
        self.moe_num_experts: int = moe_num_experts
        self.ffn_hidden_size: int = ffn_hidden_size
        self.hidden_size: int = hidden_size
        self.activation_fn: Callable = activation_fn
        self.w1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.w2 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        # Slice this expert's weights out of the stacked parameter tensors.
        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        before_activation = x @ expert_w1.t()  # (tokens, ffn_hidden_size)
        layer_1_output = self.activation_fn(before_activation)
        output = layer_1_output @ expert_w2  # back to (tokens, hidden_size)
        return output

class GLU(torch.nn.Module):
    """Per-expert gated linear unit feed-forward block:
    act(x @ w1.T) * (x @ v1.T), projected back to hidden_size by w2."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int,
                 activation_fn: Callable, device: Optional[torch.device]) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.moe_num_experts = moe_num_experts
        self.w1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.v1 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.w2 = torch.nn.Parameter(torch.rand(moe_num_experts * ffn_hidden_size, hidden_size, device=device))
        self.activation_fn = activation_fn

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
        x1 = x.matmul(expert_w1.t())  # up projection, (tokens, ffn_hidden_size)
        x2 = x.matmul(expert_v1.t())  # gate projection
        x1 = self.activation_fn(x1)
        x1 = x1 * x2  # gated activation
        x1 = x1.matmul(expert_w2)  # down projection back to (tokens, hidden_size)
        return x1
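
# A minimal sketch (not part of the original module): running a single expert's
# GLU feed-forward over a few tokens. Sizes are illustrative; MLP is called the
# same way.
def _demo_glu_expert() -> None:
    glu = GLU(hidden_size=8, ffn_hidden_size=16, moe_num_experts=4,
              activation_fn=DEFAULT_ACTIVATION_FN, device=None)
    tokens = torch.randn(3, 8)  # 3 tokens routed to the same expert
    out = glu(tokens, expert_idx=2)
    assert out.shape == (3, 8)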

class DroplessMLP(torch.nn.Module):
    """Sends every token to each of its top-k experts (no capacity limit, so
    no tokens are dropped) and combines the expert outputs with the router
    weights. The bias argument is accepted for interface parity but unused."""

    def __init__(self, hidden_size: int, ffn_hidden_size: int, mlp_type: str,
                 moe_num_experts: int, activation_fn: Callable, bias: bool,
                 device: Optional[torch.device]) -> None:
        super().__init__()
        self.moe_num_experts = moe_num_experts
        if mlp_type == 'mlp':
            self.mlp = MLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size,
                           moe_num_experts=moe_num_experts, activation_fn=activation_fn,
                           device=device)
        elif mlp_type == 'glu':
            self.mlp = GLU(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size,
                           moe_num_experts=moe_num_experts, activation_fn=activation_fn,
                           device=device)
        else:
            raise ValueError(f'Received unknown mlp_type={mlp_type!r}')

    def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor,
                top_experts: torch.Tensor) -> torch.Tensor:
        # scores is accepted for interface symmetry but not used here.
        in_shape = x.shape
        hidden_size = in_shape[-1]
        x = x.view(-1, hidden_size)
        out = torch.zeros_like(x)
        # (tokens, top_k) -> (num_experts, top_k, tokens): row e marks which
        # tokens chose expert e, and in which top-k slot.
        expert_mask = torch.nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
        for expert_idx in range(self.moe_num_experts):
            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
            if token_idx.shape[0] == 0:
                continue  # no tokens routed to this expert
            token_list = token_idx.tolist()
            topk_list = topk_idx.tolist()
            # Gather this expert's tokens, run its feed-forward block, and scale
            # each output by the router weight of the (token, slot) pair.
            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
            mlp_output = self.mlp(expert_tokens, expert_idx)
            expert_weights = expert_weights.to(mlp_output.device)
            expert_out = mlp_output * expert_weights[token_list, topk_list, None]
            out = out.to(mlp_output.device)
            token_idx = token_idx.to(mlp_output.device)
            # Accumulate: a token routed to several experts sums their outputs.
            out.index_add_(0, token_idx, expert_out)
        return out.view(in_shape)
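
# A minimal sketch (not part of the original module) of the dispatch bookkeeping
# above: one_hot + permute + where recovers, for each expert, which tokens picked
# it and in which top-k slot. Values are illustrative.
def _demo_dispatch_mask() -> None:
    top_experts = torch.tensor([[0, 2], [2, 1], [0, 1]])  # 3 tokens, top_k=2, 3 experts
    mask = F.one_hot(top_experts, num_classes=3).permute(2, 1, 0)  # (experts, top_k, tokens)
    topk_idx, token_idx = torch.where(mask[2])
    # Expert 2 was picked by token 1 in slot 0 and by token 0 in slot 1.
    assert topk_idx.tolist() == [0, 1] and token_idx.tolist() == [1, 0]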

class dMoE(torch.nn.Module):
    """Dropless mixture-of-experts layer: a LearnedRouter picks the top-k
    experts per token and a DroplessMLP computes and combines their outputs."""

    def __init__(self, device: Optional[torch.device], hidden_size: int = 1024,
                 ffn_hidden_size: int = 4096, moe_num_experts: int = 1, moe_top_k: int = 1,
                 mlp_type: str = 'mlp', activation_fn: Callable = DEFAULT_ACTIVATION_FN,
                 moe_jitter_eps: Optional[float] = None,
                 moe_normalize_expert_weights: Optional[Union[int, float]] = None,
                 uniform_expert_assignment: bool = False, bias: bool = True) -> None:
        super().__init__()
        self.router = LearnedRouter(hidden_size, moe_num_experts=moe_num_experts,
                                    moe_top_k=moe_top_k, moe_jitter_eps=moe_jitter_eps,
                                    moe_normalize_expert_weights=moe_normalize_expert_weights,
                                    uniform_expert_assignment=uniform_expert_assignment,
                                    device=device)
        self.experts = DroplessMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size,
                                   mlp_type=mlp_type, moe_num_experts=moe_num_experts,
                                   activation_fn=activation_fn, bias=bias, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scores, expert_weights, top_experts = self.router(x)
        return self.experts(x, scores, expert_weights, top_experts)
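
# A minimal end-to-end sketch (not part of the original module): route a toy
# batch through a 4-expert GLU dMoE and check that the output keeps the input
# shape. All sizes are illustrative.
if __name__ == '__main__':
    torch.manual_seed(0)
    moe = dMoE(device=None, hidden_size=8, ffn_hidden_size=16,
               moe_num_experts=4, moe_top_k=2, mlp_type='glu')
    y = moe(torch.randn(2, 5, 8))
    assert y.shape == (2, 5, 8)
    print('dMoE forward OK:', tuple(y.shape))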