Commit 5c54e17 (verified), committed by JingzeShi
Parent: ec3de5d

Upload DogeForCausalLM

Files changed (4):
  1. config.json (+1, -2)
  2. configuration_doge.py (+1, -5)
  3. model.safetensors (+1, -1)
  4. modeling_doge.py (+12, -20)
config.json CHANGED

@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "/root/autodl-tmp/data/Doge-320M",
+  "_name_or_path": "/root/autodl-tmp/small-doge/data/Doge-320M-decay/checkpoint-4000",
   "architectures": [
     "DogeForCausalLM"
   ],
@@ -11,7 +11,6 @@
   "bos_token_id": 0,
   "dynamic_mask_ratio": 0.0,
   "eos_token_id": 1,
-  "expert_retrieval_size": 64,
   "hidden_act": "silu",
   "hidden_bias": false,
   "hidden_dropout": 0.0,
configuration_doge.py CHANGED

@@ -121,8 +121,6 @@ class DogeConfig(PretrainedConfig):
             Number of Experts for the Cross Domain Mixture of Experts.
         num_experts_per_tok (`int`, *optional*, defaults to 8):
             Number of selected experts to route per-token.
-        expert_retrieval_size (`int`, *optional*, defaults to 64):
-            Dimension of the Expert retrieval states for calculating the dot product of query and key to determine the expert index.

     ```python
     >>> from transformers import DogeConfig, DogeModel
@@ -149,7 +147,7 @@ class DogeConfig(PretrainedConfig):
         "layers.*.feed_forward.gate_proj": "colwise",
         "layers.*.feed_forward.up_proj": "colwise",
         "layers.*.feed_forward.down_proj": "rowwise",
-        "layers.*.feed_forward.queries_proj": "colwise",
+        "layers.*.feed_forward.router_gate": "colwise",
         "layers.*.feed_forward.down_embed": "rowwise",
         "layers.*.feed_forward.up_embed": "rowwise",
     }
@@ -181,7 +179,6 @@ class DogeConfig(PretrainedConfig):
         is_moe=False,
         num_experts=2048,
         num_experts_per_tok=8,
-        expert_retrieval_size=64,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -207,7 +204,6 @@ class DogeConfig(PretrainedConfig):
         self.is_moe = is_moe
         self.num_experts = num_experts
         self.num_experts_per_tok = num_experts_per_tok
-        self.expert_retrieval_size = expert_retrieval_size

         # Validate the correctness of rotary position embeddings parameters
         # BC: if there is a 'type' field, copy it to 'rope_type'.
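After this change, DogeConfig no longer defines expert_retrieval_size; the remaining CDMoE knobs are num_experts and num_experts_per_tok. A minimal construction sketch (import path taken from the docstring above, values from the defaults in the diff):

    from transformers import DogeConfig

    config = DogeConfig(
        is_moe=True,             # enable the Cross Domain Mixture of Experts blocks
        num_experts=2048,        # num_keys = int(sqrt(2048)) = 45 keys per routing axis
        num_experts_per_tok=8,   # top-k experts mixed into each token's state
    )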
model.safetensors CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:bb05e3cf42d7df4c683d7c6719d195d06614686b766ed0782bc2f3b7c71afec5
+oid sha256:ce4aaf436761b12719bb9be9d3a250ba388679b324886299ed71f69c2b53a510
 size 1343277696
modeling_doge.py CHANGED

@@ -480,23 +480,17 @@ class DogeCDMoE(DogeMLP):
         self.hidden_dim = config.hidden_size
         self.act_fn = ACT2FN[config.hidden_act]

-        self.expert_retrieval_dim = config.expert_retrieval_size
         self.num_experts = config.num_experts
         self.top_k = config.num_experts_per_tok
         self.num_keys = int(math.sqrt(self.num_experts))

-        # queries and keys for retrieval experts
-        self.queries_proj = nn.Linear(self.hidden_dim, self.expert_retrieval_dim, bias=False)
-        self.keys = nn.Parameter(torch.zeros(2, self.expert_retrieval_dim // 2, self.num_keys))
+        # router gate for retrieval experts
+        self.router_gate = nn.Linear(self.hidden_dim, self.num_keys * 2)

         # experts
         self.down_embed = nn.Embedding(self.num_experts, self.hidden_dim)
         self.up_embed = nn.Embedding(self.num_experts, self.hidden_dim)

-        # scaling factor
-        self.mlp_scaling = nn.Parameter(torch.ones(self.hidden_dim))
-        self.moe_scaling = nn.Parameter(torch.zeros(self.hidden_dim))
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -504,27 +498,26 @@ class DogeCDMoE(DogeMLP):
     ) -> torch.Tensor:
         bsz, seq_len, _ = hidden_states.shape

-        # get routing weights with queries and keys
-        queries = self.queries_proj(hidden_states).view(2, bsz * seq_len, -1)
-        routing_weights = torch.matmul(queries, self.keys)
+        # get routing weights with router gate
+        routing_weights = self.router_gate(hidden_states).view(2, bsz * seq_len, -1)

         # get experts with the highest routing weights
-        (scores_x, scores_y), (indices_x, indices_y) = routing_weights.topk(self.top_k, dim=-1)
+        (scores_x, indices_x), (scores_y, indices_y) = [w.topk(self.num_keys, dim=-1) for w in routing_weights]
         all_scores = scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)
-        all_scores = all_scores.view(*scores_x.shape[:-1], -1)
-        all_indices = (indices_x.unsqueeze(-1) * self.num_keys) + indices_y.unsqueeze(-2)
-        all_indices = all_indices.view(*indices_x.shape[:-1], -1)
-        scores, pk_indices = all_scores.topk(self.top_k, dim=-1)
-        indices = all_indices.gather(-1, pk_indices)
-        down_embed = self.down_embed(indices).transpose(1, 2)
+        all_indices = indices_x.unsqueeze(-1) * self.num_keys + indices_y.unsqueeze(-2)
+        all_scores = all_scores.view(*all_scores.shape[:-2], -1)
+        all_indices = all_indices.view(*all_indices.shape[:-2], -1)
+        scores, position_indices = all_scores.topk(self.top_k, dim=-1)
+        indices = all_indices.gather(-1, position_indices)
+        down_embed = self.down_embed(indices)
         up_embed = self.up_embed(indices)

         # mix experts states with cross domain states
-        experts_weights = torch.matmul(hidden_states.view(bsz * seq_len, 1, -1), down_embed).view(bsz * seq_len, -1)
+        experts_weights = torch.matmul(down_embed, hidden_states.view(bsz * seq_len, -1, 1)).view(bsz * seq_len, -1)
         experts_weights = self.act_fn(experts_weights) * scores.softmax(dim=-1)
         experts_states = torch.matmul(experts_weights.view(bsz * seq_len, 1, -1), up_embed).view(bsz, seq_len, -1)
         hidden_states = self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
-        hidden_states = (hidden_states * self.mlp_scaling) + (experts_states * self.moe_scaling)
+        hidden_states = hidden_states + experts_states
         return hidden_states
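Net effect of the modeling change: the learned query/key expert retrieval (queries_proj plus the keys parameter) collapses into a single nn.Linear router over two product-key axes, and the learned mlp_scaling/moe_scaling mix becomes a plain residual add. Below is a self-contained sketch of the new routing path with toy sizes and shape comments; silu stands in for the config's ACT2FN choice, and this is a simplified illustration, not the repo's exact class:

    import math
    import torch
    import torch.nn as nn

    hidden_dim, num_experts, top_k = 32, 256, 8
    num_keys = int(math.sqrt(num_experts))             # 16; experts form a 16 x 16 grid

    router_gate = nn.Linear(hidden_dim, num_keys * 2)  # replaces queries_proj + keys
    down_embed = nn.Embedding(num_experts, hidden_dim)
    up_embed = nn.Embedding(num_experts, hidden_dim)

    bsz, seq_len = 2, 3
    hidden_states = torch.randn(bsz, seq_len, hidden_dim)

    # one score vector per product-key axis: (2, bsz*seq_len, num_keys)
    routing_weights = router_gate(hidden_states).view(2, bsz * seq_len, -1)

    # sort each axis, then form all num_keys^2 pairwise score sums and flat expert ids
    (scores_x, indices_x), (scores_y, indices_y) = [w.topk(num_keys, dim=-1) for w in routing_weights]
    all_scores = (scores_x.unsqueeze(-1) + scores_y.unsqueeze(-2)).flatten(-2)
    all_indices = (indices_x.unsqueeze(-1) * num_keys + indices_y.unsqueeze(-2)).flatten(-2)

    # keep the top_k pairs; the gather maps grid positions back to expert ids
    scores, position_indices = all_scores.topk(top_k, dim=-1)
    indices = all_indices.gather(-1, position_indices)  # (bsz*seq_len, top_k)

    # per-expert down/up projections via embedding lookups
    w_down = down_embed(indices)                        # (bsz*seq_len, top_k, hidden_dim)
    w_up = up_embed(indices)                            # (bsz*seq_len, top_k, hidden_dim)
    weights = torch.matmul(w_down, hidden_states.view(bsz * seq_len, -1, 1)).squeeze(-1)
    weights = nn.functional.silu(weights) * scores.softmax(dim=-1)
    experts_states = torch.matmul(weights.unsqueeze(1), w_up).view(bsz, seq_len, -1)
    print(experts_states.shape)                         # torch.Size([2, 3, 32])

Because the router emits only 2 * num_keys logits per token (on the order of 2 * sqrt(num_experts)), expert selection stays sub-linear in num_experts, and the separate expert_retrieval_size hyperparameter disappears entirely.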