Upload folder using huggingface_hub
- LICENSE +201 -0
- README.md +118 -3
- added_tokens.json +1 -0
- config.json +29 -0
- configuration_ernie4_5.py +127 -0
- generation_config.json +11 -0
- model.safetensors +3 -0
- modeling_ernie4_5.py +1068 -0
- special_tokens_map.json +1 -0
- tokenization_ernie4_5.py +377 -0
- tokenizer.model +3 -0
- tokenizer_config.json +22 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright (c) 2025 Baidu, Inc. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
CHANGED
@@ -1,3 +1,118 @@
# ERNIE-4.5-0.3B

## ERNIE 4.5 Highlights

The advanced capabilities of the ERNIE 4.5 models, particularly the MoE-based A47B and A3B series, are underpinned by several key technical innovations:

- **Multimodal MoE Pretraining:** Our models are jointly trained on both textual and visual modalities to better capture the nuances of multimodal information and improve performance on tasks involving text generation, image understanding, and cross-modal reasoning. To achieve this without one modality hindering the learning of another, we designed a heterogeneous MoE structure, incorporated three-dimensional rotary embeddings, and employed a router orthogonal loss and a multimodal token-balanced loss. These architectural choices ensure that both modalities are effectively represented, allowing for mutual reinforcement during training.
- **Scaling-Efficient Architecture and Infrastructure:** To train the large multimodal MoE models efficiently, we introduce a novel heterogeneous hybrid parallelism and a multi-level load-balancing strategy. By using on-device expert parallelism, memory-efficient pipeline scheduling, and FP8 mixed precision, we achieve strong pre-training performance. For inference, we propose a quantization method with collaborative parallelism among multiple experts to achieve lossless quantization. Built on PaddlePaddle, ERNIE 4.5 delivers high-performance inference across a wide range of hardware platforms.
- **Modality-Specific Post-training:** To meet the diverse requirements of real-world applications, we fine-tuned variants of the pretrained model for specific modalities. Our LLMs are optimized for general-purpose language understanding and generation. The VLMs focus on visual-language understanding and support both thinking and non-thinking modes. Each model was post-trained with a combination of Supervised Fine-tuning (SFT), Direct Preference Optimization (DPO), and/or a modified reinforcement learning method named Unified Preference Optimization (UPO), using targeted datasets aligned with its intended usage scenario.

## Model Overview

ERNIE-4.5-0.3B is a dense, text-only, post-trained model. Its configuration is summarized below:

| Key            | Value         |
| -------------- | ------------- |
| Modality       | Text          |
| Training Stage | Post-training |
| Params         | 0.36B         |
| Layers         | 18            |
| Heads (Q/KV)   | 16 / 2        |
| Context Length | 131072        |
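As a rough cross-check of the figures above (not an official breakdown), the parameter count implied by `config.json` can be reproduced in a few lines. With tied input/output embeddings and no biases it comes out to roughly 0.36B, which also lines up with the ~721 MB bfloat16 `model.safetensors` file:

```python
# Back-of-the-envelope parameter count from the values in config.json.
# Assumes tied embeddings and bias-free linear layers, per the config.
hidden, inter, layers, vocab = 1024, 3072, 18, 103424
q_heads, kv_heads, head_dim = 16, 2, 128

attn = hidden * q_heads * head_dim * 2      # q_proj + o_proj
attn += hidden * kv_heads * head_dim * 2    # k_proj + v_proj (GQA: 2 KV heads)
mlp = 3 * hidden * inter                    # gate_proj, up_proj, down_proj
norms = 2 * hidden                          # two RMSNorm weights per layer

total = layers * (attn + mlp + norms) + vocab * hidden + hidden  # + embeddings + final norm
print(f"{total / 1e9:.2f}B parameters")     # ~0.36B, in line with the table
```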
## Quickstart

### Model Finetuning with ERNIEKit

[ERNIEKit](https://github.com/PaddlePaddle/ERNIE) is a PaddlePaddle-based training toolkit designed for the ERNIE series of open-source large models. It provides comprehensive support for scenarios such as instruction fine-tuning (SFT, LoRA) and alignment training (DPO), ensuring optimal performance.

Usage examples:

```bash
# SFT
erniekit train --stage SFT --model_name_or_path /baidu/ERNIE-4.5-0.3B --train_dataset_path your_dataset_path
# DPO
erniekit train --stage DPO --model_name_or_path /baidu/ERNIE-4.5-0.3B --train_dataset_path your_dataset_path
```

For more detailed examples, including SFT with LoRA, multi-GPU configurations, and advanced scripts, please refer to the examples folder in the [ERNIEKit](https://github.com/PaddlePaddle/ERNIE) repository.

### FastDeploy Inference

Service deployment can be completed quickly with [FastDeploy](https://github.com/PaddlePaddle/FastDeploy) using the command below; `--max-model-len` sets the maximum number of tokens per request and `--max-num-seqs` caps the number of concurrently processed sequences. For more detailed usage instructions, please refer to the FastDeploy repository.

```bash
python -m fastdeploy.entrypoints.openai.api_server \
    --model baidu/ERNIE-4.5-0.3B-Paddle \
    --port 8180 \
    --metrics-port 8181 \
    --engine-worker-queue-port 8182 \
    --max-model-len 32768 \
    --max-num-seqs 32
```
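Once the server is running, it should accept OpenAI-style requests on the port configured above. A minimal client sketch, assuming the standard `/v1/chat/completions` route and that the `model` field matches the served checkpoint name:

```python
# Hypothetical client call against the FastDeploy server started above.
# Assumes the usual OpenAI-compatible /v1/chat/completions route on port 8180.
import requests

resp = requests.post(
    "http://localhost:8180/v1/chat/completions",
    json={
        "model": "baidu/ERNIE-4.5-0.3B-Paddle",
        "messages": [{"role": "user", "content": "Give me a short introduction to large language models."}],
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```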
### Using the `transformers` library

The following snippet shows how to load the model with the `transformers` library and generate a response for a chat-style prompt.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baidu/ERNIE-4.5-0.3B-PT"

# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

# prepare the model input
prompt = "Give me a short introduction to large language models."
messages = [
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

# conduct text completion
generated_ids = model.generate(
    model_inputs.input_ids,
    max_new_tokens=1024
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

# decode the generated ids
generate_text = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
print("generate_text:", generate_text)
```
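The snippet above loads the weights in the default dtype. On GPU it is usually preferable to keep the checkpoint's bfloat16 storage dtype and let the modules be placed automatically; a minimal variant using standard `from_pretrained` arguments (`device_map="auto"` requires the `accelerate` package):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "baidu/ERNIE-4.5-0.3B-PT",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,   # matches the dtype stored in model.safetensors
    device_map="auto",            # automatic device placement via accelerate
)
```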
### vLLM inference

vLLM support is still being upstreamed; in the meantime, please use our fork of [vllm](https://github.com/CSWYF3634076/vllm/tree/ernie):

```bash
vllm serve baidu/ERNIE-4.5-0.3B-PT --trust-remote-code
```
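For offline (non-server) use, the fork should expose vLLM's usual Python API; a sketch assuming it keeps upstream's `LLM`/`SamplingParams` interface:

```python
# Offline generation with the vLLM fork; sampling values mirror
# generation_config.json (temperature=0.8, top_p=0.8).
from vllm import LLM, SamplingParams

llm = LLM(model="baidu/ERNIE-4.5-0.3B-PT", trust_remote_code=True)
params = SamplingParams(temperature=0.8, top_p=0.8, max_tokens=256)
outputs = llm.generate(["Give me a short introduction to large language models."], params)
print(outputs[0].outputs[0].text)
```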
## License

The ERNIE 4.5 models are provided under the Apache License 2.0. This license permits commercial use, subject to its terms and conditions. Copyright (c) 2025 Baidu, Inc. All Rights Reserved.

## Citation

If you find ERNIE 4.5 useful or wish to use it in your projects, please kindly cite our technical report:

```bibtex
@misc{ernie2025technicalreport,
  title={ERNIE 4.5 Technical Report},
  author={Baidu ERNIE Team},
  year={2025},
  eprint={},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={}
}
```
added_tokens.json
ADDED
@@ -0,0 +1 @@
{"<|IMAGE_PLACEHOLDER|>": 100295, "<|AUDIO_PLACEHOLDER|>": 100296, "<|LOC_0|>": 100297, "<|LOC_1|>": 100298, "<|LOC_2|>": 100299, "<|LOC_3|>": 100300, "<|LOC_4|>": 100301, "<|LOC_5|>": 100302, "<|LOC_6|>": 100303, "<|LOC_7|>": 100304, "<|LOC_8|>": 100305, "<|LOC_9|>": 100306, "<|LOC_10|>": 100307, "<|LOC_11|>": 100308, "<|LOC_12|>": 100309, "<|LOC_13|>": 100310, "<|LOC_14|>": 100311, "<|LOC_15|>": 100312, "<|LOC_16|>": 100313, "<|LOC_17|>": 100314, "<|LOC_18|>": 100315, "<|LOC_19|>": 100316, "<|LOC_20|>": 100317, "<|LOC_21|>": 100318, "<|LOC_22|>": 100319, "<|LOC_23|>": 100320, "<|LOC_24|>": 100321, "<|LOC_25|>": 100322, "<|LOC_26|>": 100323, "<|LOC_27|>": 100324, "<|LOC_28|>": 100325, "<|LOC_29|>": 100326, "<|LOC_30|>": 100327, "<|LOC_31|>": 100328, "<|LOC_32|>": 100329, "<|LOC_33|>": 100330, "<|LOC_34|>": 100331, "<|LOC_35|>": 100332, "<|LOC_36|>": 100333, "<|LOC_37|>": 100334, "<|LOC_38|>": 100335, "<|LOC_39|>": 100336, "<|LOC_40|>": 100337, "<|LOC_41|>": 100338, "<|LOC_42|>": 100339, "<|LOC_43|>": 100340, "<|LOC_44|>": 100341, "<|LOC_45|>": 100342, "<|LOC_46|>": 100343, "<|LOC_47|>": 100344, "<|LOC_48|>": 100345, "<|LOC_49|>": 100346, "<|LOC_50|>": 100347, "<|LOC_51|>": 100348, "<|LOC_52|>": 100349, "<|LOC_53|>": 100350, "<|LOC_54|>": 100351, "<|LOC_55|>": 100352, "<|LOC_56|>": 100353, "<|LOC_57|>": 100354, "<|LOC_58|>": 100355, "<|LOC_59|>": 100356, "<|LOC_60|>": 100357, "<|LOC_61|>": 100358, "<|LOC_62|>": 100359, "<|LOC_63|>": 100360, "<|LOC_64|>": 100361, "<|LOC_65|>": 100362, "<|LOC_66|>": 100363, "<|LOC_67|>": 100364, "<|LOC_68|>": 100365, "<|LOC_69|>": 100366, "<|LOC_70|>": 100367, "<|LOC_71|>": 100368, "<|LOC_72|>": 100369, "<|LOC_73|>": 100370, "<|LOC_74|>": 100371, "<|LOC_75|>": 100372, "<|LOC_76|>": 100373, "<|LOC_77|>": 100374, "<|LOC_78|>": 100375, "<|LOC_79|>": 100376, "<|LOC_80|>": 100377, "<|LOC_81|>": 100378, "<|LOC_82|>": 100379, "<|LOC_83|>": 100380, "<|LOC_84|>": 100381, "<|LOC_85|>": 100382, "<|LOC_86|>": 100383, "<|LOC_87|>": 100384, "<|LOC_88|>": 100385, "<|LOC_89|>": 100386, "<|LOC_90|>": 100387, "<|LOC_91|>": 100388, "<|LOC_92|>": 100389, "<|LOC_93|>": 100390, "<|LOC_94|>": 100391, "<|LOC_95|>": 100392, "<|LOC_96|>": 100393, "<|LOC_97|>": 100394, "<|LOC_98|>": 100395, "<|LOC_99|>": 100396, "<|LOC_100|>": 100397, "<|LOC_101|>": 100398, "<|LOC_102|>": 100399, "<|LOC_103|>": 100400, "<|LOC_104|>": 100401, "<|LOC_105|>": 100402, "<|LOC_106|>": 100403, "<|LOC_107|>": 100404, "<|LOC_108|>": 100405, "<|LOC_109|>": 100406, "<|LOC_110|>": 100407, "<|LOC_111|>": 100408, "<|LOC_112|>": 100409, "<|LOC_113|>": 100410, "<|LOC_114|>": 100411, "<|LOC_115|>": 100412, "<|LOC_116|>": 100413, "<|LOC_117|>": 100414, "<|LOC_118|>": 100415, "<|LOC_119|>": 100416, "<|LOC_120|>": 100417, "<|LOC_121|>": 100418, "<|LOC_122|>": 100419, "<|LOC_123|>": 100420, "<|LOC_124|>": 100421, "<|LOC_125|>": 100422, "<|LOC_126|>": 100423, "<|LOC_127|>": 100424, "<|LOC_128|>": 100425, "<|LOC_129|>": 100426, "<|LOC_130|>": 100427, "<|LOC_131|>": 100428, "<|LOC_132|>": 100429, "<|LOC_133|>": 100430, "<|LOC_134|>": 100431, "<|LOC_135|>": 100432, "<|LOC_136|>": 100433, "<|LOC_137|>": 100434, "<|LOC_138|>": 100435, "<|LOC_139|>": 100436, "<|LOC_140|>": 100437, "<|LOC_141|>": 100438, "<|LOC_142|>": 100439, "<|LOC_143|>": 100440, "<|LOC_144|>": 100441, "<|LOC_145|>": 100442, "<|LOC_146|>": 100443, "<|LOC_147|>": 100444, "<|LOC_148|>": 100445, "<|LOC_149|>": 100446, "<|LOC_150|>": 100447, "<|LOC_151|>": 100448, "<|LOC_152|>": 100449, "<|LOC_153|>": 100450, "<|LOC_154|>": 100451, "<|LOC_155|>": 100452, 
"<|LOC_156|>": 100453, "<|LOC_157|>": 100454, "<|LOC_158|>": 100455, "<|LOC_159|>": 100456, "<|LOC_160|>": 100457, "<|LOC_161|>": 100458, "<|LOC_162|>": 100459, "<|LOC_163|>": 100460, "<|LOC_164|>": 100461, "<|LOC_165|>": 100462, "<|LOC_166|>": 100463, "<|LOC_167|>": 100464, "<|LOC_168|>": 100465, "<|LOC_169|>": 100466, "<|LOC_170|>": 100467, "<|LOC_171|>": 100468, "<|LOC_172|>": 100469, "<|LOC_173|>": 100470, "<|LOC_174|>": 100471, "<|LOC_175|>": 100472, "<|LOC_176|>": 100473, "<|LOC_177|>": 100474, "<|LOC_178|>": 100475, "<|LOC_179|>": 100476, "<|LOC_180|>": 100477, "<|LOC_181|>": 100478, "<|LOC_182|>": 100479, "<|LOC_183|>": 100480, "<|LOC_184|>": 100481, "<|LOC_185|>": 100482, "<|LOC_186|>": 100483, "<|LOC_187|>": 100484, "<|LOC_188|>": 100485, "<|LOC_189|>": 100486, "<|LOC_190|>": 100487, "<|LOC_191|>": 100488, "<|LOC_192|>": 100489, "<|LOC_193|>": 100490, "<|LOC_194|>": 100491, "<|LOC_195|>": 100492, "<|LOC_196|>": 100493, "<|LOC_197|>": 100494, "<|LOC_198|>": 100495, "<|LOC_199|>": 100496, "<|LOC_200|>": 100497, "<|LOC_201|>": 100498, "<|LOC_202|>": 100499, "<|LOC_203|>": 100500, "<|LOC_204|>": 100501, "<|LOC_205|>": 100502, "<|LOC_206|>": 100503, "<|LOC_207|>": 100504, "<|LOC_208|>": 100505, "<|LOC_209|>": 100506, "<|LOC_210|>": 100507, "<|LOC_211|>": 100508, "<|LOC_212|>": 100509, "<|LOC_213|>": 100510, "<|LOC_214|>": 100511, "<|LOC_215|>": 100512, "<|LOC_216|>": 100513, "<|LOC_217|>": 100514, "<|LOC_218|>": 100515, "<|LOC_219|>": 100516, "<|LOC_220|>": 100517, "<|LOC_221|>": 100518, "<|LOC_222|>": 100519, "<|LOC_223|>": 100520, "<|LOC_224|>": 100521, "<|LOC_225|>": 100522, "<|LOC_226|>": 100523, "<|LOC_227|>": 100524, "<|LOC_228|>": 100525, "<|LOC_229|>": 100526, "<|LOC_230|>": 100527, "<|LOC_231|>": 100528, "<|LOC_232|>": 100529, "<|LOC_233|>": 100530, "<|LOC_234|>": 100531, "<|LOC_235|>": 100532, "<|LOC_236|>": 100533, "<|LOC_237|>": 100534, "<|LOC_238|>": 100535, "<|LOC_239|>": 100536, "<|LOC_240|>": 100537, "<|LOC_241|>": 100538, "<|LOC_242|>": 100539, "<|LOC_243|>": 100540, "<|LOC_244|>": 100541, "<|LOC_245|>": 100542, "<|LOC_246|>": 100543, "<|LOC_247|>": 100544, "<|LOC_248|>": 100545, "<|LOC_249|>": 100546, "<|LOC_250|>": 100547, "<|LOC_251|>": 100548, "<|LOC_252|>": 100549, "<|LOC_253|>": 100550, "<|LOC_254|>": 100551, "<|LOC_255|>": 100552, "<|LOC_256|>": 100553, "<|LOC_257|>": 100554, "<|LOC_258|>": 100555, "<|LOC_259|>": 100556, "<|LOC_260|>": 100557, "<|LOC_261|>": 100558, "<|LOC_262|>": 100559, "<|LOC_263|>": 100560, "<|LOC_264|>": 100561, "<|LOC_265|>": 100562, "<|LOC_266|>": 100563, "<|LOC_267|>": 100564, "<|LOC_268|>": 100565, "<|LOC_269|>": 100566, "<|LOC_270|>": 100567, "<|LOC_271|>": 100568, "<|LOC_272|>": 100569, "<|LOC_273|>": 100570, "<|LOC_274|>": 100571, "<|LOC_275|>": 100572, "<|LOC_276|>": 100573, "<|LOC_277|>": 100574, "<|LOC_278|>": 100575, "<|LOC_279|>": 100576, "<|LOC_280|>": 100577, "<|LOC_281|>": 100578, "<|LOC_282|>": 100579, "<|LOC_283|>": 100580, "<|LOC_284|>": 100581, "<|LOC_285|>": 100582, "<|LOC_286|>": 100583, "<|LOC_287|>": 100584, "<|LOC_288|>": 100585, "<|LOC_289|>": 100586, "<|LOC_290|>": 100587, "<|LOC_291|>": 100588, "<|LOC_292|>": 100589, "<|LOC_293|>": 100590, "<|LOC_294|>": 100591, "<|LOC_295|>": 100592, "<|LOC_296|>": 100593, "<|LOC_297|>": 100594, "<|LOC_298|>": 100595, "<|LOC_299|>": 100596, "<|LOC_300|>": 100597, "<|LOC_301|>": 100598, "<|LOC_302|>": 100599, "<|LOC_303|>": 100600, "<|LOC_304|>": 100601, "<|LOC_305|>": 100602, "<|LOC_306|>": 100603, "<|LOC_307|>": 100604, "<|LOC_308|>": 100605, "<|LOC_309|>": 100606, 
"<|LOC_310|>": 100607, "<|LOC_311|>": 100608, "<|LOC_312|>": 100609, "<|LOC_313|>": 100610, "<|LOC_314|>": 100611, "<|LOC_315|>": 100612, "<|LOC_316|>": 100613, "<|LOC_317|>": 100614, "<|LOC_318|>": 100615, "<|LOC_319|>": 100616, "<|LOC_320|>": 100617, "<|LOC_321|>": 100618, "<|LOC_322|>": 100619, "<|LOC_323|>": 100620, "<|LOC_324|>": 100621, "<|LOC_325|>": 100622, "<|LOC_326|>": 100623, "<|LOC_327|>": 100624, "<|LOC_328|>": 100625, "<|LOC_329|>": 100626, "<|LOC_330|>": 100627, "<|LOC_331|>": 100628, "<|LOC_332|>": 100629, "<|LOC_333|>": 100630, "<|LOC_334|>": 100631, "<|LOC_335|>": 100632, "<|LOC_336|>": 100633, "<|LOC_337|>": 100634, "<|LOC_338|>": 100635, "<|LOC_339|>": 100636, "<|LOC_340|>": 100637, "<|LOC_341|>": 100638, "<|LOC_342|>": 100639, "<|LOC_343|>": 100640, "<|LOC_344|>": 100641, "<|LOC_345|>": 100642, "<|LOC_346|>": 100643, "<|LOC_347|>": 100644, "<|LOC_348|>": 100645, "<|LOC_349|>": 100646, "<|LOC_350|>": 100647, "<|LOC_351|>": 100648, "<|LOC_352|>": 100649, "<|LOC_353|>": 100650, "<|LOC_354|>": 100651, "<|LOC_355|>": 100652, "<|LOC_356|>": 100653, "<|LOC_357|>": 100654, "<|LOC_358|>": 100655, "<|LOC_359|>": 100656, "<|LOC_360|>": 100657, "<|LOC_361|>": 100658, "<|LOC_362|>": 100659, "<|LOC_363|>": 100660, "<|LOC_364|>": 100661, "<|LOC_365|>": 100662, "<|LOC_366|>": 100663, "<|LOC_367|>": 100664, "<|LOC_368|>": 100665, "<|LOC_369|>": 100666, "<|LOC_370|>": 100667, "<|LOC_371|>": 100668, "<|LOC_372|>": 100669, "<|LOC_373|>": 100670, "<|LOC_374|>": 100671, "<|LOC_375|>": 100672, "<|LOC_376|>": 100673, "<|LOC_377|>": 100674, "<|LOC_378|>": 100675, "<|LOC_379|>": 100676, "<|LOC_380|>": 100677, "<|LOC_381|>": 100678, "<|LOC_382|>": 100679, "<|LOC_383|>": 100680, "<|LOC_384|>": 100681, "<|LOC_385|>": 100682, "<|LOC_386|>": 100683, "<|LOC_387|>": 100684, "<|LOC_388|>": 100685, "<|LOC_389|>": 100686, "<|LOC_390|>": 100687, "<|LOC_391|>": 100688, "<|LOC_392|>": 100689, "<|LOC_393|>": 100690, "<|LOC_394|>": 100691, "<|LOC_395|>": 100692, "<|LOC_396|>": 100693, "<|LOC_397|>": 100694, "<|LOC_398|>": 100695, "<|LOC_399|>": 100696, "<|LOC_400|>": 100697, "<|LOC_401|>": 100698, "<|LOC_402|>": 100699, "<|LOC_403|>": 100700, "<|LOC_404|>": 100701, "<|LOC_405|>": 100702, "<|LOC_406|>": 100703, "<|LOC_407|>": 100704, "<|LOC_408|>": 100705, "<|LOC_409|>": 100706, "<|LOC_410|>": 100707, "<|LOC_411|>": 100708, "<|LOC_412|>": 100709, "<|LOC_413|>": 100710, "<|LOC_414|>": 100711, "<|LOC_415|>": 100712, "<|LOC_416|>": 100713, "<|LOC_417|>": 100714, "<|LOC_418|>": 100715, "<|LOC_419|>": 100716, "<|LOC_420|>": 100717, "<|LOC_421|>": 100718, "<|LOC_422|>": 100719, "<|LOC_423|>": 100720, "<|LOC_424|>": 100721, "<|LOC_425|>": 100722, "<|LOC_426|>": 100723, "<|LOC_427|>": 100724, "<|LOC_428|>": 100725, "<|LOC_429|>": 100726, "<|LOC_430|>": 100727, "<|LOC_431|>": 100728, "<|LOC_432|>": 100729, "<|LOC_433|>": 100730, "<|LOC_434|>": 100731, "<|LOC_435|>": 100732, "<|LOC_436|>": 100733, "<|LOC_437|>": 100734, "<|LOC_438|>": 100735, "<|LOC_439|>": 100736, "<|LOC_440|>": 100737, "<|LOC_441|>": 100738, "<|LOC_442|>": 100739, "<|LOC_443|>": 100740, "<|LOC_444|>": 100741, "<|LOC_445|>": 100742, "<|LOC_446|>": 100743, "<|LOC_447|>": 100744, "<|LOC_448|>": 100745, "<|LOC_449|>": 100746, "<|LOC_450|>": 100747, "<|LOC_451|>": 100748, "<|LOC_452|>": 100749, "<|LOC_453|>": 100750, "<|LOC_454|>": 100751, "<|LOC_455|>": 100752, "<|LOC_456|>": 100753, "<|LOC_457|>": 100754, "<|LOC_458|>": 100755, "<|LOC_459|>": 100756, "<|LOC_460|>": 100757, "<|LOC_461|>": 100758, "<|LOC_462|>": 100759, "<|LOC_463|>": 100760, 
"<|LOC_464|>": 100761, "<|LOC_465|>": 100762, "<|LOC_466|>": 100763, "<|LOC_467|>": 100764, "<|LOC_468|>": 100765, "<|LOC_469|>": 100766, "<|LOC_470|>": 100767, "<|LOC_471|>": 100768, "<|LOC_472|>": 100769, "<|LOC_473|>": 100770, "<|LOC_474|>": 100771, "<|LOC_475|>": 100772, "<|LOC_476|>": 100773, "<|LOC_477|>": 100774, "<|LOC_478|>": 100775, "<|LOC_479|>": 100776, "<|LOC_480|>": 100777, "<|LOC_481|>": 100778, "<|LOC_482|>": 100779, "<|LOC_483|>": 100780, "<|LOC_484|>": 100781, "<|LOC_485|>": 100782, "<|LOC_486|>": 100783, "<|LOC_487|>": 100784, "<|LOC_488|>": 100785, "<|LOC_489|>": 100786, "<|LOC_490|>": 100787, "<|LOC_491|>": 100788, "<|LOC_492|>": 100789, "<|LOC_493|>": 100790, "<|LOC_494|>": 100791, "<|LOC_495|>": 100792, "<|LOC_496|>": 100793, "<|LOC_497|>": 100794, "<|LOC_498|>": 100795, "<|LOC_499|>": 100796, "<|LOC_500|>": 100797, "<|LOC_501|>": 100798, "<|LOC_502|>": 100799, "<|LOC_503|>": 100800, "<|LOC_504|>": 100801, "<|LOC_505|>": 100802, "<|LOC_506|>": 100803, "<|LOC_507|>": 100804, "<|LOC_508|>": 100805, "<|LOC_509|>": 100806, "<|LOC_510|>": 100807, "<|LOC_511|>": 100808, "<|LOC_512|>": 100809, "<|LOC_513|>": 100810, "<|LOC_514|>": 100811, "<|LOC_515|>": 100812, "<|LOC_516|>": 100813, "<|LOC_517|>": 100814, "<|LOC_518|>": 100815, "<|LOC_519|>": 100816, "<|LOC_520|>": 100817, "<|LOC_521|>": 100818, "<|LOC_522|>": 100819, "<|LOC_523|>": 100820, "<|LOC_524|>": 100821, "<|LOC_525|>": 100822, "<|LOC_526|>": 100823, "<|LOC_527|>": 100824, "<|LOC_528|>": 100825, "<|LOC_529|>": 100826, "<|LOC_530|>": 100827, "<|LOC_531|>": 100828, "<|LOC_532|>": 100829, "<|LOC_533|>": 100830, "<|LOC_534|>": 100831, "<|LOC_535|>": 100832, "<|LOC_536|>": 100833, "<|LOC_537|>": 100834, "<|LOC_538|>": 100835, "<|LOC_539|>": 100836, "<|LOC_540|>": 100837, "<|LOC_541|>": 100838, "<|LOC_542|>": 100839, "<|LOC_543|>": 100840, "<|LOC_544|>": 100841, "<|LOC_545|>": 100842, "<|LOC_546|>": 100843, "<|LOC_547|>": 100844, "<|LOC_548|>": 100845, "<|LOC_549|>": 100846, "<|LOC_550|>": 100847, "<|LOC_551|>": 100848, "<|LOC_552|>": 100849, "<|LOC_553|>": 100850, "<|LOC_554|>": 100851, "<|LOC_555|>": 100852, "<|LOC_556|>": 100853, "<|LOC_557|>": 100854, "<|LOC_558|>": 100855, "<|LOC_559|>": 100856, "<|LOC_560|>": 100857, "<|LOC_561|>": 100858, "<|LOC_562|>": 100859, "<|LOC_563|>": 100860, "<|LOC_564|>": 100861, "<|LOC_565|>": 100862, "<|LOC_566|>": 100863, "<|LOC_567|>": 100864, "<|LOC_568|>": 100865, "<|LOC_569|>": 100866, "<|LOC_570|>": 100867, "<|LOC_571|>": 100868, "<|LOC_572|>": 100869, "<|LOC_573|>": 100870, "<|LOC_574|>": 100871, "<|LOC_575|>": 100872, "<|LOC_576|>": 100873, "<|LOC_577|>": 100874, "<|LOC_578|>": 100875, "<|LOC_579|>": 100876, "<|LOC_580|>": 100877, "<|LOC_581|>": 100878, "<|LOC_582|>": 100879, "<|LOC_583|>": 100880, "<|LOC_584|>": 100881, "<|LOC_585|>": 100882, "<|LOC_586|>": 100883, "<|LOC_587|>": 100884, "<|LOC_588|>": 100885, "<|LOC_589|>": 100886, "<|LOC_590|>": 100887, "<|LOC_591|>": 100888, "<|LOC_592|>": 100889, "<|LOC_593|>": 100890, "<|LOC_594|>": 100891, "<|LOC_595|>": 100892, "<|LOC_596|>": 100893, "<|LOC_597|>": 100894, "<|LOC_598|>": 100895, "<|LOC_599|>": 100896, "<|LOC_600|>": 100897, "<|LOC_601|>": 100898, "<|LOC_602|>": 100899, "<|LOC_603|>": 100900, "<|LOC_604|>": 100901, "<|LOC_605|>": 100902, "<|LOC_606|>": 100903, "<|LOC_607|>": 100904, "<|LOC_608|>": 100905, "<|LOC_609|>": 100906, "<|LOC_610|>": 100907, "<|LOC_611|>": 100908, "<|LOC_612|>": 100909, "<|LOC_613|>": 100910, "<|LOC_614|>": 100911, "<|LOC_615|>": 100912, "<|LOC_616|>": 100913, "<|LOC_617|>": 100914, 
"<|LOC_618|>": 100915, "<|LOC_619|>": 100916, "<|LOC_620|>": 100917, "<|LOC_621|>": 100918, "<|LOC_622|>": 100919, "<|LOC_623|>": 100920, "<|LOC_624|>": 100921, "<|LOC_625|>": 100922, "<|LOC_626|>": 100923, "<|LOC_627|>": 100924, "<|LOC_628|>": 100925, "<|LOC_629|>": 100926, "<|LOC_630|>": 100927, "<|LOC_631|>": 100928, "<|LOC_632|>": 100929, "<|LOC_633|>": 100930, "<|LOC_634|>": 100931, "<|LOC_635|>": 100932, "<|LOC_636|>": 100933, "<|LOC_637|>": 100934, "<|LOC_638|>": 100935, "<|LOC_639|>": 100936, "<|LOC_640|>": 100937, "<|LOC_641|>": 100938, "<|LOC_642|>": 100939, "<|LOC_643|>": 100940, "<|LOC_644|>": 100941, "<|LOC_645|>": 100942, "<|LOC_646|>": 100943, "<|LOC_647|>": 100944, "<|LOC_648|>": 100945, "<|LOC_649|>": 100946, "<|LOC_650|>": 100947, "<|LOC_651|>": 100948, "<|LOC_652|>": 100949, "<|LOC_653|>": 100950, "<|LOC_654|>": 100951, "<|LOC_655|>": 100952, "<|LOC_656|>": 100953, "<|LOC_657|>": 100954, "<|LOC_658|>": 100955, "<|LOC_659|>": 100956, "<|LOC_660|>": 100957, "<|LOC_661|>": 100958, "<|LOC_662|>": 100959, "<|LOC_663|>": 100960, "<|LOC_664|>": 100961, "<|LOC_665|>": 100962, "<|LOC_666|>": 100963, "<|LOC_667|>": 100964, "<|LOC_668|>": 100965, "<|LOC_669|>": 100966, "<|LOC_670|>": 100967, "<|LOC_671|>": 100968, "<|LOC_672|>": 100969, "<|LOC_673|>": 100970, "<|LOC_674|>": 100971, "<|LOC_675|>": 100972, "<|LOC_676|>": 100973, "<|LOC_677|>": 100974, "<|LOC_678|>": 100975, "<|LOC_679|>": 100976, "<|LOC_680|>": 100977, "<|LOC_681|>": 100978, "<|LOC_682|>": 100979, "<|LOC_683|>": 100980, "<|LOC_684|>": 100981, "<|LOC_685|>": 100982, "<|LOC_686|>": 100983, "<|LOC_687|>": 100984, "<|LOC_688|>": 100985, "<|LOC_689|>": 100986, "<|LOC_690|>": 100987, "<|LOC_691|>": 100988, "<|LOC_692|>": 100989, "<|LOC_693|>": 100990, "<|LOC_694|>": 100991, "<|LOC_695|>": 100992, "<|LOC_696|>": 100993, "<|LOC_697|>": 100994, "<|LOC_698|>": 100995, "<|LOC_699|>": 100996, "<|LOC_700|>": 100997, "<|LOC_701|>": 100998, "<|LOC_702|>": 100999, "<|LOC_703|>": 101000, "<|LOC_704|>": 101001, "<|LOC_705|>": 101002, "<|LOC_706|>": 101003, "<|LOC_707|>": 101004, "<|LOC_708|>": 101005, "<|LOC_709|>": 101006, "<|LOC_710|>": 101007, "<|LOC_711|>": 101008, "<|LOC_712|>": 101009, "<|LOC_713|>": 101010, "<|LOC_714|>": 101011, "<|LOC_715|>": 101012, "<|LOC_716|>": 101013, "<|LOC_717|>": 101014, "<|LOC_718|>": 101015, "<|LOC_719|>": 101016, "<|LOC_720|>": 101017, "<|LOC_721|>": 101018, "<|LOC_722|>": 101019, "<|LOC_723|>": 101020, "<|LOC_724|>": 101021, "<|LOC_725|>": 101022, "<|LOC_726|>": 101023, "<|LOC_727|>": 101024, "<|LOC_728|>": 101025, "<|LOC_729|>": 101026, "<|LOC_730|>": 101027, "<|LOC_731|>": 101028, "<|LOC_732|>": 101029, "<|LOC_733|>": 101030, "<|LOC_734|>": 101031, "<|LOC_735|>": 101032, "<|LOC_736|>": 101033, "<|LOC_737|>": 101034, "<|LOC_738|>": 101035, "<|LOC_739|>": 101036, "<|LOC_740|>": 101037, "<|LOC_741|>": 101038, "<|LOC_742|>": 101039, "<|LOC_743|>": 101040, "<|LOC_744|>": 101041, "<|LOC_745|>": 101042, "<|LOC_746|>": 101043, "<|LOC_747|>": 101044, "<|LOC_748|>": 101045, "<|LOC_749|>": 101046, "<|LOC_750|>": 101047, "<|LOC_751|>": 101048, "<|LOC_752|>": 101049, "<|LOC_753|>": 101050, "<|LOC_754|>": 101051, "<|LOC_755|>": 101052, "<|LOC_756|>": 101053, "<|LOC_757|>": 101054, "<|LOC_758|>": 101055, "<|LOC_759|>": 101056, "<|LOC_760|>": 101057, "<|LOC_761|>": 101058, "<|LOC_762|>": 101059, "<|LOC_763|>": 101060, "<|LOC_764|>": 101061, "<|LOC_765|>": 101062, "<|LOC_766|>": 101063, "<|LOC_767|>": 101064, "<|LOC_768|>": 101065, "<|LOC_769|>": 101066, "<|LOC_770|>": 101067, "<|LOC_771|>": 101068, 
"<|LOC_772|>": 101069, "<|LOC_773|>": 101070, "<|LOC_774|>": 101071, "<|LOC_775|>": 101072, "<|LOC_776|>": 101073, "<|LOC_777|>": 101074, "<|LOC_778|>": 101075, "<|LOC_779|>": 101076, "<|LOC_780|>": 101077, "<|LOC_781|>": 101078, "<|LOC_782|>": 101079, "<|LOC_783|>": 101080, "<|LOC_784|>": 101081, "<|LOC_785|>": 101082, "<|LOC_786|>": 101083, "<|LOC_787|>": 101084, "<|LOC_788|>": 101085, "<|LOC_789|>": 101086, "<|LOC_790|>": 101087, "<|LOC_791|>": 101088, "<|LOC_792|>": 101089, "<|LOC_793|>": 101090, "<|LOC_794|>": 101091, "<|LOC_795|>": 101092, "<|LOC_796|>": 101093, "<|LOC_797|>": 101094, "<|LOC_798|>": 101095, "<|LOC_799|>": 101096, "<|LOC_800|>": 101097, "<|LOC_801|>": 101098, "<|LOC_802|>": 101099, "<|LOC_803|>": 101100, "<|LOC_804|>": 101101, "<|LOC_805|>": 101102, "<|LOC_806|>": 101103, "<|LOC_807|>": 101104, "<|LOC_808|>": 101105, "<|LOC_809|>": 101106, "<|LOC_810|>": 101107, "<|LOC_811|>": 101108, "<|LOC_812|>": 101109, "<|LOC_813|>": 101110, "<|LOC_814|>": 101111, "<|LOC_815|>": 101112, "<|LOC_816|>": 101113, "<|LOC_817|>": 101114, "<|LOC_818|>": 101115, "<|LOC_819|>": 101116, "<|LOC_820|>": 101117, "<|LOC_821|>": 101118, "<|LOC_822|>": 101119, "<|LOC_823|>": 101120, "<|LOC_824|>": 101121, "<|LOC_825|>": 101122, "<|LOC_826|>": 101123, "<|LOC_827|>": 101124, "<|LOC_828|>": 101125, "<|LOC_829|>": 101126, "<|LOC_830|>": 101127, "<|LOC_831|>": 101128, "<|LOC_832|>": 101129, "<|LOC_833|>": 101130, "<|LOC_834|>": 101131, "<|LOC_835|>": 101132, "<|LOC_836|>": 101133, "<|LOC_837|>": 101134, "<|LOC_838|>": 101135, "<|LOC_839|>": 101136, "<|LOC_840|>": 101137, "<|LOC_841|>": 101138, "<|LOC_842|>": 101139, "<|LOC_843|>": 101140, "<|LOC_844|>": 101141, "<|LOC_845|>": 101142, "<|LOC_846|>": 101143, "<|LOC_847|>": 101144, "<|LOC_848|>": 101145, "<|LOC_849|>": 101146, "<|LOC_850|>": 101147, "<|LOC_851|>": 101148, "<|LOC_852|>": 101149, "<|LOC_853|>": 101150, "<|LOC_854|>": 101151, "<|LOC_855|>": 101152, "<|LOC_856|>": 101153, "<|LOC_857|>": 101154, "<|LOC_858|>": 101155, "<|LOC_859|>": 101156, "<|LOC_860|>": 101157, "<|LOC_861|>": 101158, "<|LOC_862|>": 101159, "<|LOC_863|>": 101160, "<|LOC_864|>": 101161, "<|LOC_865|>": 101162, "<|LOC_866|>": 101163, "<|LOC_867|>": 101164, "<|LOC_868|>": 101165, "<|LOC_869|>": 101166, "<|LOC_870|>": 101167, "<|LOC_871|>": 101168, "<|LOC_872|>": 101169, "<|LOC_873|>": 101170, "<|LOC_874|>": 101171, "<|LOC_875|>": 101172, "<|LOC_876|>": 101173, "<|LOC_877|>": 101174, "<|LOC_878|>": 101175, "<|LOC_879|>": 101176, "<|LOC_880|>": 101177, "<|LOC_881|>": 101178, "<|LOC_882|>": 101179, "<|LOC_883|>": 101180, "<|LOC_884|>": 101181, "<|LOC_885|>": 101182, "<|LOC_886|>": 101183, "<|LOC_887|>": 101184, "<|LOC_888|>": 101185, "<|LOC_889|>": 101186, "<|LOC_890|>": 101187, "<|LOC_891|>": 101188, "<|LOC_892|>": 101189, "<|LOC_893|>": 101190, "<|LOC_894|>": 101191, "<|LOC_895|>": 101192, "<|LOC_896|>": 101193, "<|LOC_897|>": 101194, "<|LOC_898|>": 101195, "<|LOC_899|>": 101196, "<|LOC_900|>": 101197, "<|LOC_901|>": 101198, "<|LOC_902|>": 101199, "<|LOC_903|>": 101200, "<|LOC_904|>": 101201, "<|LOC_905|>": 101202, "<|LOC_906|>": 101203, "<|LOC_907|>": 101204, "<|LOC_908|>": 101205, "<|LOC_909|>": 101206, "<|LOC_910|>": 101207, "<|LOC_911|>": 101208, "<|LOC_912|>": 101209, "<|LOC_913|>": 101210, "<|LOC_914|>": 101211, "<|LOC_915|>": 101212, "<|LOC_916|>": 101213, "<|LOC_917|>": 101214, "<|LOC_918|>": 101215, "<|LOC_919|>": 101216, "<|LOC_920|>": 101217, "<|LOC_921|>": 101218, "<|LOC_922|>": 101219, "<|LOC_923|>": 101220, "<|LOC_924|>": 101221, "<|LOC_925|>": 101222, 
"<|LOC_926|>": 101223, "<|LOC_927|>": 101224, "<|LOC_928|>": 101225, "<|LOC_929|>": 101226, "<|LOC_930|>": 101227, "<|LOC_931|>": 101228, "<|LOC_932|>": 101229, "<|LOC_933|>": 101230, "<|LOC_934|>": 101231, "<|LOC_935|>": 101232, "<|LOC_936|>": 101233, "<|LOC_937|>": 101234, "<|LOC_938|>": 101235, "<|LOC_939|>": 101236, "<|LOC_940|>": 101237, "<|LOC_941|>": 101238, "<|LOC_942|>": 101239, "<|LOC_943|>": 101240, "<|LOC_944|>": 101241, "<|LOC_945|>": 101242, "<|LOC_946|>": 101243, "<|LOC_947|>": 101244, "<|LOC_948|>": 101245, "<|LOC_949|>": 101246, "<|LOC_950|>": 101247, "<|LOC_951|>": 101248, "<|LOC_952|>": 101249, "<|LOC_953|>": 101250, "<|LOC_954|>": 101251, "<|LOC_955|>": 101252, "<|LOC_956|>": 101253, "<|LOC_957|>": 101254, "<|LOC_958|>": 101255, "<|LOC_959|>": 101256, "<|LOC_960|>": 101257, "<|LOC_961|>": 101258, "<|LOC_962|>": 101259, "<|LOC_963|>": 101260, "<|LOC_964|>": 101261, "<|LOC_965|>": 101262, "<|LOC_966|>": 101263, "<|LOC_967|>": 101264, "<|LOC_968|>": 101265, "<|LOC_969|>": 101266, "<|LOC_970|>": 101267, "<|LOC_971|>": 101268, "<|LOC_972|>": 101269, "<|LOC_973|>": 101270, "<|LOC_974|>": 101271, "<|LOC_975|>": 101272, "<|LOC_976|>": 101273, "<|LOC_977|>": 101274, "<|LOC_978|>": 101275, "<|LOC_979|>": 101276, "<|LOC_980|>": 101277, "<|LOC_981|>": 101278, "<|LOC_982|>": 101279, "<|LOC_983|>": 101280, "<|LOC_984|>": 101281, "<|LOC_985|>": 101282, "<|LOC_986|>": 101283, "<|LOC_987|>": 101284, "<|LOC_988|>": 101285, "<|LOC_989|>": 101286, "<|LOC_990|>": 101287, "<|LOC_991|>": 101288, "<|LOC_992|>": 101289, "<|LOC_993|>": 101290, "<|LOC_994|>": 101291, "<|LOC_995|>": 101292, "<|LOC_996|>": 101293, "<|LOC_997|>": 101294, "<|LOC_998|>": 101295, "<|LOC_999|>": 101296, "<|LOC_1000|>": 101297, "<|LOC_BEGIN|>": 101298, "<|LOC_END|>": 101299, "<|LOC_SEP|>": 101300, "<|CROP_COL_SEP|>": 101301, "<|CROP_ROW_SEP|>": 101302, "<|IMAGE_SEP|>": 101303}
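The file above registers the multimodal placeholder and location tokens (IDs 100295-101303) on top of the base vocabulary; the text-only model simply carries them in its tokenizer. A quick check that they resolve to the listed IDs, assuming the tokenizer is loaded as in the README:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", trust_remote_code=True)

# IDs taken directly from added_tokens.json above
assert tok.convert_tokens_to_ids("<|IMAGE_PLACEHOLDER|>") == 100295
assert tok.convert_tokens_to_ids("<|LOC_0|>") == 100297
assert tok.convert_tokens_to_ids("<|IMAGE_SEP|>") == 101303
```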
config.json
ADDED
@@ -0,0 +1,29 @@
{
    "architectures": [
        "Ernie4_5_ForCausalLM"
    ],
    "auto_map": {
        "AutoConfig": "configuration_ernie4_5.Ernie4_5_Config",
        "AutoModel": "modeling_ernie4_5.Ernie4_5_Model",
        "AutoModelForCausalLM": "modeling_ernie4_5.Ernie4_5_ForCausalLM"
    },
    "bos_token_id": 1,
    "eos_token_id": 2,
    "hidden_act": "silu",
    "hidden_size": 1024,
    "intermediate_size": 3072,
    "max_position_embeddings": 131072,
    "model_type": "ernie4_5",
    "num_attention_heads": 16,
    "num_key_value_heads": 2,
    "head_dim": 128,
    "num_hidden_layers": 18,
    "pad_token_id": 0,
    "rms_norm_eps": 1e-05,
    "use_cache": false,
    "vocab_size": 103424,
    "rope_theta": 500000,
    "use_bias": false,
    "tie_word_embeddings": true,
    "torch_dtype": "bfloat16"
}
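Because of the `auto_map` entries above, loading this repository with `trust_remote_code=True` resolves the configuration to the bundled `Ernie4_5_Config` class; a quick sanity check (using the model ID from the README):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("baidu/ERNIE-4.5-0.3B-PT", trust_remote_code=True)
print(type(cfg).__name__)   # Ernie4_5_Config, provided by configuration_ernie4_5.py
print(cfg.hidden_size, cfg.num_hidden_layers, cfg.num_key_value_heads)  # 1024 18 2
```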
configuration_ernie4_5.py
ADDED
@@ -0,0 +1,127 @@
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import PretrainedConfig


class Ernie4_5_Config(PretrainedConfig):
    """
    Configuration class for the ERNIE 4.5 model.

    This class stores the configuration of an Ernie model, defining the model architecture.
    It inherits from PretrainedConfig and can be used to control model outputs.
    """

    model_type = "ernie4_5"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for the base model
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        intermediate_size=11008,
        max_position_embeddings=32768,
        num_hidden_layers=2,
        num_attention_heads=2,
        rms_norm_eps=1e-6,
        use_cache=False,
        use_flash_attention=False,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        use_bias=False,
        rope_theta=10000,
        weight_share_add_bias=True,
        ignored_index=-100,
        attention_probs_dropout_prob=0.0,
        hidden_dropout_prob=0.0,
        compression_ratio: float = 1.0,
        num_key_value_heads=None,
        max_sequence_length=None,
        **kwargs,
    ):
        """
        Initialize configuration with default or specified parameters.

        Args:
            vocab_size (int): Size of the vocabulary (number of unique tokens)
            hidden_size (int): Dimensionality of the encoder layers and the pooler layer
            intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer
            max_position_embeddings (int): Maximum sequence length the model can handle
            num_hidden_layers (int): Number of hidden layers in the Transformer encoder
            num_attention_heads (int): Number of attention heads for each attention layer
            rms_norm_eps (float): The epsilon used by the RMS normalization layers
            use_cache (bool): Whether to use caching for faster generation (decoding)
            use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation
            pad_token_id (int): Token ID used for padding sequences
            bos_token_id (int): Token ID used for beginning-of-sequence
            eos_token_id (int): Token ID used for end-of-sequence
            use_bias (bool): Whether to use bias terms in linear layers
            rope_theta (float): The base period of the RoPE embeddings
            weight_share_add_bias (bool): Whether to share bias weights in certain layers
            ignored_index (int): Target value that is ignored during loss computation
            attention_probs_dropout_prob (float): Dropout probability for attention weights
            hidden_dropout_prob (float): Dropout probability for hidden layers
            compression_ratio (float): Ratio for KV cache compression (1.0 = no compression)
            num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention)
            max_sequence_length (int): Maximum sequence length for positional embeddings
            **kwargs: Additional keyword arguments passed to the parent class
        """

        # Set default for tied embeddings if not specified.
        if "tie_word_embeddings" not in kwargs:
            kwargs["tie_word_embeddings"] = False
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.use_flash_attention = use_flash_attention
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.use_bias = use_bias
        self.weight_share_add_bias = weight_share_add_bias
        self.rope_theta = rope_theta
        self.ignored_index = ignored_index
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_dropout_prob = hidden_dropout_prob
        self.compression_ratio = compression_ratio
        self.num_key_value_heads = num_key_value_heads
        self.max_sequence_length = max_sequence_length
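For illustration only, the class can also be instantiated directly with the values shipped in `config.json` (assuming a local clone so the module is importable). Note that `head_dim` is not an explicit `__init__` argument, so it would be carried through `**kwargs` onto the config object via `PretrainedConfig`:

```python
# Hypothetical direct instantiation with the checkpoint's values.
from configuration_ernie4_5 import Ernie4_5_Config

cfg = Ernie4_5_Config(
    vocab_size=103424,
    hidden_size=1024,
    intermediate_size=3072,
    num_hidden_layers=18,
    num_attention_heads=16,
    num_key_value_heads=2,
    max_position_embeddings=131072,
    rope_theta=500000,
    tie_word_embeddings=True,
    head_dim=128,  # passed through **kwargs
)
# 16 query heads share 2 KV heads -> 8-way grouped-query attention
assert cfg.num_attention_heads % cfg.num_key_value_heads == 0
```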
generation_config.json
ADDED
@@ -0,0 +1,11 @@
{
    "do_sample": true,
    "top_p": 0.8,
    "temperature": 0.8,
    "bos_token_id": 1,
    "eos_token_id": 2,
    "pad_token_id": 0,
    "repetition_penalty": 1.0,
    "frequency_penalty": 0.0,
    "presence_penalty": 0.0
}
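These defaults are applied by `model.generate()` when the checkpoint is loaded with `transformers`; individual values can be overridden per call (reusing `model` and `model_inputs` from the README snippet). `frequency_penalty` and `presence_penalty` are typically consumed by serving front ends such as FastDeploy or vLLM rather than by `transformers` itself:

```python
# Per-call override of the shipped generation defaults.
generated_ids = model.generate(
    model_inputs.input_ids,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.8,          # same as the shipped default
    top_p=0.8,                # same as the shipped default
    repetition_penalty=1.05,  # example override
)
```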
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2fa450d7550b5600b030f09b9a954ffad3ef31476e44fb87e2154e9ac23b51d3
size 721514672
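This is a Git LFS pointer rather than the weights themselves; the 721,514,672-byte payload is consistent with roughly 0.36B parameters stored as bfloat16 (2 bytes each). Once the actual file is downloaded, the tensor inventory can be checked without loading any weights; a sketch using the `safetensors` library:

```python
# Count parameters from the safetensors header without materializing tensors.
from math import prod
from safetensors import safe_open

with safe_open("model.safetensors", framework="pt") as f:
    total = sum(prod(f.get_slice(name).get_shape()) for name in f.keys())
print(f"{total / 1e9:.2f}B parameters")  # expected ~0.36B
```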
modeling_ernie4_5.py
ADDED
@@ -0,0 +1,1068 @@
1 |
+
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import Optional, Tuple, Union
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
21 |
+
|
22 |
+
from transformers.activations import ACT2FN
|
23 |
+
from transformers.modeling_utils import PreTrainedModel
|
24 |
+
from transformers.generation import GenerationMixin
|
25 |
+
from transformers.modeling_outputs import (
|
26 |
+
BaseModelOutputWithPast,
|
27 |
+
CausalLMOutputWithPast,
|
28 |
+
)
|
29 |
+
from transformers.utils import logging
|
30 |
+
|
31 |
+
from .configuration_ernie4_5 import Ernie4_5_Config
|
32 |
+
|
33 |
+
|
34 |
+
logger = logging.get_logger(__name__)
|
35 |
+
|
36 |
+
|
37 |
+
class Ernie4_5_RMSNorm(nn.Module):
|
38 |
+
"""
|
39 |
+
Root Mean Square Layer Normalization (Ernie4_5_RMSNorm) implementation.
|
40 |
+
|
41 |
+
Ernie4_5_RMSNorm is a simplified version of LayerNorm that focuses on the root mean square of inputs,
|
42 |
+
omitting the mean-centering operation. This provides computational efficiency while maintaining
|
43 |
+
good performance.
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(self, config):
|
47 |
+
"""
|
48 |
+
Initialize Ernie4_5_RMSNorm layer.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
config: Model configuration.
|
52 |
+
"""
|
53 |
+
super().__init__()
|
54 |
+
self.hidden_size = config.hidden_size
|
55 |
+
self.weight = nn.Parameter(
|
56 |
+
torch.ones(self.hidden_size, dtype=torch.get_default_dtype())
|
57 |
+
)
|
58 |
+
self.variance_epsilon = config.rms_norm_eps
|
59 |
+
|
60 |
+
def forward(self, hidden_states):
|
61 |
+
"""
|
62 |
+
Apply RMS normalization to input hidden states.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
hidden_states (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
Tensor: Normalized output tensor of same shape as input
|
69 |
+
|
70 |
+
Note:
|
71 |
+
- Computes the RMS normalization manually:
|
72 |
+
1. Compute variance of features
|
73 |
+
2. Apply reciprocal square root normalization
|
74 |
+
3. Scale by learned weight parameter
|
75 |
+
- The variance is computed in float32 for numerical stability; the result is cast back to the weight dtype
|
76 |
+
"""
|
77 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
78 |
+
hidden_states = torch.rsqrt(variance + self.variance_epsilon) * hidden_states
|
79 |
+
return hidden_states.to(self.weight.dtype) * self.weight
|
80 |
+
|
81 |
+
|
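# Illustrative sketch: a standalone check of the RMS normalization math used
# above, in plain torch so it does not depend on Ernie4_5_Config. The shapes
# and eps value are example assumptions, not taken from the model config.
import torch

x = torch.randn(2, 4, 8)                       # [batch, seq_len, hidden_size]
weight = torch.ones(8)                         # learned per-channel scale (at init)
eps = 1e-6                                     # assumed rms_norm_eps

variance = x.float().pow(2).mean(-1, keepdim=True)
y = torch.rsqrt(variance + eps) * x            # divide by the RMS of the features
y = y.to(weight.dtype) * weight

# after normalization, the mean square of each feature vector is ~1
assert torch.allclose(y.pow(2).mean(-1), torch.ones(2, 4), atol=1e-3)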
82 |
+
class Ernie4_5_RopeEmbedding(nn.Module):
|
83 |
+
"""
|
84 |
+
Rotary Position Embedding (RoPE) implementation for transformer models.
|
85 |
+
|
86 |
+
RoPE encodes absolute positional information with rotation matrices and
|
87 |
+
naturally incorporates relative position information in self-attention.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
head_dim (int): Dimension size of each attention head
|
91 |
+
compression_ratio (float, optional): Sequence length compression ratio. Defaults to 1.0.
|
92 |
+
base (int, optional): Base value for frequency calculation. Defaults to 10000.
|
93 |
+
|
94 |
+
Attributes:
|
95 |
+
head_dim (int): Dimension size of each attention head
|
96 |
+
compression_ratio (float): Sequence length compression factor
|
97 |
+
base (int): Base value for frequency calculation
|
98 |
+
"""
|
99 |
+
|
100 |
+
def __init__(self, head_dim, compression_ratio=1.0, base=10000):
|
101 |
+
"""
|
102 |
+
Initialize RoPE embedding layer.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
head_dim: Dimension of each attention head
|
106 |
+
compression_ratio: Scaling factor for position indices
|
107 |
+
base: Base value for frequency calculation
|
108 |
+
"""
|
109 |
+
super().__init__()
|
110 |
+
self.head_dim = head_dim
|
111 |
+
self.compression_ratio = compression_ratio
|
112 |
+
self.base = base
|
113 |
+
|
114 |
+
def forward(self, seq_length, position_ids=None):
|
115 |
+
"""
|
116 |
+
Compute rotary position embeddings for given sequence length.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
seq_length (int): Maximum sequence length
|
120 |
+
position_ids (Tensor, optional): Custom position indices. Defaults to None.
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
Tensor: Rotary position embeddings of shape [1, 1, seq_length, head_dim]
|
124 |
+
"""
|
125 |
+
indices = torch.arange(0, self.head_dim, 2, dtype=torch.float32)
|
126 |
+
indices = 1 / self.base ** (indices / self.head_dim)
|
127 |
+
if position_ids is None:
|
128 |
+
position_ids = torch.arange(
|
129 |
+
0, seq_length, 1, dtype=torch.float32
|
130 |
+
).unsqueeze(1)
|
131 |
+
position_ids = position_ids / self.compression_ratio
|
132 |
+
sinusoid_inp = position_ids * indices.unsqueeze(0)
|
133 |
+
else:
|
134 |
+
position_ids = position_ids / self.compression_ratio
|
135 |
+
seq_length = position_ids.shape[-1]
|
136 |
+
sinusoid_inp = position_ids.unsqueeze(-1).to(
|
137 |
+
torch.float32
|
138 |
+
) * indices.unsqueeze(0)
|
139 |
+
pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
|
140 |
+
pos_emb = pos_emb.view(-1, 1, seq_length, self.head_dim)
|
141 |
+
pos_emb = pos_emb.detach()
|
142 |
+
return pos_emb
|
143 |
+
|
144 |
+
def apply_rotary(self, rp, q, k):
|
145 |
+
"""
|
146 |
+
Apply rotary position embeddings to queries and keys.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
rp (Tensor): Rotary position embeddings
|
150 |
+
q (Tensor): Query tensor [batch, heads, seq_len, dim]
|
151 |
+
k (Tensor): Key tensor [batch, heads, seq_len, dim]
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
Tuple[Tensor, Tensor]: Rotated queries and keys
|
155 |
+
"""
|
156 |
+
sin, cos = torch.chunk(rp.to(q.device), 2, dim=-1)
|
157 |
+
# sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
|
158 |
+
sin_pos = torch.stack([sin, sin], dim=-1).reshape(rp.shape)
|
159 |
+
# cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
|
160 |
+
cos_pos = torch.stack([cos, cos], dim=-1).reshape(rp.shape)
|
161 |
+
# rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
|
162 |
+
rotate_half_q = torch.stack(
|
163 |
+
[-q[:, :, :, 1::2], q[:, :, :, 0::2]], dim=-1
|
164 |
+
).reshape(q.shape)
|
165 |
+
query = (q.to(torch.float32) * cos_pos) + (
|
166 |
+
rotate_half_q.to(torch.float32) * sin_pos
|
167 |
+
)
|
168 |
+
# rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
|
169 |
+
rotate_half_k = torch.stack(
|
170 |
+
[-k[:, :, :, 1::2], k[:, :, :, 0::2]], dim=-1
|
171 |
+
).reshape(k.shape)
|
172 |
+
key = (k.to(torch.float32) * cos_pos) + (
|
173 |
+
rotate_half_k.to(torch.float32) * sin_pos
|
174 |
+
)
|
175 |
+
return query, key
|
176 |
+
|
177 |
+
|
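# Illustrative sketch: a standalone illustration of the interleaved rotary
# transform applied in apply_rotary above. Positions enter only through the
# sin/cos tables, and the rotation leaves each head vector's norm unchanged.
# All sizes below are example assumptions, not the model's actual dimensions.
import torch

head_dim, seq_len, base = 8, 4, 10000
inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)    # [seq, 1]
angles = pos * inv_freq.unsqueeze(0)                             # [seq, head_dim/2]
sin, cos = torch.sin(angles), torch.cos(angles)

# interleave to [s0, s0, s1, s1, ...] so each (even, odd) channel pair shares an angle
sin_pos = torch.stack([sin, sin], dim=-1).reshape(seq_len, head_dim)
cos_pos = torch.stack([cos, cos], dim=-1).reshape(seq_len, head_dim)

q = torch.randn(seq_len, head_dim)
rotate_half_q = torch.stack([-q[:, 1::2], q[:, 0::2]], dim=-1).reshape(q.shape)
q_rope = q * cos_pos + rotate_half_q * sin_pos

# each 2-D channel pair is rotated, so vector norms are preserved
assert torch.allclose(q.norm(dim=-1), q_rope.norm(dim=-1), atol=1e-4)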
178 |
+
class Ernie4_5_FusedDropoutImpl(nn.Module):
|
179 |
+
"""
|
180 |
+
Fused dropout implementation with residual connection support.
|
181 |
+
|
182 |
+
This layer combines dropout and residual addition in a single operation for better performance,
|
183 |
+
particularly on GPU devices. The dropout is conditionally applied based on the probability.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
prob (float): Dropout probability (between 0 and 1)
|
187 |
+
|
188 |
+
Attributes:
|
189 |
+
prob (float): Stores the dropout probability
|
190 |
+
dropout (nn.Dropout): The actual dropout layer instance
|
191 |
+
"""
|
192 |
+
|
193 |
+
def __init__(self, prob):
|
194 |
+
"""
|
195 |
+
Initialize the fused dropout layer.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
prob (float): Dropout probability (0 means no dropout)
|
199 |
+
"""
|
200 |
+
super().__init__()
|
201 |
+
self.prob = prob
|
202 |
+
self.dropout = nn.Dropout(p=prob)
|
203 |
+
|
204 |
+
def forward(self, x, y):
|
205 |
+
"""
|
206 |
+
Forward pass of the fused dropout layer.
|
207 |
+
|
208 |
+
Args:
|
209 |
+
x (Tensor): Input tensor to potentially apply dropout
|
210 |
+
y (Tensor): Residual tensor to add to the (possibly dropped out) x
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
Tensor: Result of x (with optional dropout) + y
|
214 |
+
"""
|
215 |
+
if self.prob > 0:
|
216 |
+
x = self.dropout(x)
|
217 |
+
output = x + y
|
218 |
+
|
219 |
+
return output
|
220 |
+
|
221 |
+
|
222 |
+
class Ernie4_5_MLP(nn.Module):
|
223 |
+
"""
|
224 |
+
Ernie4_5_MLP - Gated Multi-Layer Perceptron module used in Ernie model.
|
225 |
+
"""
|
226 |
+
|
227 |
+
def __init__(self, config, layer_idx=0):
|
228 |
+
"""
|
229 |
+
Initialize the MLP module with configuration options.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
config: Model configurations.
|
233 |
+
layer_idx (int): Index of current layer (default: 0)
|
234 |
+
"""
|
235 |
+
super().__init__()
|
236 |
+
self.config = config
|
237 |
+
self.layer_idx = layer_idx
|
238 |
+
self.hidden_size = config.hidden_size
|
239 |
+
self.intermediate_size = config.intermediate_size
|
240 |
+
|
241 |
+
self.gate_proj = nn.Linear(
|
242 |
+
self.hidden_size, self.intermediate_size, bias=config.use_bias
|
243 |
+
)
|
244 |
+
self.up_proj = nn.Linear(
|
245 |
+
self.hidden_size, self.intermediate_size, bias=config.use_bias
|
246 |
+
)
|
247 |
+
self.down_proj = nn.Linear(
|
248 |
+
self.intermediate_size, self.hidden_size, bias=config.use_bias
|
249 |
+
)
|
250 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
251 |
+
|
252 |
+
def forward(self, x):
|
253 |
+
"""
|
254 |
+
Args:
|
255 |
+
x (Tensor): shape [batch_size, seq_len, hidden_size]
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
Tensor: shape [batch_size, seq_len, hidden_size]
|
259 |
+
"""
|
260 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
261 |
+
return down_proj
|
262 |
+
|
263 |
+
|
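# Illustrative sketch: the gated (SwiGLU-style) MLP pattern used above,
# down_proj(act(gate_proj(x)) * up_proj(x)). SiLU is assumed as the activation
# and the sizes are example values, not the model's configuration.
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 64
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 5, hidden_size)                # [batch, seq_len, hidden_size]
y = down_proj(F.silu(gate_proj(x)) * up_proj(x))  # gated expansion, then projection back
assert y.shape == x.shape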
264 |
+
class Ernie4_5_Attention(nn.Module):
|
265 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
266 |
+
|
267 |
+
def __init__(self, config, layer_idx=0):
|
268 |
+
"""Initialize the attention layer.
|
269 |
+
|
270 |
+
Args:
|
271 |
+
config: Model configuration.
|
272 |
+
layer_idx (int, optional): Index in transformer stack. Defaults to 0.
|
273 |
+
"""
|
274 |
+
super().__init__()
|
275 |
+
self.layer_idx = layer_idx
|
276 |
+
self.hidden_size = config.hidden_size
|
277 |
+
self.num_heads = config.num_attention_heads
|
278 |
+
self.num_key_value_heads = config.num_key_value_heads
|
279 |
+
|
280 |
+
if config.head_dim is None:
|
281 |
+
self.head_dim = self.hidden_size // self.num_heads
|
282 |
+
else:
|
283 |
+
self.head_dim = config.head_dim
|
284 |
+
|
285 |
+
self.is_gqa = (
|
286 |
+
self.num_key_value_heads is not None
|
287 |
+
and self.num_key_value_heads != self.num_heads
|
288 |
+
)
|
289 |
+
|
290 |
+
if self.is_gqa:
|
291 |
+
logger.info(
|
292 |
+
f"use GQA - num_heads: {self.num_heads} - num_key_value_heads: {self.num_key_value_heads}"
|
293 |
+
)
|
294 |
+
assert (
|
295 |
+
self.num_heads % self.num_key_value_heads == 0
|
296 |
+
), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}"
|
297 |
+
kv_hidden_size = self.head_dim * self.num_key_value_heads
|
298 |
+
q_hidden_size = self.head_dim * self.num_heads
|
299 |
+
else:
|
300 |
+
q_hidden_size = kv_hidden_size = self.head_dim * self.num_heads
|
301 |
+
|
302 |
+
self.q_proj = nn.Linear(self.hidden_size, q_hidden_size, bias=config.use_bias)
|
303 |
+
self.k_proj = nn.Linear(self.hidden_size, kv_hidden_size, bias=config.use_bias)
|
304 |
+
self.v_proj = nn.Linear(self.hidden_size, kv_hidden_size, bias=config.use_bias)
|
305 |
+
self.o_proj = nn.Linear(q_hidden_size, self.hidden_size, bias=config.use_bias)
|
306 |
+
|
307 |
+
self.rotary_emb = Ernie4_5_RopeEmbedding(
|
308 |
+
self.head_dim,
|
309 |
+
compression_ratio=config.compression_ratio,
|
310 |
+
base=config.rope_theta,
|
311 |
+
)
|
312 |
+
self.config = config
|
313 |
+
|
314 |
+
self.set_attn_func()
|
315 |
+
|
316 |
+
def set_attn_func(self):
|
317 |
+
"""Configure attention function based on settings.
|
318 |
+
|
319 |
+
Selects between flash/core attention.
|
320 |
+
"""
|
321 |
+
config = self.config
|
322 |
+
if config.use_flash_attention:
|
323 |
+
self.attn_func = self._flash_attention_wrapper
|
324 |
+
else:
|
325 |
+
self.attn_func = self.core_attn
|
326 |
+
|
327 |
+
def forward(
|
328 |
+
self,
|
329 |
+
hidden_states,
|
330 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
331 |
+
attention_mask: Optional[torch.Tensor] = None,
|
332 |
+
attn_mask_start_row_indices: Optional[torch.Tensor] = None,
|
333 |
+
position_ids: Optional[Tuple[torch.Tensor]] = None,
|
334 |
+
output_attentions: bool = False,
|
335 |
+
use_cache: bool = False,
|
336 |
+
token_type_ids: Optional[Tuple[torch.Tensor]] = None,
|
337 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
338 |
+
"""Compute attention outputs.
|
339 |
+
|
340 |
+
Args:
|
341 |
+
hidden_states (torch.Tensor): Input tensor [bsz, seq_len, hidden_size]
|
342 |
+
past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached key/value states
|
343 |
+
attention_mask (Optional[torch.Tensor]): Attention mask tensor
|
344 |
+
attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length attention indices
|
345 |
+
position_ids (Optional[torch.Tensor]): Position indices for RoPE
|
346 |
+
output_attentions (bool): Return attention weights if True
|
347 |
+
use_cache (bool): Cache key/value states if True
|
348 |
+
|
349 |
+
Returns:
|
350 |
+
Tuple containing:
|
351 |
+
- attention_output: [bsz, seq_len, hidden_size]
|
352 |
+
- attention_weights: Optional attention probabilities
|
353 |
+
- updated_key_value_cache: Optional updated cache
|
354 |
+
"""
|
355 |
+
if token_type_ids is not None:
|
356 |
+
token_type_ids = token_type_ids[:, :-1]
|
357 |
+
|
358 |
+
bsz, q_len, _ = hidden_states.shape
|
359 |
+
|
360 |
+
query_states = self.q_proj(hidden_states).reshape(
|
361 |
+
[bsz, q_len, -1, self.head_dim]
|
362 |
+
)
|
363 |
+
key_states = self.k_proj(hidden_states).reshape([bsz, q_len, -1, self.head_dim])
|
364 |
+
value_states = self.v_proj(hidden_states).reshape(
|
365 |
+
[bsz, q_len, -1, self.head_dim]
|
366 |
+
)
|
367 |
+
|
368 |
+
attn_output, attn_weights, past_key_value = self.rope_attn(
|
369 |
+
query_states=query_states,
|
370 |
+
key_states=key_states,
|
371 |
+
value_states=value_states,
|
372 |
+
attention_mask=attention_mask,
|
373 |
+
position_ids=position_ids,
|
374 |
+
output_attentions=output_attentions,
|
375 |
+
past_key_value=past_key_value,
|
376 |
+
use_cache=use_cache,
|
377 |
+
attn_mask_start_row_indices=attn_mask_start_row_indices,
|
378 |
+
)
|
379 |
+
|
380 |
+
attn_output = self.o_proj(attn_output)
|
381 |
+
|
382 |
+
if not output_attentions:
|
383 |
+
attn_weights = None
|
384 |
+
|
385 |
+
return attn_output, attn_weights, past_key_value
|
386 |
+
|
387 |
+
def repeat_kv(self, hidden_states, n_rep):
|
388 |
+
"""
|
389 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
390 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
391 |
+
"""
|
392 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
393 |
+
if n_rep == 1:
|
394 |
+
return hidden_states
|
395 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
396 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
397 |
+
)
|
398 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
399 |
+
|
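# Illustrative sketch: a standalone check that the expand + reshape trick in
# repeat_kv above matches torch.repeat_interleave along the head dimension.
# Shapes and the repeat factor are example assumptions.
import torch

batch, kv_heads, seq_len, head_dim, n_rep = 2, 2, 5, 4, 3
kv = torch.randn(batch, kv_heads, seq_len, head_dim)

expanded = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
expanded = expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)

reference = torch.repeat_interleave(kv, repeats=n_rep, dim=1)
assert torch.equal(expanded, reference)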
400 |
+
def _flash_attention_wrapper(
|
401 |
+
self,
|
402 |
+
q,
|
403 |
+
k,
|
404 |
+
v,
|
405 |
+
attention_mask=None,
|
406 |
+
attn_mask_start_row_indices=None,
|
407 |
+
seq_length=None,
|
408 |
+
):
|
409 |
+
"""Wrapper for flash attention implementation.
|
410 |
+
|
411 |
+
Args:
|
412 |
+
q (torch.Tensor): Query tensor
|
413 |
+
k (torch.Tensor): Key tensor
|
414 |
+
v (torch.Tensor): Value tensor
|
415 |
+
attention_mask (Optional[torch.Tensor]): Attention mask
|
416 |
+
attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices
|
417 |
+
seq_length (Optional[int]): Sequence length
|
418 |
+
|
419 |
+
Returns:
|
420 |
+
Tuple[torch.Tensor, torch.Tensor]: Attention output and weights
|
421 |
+
"""
|
422 |
+
q = q.transpose(1, 2)
|
423 |
+
k = k.transpose(1, 2)
|
424 |
+
v = v.transpose(1, 2)
|
425 |
+
|
426 |
+
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
427 |
+
out = F.scaled_dot_product_attention(
|
428 |
+
q,
|
429 |
+
k,
|
430 |
+
v,
|
431 |
+
attn_mask=attention_mask,
|
432 |
+
dropout_p=self.config.attention_probs_dropout_prob,
|
433 |
+
is_causal=attention_mask is None and q.shape[2] != 1,  # q is [b, h, s, d] here, so check the query length, not the head count
|
434 |
+
scale=1
|
435 |
+
/ (getattr(self.config, "scale_qk_coeff", 1.0) * self.head_dim**0.5),
|
436 |
+
enable_gqa=self.is_gqa,
|
437 |
+
)
|
438 |
+
out = out.transpose(1, 2)
|
439 |
+
out = out.contiguous().view(out.size(0), out.size(1), -1)
|
440 |
+
|
441 |
+
return out, None
|
442 |
+
|
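# Illustrative sketch: the fused scaled_dot_product_attention call that the
# wrapper above builds, with tensors in [batch, heads, seq, head_dim] layout,
# causal masking when no explicit mask is given, and the 1/(coeff * sqrt(d))
# scale. The FLASH_ATTENTION backend pin is omitted so this also runs on CPU;
# shapes and the coefficient are example assumptions.
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim, scale_qk_coeff = 1, 4, 6, 8, 1.0
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
v = torch.randn(bsz, n_heads, seq_len, head_dim)

out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=True,                                  # lower-triangular mask, as in prefill
    scale=1.0 / (scale_qk_coeff * head_dim**0.5),
)
assert out.shape == (bsz, n_heads, seq_len, head_dim)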
443 |
+
def core_attn(
|
444 |
+
self,
|
445 |
+
q,
|
446 |
+
k,
|
447 |
+
v,
|
448 |
+
attention_mask=None,
|
449 |
+
attn_mask_start_row_indices=None,
|
450 |
+
seq_length=None,
|
451 |
+
):
|
452 |
+
"""Standard self-attention implementation.
|
453 |
+
|
454 |
+
Args:
|
455 |
+
q (torch.Tensor): Query tensor
|
456 |
+
k (torch.Tensor): Key tensor
|
457 |
+
v (torch.Tensor): Value tensor
|
458 |
+
attention_mask (Optional[torch.Tensor]): Attention mask
|
459 |
+
attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices
|
460 |
+
seq_length (Optional[int]): Sequence length
|
461 |
+
|
462 |
+
Returns:
|
463 |
+
Tuple[torch.Tensor, torch.Tensor]: Attention output and weights
|
464 |
+
"""
|
465 |
+
origin_dtype = q.dtype
|
466 |
+
|
467 |
+
q = q.permute(0, 2, 1, 3)
|
468 |
+
k = k.permute(0, 2, 1, 3)
|
469 |
+
v = v.permute(0, 2, 1, 3)
|
470 |
+
|
471 |
+
scale_qk_coeff = (
|
472 |
+
getattr(self.config, "scale_qk_coeff", 1.0) * self.head_dim**0.5
|
473 |
+
)
|
474 |
+
|
475 |
+
q = q / scale_qk_coeff
|
476 |
+
|
477 |
+
# Handle GQA case - repeat k and v heads to match q heads
|
478 |
+
if self.is_gqa:
|
479 |
+
# [batch, num_key_value_heads, seq_len, head_dim] -> [batch, num_heads, seq_len, head_dim]
|
480 |
+
repeat_factor = self.num_heads // self.num_key_value_heads
|
481 |
+
k = self.repeat_kv(k, repeat_factor)
|
482 |
+
v = self.repeat_kv(v, repeat_factor)
|
483 |
+
|
484 |
+
attn_scores = torch.matmul(q, k.transpose(-2, -1))
|
485 |
+
|
486 |
+
if getattr(self.config, "scale_qk_coeff", 1.0) != 1.0:
|
487 |
+
attn_scores = attn_scores * getattr(self.config, "scale_qk_coeff", 1.0)
|
488 |
+
|
489 |
+
# Causal mask
|
490 |
+
seq_len = attn_scores.size(-1)
|
491 |
+
mask = torch.triu(
|
492 |
+
torch.ones((seq_len, seq_len), dtype=torch.bool, device=attn_scores.device),
|
493 |
+
diagonal=1,
|
494 |
+
)
|
495 |
+
attn_scores = attn_scores.masked_fill(mask, float("-inf"))
|
496 |
+
attn_weights = F.softmax(attn_scores, dim=-1)
|
497 |
+
|
498 |
+
attn_weights = attn_weights.to(origin_dtype)
|
499 |
+
|
500 |
+
# attention_probs_dropout_prob default 0.0
|
501 |
+
if getattr(self.config, "attention_probs_dropout_prob", 0.0) > 0:
|
502 |
+
attn_weights = F.dropout(
|
503 |
+
attn_weights,
|
504 |
+
p=self.config.attention_probs_dropout_prob,
|
505 |
+
training=self.training,
|
506 |
+
)
|
507 |
+
|
508 |
+
# [batch, num_heads, q_len, k_len] @ [batch, num_heads, k_len, head_dim] -> [batch, num_heads, q_len, head_dim]
|
509 |
+
out = torch.matmul(attn_weights, v)
|
510 |
+
|
511 |
+
# [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, num_heads, head_dim]
|
512 |
+
out = out.permute(0, 2, 1, 3)
|
513 |
+
# [batch, seq_len, hidden_size]
|
514 |
+
out = out.contiguous().view(out.size(0), out.size(1), -1)
|
515 |
+
|
516 |
+
return out, attn_weights
|
517 |
+
|
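# Illustrative sketch: a standalone check that the manual path above (scaled
# scores, upper-triangular -inf mask, softmax, weighted sum over values)
# matches torch's fused SDPA with is_causal=True. Shapes are example values.
import torch
import torch.nn.functional as F

b, h, s, d = 1, 2, 5, 4
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

scores = torch.matmul(q / d**0.5, k.transpose(-2, -1))
causal = torch.triu(torch.ones(s, s, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))
manual = torch.matmul(F.softmax(scores, dim=-1), v)

fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(manual, fused, atol=1e-5)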
518 |
+
def rope_attn(
|
519 |
+
self,
|
520 |
+
query_states,
|
521 |
+
key_states,
|
522 |
+
value_states,
|
523 |
+
attention_mask,
|
524 |
+
position_ids,
|
525 |
+
output_attentions=False,
|
526 |
+
past_key_value=None,
|
527 |
+
use_cache=False,
|
528 |
+
attn_mask_start_row_indices=None,
|
529 |
+
):
|
530 |
+
"""Attention computation with rotary embeddings.
|
531 |
+
|
532 |
+
Args:
|
533 |
+
query_states (torch.Tensor): Query states
|
534 |
+
key_states (torch.Tensor): Key states
|
535 |
+
value_states (torch.Tensor): Value states
|
536 |
+
attention_mask (Optional[torch.Tensor]): Attention mask
|
537 |
+
position_ids (Optional[torch.Tensor]): Position indices
|
538 |
+
output_attentions (bool): Return attention weights
|
539 |
+
past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached states
|
540 |
+
use_cache (bool): Cache new states
|
541 |
+
attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices
|
542 |
+
|
543 |
+
Returns:
|
544 |
+
Tuple containing:
|
545 |
+
- attention_output: Result tensor
|
546 |
+
- attention_weights: Optional weights
|
547 |
+
- updated_key_value_cache: Optional cache
|
548 |
+
"""
|
549 |
+
|
550 |
+
query_states_dtype = query_states.dtype
|
551 |
+
|
552 |
+
kv_seq_len = key_states.shape[-3]
|
553 |
+
offset = 0
|
554 |
+
if past_key_value is not None:
|
555 |
+
offset = past_key_value[0].shape[-3]
|
556 |
+
kv_seq_len += offset
|
557 |
+
|
558 |
+
cos_sin = self.rotary_emb(kv_seq_len).permute(
|
559 |
+
[0, 2, 1, 3]
|
560 |
+
) # [b,h,s,d]->[b,s,h,d]
|
561 |
+
if offset > 0:
|
562 |
+
cos_sin = cos_sin[:, offset:]
|
563 |
+
query_states, key_states = self.rotary_emb.apply_rotary(
|
564 |
+
cos_sin, query_states, key_states
|
565 |
+
)
|
566 |
+
|
567 |
+
query_states = query_states.to(query_states_dtype)
|
568 |
+
key_states = key_states.to(query_states_dtype)
|
569 |
+
if past_key_value is not None:
|
570 |
+
# reuse k, v, self_attention
|
571 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=1)
|
572 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=1)
|
573 |
+
|
574 |
+
# shape: [2, b, s, kvh, d]
|
575 |
+
past_key_value = [key_states, value_states] if use_cache else None
|
576 |
+
seq_length = query_states.shape[1]
|
577 |
+
attn_output, attn_weights = self.attn_func(
|
578 |
+
query_states,
|
579 |
+
key_states,
|
580 |
+
value_states,
|
581 |
+
attention_mask,
|
582 |
+
attn_mask_start_row_indices,
|
583 |
+
seq_length,
|
584 |
+
)
|
585 |
+
return attn_output, attn_weights, past_key_value
|
586 |
+
|
587 |
+
|
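# Illustrative sketch: the KV-cache bookkeeping in rope_attn above. Cached keys
# and values are concatenated with the new projections along the sequence axis
# (dim=1 in the [batch, seq, heads, head_dim] layout used there), and the
# rotary table for the new tokens is sliced starting at the cached offset.
# Sizes are example assumptions.
import torch

batch, heads, head_dim = 1, 2, 4
past_key = torch.randn(batch, 3, heads, head_dim)   # 3 cached positions
new_key = torch.randn(batch, 1, heads, head_dim)    # 1 freshly projected position

offset = past_key.shape[1]                          # the new token gets position index `offset`
key = torch.cat([past_key, new_key], dim=1)
assert key.shape[1] == offset + new_key.shape[1]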
588 |
+
class Ernie4_5_DecoderLayer(nn.Module):
|
589 |
+
"""
|
590 |
+
A single transformer decoder layer in ERNIE model.
|
591 |
+
"""
|
592 |
+
|
593 |
+
def __init__(self, config, layer_idx):
|
594 |
+
"""Initialize the decoder layer.
|
595 |
+
|
596 |
+
Args:
|
597 |
+
config: Model configuration.
|
598 |
+
layer_idx (int): Index of this layer in the transformer stack
|
599 |
+
"""
|
600 |
+
super().__init__()
|
601 |
+
self.hidden_size = config.hidden_size
|
602 |
+
self.layer_idx = layer_idx
|
603 |
+
self.config = config
|
604 |
+
|
605 |
+
self.self_attn = Ernie4_5_Attention(config, layer_idx)
|
606 |
+
self.mlp = Ernie4_5_MLP(config)
|
607 |
+
|
608 |
+
self.input_layernorm = Ernie4_5_RMSNorm(config)
|
609 |
+
self.post_attention_layernorm = Ernie4_5_RMSNorm(config)
|
610 |
+
|
611 |
+
self.residual_add1 = Ernie4_5_FusedDropoutImpl(config.hidden_dropout_prob)
|
612 |
+
self.residual_add2 = Ernie4_5_FusedDropoutImpl(config.hidden_dropout_prob)
|
613 |
+
|
614 |
+
def forward(
|
615 |
+
self,
|
616 |
+
hidden_states: torch.Tensor,
|
617 |
+
attention_mask: Optional[torch.Tensor] = None,
|
618 |
+
attn_mask_start_row_indices: Optional[torch.Tensor] = None,
|
619 |
+
position_ids: Optional[torch.Tensor] = None,
|
620 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
621 |
+
output_attentions: Optional[bool] = False,
|
622 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
623 |
+
use_cache: Optional[bool] = False,
|
624 |
+
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
625 |
+
"""Forward pass through the decoder layer.
|
626 |
+
|
627 |
+
Args:
|
628 |
+
hidden_states (torch.Tensor): Input tensor [batch_size, seq_len, hidden_size]
|
629 |
+
attention_mask (Optional[torch.Tensor]): Attention mask tensor
|
630 |
+
attn_mask_start_row_indices (Optional[torch.Tensor]): Indices for variable length attention
|
631 |
+
position_ids (Optional[torch.Tensor]): Position indices for rotary embeddings
|
632 |
+
output_attentions (Optional[bool]): Whether to return attention weights
|
633 |
+
past_key_value (Optional[Tuple[torch.Tensor]]): Cached key/value states
|
634 |
+
use_cache (Optional[bool]): Whether to cache key/value states
|
635 |
+
|
636 |
+
Returns:
|
637 |
+
Union: Various output combinations depending on arguments:
|
638 |
+
- Base case: Hidden states tensor
|
639 |
+
- With attention: Tuple of (hidden_states, attention_weights)
|
640 |
+
- With cache: Tuple of (hidden_states, cached_key_value)
|
641 |
+
"""
|
642 |
+
residual = hidden_states
|
643 |
+
|
644 |
+
hidden_states = self.input_layernorm(hidden_states)
|
645 |
+
|
646 |
+
# Self Attention
|
647 |
+
(hidden_states, self_attn_weights, present_key_value) = self.self_attn(
|
648 |
+
hidden_states=hidden_states,
|
649 |
+
past_key_value=past_key_value,
|
650 |
+
attention_mask=attention_mask,
|
651 |
+
attn_mask_start_row_indices=attn_mask_start_row_indices,
|
652 |
+
position_ids=position_ids,
|
653 |
+
output_attentions=output_attentions,
|
654 |
+
use_cache=use_cache,
|
655 |
+
token_type_ids=token_type_ids,
|
656 |
+
)
|
657 |
+
hidden_states = self.residual_add1(hidden_states, residual)
|
658 |
+
|
659 |
+
# Fully Connected
|
660 |
+
residual = hidden_states
|
661 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
662 |
+
hidden_states = self.mlp(hidden_states)
|
663 |
+
|
664 |
+
hidden_states = self.residual_add2(hidden_states, residual)
|
665 |
+
outputs = (hidden_states,)
|
666 |
+
|
667 |
+
if output_attentions:
|
668 |
+
outputs += (self_attn_weights,)
|
669 |
+
|
670 |
+
if use_cache:
|
671 |
+
outputs += (present_key_value,)
|
672 |
+
|
673 |
+
if type(outputs) is tuple and len(outputs) == 1:
|
674 |
+
outputs = outputs[0]
|
675 |
+
|
676 |
+
return outputs
|
677 |
+
|
678 |
+
|
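# Illustrative sketch: the pre-norm residual pattern of the decoder layer above,
# normalize -> sublayer -> add residual, once for attention and once for the
# MLP. The sublayers here are simple stand-ins (LayerNorm / Linear), not the
# real RMSNorm, attention, or gated MLP.
import torch
import torch.nn as nn

hidden_size = 8
input_norm, post_attn_norm = nn.LayerNorm(hidden_size), nn.LayerNorm(hidden_size)
attn_stub, mlp_stub = nn.Linear(hidden_size, hidden_size), nn.Linear(hidden_size, hidden_size)

x = torch.randn(2, 4, hidden_size)
h = x + attn_stub(input_norm(x))        # residual around the attention sublayer
h = h + mlp_stub(post_attn_norm(h))     # residual around the feed-forward sublayer
assert h.shape == x.shape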
679 |
+
class Ernie4_5_PretrainedModel(PreTrainedModel):
|
680 |
+
"""Base class for ERNIE pretrained models."""
|
681 |
+
|
682 |
+
config_class = Ernie4_5_Config
|
683 |
+
base_model_prefix = "ernie"
|
684 |
+
|
685 |
+
|
686 |
+
class Ernie4_5_Model(Ernie4_5_PretrainedModel):
|
687 |
+
|
688 |
+
def __init__(self, config):
|
689 |
+
"""Initialize the ERNIE model architecture.
|
690 |
+
|
691 |
+
Args:
|
692 |
+
config: Model configuration.
|
693 |
+
"""
|
694 |
+
super().__init__(config)
|
695 |
+
self.padding_idx = config.pad_token_id
|
696 |
+
self.vocab_size = config.vocab_size
|
697 |
+
self.hidden_size = config.hidden_size
|
698 |
+
self.config = config
|
699 |
+
|
700 |
+
self.embed_tokens = nn.Embedding(
|
701 |
+
self.vocab_size,
|
702 |
+
self.hidden_size,
|
703 |
+
)
|
704 |
+
|
705 |
+
self.layers = nn.ModuleList(
|
706 |
+
[Ernie4_5_DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
|
707 |
+
)
|
708 |
+
|
709 |
+
self.norm = Ernie4_5_RMSNorm(config)
|
710 |
+
|
711 |
+
self.gradient_checkpointing = False
|
712 |
+
|
713 |
+
def get_input_embeddings(self):
|
714 |
+
"""Get the input embedding layer.
|
715 |
+
|
716 |
+
Returns:
|
717 |
+
nn.Embedding: The embedding layer for input tokens
|
718 |
+
"""
|
719 |
+
return self.embed_tokens
|
720 |
+
|
721 |
+
def set_input_embeddings(self, value):
|
722 |
+
"""Set new input embeddings.
|
723 |
+
|
724 |
+
Args:
|
725 |
+
value (nn.Embedding): New embedding layer to use
|
726 |
+
"""
|
727 |
+
self.embed_tokens = value
|
728 |
+
|
729 |
+
def forward(
|
730 |
+
self,
|
731 |
+
input_ids=None,
|
732 |
+
position_ids=None,
|
733 |
+
token_type_ids=None,
|
734 |
+
attention_mask=None,
|
735 |
+
attn_mask_start_row_indices=None,
|
736 |
+
inputs_embeds=None,
|
737 |
+
use_cache=None,
|
738 |
+
past_key_values=None,
|
739 |
+
output_attentions=False,
|
740 |
+
output_hidden_states=None,
|
741 |
+
return_dict=False,
|
742 |
+
):
|
743 |
+
"""Forward pass through the ERNIE model.
|
744 |
+
|
745 |
+
Args:
|
746 |
+
input_ids (Optional[torch.Tensor]): Input token IDs
|
747 |
+
position_ids (Optional[torch.Tensor]): Position indices
|
748 |
+
attention_mask (Optional[torch.Tensor]): Attention mask
|
749 |
+
attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length attention indices
|
750 |
+
inputs_embeds (Optional[torch.Tensor]): Precomputed embeddings
|
751 |
+
use_cache (Optional[bool]): Whether to cache key/value states
|
752 |
+
past_key_values (Optional[Tuple[Tuple[torch.Tensor]]]): Cached key/value states
|
753 |
+
output_attentions (Optional[bool]): Whether to output attention weights
|
754 |
+
output_hidden_states (Optional[bool]): Whether to output all hidden states
|
755 |
+
return_dict (Optional[bool]): Whether to return dict or tuple
|
756 |
+
|
757 |
+
Returns:
|
758 |
+
Union[Tuple, BaseModelOutputWithPast]:
|
759 |
+
Various outputs depending on configuration, including:
|
760 |
+
- last_hidden_state: Final layer hidden states
|
761 |
+
- past_key_values: Cached key/value states if use_cache=True
|
762 |
+
- hidden_states: All hidden states if output_hidden_states=True
|
763 |
+
- attentions: Attention weights if output_attentions=True
|
764 |
+
"""
|
765 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
766 |
+
|
767 |
+
# retrieve input_ids and inputs_embeds
|
768 |
+
if input_ids is not None and inputs_embeds is not None:
|
769 |
+
raise ValueError(
|
770 |
+
"You cannot specify both input_ids and inputs_embeds at the same time"
|
771 |
+
)
|
772 |
+
elif input_ids is not None:
|
773 |
+
_, seq_length = input_ids.shape
|
774 |
+
elif inputs_embeds is not None:
|
775 |
+
_, seq_length, _ = inputs_embeds.shape
|
776 |
+
else:
|
777 |
+
raise ValueError(
|
778 |
+
"You have to specify either input_ids or inputs_embeds"
|
779 |
+
)
|
780 |
+
|
781 |
+
if past_key_values is None:
|
782 |
+
past_key_values = tuple([None] * len(self.layers))
|
783 |
+
|
784 |
+
if inputs_embeds is None:
|
785 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
786 |
+
inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype)
|
787 |
+
|
788 |
+
hidden_states = inputs_embeds
|
789 |
+
|
790 |
+
# decoder layers
|
791 |
+
all_hidden_states = () if output_hidden_states else None
|
792 |
+
all_self_attns = () if output_attentions else None
|
793 |
+
next_decoder_cache = () if use_cache else None
|
794 |
+
|
795 |
+
for idx, (decoder_layer) in enumerate(self.layers):
|
796 |
+
|
797 |
+
if output_hidden_states:
|
798 |
+
all_hidden_states += (hidden_states,)
|
799 |
+
|
800 |
+
past_key_value = (
|
801 |
+
past_key_values[idx] if past_key_values is not None else None
|
802 |
+
)
|
803 |
+
|
804 |
+
layer_outputs = decoder_layer(
|
805 |
+
hidden_states,
|
806 |
+
attention_mask,
|
807 |
+
attn_mask_start_row_indices,
|
808 |
+
position_ids,
|
809 |
+
token_type_ids,
|
810 |
+
output_attentions,
|
811 |
+
past_key_value,
|
812 |
+
use_cache,
|
813 |
+
)
|
814 |
+
|
815 |
+
if isinstance(layer_outputs, (tuple, list)):
|
816 |
+
hidden_states = layer_outputs[0]
|
817 |
+
else:
|
818 |
+
hidden_states = layer_outputs
|
819 |
+
|
820 |
+
if use_cache:
|
821 |
+
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
822 |
+
|
823 |
+
if output_attentions:
|
824 |
+
all_self_attns += (layer_outputs[1],)
|
825 |
+
|
826 |
+
# apply kv cache
|
827 |
+
if past_key_value is not None:
|
828 |
+
hidden_states = hidden_states[:, -1:, :]
|
829 |
+
|
830 |
+
hidden_states = self.norm(hidden_states)
|
831 |
+
|
832 |
+
# add hidden states from the last decoder layer
|
833 |
+
if output_hidden_states:
|
834 |
+
all_hidden_states += (hidden_states,)
|
835 |
+
|
836 |
+
next_cache = next_decoder_cache if use_cache else None
|
837 |
+
|
838 |
+
if not return_dict:
|
839 |
+
return tuple(
|
840 |
+
v
|
841 |
+
for v in [
|
842 |
+
hidden_states,
|
843 |
+
next_cache,
|
844 |
+
all_hidden_states,
|
845 |
+
all_self_attns,
|
846 |
+
]
|
847 |
+
if v is not None
|
848 |
+
)
|
849 |
+
|
850 |
+
return BaseModelOutputWithPast(
|
851 |
+
last_hidden_state=hidden_states,
|
852 |
+
past_key_values=next_cache,
|
853 |
+
hidden_states=all_hidden_states,
|
854 |
+
attentions=all_self_attns,
|
855 |
+
)
|
856 |
+
|
857 |
+
|
858 |
+
class Ernie4_5_LMHead(nn.Module):
|
859 |
+
"""Language model head for ERNIE"""
|
860 |
+
|
861 |
+
def __init__(self, config):
|
862 |
+
"""Initialize the language model head.
|
863 |
+
|
864 |
+
Args:
|
865 |
+
config: Model configuration containing:
|
866 |
+
- vocab_size: Size of vocabulary
|
867 |
+
- hidden_size: Dimension of hidden states
|
868 |
+
- tie_word_embeddings: Whether to tie input/output embeddings
|
869 |
+
- weight_share_add_bias: Whether to add bias when weight sharing
|
870 |
+
- use_bias: Whether to use bias term
|
871 |
+
"""
|
872 |
+
|
873 |
+
super(Ernie4_5_LMHead, self).__init__()
|
874 |
+
self.config = config
|
875 |
+
vocab_size = config.vocab_size
|
876 |
+
|
877 |
+
if config.tie_word_embeddings:
|
878 |
+
# Weight of shape [vocab_size, hidden_size]
|
879 |
+
self.weight = nn.Parameter(
|
880 |
+
torch.empty(
|
881 |
+
vocab_size, config.hidden_size, dtype=torch.get_default_dtype()
|
882 |
+
)
|
883 |
+
)
|
884 |
+
else:
|
885 |
+
# Weight of shape [hidden_size, vocab_size]
|
886 |
+
self.weight = nn.Parameter(
|
887 |
+
torch.empty(
|
888 |
+
config.hidden_size, vocab_size, dtype=torch.get_default_dtype()
|
889 |
+
)
|
890 |
+
)
|
891 |
+
nn.init.xavier_uniform_(self.weight)
|
892 |
+
|
893 |
+
logger.info(
|
894 |
+
f"output-weight: {self.weight.shape}, tie_word_embeddings: {config.tie_word_embeddings}"
|
895 |
+
)
|
896 |
+
|
897 |
+
if config.weight_share_add_bias and config.use_bias:
|
898 |
+
self.bias = nn.Parameter(
|
899 |
+
torch.zeros(vocab_size, dtype=torch.get_default_dtype())
|
900 |
+
)
|
901 |
+
else:
|
902 |
+
self.bias = None
|
903 |
+
|
904 |
+
def forward(self, hidden_states):
|
905 |
+
"""Project hidden states to vocabulary logits.
|
906 |
+
|
907 |
+
Args:
|
908 |
+
hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
|
909 |
+
|
910 |
+
Returns:
|
911 |
+
Logits tensor of shape [batch_size, seq_len, vocab_size]
|
912 |
+
"""
|
913 |
+
return self.calc_lm_head_logits(
|
914 |
+
self.config, hidden_states, self.weight, self.bias
|
915 |
+
)
|
916 |
+
|
917 |
+
def calc_lm_head_logits(self, config, hidden_states, weight, bias):
|
918 |
+
"""
|
919 |
+
Calculate language model head logits.
|
920 |
+
|
921 |
+
This is the core function that computes the final output logits for a language model.
|
922 |
+
|
923 |
+
Args:
|
924 |
+
config: Model configuration.
|
925 |
+
hidden_states (Tensor): Hidden states from the transformer layers
|
926 |
+
weight (Tensor): Weight matrix for the language model head
|
927 |
+
bias (Tensor): Bias vector for the language model head
|
928 |
+
|
929 |
+
Returns:
|
930 |
+
Tensor: The computed logits for language modeling.
|
931 |
+
"""
|
932 |
+
|
933 |
+
if config.tie_word_embeddings:
|
934 |
+
logits = torch.matmul(hidden_states, weight.T)
|
935 |
+
else:
|
936 |
+
logits = torch.matmul(hidden_states, weight)
|
937 |
+
|
938 |
+
if bias is not None:
|
939 |
+
logits = logits + bias
|
940 |
+
|
941 |
+
return logits
|
942 |
+
|
943 |
+
|
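# Illustrative sketch: the two weight layouts handled in calc_lm_head_logits
# above. With tied embeddings the head reuses a [vocab, hidden] matrix and
# multiplies by its transpose; otherwise a [hidden, vocab] matrix is used
# directly. Sizes are example assumptions.
import torch

hidden_size, vocab_size = 8, 32
hidden_states = torch.randn(2, 5, hidden_size)        # [batch, seq, hidden]

tied_weight = torch.randn(vocab_size, hidden_size)    # shared with the input embedding
logits_tied = torch.matmul(hidden_states, tied_weight.T)

untied_weight = torch.randn(hidden_size, vocab_size)  # head-specific projection
logits_untied = torch.matmul(hidden_states, untied_weight)

assert logits_tied.shape == logits_untied.shape == (2, 5, vocab_size)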
944 |
+
class Ernie4_5_ForCausalLM(Ernie4_5_PretrainedModel, GenerationMixin):
|
945 |
+
"""ERNIE model for causal language modeling."""
|
946 |
+
|
947 |
+
_tied_weights_keys = ["lm_head.weight"]
|
948 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
949 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
950 |
+
|
951 |
+
def __init__(self, config):
|
952 |
+
"""
|
953 |
+
Initializes the ERNIE model for causal language modeling.
|
954 |
+
|
955 |
+
Args:
|
956 |
+
config: Model configuration.
|
957 |
+
"""
|
958 |
+
super().__init__(config)
|
959 |
+
|
960 |
+
self.config = config
|
961 |
+
self.model = Ernie4_5_Model(config)
|
962 |
+
self.lm_head = Ernie4_5_LMHead(config)
|
963 |
+
|
964 |
+
# Initialize weights and apply final processing
|
965 |
+
self.post_init()
|
966 |
+
|
967 |
+
@torch.no_grad()
|
968 |
+
def set_state_dict(self, state_dict, *args, **kwargs):
|
969 |
+
"""
|
970 |
+
Loads the model state dictionary.
|
971 |
+
"""
|
972 |
+
ret = super().set_state_dict(state_dict)
|
973 |
+
return ret
|
974 |
+
|
975 |
+
def get_input_embeddings(self):
|
976 |
+
"""Returns the input embeddings layer."""
|
977 |
+
return self.model.embed_tokens
|
978 |
+
|
979 |
+
def set_input_embeddings(self, value):
|
980 |
+
"""Sets the input embeddings layer."""
|
981 |
+
self.model.embed_tokens = value
|
982 |
+
|
983 |
+
def get_output_embeddings(self):
|
984 |
+
"""Returns the output embeddings (LM head)."""
|
985 |
+
return self.lm_head
|
986 |
+
|
987 |
+
def set_output_embeddings(self, new_embeddings):
|
988 |
+
"""Sets the output embeddings layer."""
|
989 |
+
self.lm_head = new_embeddings
|
990 |
+
|
991 |
+
def set_decoder(self, decoder):
|
992 |
+
"""Sets the ERNIE decoder model."""
|
993 |
+
self.model = decoder
|
994 |
+
|
995 |
+
def get_decoder(self):
|
996 |
+
"""Gets the ERNIE decoder model."""
|
997 |
+
return self.model
|
998 |
+
|
999 |
+
def forward(
|
1000 |
+
self,
|
1001 |
+
input_ids,
|
1002 |
+
position_ids=None,
|
1003 |
+
attention_mask=None,
|
1004 |
+
attn_mask_start_row_indices=None,
|
1005 |
+
token_type_ids=None,
|
1006 |
+
inputs_embeds=None,
|
1007 |
+
labels=None,
|
1008 |
+
use_cache=False,
|
1009 |
+
past_key_values=None,
|
1010 |
+
output_attentions=None,
|
1011 |
+
output_hidden_states=None,
|
1012 |
+
**kwargs,
|
1013 |
+
):
|
1014 |
+
"""
|
1015 |
+
Forward pass for causal language modeling.
|
1016 |
+
|
1017 |
+
Args:
|
1018 |
+
input_ids (torch.Tensor): Input token IDs.
|
1019 |
+
position_ids (torch.Tensor): Position IDs.
|
1020 |
+
attention_mask (torch.Tensor): Attention mask.
|
1021 |
+
attn_mask_start_row_indices (torch.Tensor): Attention mask start indices.
|
1022 |
+
inputs_embeds (torch.Tensor): Optional embedded inputs.
|
1023 |
+
labels (torch.Tensor): Target labels.
|
1024 |
+
use_cache (bool): Whether to use cached hidden states.
|
1025 |
+
past_key_values (dict): Pre-computed hidden states.
|
1026 |
+
output_attentions (bool): Whether to output attentions.
|
1027 |
+
output_hidden_states (bool): Whether to output hidden states.
|
1028 |
+
|
1029 |
+
Returns:
|
1030 |
+
CausalLMOutputWithPast: Model outputs.
|
1031 |
+
"""
|
1032 |
+
|
1033 |
+
if past_key_values is not None:
|
1034 |
+
input_ids = input_ids[:, -1:]
|
1035 |
+
|
1036 |
+
outputs = self.model(
|
1037 |
+
input_ids,
|
1038 |
+
position_ids=position_ids,
|
1039 |
+
attention_mask=attention_mask,
|
1040 |
+
token_type_ids=token_type_ids,
|
1041 |
+
attn_mask_start_row_indices=attn_mask_start_row_indices,
|
1042 |
+
inputs_embeds=inputs_embeds,
|
1043 |
+
use_cache=use_cache,
|
1044 |
+
past_key_values=past_key_values,
|
1045 |
+
output_attentions=output_attentions,
|
1046 |
+
output_hidden_states=output_hidden_states,
|
1047 |
+
return_dict=True,
|
1048 |
+
)
|
1049 |
+
|
1050 |
+
hidden_states = outputs.last_hidden_state
|
1051 |
+
logits = self.lm_head(hidden_states)
|
1052 |
+
|
1053 |
+
loss = None
|
1054 |
+
if labels is not None:
|
1055 |
+
loss = self.loss_function(
|
1056 |
+
logits=logits,
|
1057 |
+
labels=labels,
|
1058 |
+
vocab_size=self.config.vocab_size,
|
1059 |
+
**kwargs,
|
1060 |
+
)
|
1061 |
+
|
1062 |
+
return CausalLMOutputWithPast(
|
1063 |
+
loss=loss,
|
1064 |
+
logits=logits,
|
1065 |
+
past_key_values=outputs.past_key_values,
|
1066 |
+
hidden_states=outputs.hidden_states,
|
1067 |
+
attentions=outputs.attentions,
|
1068 |
+
)
|
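# Illustrative usage sketch: loading this remote-code model with transformers.
# The repo id below is a placeholder, not the actual repository name;
# trust_remote_code=True is needed so the Ernie4_5_ForCausalLM and
# Ernie4_5_Tokenizer classes shipped in this upload are used.
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/your-ernie4_5-repo"   # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hello, world!", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))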
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
1 |
+
{"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<unk>", "unk_token": "<unk>", "cls_token": "<|begin_of_sentence|>", "sep_token": "<|end_of_sentence|>", "mask_token": "<mask:1>", "sys_start_token": "<mask:4>", "sys_end_token": "<mask:5>", "header_start_token": "<mask:6>", "header_end_token": "<mask:7>", "additional_special_tokens": ["<|IMAGE_PLACEHOLDER|>", "<|AUDIO_PLACEHOLDER|>", "<|LOC_0|>", "<|LOC_1|>", "<|LOC_2|>", "<|LOC_3|>", "<|LOC_4|>", "<|LOC_5|>", "<|LOC_6|>", "<|LOC_7|>", "<|LOC_8|>", "<|LOC_9|>", "<|LOC_10|>", "<|LOC_11|>", "<|LOC_12|>", "<|LOC_13|>", "<|LOC_14|>", "<|LOC_15|>", "<|LOC_16|>", "<|LOC_17|>", "<|LOC_18|>", "<|LOC_19|>", "<|LOC_20|>", "<|LOC_21|>", "<|LOC_22|>", "<|LOC_23|>", "<|LOC_24|>", "<|LOC_25|>", "<|LOC_26|>", "<|LOC_27|>", "<|LOC_28|>", "<|LOC_29|>", "<|LOC_30|>", "<|LOC_31|>", "<|LOC_32|>", "<|LOC_33|>", "<|LOC_34|>", "<|LOC_35|>", "<|LOC_36|>", "<|LOC_37|>", "<|LOC_38|>", "<|LOC_39|>", "<|LOC_40|>", "<|LOC_41|>", "<|LOC_42|>", "<|LOC_43|>", "<|LOC_44|>", "<|LOC_45|>", "<|LOC_46|>", "<|LOC_47|>", "<|LOC_48|>", "<|LOC_49|>", "<|LOC_50|>", "<|LOC_51|>", "<|LOC_52|>", "<|LOC_53|>", "<|LOC_54|>", "<|LOC_55|>", "<|LOC_56|>", "<|LOC_57|>", "<|LOC_58|>", "<|LOC_59|>", "<|LOC_60|>", "<|LOC_61|>", "<|LOC_62|>", "<|LOC_63|>", "<|LOC_64|>", "<|LOC_65|>", "<|LOC_66|>", "<|LOC_67|>", "<|LOC_68|>", "<|LOC_69|>", "<|LOC_70|>", "<|LOC_71|>", "<|LOC_72|>", "<|LOC_73|>", "<|LOC_74|>", "<|LOC_75|>", "<|LOC_76|>", "<|LOC_77|>", "<|LOC_78|>", "<|LOC_79|>", "<|LOC_80|>", "<|LOC_81|>", "<|LOC_82|>", "<|LOC_83|>", "<|LOC_84|>", "<|LOC_85|>", "<|LOC_86|>", "<|LOC_87|>", "<|LOC_88|>", "<|LOC_89|>", "<|LOC_90|>", "<|LOC_91|>", "<|LOC_92|>", "<|LOC_93|>", "<|LOC_94|>", "<|LOC_95|>", "<|LOC_96|>", "<|LOC_97|>", "<|LOC_98|>", "<|LOC_99|>", "<|LOC_100|>", "<|LOC_101|>", "<|LOC_102|>", "<|LOC_103|>", "<|LOC_104|>", "<|LOC_105|>", "<|LOC_106|>", "<|LOC_107|>", "<|LOC_108|>", "<|LOC_109|>", "<|LOC_110|>", "<|LOC_111|>", "<|LOC_112|>", "<|LOC_113|>", "<|LOC_114|>", "<|LOC_115|>", "<|LOC_116|>", "<|LOC_117|>", "<|LOC_118|>", "<|LOC_119|>", "<|LOC_120|>", "<|LOC_121|>", "<|LOC_122|>", "<|LOC_123|>", "<|LOC_124|>", "<|LOC_125|>", "<|LOC_126|>", "<|LOC_127|>", "<|LOC_128|>", "<|LOC_129|>", "<|LOC_130|>", "<|LOC_131|>", "<|LOC_132|>", "<|LOC_133|>", "<|LOC_134|>", "<|LOC_135|>", "<|LOC_136|>", "<|LOC_137|>", "<|LOC_138|>", "<|LOC_139|>", "<|LOC_140|>", "<|LOC_141|>", "<|LOC_142|>", "<|LOC_143|>", "<|LOC_144|>", "<|LOC_145|>", "<|LOC_146|>", "<|LOC_147|>", "<|LOC_148|>", "<|LOC_149|>", "<|LOC_150|>", "<|LOC_151|>", "<|LOC_152|>", "<|LOC_153|>", "<|LOC_154|>", "<|LOC_155|>", "<|LOC_156|>", "<|LOC_157|>", "<|LOC_158|>", "<|LOC_159|>", "<|LOC_160|>", "<|LOC_161|>", "<|LOC_162|>", "<|LOC_163|>", "<|LOC_164|>", "<|LOC_165|>", "<|LOC_166|>", "<|LOC_167|>", "<|LOC_168|>", "<|LOC_169|>", "<|LOC_170|>", "<|LOC_171|>", "<|LOC_172|>", "<|LOC_173|>", "<|LOC_174|>", "<|LOC_175|>", "<|LOC_176|>", "<|LOC_177|>", "<|LOC_178|>", "<|LOC_179|>", "<|LOC_180|>", "<|LOC_181|>", "<|LOC_182|>", "<|LOC_183|>", "<|LOC_184|>", "<|LOC_185|>", "<|LOC_186|>", "<|LOC_187|>", "<|LOC_188|>", "<|LOC_189|>", "<|LOC_190|>", "<|LOC_191|>", "<|LOC_192|>", "<|LOC_193|>", "<|LOC_194|>", "<|LOC_195|>", "<|LOC_196|>", "<|LOC_197|>", "<|LOC_198|>", "<|LOC_199|>", "<|LOC_200|>", "<|LOC_201|>", "<|LOC_202|>", "<|LOC_203|>", "<|LOC_204|>", "<|LOC_205|>", "<|LOC_206|>", "<|LOC_207|>", "<|LOC_208|>", "<|LOC_209|>", "<|LOC_210|>", "<|LOC_211|>", "<|LOC_212|>", "<|LOC_213|>", "<|LOC_214|>", "<|LOC_215|>", "<|LOC_216|>", "<|LOC_217|>", 
"<|LOC_218|>", "<|LOC_219|>", "<|LOC_220|>", "<|LOC_221|>", "<|LOC_222|>", "<|LOC_223|>", "<|LOC_224|>", "<|LOC_225|>", "<|LOC_226|>", "<|LOC_227|>", "<|LOC_228|>", "<|LOC_229|>", "<|LOC_230|>", "<|LOC_231|>", "<|LOC_232|>", "<|LOC_233|>", "<|LOC_234|>", "<|LOC_235|>", "<|LOC_236|>", "<|LOC_237|>", "<|LOC_238|>", "<|LOC_239|>", "<|LOC_240|>", "<|LOC_241|>", "<|LOC_242|>", "<|LOC_243|>", "<|LOC_244|>", "<|LOC_245|>", "<|LOC_246|>", "<|LOC_247|>", "<|LOC_248|>", "<|LOC_249|>", "<|LOC_250|>", "<|LOC_251|>", "<|LOC_252|>", "<|LOC_253|>", "<|LOC_254|>", "<|LOC_255|>", "<|LOC_256|>", "<|LOC_257|>", "<|LOC_258|>", "<|LOC_259|>", "<|LOC_260|>", "<|LOC_261|>", "<|LOC_262|>", "<|LOC_263|>", "<|LOC_264|>", "<|LOC_265|>", "<|LOC_266|>", "<|LOC_267|>", "<|LOC_268|>", "<|LOC_269|>", "<|LOC_270|>", "<|LOC_271|>", "<|LOC_272|>", "<|LOC_273|>", "<|LOC_274|>", "<|LOC_275|>", "<|LOC_276|>", "<|LOC_277|>", "<|LOC_278|>", "<|LOC_279|>", "<|LOC_280|>", "<|LOC_281|>", "<|LOC_282|>", "<|LOC_283|>", "<|LOC_284|>", "<|LOC_285|>", "<|LOC_286|>", "<|LOC_287|>", "<|LOC_288|>", "<|LOC_289|>", "<|LOC_290|>", "<|LOC_291|>", "<|LOC_292|>", "<|LOC_293|>", "<|LOC_294|>", "<|LOC_295|>", "<|LOC_296|>", "<|LOC_297|>", "<|LOC_298|>", "<|LOC_299|>", "<|LOC_300|>", "<|LOC_301|>", "<|LOC_302|>", "<|LOC_303|>", "<|LOC_304|>", "<|LOC_305|>", "<|LOC_306|>", "<|LOC_307|>", "<|LOC_308|>", "<|LOC_309|>", "<|LOC_310|>", "<|LOC_311|>", "<|LOC_312|>", "<|LOC_313|>", "<|LOC_314|>", "<|LOC_315|>", "<|LOC_316|>", "<|LOC_317|>", "<|LOC_318|>", "<|LOC_319|>", "<|LOC_320|>", "<|LOC_321|>", "<|LOC_322|>", "<|LOC_323|>", "<|LOC_324|>", "<|LOC_325|>", "<|LOC_326|>", "<|LOC_327|>", "<|LOC_328|>", "<|LOC_329|>", "<|LOC_330|>", "<|LOC_331|>", "<|LOC_332|>", "<|LOC_333|>", "<|LOC_334|>", "<|LOC_335|>", "<|LOC_336|>", "<|LOC_337|>", "<|LOC_338|>", "<|LOC_339|>", "<|LOC_340|>", "<|LOC_341|>", "<|LOC_342|>", "<|LOC_343|>", "<|LOC_344|>", "<|LOC_345|>", "<|LOC_346|>", "<|LOC_347|>", "<|LOC_348|>", "<|LOC_349|>", "<|LOC_350|>", "<|LOC_351|>", "<|LOC_352|>", "<|LOC_353|>", "<|LOC_354|>", "<|LOC_355|>", "<|LOC_356|>", "<|LOC_357|>", "<|LOC_358|>", "<|LOC_359|>", "<|LOC_360|>", "<|LOC_361|>", "<|LOC_362|>", "<|LOC_363|>", "<|LOC_364|>", "<|LOC_365|>", "<|LOC_366|>", "<|LOC_367|>", "<|LOC_368|>", "<|LOC_369|>", "<|LOC_370|>", "<|LOC_371|>", "<|LOC_372|>", "<|LOC_373|>", "<|LOC_374|>", "<|LOC_375|>", "<|LOC_376|>", "<|LOC_377|>", "<|LOC_378|>", "<|LOC_379|>", "<|LOC_380|>", "<|LOC_381|>", "<|LOC_382|>", "<|LOC_383|>", "<|LOC_384|>", "<|LOC_385|>", "<|LOC_386|>", "<|LOC_387|>", "<|LOC_388|>", "<|LOC_389|>", "<|LOC_390|>", "<|LOC_391|>", "<|LOC_392|>", "<|LOC_393|>", "<|LOC_394|>", "<|LOC_395|>", "<|LOC_396|>", "<|LOC_397|>", "<|LOC_398|>", "<|LOC_399|>", "<|LOC_400|>", "<|LOC_401|>", "<|LOC_402|>", "<|LOC_403|>", "<|LOC_404|>", "<|LOC_405|>", "<|LOC_406|>", "<|LOC_407|>", "<|LOC_408|>", "<|LOC_409|>", "<|LOC_410|>", "<|LOC_411|>", "<|LOC_412|>", "<|LOC_413|>", "<|LOC_414|>", "<|LOC_415|>", "<|LOC_416|>", "<|LOC_417|>", "<|LOC_418|>", "<|LOC_419|>", "<|LOC_420|>", "<|LOC_421|>", "<|LOC_422|>", "<|LOC_423|>", "<|LOC_424|>", "<|LOC_425|>", "<|LOC_426|>", "<|LOC_427|>", "<|LOC_428|>", "<|LOC_429|>", "<|LOC_430|>", "<|LOC_431|>", "<|LOC_432|>", "<|LOC_433|>", "<|LOC_434|>", "<|LOC_435|>", "<|LOC_436|>", "<|LOC_437|>", "<|LOC_438|>", "<|LOC_439|>", "<|LOC_440|>", "<|LOC_441|>", "<|LOC_442|>", "<|LOC_443|>", "<|LOC_444|>", "<|LOC_445|>", "<|LOC_446|>", "<|LOC_447|>", "<|LOC_448|>", "<|LOC_449|>", "<|LOC_450|>", "<|LOC_451|>", "<|LOC_452|>", "<|LOC_453|>", "<|LOC_454|>", 
"<|LOC_455|>", "<|LOC_456|>", "<|LOC_457|>", "<|LOC_458|>", "<|LOC_459|>", "<|LOC_460|>", "<|LOC_461|>", "<|LOC_462|>", "<|LOC_463|>", "<|LOC_464|>", "<|LOC_465|>", "<|LOC_466|>", "<|LOC_467|>", "<|LOC_468|>", "<|LOC_469|>", "<|LOC_470|>", "<|LOC_471|>", "<|LOC_472|>", "<|LOC_473|>", "<|LOC_474|>", "<|LOC_475|>", "<|LOC_476|>", "<|LOC_477|>", "<|LOC_478|>", "<|LOC_479|>", "<|LOC_480|>", "<|LOC_481|>", "<|LOC_482|>", "<|LOC_483|>", "<|LOC_484|>", "<|LOC_485|>", "<|LOC_486|>", "<|LOC_487|>", "<|LOC_488|>", "<|LOC_489|>", "<|LOC_490|>", "<|LOC_491|>", "<|LOC_492|>", "<|LOC_493|>", "<|LOC_494|>", "<|LOC_495|>", "<|LOC_496|>", "<|LOC_497|>", "<|LOC_498|>", "<|LOC_499|>", "<|LOC_500|>", "<|LOC_501|>", "<|LOC_502|>", "<|LOC_503|>", "<|LOC_504|>", "<|LOC_505|>", "<|LOC_506|>", "<|LOC_507|>", "<|LOC_508|>", "<|LOC_509|>", "<|LOC_510|>", "<|LOC_511|>", "<|LOC_512|>", "<|LOC_513|>", "<|LOC_514|>", "<|LOC_515|>", "<|LOC_516|>", "<|LOC_517|>", "<|LOC_518|>", "<|LOC_519|>", "<|LOC_520|>", "<|LOC_521|>", "<|LOC_522|>", "<|LOC_523|>", "<|LOC_524|>", "<|LOC_525|>", "<|LOC_526|>", "<|LOC_527|>", "<|LOC_528|>", "<|LOC_529|>", "<|LOC_530|>", "<|LOC_531|>", "<|LOC_532|>", "<|LOC_533|>", "<|LOC_534|>", "<|LOC_535|>", "<|LOC_536|>", "<|LOC_537|>", "<|LOC_538|>", "<|LOC_539|>", "<|LOC_540|>", "<|LOC_541|>", "<|LOC_542|>", "<|LOC_543|>", "<|LOC_544|>", "<|LOC_545|>", "<|LOC_546|>", "<|LOC_547|>", "<|LOC_548|>", "<|LOC_549|>", "<|LOC_550|>", "<|LOC_551|>", "<|LOC_552|>", "<|LOC_553|>", "<|LOC_554|>", "<|LOC_555|>", "<|LOC_556|>", "<|LOC_557|>", "<|LOC_558|>", "<|LOC_559|>", "<|LOC_560|>", "<|LOC_561|>", "<|LOC_562|>", "<|LOC_563|>", "<|LOC_564|>", "<|LOC_565|>", "<|LOC_566|>", "<|LOC_567|>", "<|LOC_568|>", "<|LOC_569|>", "<|LOC_570|>", "<|LOC_571|>", "<|LOC_572|>", "<|LOC_573|>", "<|LOC_574|>", "<|LOC_575|>", "<|LOC_576|>", "<|LOC_577|>", "<|LOC_578|>", "<|LOC_579|>", "<|LOC_580|>", "<|LOC_581|>", "<|LOC_582|>", "<|LOC_583|>", "<|LOC_584|>", "<|LOC_585|>", "<|LOC_586|>", "<|LOC_587|>", "<|LOC_588|>", "<|LOC_589|>", "<|LOC_590|>", "<|LOC_591|>", "<|LOC_592|>", "<|LOC_593|>", "<|LOC_594|>", "<|LOC_595|>", "<|LOC_596|>", "<|LOC_597|>", "<|LOC_598|>", "<|LOC_599|>", "<|LOC_600|>", "<|LOC_601|>", "<|LOC_602|>", "<|LOC_603|>", "<|LOC_604|>", "<|LOC_605|>", "<|LOC_606|>", "<|LOC_607|>", "<|LOC_608|>", "<|LOC_609|>", "<|LOC_610|>", "<|LOC_611|>", "<|LOC_612|>", "<|LOC_613|>", "<|LOC_614|>", "<|LOC_615|>", "<|LOC_616|>", "<|LOC_617|>", "<|LOC_618|>", "<|LOC_619|>", "<|LOC_620|>", "<|LOC_621|>", "<|LOC_622|>", "<|LOC_623|>", "<|LOC_624|>", "<|LOC_625|>", "<|LOC_626|>", "<|LOC_627|>", "<|LOC_628|>", "<|LOC_629|>", "<|LOC_630|>", "<|LOC_631|>", "<|LOC_632|>", "<|LOC_633|>", "<|LOC_634|>", "<|LOC_635|>", "<|LOC_636|>", "<|LOC_637|>", "<|LOC_638|>", "<|LOC_639|>", "<|LOC_640|>", "<|LOC_641|>", "<|LOC_642|>", "<|LOC_643|>", "<|LOC_644|>", "<|LOC_645|>", "<|LOC_646|>", "<|LOC_647|>", "<|LOC_648|>", "<|LOC_649|>", "<|LOC_650|>", "<|LOC_651|>", "<|LOC_652|>", "<|LOC_653|>", "<|LOC_654|>", "<|LOC_655|>", "<|LOC_656|>", "<|LOC_657|>", "<|LOC_658|>", "<|LOC_659|>", "<|LOC_660|>", "<|LOC_661|>", "<|LOC_662|>", "<|LOC_663|>", "<|LOC_664|>", "<|LOC_665|>", "<|LOC_666|>", "<|LOC_667|>", "<|LOC_668|>", "<|LOC_669|>", "<|LOC_670|>", "<|LOC_671|>", "<|LOC_672|>", "<|LOC_673|>", "<|LOC_674|>", "<|LOC_675|>", "<|LOC_676|>", "<|LOC_677|>", "<|LOC_678|>", "<|LOC_679|>", "<|LOC_680|>", "<|LOC_681|>", "<|LOC_682|>", "<|LOC_683|>", "<|LOC_684|>", "<|LOC_685|>", "<|LOC_686|>", "<|LOC_687|>", "<|LOC_688|>", "<|LOC_689|>", "<|LOC_690|>", "<|LOC_691|>", 
"<|LOC_692|>", "<|LOC_693|>", "<|LOC_694|>", "<|LOC_695|>", "<|LOC_696|>", "<|LOC_697|>", "<|LOC_698|>", "<|LOC_699|>", "<|LOC_700|>", "<|LOC_701|>", "<|LOC_702|>", "<|LOC_703|>", "<|LOC_704|>", "<|LOC_705|>", "<|LOC_706|>", "<|LOC_707|>", "<|LOC_708|>", "<|LOC_709|>", "<|LOC_710|>", "<|LOC_711|>", "<|LOC_712|>", "<|LOC_713|>", "<|LOC_714|>", "<|LOC_715|>", "<|LOC_716|>", "<|LOC_717|>", "<|LOC_718|>", "<|LOC_719|>", "<|LOC_720|>", "<|LOC_721|>", "<|LOC_722|>", "<|LOC_723|>", "<|LOC_724|>", "<|LOC_725|>", "<|LOC_726|>", "<|LOC_727|>", "<|LOC_728|>", "<|LOC_729|>", "<|LOC_730|>", "<|LOC_731|>", "<|LOC_732|>", "<|LOC_733|>", "<|LOC_734|>", "<|LOC_735|>", "<|LOC_736|>", "<|LOC_737|>", "<|LOC_738|>", "<|LOC_739|>", "<|LOC_740|>", "<|LOC_741|>", "<|LOC_742|>", "<|LOC_743|>", "<|LOC_744|>", "<|LOC_745|>", "<|LOC_746|>", "<|LOC_747|>", "<|LOC_748|>", "<|LOC_749|>", "<|LOC_750|>", "<|LOC_751|>", "<|LOC_752|>", "<|LOC_753|>", "<|LOC_754|>", "<|LOC_755|>", "<|LOC_756|>", "<|LOC_757|>", "<|LOC_758|>", "<|LOC_759|>", "<|LOC_760|>", "<|LOC_761|>", "<|LOC_762|>", "<|LOC_763|>", "<|LOC_764|>", "<|LOC_765|>", "<|LOC_766|>", "<|LOC_767|>", "<|LOC_768|>", "<|LOC_769|>", "<|LOC_770|>", "<|LOC_771|>", "<|LOC_772|>", "<|LOC_773|>", "<|LOC_774|>", "<|LOC_775|>", "<|LOC_776|>", "<|LOC_777|>", "<|LOC_778|>", "<|LOC_779|>", "<|LOC_780|>", "<|LOC_781|>", "<|LOC_782|>", "<|LOC_783|>", "<|LOC_784|>", "<|LOC_785|>", "<|LOC_786|>", "<|LOC_787|>", "<|LOC_788|>", "<|LOC_789|>", "<|LOC_790|>", "<|LOC_791|>", "<|LOC_792|>", "<|LOC_793|>", "<|LOC_794|>", "<|LOC_795|>", "<|LOC_796|>", "<|LOC_797|>", "<|LOC_798|>", "<|LOC_799|>", "<|LOC_800|>", "<|LOC_801|>", "<|LOC_802|>", "<|LOC_803|>", "<|LOC_804|>", "<|LOC_805|>", "<|LOC_806|>", "<|LOC_807|>", "<|LOC_808|>", "<|LOC_809|>", "<|LOC_810|>", "<|LOC_811|>", "<|LOC_812|>", "<|LOC_813|>", "<|LOC_814|>", "<|LOC_815|>", "<|LOC_816|>", "<|LOC_817|>", "<|LOC_818|>", "<|LOC_819|>", "<|LOC_820|>", "<|LOC_821|>", "<|LOC_822|>", "<|LOC_823|>", "<|LOC_824|>", "<|LOC_825|>", "<|LOC_826|>", "<|LOC_827|>", "<|LOC_828|>", "<|LOC_829|>", "<|LOC_830|>", "<|LOC_831|>", "<|LOC_832|>", "<|LOC_833|>", "<|LOC_834|>", "<|LOC_835|>", "<|LOC_836|>", "<|LOC_837|>", "<|LOC_838|>", "<|LOC_839|>", "<|LOC_840|>", "<|LOC_841|>", "<|LOC_842|>", "<|LOC_843|>", "<|LOC_844|>", "<|LOC_845|>", "<|LOC_846|>", "<|LOC_847|>", "<|LOC_848|>", "<|LOC_849|>", "<|LOC_850|>", "<|LOC_851|>", "<|LOC_852|>", "<|LOC_853|>", "<|LOC_854|>", "<|LOC_855|>", "<|LOC_856|>", "<|LOC_857|>", "<|LOC_858|>", "<|LOC_859|>", "<|LOC_860|>", "<|LOC_861|>", "<|LOC_862|>", "<|LOC_863|>", "<|LOC_864|>", "<|LOC_865|>", "<|LOC_866|>", "<|LOC_867|>", "<|LOC_868|>", "<|LOC_869|>", "<|LOC_870|>", "<|LOC_871|>", "<|LOC_872|>", "<|LOC_873|>", "<|LOC_874|>", "<|LOC_875|>", "<|LOC_876|>", "<|LOC_877|>", "<|LOC_878|>", "<|LOC_879|>", "<|LOC_880|>", "<|LOC_881|>", "<|LOC_882|>", "<|LOC_883|>", "<|LOC_884|>", "<|LOC_885|>", "<|LOC_886|>", "<|LOC_887|>", "<|LOC_888|>", "<|LOC_889|>", "<|LOC_890|>", "<|LOC_891|>", "<|LOC_892|>", "<|LOC_893|>", "<|LOC_894|>", "<|LOC_895|>", "<|LOC_896|>", "<|LOC_897|>", "<|LOC_898|>", "<|LOC_899|>", "<|LOC_900|>", "<|LOC_901|>", "<|LOC_902|>", "<|LOC_903|>", "<|LOC_904|>", "<|LOC_905|>", "<|LOC_906|>", "<|LOC_907|>", "<|LOC_908|>", "<|LOC_909|>", "<|LOC_910|>", "<|LOC_911|>", "<|LOC_912|>", "<|LOC_913|>", "<|LOC_914|>", "<|LOC_915|>", "<|LOC_916|>", "<|LOC_917|>", "<|LOC_918|>", "<|LOC_919|>", "<|LOC_920|>", "<|LOC_921|>", "<|LOC_922|>", "<|LOC_923|>", "<|LOC_924|>", "<|LOC_925|>", "<|LOC_926|>", "<|LOC_927|>", "<|LOC_928|>", 
"<|LOC_929|>", "<|LOC_930|>", "<|LOC_931|>", "<|LOC_932|>", "<|LOC_933|>", "<|LOC_934|>", "<|LOC_935|>", "<|LOC_936|>", "<|LOC_937|>", "<|LOC_938|>", "<|LOC_939|>", "<|LOC_940|>", "<|LOC_941|>", "<|LOC_942|>", "<|LOC_943|>", "<|LOC_944|>", "<|LOC_945|>", "<|LOC_946|>", "<|LOC_947|>", "<|LOC_948|>", "<|LOC_949|>", "<|LOC_950|>", "<|LOC_951|>", "<|LOC_952|>", "<|LOC_953|>", "<|LOC_954|>", "<|LOC_955|>", "<|LOC_956|>", "<|LOC_957|>", "<|LOC_958|>", "<|LOC_959|>", "<|LOC_960|>", "<|LOC_961|>", "<|LOC_962|>", "<|LOC_963|>", "<|LOC_964|>", "<|LOC_965|>", "<|LOC_966|>", "<|LOC_967|>", "<|LOC_968|>", "<|LOC_969|>", "<|LOC_970|>", "<|LOC_971|>", "<|LOC_972|>", "<|LOC_973|>", "<|LOC_974|>", "<|LOC_975|>", "<|LOC_976|>", "<|LOC_977|>", "<|LOC_978|>", "<|LOC_979|>", "<|LOC_980|>", "<|LOC_981|>", "<|LOC_982|>", "<|LOC_983|>", "<|LOC_984|>", "<|LOC_985|>", "<|LOC_986|>", "<|LOC_987|>", "<|LOC_988|>", "<|LOC_989|>", "<|LOC_990|>", "<|LOC_991|>", "<|LOC_992|>", "<|LOC_993|>", "<|LOC_994|>", "<|LOC_995|>", "<|LOC_996|>", "<|LOC_997|>", "<|LOC_998|>", "<|LOC_999|>", "<|LOC_1000|>", "<|LOC_BEGIN|>", "<|LOC_END|>", "<|LOC_SEP|>", "<|CROP_COL_SEP|>", "<|CROP_ROW_SEP|>", "<|IMAGE_SEP|>"]}
|
tokenization_ernie4_5.py
ADDED
@@ -0,0 +1,377 @@
# Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from shutil import copyfile
from typing import Dict, List, Optional, Tuple, Union

import torch
import numpy as np
import sentencepiece as spm

from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import (
    PaddingStrategy,
)
from transformers.utils import logging


logger = logging.get_logger(__name__)


class Ernie4_5_Tokenizer(PreTrainedTokenizer):
    """SentencePiece-based tokenizer for ERNIE 4.5 models."""

    vocab_files_names = {
        "vocab_file": "tokenizer.model",
    }
    # Model input names expected by the tokenizer
    model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
    # Padding side (where to add padding tokens)
    padding_side = "right"

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        cls_token="<cls>",
        eos_token="</s>",
        mask_token="<mask:0>",
        pad_token="<pad>",
        sep_token="<sep>",
        unk_token="<unk>",
        additional_special_tokens=None,
        split_special_tokens=False,
        alpha=None,
        tokenizer_alpha=None,
        **kwargs,
    ):
        """
        Initialize the ERNIE tokenizer.

        Args:
            vocab_file (str): Path to the SentencePiece model file.
            bos_token (str, optional): Beginning of sentence token. Defaults to "<s>".
            cls_token (str, optional): Classification token. Defaults to "<cls>".
            eos_token (str, optional): End of sentence token. Defaults to "</s>".
            mask_token (str, optional): Mask token. Defaults to "<mask:0>".
            pad_token (str, optional): Padding token. Defaults to "<pad>".
            sep_token (str, optional): Separator token. Defaults to "<sep>".
            unk_token (str, optional): Unknown token. Defaults to "<unk>".
            additional_special_tokens (List[str], optional): Additional special tokens.
                Defaults to ["<mask:1>", "<mask:7>"].
            split_special_tokens (bool, optional): Whether to split special tokens. Defaults to False.
            alpha (None, optional): Currently unused parameter. Reserved for future use.
            tokenizer_alpha (float, optional): Alpha parameter for SentencePiece subword sampling.
            **kwargs: Additional keyword arguments passed to the parent class.
        """

        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)
        self.alpha = alpha
        self.pad_id = self._convert_token_to_id(pad_token)
        self.tokenizer_alpha = tokenizer_alpha

        if additional_special_tokens is None:
            additional_special_tokens = ["<mask:1>", "<mask:7>"]
        super().__init__(
            bos_token=bos_token,
            cls_token=cls_token,
            eos_token=eos_token,
            mask_token=mask_token,
            pad_token=pad_token,
            sep_token=sep_token,
            unk_token=unk_token,
            additional_special_tokens=additional_special_tokens,
            split_special_tokens=split_special_tokens,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Returns the size of the vocabulary.

        Returns:
            int: The number of tokens in the vocabulary.
        """
        return self.sp_model.vocab_size()

    def get_vocab(self):
        """Get the vocabulary as a dictionary mapping tokens to their IDs.

        Returns:
            dict: A dictionary mapping tokens to their corresponding IDs.
        """
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Tokenize text using SentencePiece.

        Args:
            text (str): The text to tokenize.

        Returns:
            list: A list of tokens.
        """
        if self.tokenizer_alpha is not None:
            return self.sp_model.encode_as_pieces(
                text,
                enable_sampling=True,
                nbest_size=-1,
                alpha=self.tokenizer_alpha,
            )
        else:
            return self.sp_model.encode_as_pieces(text)

    def _convert_token_to_id(self, token):
        """Convert a token (str) to an ID using the vocabulary.

        Args:
            token (str): The token to convert.

        Returns:
            int: The corresponding token ID.
        """
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, id):
        """Convert an ID to a token (str) using the vocabulary.

        Args:
            id (int): The token ID to convert.

        Returns:
            str: The corresponding token.
        """
        if id >= self.vocab_size:
            return self.unk_token
        else:
            return self.sp_model.id_to_piece(id)

    def convert_tokens_to_string(self, tokens):
        """Convert a sequence of tokens back to a single string.

        Args:
            tokens (List[str]): A list of tokens to convert.

        Returns:
            str: The reconstructed string.
        """
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for token in tokens:
            # make sure that special tokens are not decoded using the sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Build model inputs by adding special tokens to sequences.

        Args:
            token_ids_0 (List[int]): List of token IDs for the first sequence.
            token_ids_1 (List[int], optional): List of token IDs for the second sequence.

        Returns:
            List[int]: List of token IDs with special tokens added.
        """
        output = token_ids_0
        last_cls_index = -1
        last_sep_index = -1
        if self.cls_token_id in output:
            last_cls_index = len(output) - output[::-1].index(self.cls_token_id) - 1
        if self.sep_token_id in output:
            last_sep_index = len(output) - output[::-1].index(self.sep_token_id) - 1

        if last_cls_index > last_sep_index:
            next_token_id = self.sep_token_id
        elif last_sep_index > last_cls_index:
            next_token_id = self.cls_token_id
        else:
            output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
            next_token_id = self.cls_token_id

        output = [self.bos_token_id] + output
        # Assume no markup in text if token_ids_1 is given.
        if token_ids_1 is not None:
            output = output + token_ids_1 + [next_token_id]
        return output

    def get_special_tokens_mask(
        self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
    ):
        """Get a mask showing which tokens are special tokens.

        Args:
            token_ids_0 (List[int]): List of token IDs for the first sequence.
            token_ids_1 (List[int], optional): List of token IDs for the second sequence.
            already_has_special_tokens (bool): Whether the tokens already include special tokens.

        Returns:
            List[int]: A mask where 1 indicates special tokens and 0 indicates regular tokens.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0, token_ids_1, already_has_special_tokens=True
            )

        # [bos_token, cls_token, tokens_0, sep_token]
        if token_ids_1 is None:
            return [1, 1] + ([0] * len(token_ids_0)) + [1]
        # [bos_token, cls_token, tokens_0, sep_token, tokens_1, cls_token]
        return [1, 1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (str): The directory in which to save the vocabulary.
            filename_prefix (Optional[str]): Optional prefix for the saved filename.

        Returns:
            Tuple[str]: Paths to the files saved.

        Note:
            If ``save_directory`` is not a valid directory, an error is logged and
            nothing is saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + self.vocab_files_names["vocab_file"],
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)

        return (out_vocab_file,)

    def _pad(
        self,
        encoded_inputs: Union[Dict],
        max_length: Optional[int] = None,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        padding_side: Optional[str] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs according to the specified strategy.

        Args:
            encoded_inputs (Union[Dict]): Dictionary of encoded inputs.
            max_length (Optional[int]): Maximum length to pad to.
            padding_strategy (PaddingStrategy): Strategy for padding.
            pad_to_multiple_of (Optional[int]): Pad to a multiple of this value.
            padding_side (Optional[str]): Optional override of the default padding side.
            return_attention_mask (Optional[bool]): Whether to return an attention mask.

        Returns:
            dict: Dictionary with padded inputs and optional attention mask.

        Raises:
            ValueError: If attention_mask has an unexpected type or the padding side is invalid.
        """
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names
        if return_attention_mask:
            required_input = encoded_inputs[self.model_input_names[0]]
            if padding_strategy == PaddingStrategy.LONGEST:
                max_length = len(required_input)
            if (
                max_length is not None
                and pad_to_multiple_of is not None
                and (max_length % pad_to_multiple_of != 0)
            ):
                max_length = (
                    (max_length // pad_to_multiple_of) + 1
                ) * pad_to_multiple_of
            needs_to_be_padded = (
                padding_strategy != PaddingStrategy.DO_NOT_PAD
                and len(required_input) != max_length
            )

            if (
                "attention_mask" in encoded_inputs
                and encoded_inputs["attention_mask"] is not None
            ):
                attention_mask = encoded_inputs.pop("attention_mask")
                if isinstance(attention_mask, torch.Tensor):
                    attention_mask = attention_mask.numpy()
                elif isinstance(attention_mask, list):
                    attention_mask = np.array(attention_mask)
                elif not isinstance(attention_mask, np.ndarray):
                    raise ValueError(
                        f"Unexpected type {type(attention_mask)} of attention_mask"
                    )
            else:
                # Create a default (lower-triangular) attention mask if none is provided
                attention_mask = np.tril(
                    np.ones((len(required_input), len(required_input)), dtype=np.int64)
                )
                attention_mask = np.expand_dims(attention_mask, axis=0)

            if needs_to_be_padded:
                difference = max_length - len(required_input)
                if self.padding_side == "right":
                    if attention_mask.ndim == 1:
                        pad_width = [(0, difference)]
                    else:
                        pad_width = [(0, 0), (0, difference), (0, difference)]
                elif self.padding_side == "left":
                    if attention_mask.ndim == 1:
                        pad_width = [(difference, 0)]
                    else:
                        pad_width = [(0, 0), (difference, 0), (difference, 0)]
                else:
                    raise ValueError(
                        "Invalid padding side: " + str(self.padding_side)
                    )
                attention_mask = np.pad(
                    attention_mask,
                    pad_width=pad_width,
                    mode="constant",
                    constant_values=0,
                )

        encoded_inputs = super()._pad(
            encoded_inputs,
            max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=False,
        )
        if return_attention_mask:
            encoded_inputs["attention_mask"] = attention_mask.tolist()
        return encoded_inputs
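Usage note: a minimal loading sketch, assuming this folder (or its Hub repository id, used below as a placeholder) and trust_remote_code=True, since Ernie4_5_Tokenizer is shipped with the repo via the auto_map entry in tokenizer_config.json rather than with transformers.

# Hypothetical loading sketch; "path/to/this/folder" stands in for the real local
# path or Hub repo id of this upload.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "path/to/this/folder",
    trust_remote_code=True,  # lets transformers import Ernie4_5_Tokenizer from this repo
)

encoded = tokenizer("Hello, ERNIE 4.5!")
print(encoded["input_ids"])
print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))
print(tokenizer.decode(encoded["input_ids"]))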
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:34ef7db83df785924fb83d7b887b6e822a031c56e15cff40aaf9b982988180df
size 1614363
tokenizer_config.json
ADDED
@@ -0,0 +1,22 @@
{
  "bos_token": "<s>",
  "eos_token": "</s>",
  "pad_token": "<unk>",
  "unk_token": "<unk>",
  "cls_token": "<|begin_of_sentence|>",
  "sep_token": "<|end_of_sentence|>",
  "mask_token": "<mask:1>",
  "sys_start_token": "<mask:4>",
  "sys_end_token": "<mask:5>",
  "header_start_token": "<mask:6>",
  "header_end_token": "<mask:7>",
  "additional_special_tokens": null,
  "tokenizer_class": "Ernie4_5_Tokenizer",
  "auto_map": {
    "AutoTokenizer": [
      "tokenization_ernie4_5.Ernie4_5_Tokenizer",
      null
    ]
  },
  "chat_template": "{%- if not add_generation_prompt is defined -%}\n {%- set add_generation_prompt = true -%}\n{%- endif -%}\n{%- if not cls_token is defined -%}\n {%- set cls_token = \"<|begin_of_sentence|>\" -%}\n{%- endif -%}\n{%- if not sep_token is defined -%}\n {%- set sep_token = \"<|end_of_sentence|>\" -%}\n{%- endif -%}\n{{- cls_token -}}\n{%- for message in messages -%}\n {%- if message[\"role\"] == \"user\" -%}\n {{- \"User: \" + message[\"content\"] + \"\n\" -}}\n {%- elif message[\"role\"] == \"assistant\" -%}\n {{- \"Assistant: \" + message[\"content\"] + sep_token -}}\n {%- elif message[\"role\"] == \"system\" -%}\n {{- message[\"content\"] + \"\n\" -}}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{- \"Assistant: \" -}}\n{%- endif -%}"
}
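For context, a minimal sketch of rendering a conversation with the chat_template above; it assumes a tokenizer object loaded as in the earlier sketch, and the message contents are illustrative only.

# Render a conversation using the chat_template from tokenizer_config.json.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello in one sentence."},
]

# add_generation_prompt=True appends the trailing "Assistant: " prompt defined by the template.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)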