beyoru committed on
Commit 575f2c0 · verified · 1 Parent(s): 2f51806

Update qwen3_moe_model.py

Files changed (1)
  1. qwen3_moe_model.py +37 -44
qwen3_moe_model.py CHANGED
@@ -1,4 +1,3 @@
-import os
 import torch
 from torch.nn import ModuleDict
 from transformers import (
@@ -9,92 +8,86 @@ from transformers import (
 )
 from .qwen3_moe_config import Qwen3MoEConfig
 from transformers.modeling_outputs import CausalLMOutput
-from typing import Optional, Tuple, Union, Dict
-
+from typing import Optional, Tuple, Union
 
 
 class Qwen3MoEForCausalLM(PreTrainedModel):
     config_class = Qwen3MoEConfig
 
-    def __init__(self, config: Qwen3MoEConfig):
+    def __init__(self, config: Qwen3MoEConfig, hub_repo_id: str):
+        """
+        hub_repo_id: str, e.g. "beyoru/Qwen3-MaCoTo"
+        """
         super().__init__(config)
+        self.hub_repo_id = hub_repo_id
+
+        # Load the router model + tokenizer from the "router" subfolder
         self.router = AutoModelForSequenceClassification.from_pretrained(
-            config.router_model_path,
+            hub_repo_id,
+            subfolder="router",
             torch_dtype=config.torch_dtype,
-            trust_remote_code=True,
-            local_files_only=True
+            trust_remote_code=True
         )
-
         self.router_tokenizer = AutoTokenizer.from_pretrained(
-            config.router_model_path,
-            trust_remote_code=True,
-            local_files_only=True
+            hub_repo_id,
+            subfolder="router",
+            trust_remote_code=True
        )
 
+        # Load the expert models from their corresponding subfolders
         self.experts = ModuleDict({
             label: AutoModelForCausalLM.from_pretrained(
-                path,
+                hub_repo_id,
+                subfolder=folder,  # subfolders: code/math/if/tool
                 torch_dtype=config.torch_dtype,
-                trust_remote_code=True,
-                local_files_only=True
+                trust_remote_code=True
             )
-            for label, path in config.expert_model_paths.items()
+            for label, folder in config.expert_model_paths.items()
         })
 
+        # Load the shared expert tokenizer from the repo root
         self.expert_tokenizer = AutoTokenizer.from_pretrained(
-            config.tokenizer_path,
-            trust_remote_code=True,
-            local_files_only=True
+            hub_repo_id,
+            subfolder=".",
+            trust_remote_code=True
        )
 
     @classmethod
-    def from_pretrained(cls, pretrained_dir: str, config: Optional[Qwen3MoEConfig] = None, **kwargs):
+    def from_pretrained(cls, hub_repo_id: str, config: Optional[Qwen3MoEConfig] = None, **kwargs):
         if config is None:
-            config = Qwen3MoEConfig.from_pretrained(pretrained_dir)
-
-        base = pretrained_dir
-        config.router_model_path = os.path.join(base, config.router_model_path)
-        config.expert_model_paths = {
-            label: os.path.join(base, path)
-            for label, path in config.expert_model_paths.items()
-        }
-        config.tokenizer_path = os.path.join(base, config.tokenizer_path)
-
-        return cls(config)
+            # Load config.json from the repo root
+            config = Qwen3MoEConfig.from_pretrained(hub_repo_id, **kwargs)
+        return cls(config, hub_repo_id=hub_repo_id)
 
     def get_tokenizer(self):
         return self.expert_tokenizer
 
     def route(self, plain_text: str) -> str:
+        """
+        Use the router model to pick the matching expert for the input text.
+        """
         with torch.no_grad():
             inputs = self.router_tokenizer(plain_text, return_tensors="pt").to(self.router.device)
             logits = self.router(**inputs).logits
-
             if logits.dim() == 2:
                 class_id = torch.argmax(logits, dim=-1).item()
                 return self.config.labels[class_id]
-
         return self.config.labels[0]
 
-    def generate(
-        self,
-        text: str,
-        max_new_tokens: int = 50,
-        **kwargs
-    ) -> torch.LongTensor:
-        # 1. Route using router tokenizer
+    def generate(self, text: str, max_new_tokens: int = 50, **kwargs) -> torch.LongTensor:
+        """
+        Generate text with the expert selected by the router.
+        """
         plain_text = text
         if "<|im_start|>" in plain_text:
-            temp = plain_text.split("<|im_start|>")[-2]
-            plain_text = temp[:temp.find("<|im_end|>")][4:]
+            temp = plain_text.split("<|im_start|>")[-2]
+            plain_text = temp[:temp.find("<|im_end|>")][4:]
 
         label = self.route(plain_text)
         expert = self.experts[label]
 
-        # 2. Tokenize once with the expert tokenizer
         inputs = self.expert_tokenizer(text, return_tensors="pt").to(expert.device)
 
-        # 3. Generate using selected expert
         return expert.generate(
             input_ids=inputs.input_ids,
             attention_mask=inputs.attention_mask,
@@ -109,4 +102,4 @@ class Qwen3MoEForCausalLM(PreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         **kwargs
     ) -> Union[Tuple, CausalLMOutput]:
-        raise NotImplementedError("Use `generate(text=...)` instead for inference.")
+        raise NotImplementedError("Use `generate(text=...)` instead for inference.")
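
For downstream users, here is a minimal usage sketch of the revised hub-based API. It is a sketch under assumptions, not part of the commit: the direct import path is hypothetical (on the Hub the class would normally be resolved via trust_remote_code), the prompt string and decode step are illustrative, and the repo id is taken from the docstring above.

# Hypothetical usage sketch for this revision (not part of the commit).
# Assumes the class is importable directly from the module file.
from qwen3_moe_model import Qwen3MoEForCausalLM

# from_pretrained() now takes a hub repo id rather than a local directory;
# __init__ then pulls the router from "router/" and each expert from the
# subfolder named in config.expert_model_paths (code/math/if/tool).
model = Qwen3MoEForCausalLM.from_pretrained("beyoru/Qwen3-MaCoTo")
tokenizer = model.get_tokenizer()

# generate() extracts the last user turn from the chat template, routes it,
# and delegates to the chosen expert's generate(); it returns token ids.
prompt = (
    "<|im_start|>user\nWrite a function that reverses a string.<|im_end|>\n"
    "<|im_start|>assistant\n"
)
output_ids = model.generate(prompt, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Note that __init__ materializes every expert eagerly in a ModuleDict, so memory scales with the number of experts even though only the routed expert participates in decoding.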