Update qwen3_moe_model.py

qwen3_moe_model.py  CHANGED  (+37 −44)
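This change drops the local-filesystem loading path (resolving `config.router_model_path`, `config.expert_model_paths`, and `config.tokenizer_path` against a base directory with `os.path.join` and `local_files_only=True`) in favor of loading every component from a single Hugging Face Hub repo via `subfolder=` arguments: the router model and its tokenizer from the `router` subfolder, each expert from its own subfolder, and the shared expert tokenizer from the repo root.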
@@ -1,4 +1,3 @@
-import os
 import torch
 from torch.nn import ModuleDict
 from transformers import (
@@ -9,92 +8,86 @@ from transformers import (
 )
 from .qwen3_moe_config import Qwen3MoEConfig
 from transformers.modeling_outputs import CausalLMOutput
-from typing import Optional, Tuple, Union
-
+from typing import Optional, Tuple, Union
 
 
 class Qwen3MoEForCausalLM(PreTrainedModel):
     config_class = Qwen3MoEConfig
 
-    def __init__(self, config: Qwen3MoEConfig):
+    def __init__(self, config: Qwen3MoEConfig, hub_repo_id: str):
+        """
+        hub_repo_id: str, e.g. "beyoru/Qwen3-MaCoTo"
+        """
         super().__init__(config)
+        self.hub_repo_id = hub_repo_id
+
+        # Load the router model + tokenizer from the "router" subfolder
         self.router = AutoModelForSequenceClassification.from_pretrained(
-            config.router_model_path,
+            hub_repo_id,
+            subfolder="router",
             torch_dtype=config.torch_dtype,
-            trust_remote_code=True,
-            local_files_only=True
+            trust_remote_code=True
         )
-
         self.router_tokenizer = AutoTokenizer.from_pretrained(
-            config.router_model_path,
-            trust_remote_code=True,
-            local_files_only=True
+            hub_repo_id,
+            subfolder="router",
+            trust_remote_code=True
         )
 
+        # Load the expert models from their respective subfolders
         self.experts = ModuleDict({
             label: AutoModelForCausalLM.from_pretrained(
-                path,
+                hub_repo_id,
+                subfolder=folder,  # subfolders: code/math/if/tool
                 torch_dtype=config.torch_dtype,
-                trust_remote_code=True,
-                local_files_only=True
+                trust_remote_code=True
             )
-            for label, path in config.expert_model_paths.items()
+            for label, folder in config.expert_model_paths.items()
         })
 
+        # Load the shared expert tokenizer from the repo root
         self.expert_tokenizer = AutoTokenizer.from_pretrained(
-            config.tokenizer_path,
-            trust_remote_code=True,
-            local_files_only=True
+            hub_repo_id,
+            subfolder=".",
+            trust_remote_code=True
        )
 
     @classmethod
-    def from_pretrained(cls, base, config: Optional[Qwen3MoEConfig] = None, **kwargs):
+    def from_pretrained(cls, hub_repo_id: str, config: Optional[Qwen3MoEConfig] = None, **kwargs):
         if config is None:
-            config = Qwen3MoEConfig.from_pretrained(base)
-
-        config.router_model_path = os.path.join(base, config.router_model_path)
-        config.expert_model_paths = {
-            label: os.path.join(base, path)
-            for label, path in config.expert_model_paths.items()
-        }
-        config.tokenizer_path = os.path.join(base, config.tokenizer_path)
-
-        return cls(config)
+            # Load config.json from the repo root
+            config = Qwen3MoEConfig.from_pretrained(hub_repo_id, **kwargs)
+        return cls(config, hub_repo_id=hub_repo_id)
 
     def get_tokenizer(self):
         return self.expert_tokenizer
 
     def route(self, plain_text: str) -> str:
+        """
+        Use the router model to pick the most suitable expert for the input text.
+        """
         with torch.no_grad():
             inputs = self.router_tokenizer(plain_text, return_tensors="pt").to(self.router.device)
             logits = self.router(**inputs).logits
-
         if logits.dim() == 2:
             class_id = torch.argmax(logits, dim=-1).item()
             return self.config.labels[class_id]
-
         return self.config.labels[0]
 
-    def generate(
-        self,
-        text: str,
-        max_new_tokens: int = 50,
-        **kwargs
-    ) -> torch.LongTensor:
-        # 1. Route using router tokenizer
+    def generate(self, text: str, max_new_tokens: int = 50, **kwargs) -> torch.LongTensor:
+        """
+        Generate text with the expert selected by the router.
+        """
         plain_text = text
         if "<|im_start|>" in plain_text:
             temp = plain_text.split("<|im_start|>")[-2]
             plain_text = temp[:temp.find("<|im_end|>")][4:]
 
         label = self.route(plain_text)
         expert = self.experts[label]
 
-        # 2. Tokenize once with the expert tokenizer
         inputs = self.expert_tokenizer(text, return_tensors="pt").to(expert.device)
 
-        # 3. Generate using selected expert
         return expert.generate(
             input_ids=inputs.input_ids,
             attention_mask=inputs.attention_mask,
@@ -109,4 +102,4 @@ class Qwen3MoEForCausalLM(PreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
         **kwargs
     ) -> Union[Tuple, CausalLMOutput]:
-        raise NotImplementedError("Use `generate(text=...)` instead for inference.")
+        raise NotImplementedError("Use `generate(text=...)` instead for inference.")
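For context, a minimal usage sketch of the new Hub-loading path. The repo id comes from the `__init__` docstring; the import path and the prompt are illustrative assumptions, not part of the commit:

    # Minimal sketch, assuming the package exposes Qwen3MoEForCausalLM and that
    # config.json at the repo root defines `labels` and `expert_model_paths`.
    from qwen3_moe_model import Qwen3MoEForCausalLM

    model = Qwen3MoEForCausalLM.from_pretrained("beyoru/Qwen3-MaCoTo")
    tokenizer = model.get_tokenizer()

    # Plain prompt: routed as-is, then generated by the selected expert.
    out = model.generate("Write a Python function that reverses a string.",
                         max_new_tokens=128)
    print(tokenizer.decode(out[0], skip_special_tokens=True))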
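The `<|im_start|>` branch in `generate` recovers the routing text from a ChatML-formatted prompt: `split("<|im_start|>")[-2]` picks the second-to-last segment (the user message when the prompt ends with the assistant header), the `find("<|im_end|>")` slice cuts the end-of-message marker, and `[4:]` strips the four-character `user` role tag. A worked example on an assumed ChatML prompt:

    text = "<|im_start|>user\nSolve 2+2<|im_end|>\n<|im_start|>assistant\n"
    temp = text.split("<|im_start|>")[-2]        # "user\nSolve 2+2<|im_end|>\n"
    plain = temp[:temp.find("<|im_end|>")]       # "user\nSolve 2+2"
    plain = plain[4:]                            # "\nSolve 2+2" -- what the router sees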