hlfby06 commited on
Commit
c6c000a
·
verified ·
1 Parent(s): 3f5bd73

Upload folder using huggingface_hub

Browse files
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,118 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ERNIE-4.5-0.3B
2
+
3
+ ## ERNIE 4.5 Highlights
4
+
5
+ The advanced capabilities of the ERNIE 4.5 models, particularly the MoE-based A47B and A3B series, are underpinned by several key technical innovations:
6
+
7
+ - **Multimodal MoE Pretraining:** Our models are jointly trained on both textual and visual modalities to better capture the nuances of multimodal information and improve performance on tasks involving text generation, image understanding, and cross-modal reasoning. To achieve this without one modality hindering the learning of another, we designed a heterogeneous MoE structure, incorporated three-dimensional rotary embeddings, and employed router orthogonal loss and multimodal token-balanced loss. These architectural choices ensure that both modalities are effectively represented, allowing for mutual reinforcement during training.
8
+ - **Scaling-Efficient Architecture and Infrastructure:** To train the large multimodal MoE models efficiently, we introduce a novel heterogeneous hybrid parallelism and multi-level load balancing strategy for efficient training of ERNIE 4.5 models. By using on-device expert parallelism, memory-efficient pipeline scheduling, and FP8 mixed precision, we achieve ideal pre-training performance. For inference, we propose a quantization method with collaborative parallelism among multiple experts to achieve lossless quantization. Built on PaddlePaddle, ERNIE 4.5 delivers high-performance inference across a wide range of hardware platforms.
9
+ - **Modality-Specific Post-training:** To meet the diverse requirements of real-world applications, we fine-tuned variants of the pretrained model for specific modalities. Our LLMs are optimized for general-purpose language understanding and generation. The VLMs focuses on visual-language understanding and supports both thinking and no-thinking mode. Each model employed a combination of Supervised Fine-tuning (SFT), Direct Preference Optimization (DPO) or a modified reinforcement learning method named Unified Preference Optimization (UPO) for post-training, using targeted datasets aligned with its intended usage scenario.
10
+
11
+ ## Model Overview
12
+
13
+ ERNIE-4.5-0.3B is a text dense Post-trained model. The following are the model configuration details:
14
+
15
+ | Key | Value |
16
+ | -------------- | ------------ |
17
+ | Modality | Text |
18
+ | Training Stage | Posttraining |
19
+ | Params | 0.36B |
20
+ | Layers | 18 |
21
+ | Heads(Q/KV) | 16 / 2 |
22
+ | Context Length | 131072 |
23
+
24
+ ## Quickstart
25
+
26
+ ### Model Finetuning with ERNIEKit
27
+
28
+ [ERNIEKit](https://github.com/PaddlePaddle/ERNIE) is a training toolkit based on PaddlePaddle, specifically designed for the ERNIE series of open-source large models. It provides comprehensive support for scenarios such as instruction fine-tuning (SFT, LoRA) and alignment training (DPO), ensuring optimal performance.
29
+
30
+ Usage Examples:
31
+
32
+ ```bash
33
+ # SFT
34
+ erniekit train --stage SFT --model_name_or_path /baidu/ERNIE-4.5-0.3B --train_dataset_path your_dataset_path
35
+ # DPO
36
+ erniekit train --stage DPO --model_name_or_path /baidu/ERNIE-4.5-0.3B --train_dataset_path your_dataset_path
37
+ ```
38
+
39
+ For more detailed examples, including SFT with LoRA, multi-GPU configurations, and advanced scripts, please refer to the examples folder within the [ERNIEKit](https://github.com/PaddlePaddle/ERNIE) repository.
40
+
41
+ ### FastDeploy Inference
42
+
43
+ Service deployment can be quickly completed using FastDeploy in the following command. For more detailed usage instructions, please refer to the [FastDeploy Repository](https://github.com/PaddlePaddle/FastDeploy).
44
+
45
+ ```bash
46
+ python -m fastdeploy.entrypoints.openai.api_server \
47
+ --model BAIDU/ERNIE-4.5-0.3B-Paddle \
48
+ --port 8180 \
49
+ --metrics-port 8181 \
50
+ --engine-worker-queue-port 8182 \
51
+ --max-model-len 32768 \ # Maximum supported number of tokens
52
+ --max-num-seqs 32 # Maximum concurrent processing capacity
53
+ ```
54
+
55
+ ### Using `transformers` library
56
+
57
+ The following contains a code snippet illustrating how to use the model generate content based on given inputs.
58
+
59
+ ```python
60
+ from transformers import AutoModelForCausalLM, AutoTokenizer
61
+
62
+ model_name = "baidu/ERNIE-4.5-0.3B-PT"
63
+
64
+ # load the tokenizer and the model
65
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
66
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
67
+
68
+ # prepare the model input
69
+ prompt = "Give me a short introduction to large language model."
70
+ messages = [
71
+ {"role": "user", "content": prompt}
72
+ ]
73
+ text = tokenizer.apply_chat_template(
74
+ messages,
75
+ tokenize=False,
76
+ add_generation_prompt=True
77
+ )
78
+ model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)
79
+
80
+ # conduct text completion
81
+ generated_ids = model.generate(
82
+ model_inputs.input_ids,
83
+ max_new_tokens=1024
84
+ )
85
+ output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
86
+
87
+ # decode the generated ids
88
+ generate_text = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
89
+ print("generate_text:", generate_text)
90
+ ```
91
+
92
+ ### vLLM inference
93
+
94
+ vLLM is currently being adapted, priority can be given to using our fork repository [vllm](https://github.com/CSWYF3634076/vllm/tree/ernie)
95
+
96
+ ```bash
97
+ vllm serve baidu/ERNIE-4.5-0.3B-PT --trust-remote-code
98
+ ```
99
+
100
+ ## License
101
+
102
+ The ERNIE 4.5 models are provided under the Apache License 2.0. This license permits commercial use, subject to its terms and conditions. Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
103
+
104
+ ## Citation
105
+
106
+ If you find ERNIE 4.5 useful or wish to use it in your projects, please kindly cite our technical report:
107
+
108
+ ```bibtex
109
+ @misc{ernie2025technicalreport,
110
+ title={ERNIE 4.5 Technical Report},
111
+ author={Baidu ERNIE Team},
112
+ year={2025},
113
+ eprint={},
114
+ archivePrefix={arXiv},
115
+ primaryClass={cs.CL},
116
+ url={}
117
+ }
118
+ ```
added_tokens.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"<|IMAGE_PLACEHOLDER|>": 100295, "<|AUDIO_PLACEHOLDER|>": 100296, "<|LOC_0|>": 100297, "<|LOC_1|>": 100298, "<|LOC_2|>": 100299, "<|LOC_3|>": 100300, "<|LOC_4|>": 100301, "<|LOC_5|>": 100302, "<|LOC_6|>": 100303, "<|LOC_7|>": 100304, "<|LOC_8|>": 100305, "<|LOC_9|>": 100306, "<|LOC_10|>": 100307, "<|LOC_11|>": 100308, "<|LOC_12|>": 100309, "<|LOC_13|>": 100310, "<|LOC_14|>": 100311, "<|LOC_15|>": 100312, "<|LOC_16|>": 100313, "<|LOC_17|>": 100314, "<|LOC_18|>": 100315, "<|LOC_19|>": 100316, "<|LOC_20|>": 100317, "<|LOC_21|>": 100318, "<|LOC_22|>": 100319, "<|LOC_23|>": 100320, "<|LOC_24|>": 100321, "<|LOC_25|>": 100322, "<|LOC_26|>": 100323, "<|LOC_27|>": 100324, "<|LOC_28|>": 100325, "<|LOC_29|>": 100326, "<|LOC_30|>": 100327, "<|LOC_31|>": 100328, "<|LOC_32|>": 100329, "<|LOC_33|>": 100330, "<|LOC_34|>": 100331, "<|LOC_35|>": 100332, "<|LOC_36|>": 100333, "<|LOC_37|>": 100334, "<|LOC_38|>": 100335, "<|LOC_39|>": 100336, "<|LOC_40|>": 100337, "<|LOC_41|>": 100338, "<|LOC_42|>": 100339, "<|LOC_43|>": 100340, "<|LOC_44|>": 100341, "<|LOC_45|>": 100342, "<|LOC_46|>": 100343, "<|LOC_47|>": 100344, "<|LOC_48|>": 100345, "<|LOC_49|>": 100346, "<|LOC_50|>": 100347, "<|LOC_51|>": 100348, "<|LOC_52|>": 100349, "<|LOC_53|>": 100350, "<|LOC_54|>": 100351, "<|LOC_55|>": 100352, "<|LOC_56|>": 100353, "<|LOC_57|>": 100354, "<|LOC_58|>": 100355, "<|LOC_59|>": 100356, "<|LOC_60|>": 100357, "<|LOC_61|>": 100358, "<|LOC_62|>": 100359, "<|LOC_63|>": 100360, "<|LOC_64|>": 100361, "<|LOC_65|>": 100362, "<|LOC_66|>": 100363, "<|LOC_67|>": 100364, "<|LOC_68|>": 100365, "<|LOC_69|>": 100366, "<|LOC_70|>": 100367, "<|LOC_71|>": 100368, "<|LOC_72|>": 100369, "<|LOC_73|>": 100370, "<|LOC_74|>": 100371, "<|LOC_75|>": 100372, "<|LOC_76|>": 100373, "<|LOC_77|>": 100374, "<|LOC_78|>": 100375, "<|LOC_79|>": 100376, "<|LOC_80|>": 100377, "<|LOC_81|>": 100378, "<|LOC_82|>": 100379, "<|LOC_83|>": 100380, "<|LOC_84|>": 100381, "<|LOC_85|>": 100382, "<|LOC_86|>": 100383, "<|LOC_87|>": 100384, "<|LOC_88|>": 100385, "<|LOC_89|>": 100386, "<|LOC_90|>": 100387, "<|LOC_91|>": 100388, "<|LOC_92|>": 100389, "<|LOC_93|>": 100390, "<|LOC_94|>": 100391, "<|LOC_95|>": 100392, "<|LOC_96|>": 100393, "<|LOC_97|>": 100394, "<|LOC_98|>": 100395, "<|LOC_99|>": 100396, "<|LOC_100|>": 100397, "<|LOC_101|>": 100398, "<|LOC_102|>": 100399, "<|LOC_103|>": 100400, "<|LOC_104|>": 100401, "<|LOC_105|>": 100402, "<|LOC_106|>": 100403, "<|LOC_107|>": 100404, "<|LOC_108|>": 100405, "<|LOC_109|>": 100406, "<|LOC_110|>": 100407, "<|LOC_111|>": 100408, "<|LOC_112|>": 100409, "<|LOC_113|>": 100410, "<|LOC_114|>": 100411, "<|LOC_115|>": 100412, "<|LOC_116|>": 100413, "<|LOC_117|>": 100414, "<|LOC_118|>": 100415, "<|LOC_119|>": 100416, "<|LOC_120|>": 100417, "<|LOC_121|>": 100418, "<|LOC_122|>": 100419, "<|LOC_123|>": 100420, "<|LOC_124|>": 100421, "<|LOC_125|>": 100422, "<|LOC_126|>": 100423, "<|LOC_127|>": 100424, "<|LOC_128|>": 100425, "<|LOC_129|>": 100426, "<|LOC_130|>": 100427, "<|LOC_131|>": 100428, "<|LOC_132|>": 100429, "<|LOC_133|>": 100430, "<|LOC_134|>": 100431, "<|LOC_135|>": 100432, "<|LOC_136|>": 100433, "<|LOC_137|>": 100434, "<|LOC_138|>": 100435, "<|LOC_139|>": 100436, "<|LOC_140|>": 100437, "<|LOC_141|>": 100438, "<|LOC_142|>": 100439, "<|LOC_143|>": 100440, "<|LOC_144|>": 100441, "<|LOC_145|>": 100442, "<|LOC_146|>": 100443, "<|LOC_147|>": 100444, "<|LOC_148|>": 100445, "<|LOC_149|>": 100446, "<|LOC_150|>": 100447, "<|LOC_151|>": 100448, "<|LOC_152|>": 100449, "<|LOC_153|>": 100450, "<|LOC_154|>": 100451, "<|LOC_155|>": 100452, "<|LOC_156|>": 100453, "<|LOC_157|>": 100454, "<|LOC_158|>": 100455, "<|LOC_159|>": 100456, "<|LOC_160|>": 100457, "<|LOC_161|>": 100458, "<|LOC_162|>": 100459, "<|LOC_163|>": 100460, "<|LOC_164|>": 100461, "<|LOC_165|>": 100462, "<|LOC_166|>": 100463, "<|LOC_167|>": 100464, "<|LOC_168|>": 100465, "<|LOC_169|>": 100466, "<|LOC_170|>": 100467, "<|LOC_171|>": 100468, "<|LOC_172|>": 100469, "<|LOC_173|>": 100470, "<|LOC_174|>": 100471, "<|LOC_175|>": 100472, "<|LOC_176|>": 100473, "<|LOC_177|>": 100474, "<|LOC_178|>": 100475, "<|LOC_179|>": 100476, "<|LOC_180|>": 100477, "<|LOC_181|>": 100478, "<|LOC_182|>": 100479, "<|LOC_183|>": 100480, "<|LOC_184|>": 100481, "<|LOC_185|>": 100482, "<|LOC_186|>": 100483, "<|LOC_187|>": 100484, "<|LOC_188|>": 100485, "<|LOC_189|>": 100486, "<|LOC_190|>": 100487, "<|LOC_191|>": 100488, "<|LOC_192|>": 100489, "<|LOC_193|>": 100490, "<|LOC_194|>": 100491, "<|LOC_195|>": 100492, "<|LOC_196|>": 100493, "<|LOC_197|>": 100494, "<|LOC_198|>": 100495, "<|LOC_199|>": 100496, "<|LOC_200|>": 100497, "<|LOC_201|>": 100498, "<|LOC_202|>": 100499, "<|LOC_203|>": 100500, "<|LOC_204|>": 100501, "<|LOC_205|>": 100502, "<|LOC_206|>": 100503, "<|LOC_207|>": 100504, "<|LOC_208|>": 100505, "<|LOC_209|>": 100506, "<|LOC_210|>": 100507, "<|LOC_211|>": 100508, "<|LOC_212|>": 100509, "<|LOC_213|>": 100510, "<|LOC_214|>": 100511, "<|LOC_215|>": 100512, "<|LOC_216|>": 100513, "<|LOC_217|>": 100514, "<|LOC_218|>": 100515, "<|LOC_219|>": 100516, "<|LOC_220|>": 100517, "<|LOC_221|>": 100518, "<|LOC_222|>": 100519, "<|LOC_223|>": 100520, "<|LOC_224|>": 100521, "<|LOC_225|>": 100522, "<|LOC_226|>": 100523, "<|LOC_227|>": 100524, "<|LOC_228|>": 100525, "<|LOC_229|>": 100526, "<|LOC_230|>": 100527, "<|LOC_231|>": 100528, "<|LOC_232|>": 100529, "<|LOC_233|>": 100530, "<|LOC_234|>": 100531, "<|LOC_235|>": 100532, "<|LOC_236|>": 100533, "<|LOC_237|>": 100534, "<|LOC_238|>": 100535, "<|LOC_239|>": 100536, "<|LOC_240|>": 100537, "<|LOC_241|>": 100538, "<|LOC_242|>": 100539, "<|LOC_243|>": 100540, "<|LOC_244|>": 100541, "<|LOC_245|>": 100542, "<|LOC_246|>": 100543, "<|LOC_247|>": 100544, "<|LOC_248|>": 100545, "<|LOC_249|>": 100546, "<|LOC_250|>": 100547, "<|LOC_251|>": 100548, "<|LOC_252|>": 100549, "<|LOC_253|>": 100550, "<|LOC_254|>": 100551, "<|LOC_255|>": 100552, "<|LOC_256|>": 100553, "<|LOC_257|>": 100554, "<|LOC_258|>": 100555, "<|LOC_259|>": 100556, "<|LOC_260|>": 100557, "<|LOC_261|>": 100558, "<|LOC_262|>": 100559, "<|LOC_263|>": 100560, "<|LOC_264|>": 100561, "<|LOC_265|>": 100562, "<|LOC_266|>": 100563, "<|LOC_267|>": 100564, "<|LOC_268|>": 100565, "<|LOC_269|>": 100566, "<|LOC_270|>": 100567, "<|LOC_271|>": 100568, "<|LOC_272|>": 100569, "<|LOC_273|>": 100570, "<|LOC_274|>": 100571, "<|LOC_275|>": 100572, "<|LOC_276|>": 100573, "<|LOC_277|>": 100574, "<|LOC_278|>": 100575, "<|LOC_279|>": 100576, "<|LOC_280|>": 100577, "<|LOC_281|>": 100578, "<|LOC_282|>": 100579, "<|LOC_283|>": 100580, "<|LOC_284|>": 100581, "<|LOC_285|>": 100582, "<|LOC_286|>": 100583, "<|LOC_287|>": 100584, "<|LOC_288|>": 100585, "<|LOC_289|>": 100586, "<|LOC_290|>": 100587, "<|LOC_291|>": 100588, "<|LOC_292|>": 100589, "<|LOC_293|>": 100590, "<|LOC_294|>": 100591, "<|LOC_295|>": 100592, "<|LOC_296|>": 100593, "<|LOC_297|>": 100594, "<|LOC_298|>": 100595, "<|LOC_299|>": 100596, "<|LOC_300|>": 100597, "<|LOC_301|>": 100598, "<|LOC_302|>": 100599, "<|LOC_303|>": 100600, "<|LOC_304|>": 100601, "<|LOC_305|>": 100602, "<|LOC_306|>": 100603, "<|LOC_307|>": 100604, "<|LOC_308|>": 100605, "<|LOC_309|>": 100606, "<|LOC_310|>": 100607, "<|LOC_311|>": 100608, "<|LOC_312|>": 100609, "<|LOC_313|>": 100610, "<|LOC_314|>": 100611, "<|LOC_315|>": 100612, "<|LOC_316|>": 100613, "<|LOC_317|>": 100614, "<|LOC_318|>": 100615, "<|LOC_319|>": 100616, "<|LOC_320|>": 100617, "<|LOC_321|>": 100618, "<|LOC_322|>": 100619, "<|LOC_323|>": 100620, "<|LOC_324|>": 100621, "<|LOC_325|>": 100622, "<|LOC_326|>": 100623, "<|LOC_327|>": 100624, "<|LOC_328|>": 100625, "<|LOC_329|>": 100626, "<|LOC_330|>": 100627, "<|LOC_331|>": 100628, "<|LOC_332|>": 100629, "<|LOC_333|>": 100630, "<|LOC_334|>": 100631, "<|LOC_335|>": 100632, "<|LOC_336|>": 100633, "<|LOC_337|>": 100634, "<|LOC_338|>": 100635, "<|LOC_339|>": 100636, "<|LOC_340|>": 100637, "<|LOC_341|>": 100638, "<|LOC_342|>": 100639, "<|LOC_343|>": 100640, "<|LOC_344|>": 100641, "<|LOC_345|>": 100642, "<|LOC_346|>": 100643, "<|LOC_347|>": 100644, "<|LOC_348|>": 100645, "<|LOC_349|>": 100646, "<|LOC_350|>": 100647, "<|LOC_351|>": 100648, "<|LOC_352|>": 100649, "<|LOC_353|>": 100650, "<|LOC_354|>": 100651, "<|LOC_355|>": 100652, "<|LOC_356|>": 100653, "<|LOC_357|>": 100654, "<|LOC_358|>": 100655, "<|LOC_359|>": 100656, "<|LOC_360|>": 100657, "<|LOC_361|>": 100658, "<|LOC_362|>": 100659, "<|LOC_363|>": 100660, "<|LOC_364|>": 100661, "<|LOC_365|>": 100662, "<|LOC_366|>": 100663, "<|LOC_367|>": 100664, "<|LOC_368|>": 100665, "<|LOC_369|>": 100666, "<|LOC_370|>": 100667, "<|LOC_371|>": 100668, "<|LOC_372|>": 100669, "<|LOC_373|>": 100670, "<|LOC_374|>": 100671, "<|LOC_375|>": 100672, "<|LOC_376|>": 100673, "<|LOC_377|>": 100674, "<|LOC_378|>": 100675, "<|LOC_379|>": 100676, "<|LOC_380|>": 100677, "<|LOC_381|>": 100678, "<|LOC_382|>": 100679, "<|LOC_383|>": 100680, "<|LOC_384|>": 100681, "<|LOC_385|>": 100682, "<|LOC_386|>": 100683, "<|LOC_387|>": 100684, "<|LOC_388|>": 100685, "<|LOC_389|>": 100686, "<|LOC_390|>": 100687, "<|LOC_391|>": 100688, "<|LOC_392|>": 100689, "<|LOC_393|>": 100690, "<|LOC_394|>": 100691, "<|LOC_395|>": 100692, "<|LOC_396|>": 100693, "<|LOC_397|>": 100694, "<|LOC_398|>": 100695, "<|LOC_399|>": 100696, "<|LOC_400|>": 100697, "<|LOC_401|>": 100698, "<|LOC_402|>": 100699, "<|LOC_403|>": 100700, "<|LOC_404|>": 100701, "<|LOC_405|>": 100702, "<|LOC_406|>": 100703, "<|LOC_407|>": 100704, "<|LOC_408|>": 100705, "<|LOC_409|>": 100706, "<|LOC_410|>": 100707, "<|LOC_411|>": 100708, "<|LOC_412|>": 100709, "<|LOC_413|>": 100710, "<|LOC_414|>": 100711, "<|LOC_415|>": 100712, "<|LOC_416|>": 100713, "<|LOC_417|>": 100714, "<|LOC_418|>": 100715, "<|LOC_419|>": 100716, "<|LOC_420|>": 100717, "<|LOC_421|>": 100718, "<|LOC_422|>": 100719, "<|LOC_423|>": 100720, "<|LOC_424|>": 100721, "<|LOC_425|>": 100722, "<|LOC_426|>": 100723, "<|LOC_427|>": 100724, "<|LOC_428|>": 100725, "<|LOC_429|>": 100726, "<|LOC_430|>": 100727, "<|LOC_431|>": 100728, "<|LOC_432|>": 100729, "<|LOC_433|>": 100730, "<|LOC_434|>": 100731, "<|LOC_435|>": 100732, "<|LOC_436|>": 100733, "<|LOC_437|>": 100734, "<|LOC_438|>": 100735, "<|LOC_439|>": 100736, "<|LOC_440|>": 100737, "<|LOC_441|>": 100738, "<|LOC_442|>": 100739, "<|LOC_443|>": 100740, "<|LOC_444|>": 100741, "<|LOC_445|>": 100742, "<|LOC_446|>": 100743, "<|LOC_447|>": 100744, "<|LOC_448|>": 100745, "<|LOC_449|>": 100746, "<|LOC_450|>": 100747, "<|LOC_451|>": 100748, "<|LOC_452|>": 100749, "<|LOC_453|>": 100750, "<|LOC_454|>": 100751, "<|LOC_455|>": 100752, "<|LOC_456|>": 100753, "<|LOC_457|>": 100754, "<|LOC_458|>": 100755, "<|LOC_459|>": 100756, "<|LOC_460|>": 100757, "<|LOC_461|>": 100758, "<|LOC_462|>": 100759, "<|LOC_463|>": 100760, "<|LOC_464|>": 100761, "<|LOC_465|>": 100762, "<|LOC_466|>": 100763, "<|LOC_467|>": 100764, "<|LOC_468|>": 100765, "<|LOC_469|>": 100766, "<|LOC_470|>": 100767, "<|LOC_471|>": 100768, "<|LOC_472|>": 100769, "<|LOC_473|>": 100770, "<|LOC_474|>": 100771, "<|LOC_475|>": 100772, "<|LOC_476|>": 100773, "<|LOC_477|>": 100774, "<|LOC_478|>": 100775, "<|LOC_479|>": 100776, "<|LOC_480|>": 100777, "<|LOC_481|>": 100778, "<|LOC_482|>": 100779, "<|LOC_483|>": 100780, "<|LOC_484|>": 100781, "<|LOC_485|>": 100782, "<|LOC_486|>": 100783, "<|LOC_487|>": 100784, "<|LOC_488|>": 100785, "<|LOC_489|>": 100786, "<|LOC_490|>": 100787, "<|LOC_491|>": 100788, "<|LOC_492|>": 100789, "<|LOC_493|>": 100790, "<|LOC_494|>": 100791, "<|LOC_495|>": 100792, "<|LOC_496|>": 100793, "<|LOC_497|>": 100794, "<|LOC_498|>": 100795, "<|LOC_499|>": 100796, "<|LOC_500|>": 100797, "<|LOC_501|>": 100798, "<|LOC_502|>": 100799, "<|LOC_503|>": 100800, "<|LOC_504|>": 100801, "<|LOC_505|>": 100802, "<|LOC_506|>": 100803, "<|LOC_507|>": 100804, "<|LOC_508|>": 100805, "<|LOC_509|>": 100806, "<|LOC_510|>": 100807, "<|LOC_511|>": 100808, "<|LOC_512|>": 100809, "<|LOC_513|>": 100810, "<|LOC_514|>": 100811, "<|LOC_515|>": 100812, "<|LOC_516|>": 100813, "<|LOC_517|>": 100814, "<|LOC_518|>": 100815, "<|LOC_519|>": 100816, "<|LOC_520|>": 100817, "<|LOC_521|>": 100818, "<|LOC_522|>": 100819, "<|LOC_523|>": 100820, "<|LOC_524|>": 100821, "<|LOC_525|>": 100822, "<|LOC_526|>": 100823, "<|LOC_527|>": 100824, "<|LOC_528|>": 100825, "<|LOC_529|>": 100826, "<|LOC_530|>": 100827, "<|LOC_531|>": 100828, "<|LOC_532|>": 100829, "<|LOC_533|>": 100830, "<|LOC_534|>": 100831, "<|LOC_535|>": 100832, "<|LOC_536|>": 100833, "<|LOC_537|>": 100834, "<|LOC_538|>": 100835, "<|LOC_539|>": 100836, "<|LOC_540|>": 100837, "<|LOC_541|>": 100838, "<|LOC_542|>": 100839, "<|LOC_543|>": 100840, "<|LOC_544|>": 100841, "<|LOC_545|>": 100842, "<|LOC_546|>": 100843, "<|LOC_547|>": 100844, "<|LOC_548|>": 100845, "<|LOC_549|>": 100846, "<|LOC_550|>": 100847, "<|LOC_551|>": 100848, "<|LOC_552|>": 100849, "<|LOC_553|>": 100850, "<|LOC_554|>": 100851, "<|LOC_555|>": 100852, "<|LOC_556|>": 100853, "<|LOC_557|>": 100854, "<|LOC_558|>": 100855, "<|LOC_559|>": 100856, "<|LOC_560|>": 100857, "<|LOC_561|>": 100858, "<|LOC_562|>": 100859, "<|LOC_563|>": 100860, "<|LOC_564|>": 100861, "<|LOC_565|>": 100862, "<|LOC_566|>": 100863, "<|LOC_567|>": 100864, "<|LOC_568|>": 100865, "<|LOC_569|>": 100866, "<|LOC_570|>": 100867, "<|LOC_571|>": 100868, "<|LOC_572|>": 100869, "<|LOC_573|>": 100870, "<|LOC_574|>": 100871, "<|LOC_575|>": 100872, "<|LOC_576|>": 100873, "<|LOC_577|>": 100874, "<|LOC_578|>": 100875, "<|LOC_579|>": 100876, "<|LOC_580|>": 100877, "<|LOC_581|>": 100878, "<|LOC_582|>": 100879, "<|LOC_583|>": 100880, "<|LOC_584|>": 100881, "<|LOC_585|>": 100882, "<|LOC_586|>": 100883, "<|LOC_587|>": 100884, "<|LOC_588|>": 100885, "<|LOC_589|>": 100886, "<|LOC_590|>": 100887, "<|LOC_591|>": 100888, "<|LOC_592|>": 100889, "<|LOC_593|>": 100890, "<|LOC_594|>": 100891, "<|LOC_595|>": 100892, "<|LOC_596|>": 100893, "<|LOC_597|>": 100894, "<|LOC_598|>": 100895, "<|LOC_599|>": 100896, "<|LOC_600|>": 100897, "<|LOC_601|>": 100898, "<|LOC_602|>": 100899, "<|LOC_603|>": 100900, "<|LOC_604|>": 100901, "<|LOC_605|>": 100902, "<|LOC_606|>": 100903, "<|LOC_607|>": 100904, "<|LOC_608|>": 100905, "<|LOC_609|>": 100906, "<|LOC_610|>": 100907, "<|LOC_611|>": 100908, "<|LOC_612|>": 100909, "<|LOC_613|>": 100910, "<|LOC_614|>": 100911, "<|LOC_615|>": 100912, "<|LOC_616|>": 100913, "<|LOC_617|>": 100914, "<|LOC_618|>": 100915, "<|LOC_619|>": 100916, "<|LOC_620|>": 100917, "<|LOC_621|>": 100918, "<|LOC_622|>": 100919, "<|LOC_623|>": 100920, "<|LOC_624|>": 100921, "<|LOC_625|>": 100922, "<|LOC_626|>": 100923, "<|LOC_627|>": 100924, "<|LOC_628|>": 100925, "<|LOC_629|>": 100926, "<|LOC_630|>": 100927, "<|LOC_631|>": 100928, "<|LOC_632|>": 100929, "<|LOC_633|>": 100930, "<|LOC_634|>": 100931, "<|LOC_635|>": 100932, "<|LOC_636|>": 100933, "<|LOC_637|>": 100934, "<|LOC_638|>": 100935, "<|LOC_639|>": 100936, "<|LOC_640|>": 100937, "<|LOC_641|>": 100938, "<|LOC_642|>": 100939, "<|LOC_643|>": 100940, "<|LOC_644|>": 100941, "<|LOC_645|>": 100942, "<|LOC_646|>": 100943, "<|LOC_647|>": 100944, "<|LOC_648|>": 100945, "<|LOC_649|>": 100946, "<|LOC_650|>": 100947, "<|LOC_651|>": 100948, "<|LOC_652|>": 100949, "<|LOC_653|>": 100950, "<|LOC_654|>": 100951, "<|LOC_655|>": 100952, "<|LOC_656|>": 100953, "<|LOC_657|>": 100954, "<|LOC_658|>": 100955, "<|LOC_659|>": 100956, "<|LOC_660|>": 100957, "<|LOC_661|>": 100958, "<|LOC_662|>": 100959, "<|LOC_663|>": 100960, "<|LOC_664|>": 100961, "<|LOC_665|>": 100962, "<|LOC_666|>": 100963, "<|LOC_667|>": 100964, "<|LOC_668|>": 100965, "<|LOC_669|>": 100966, "<|LOC_670|>": 100967, "<|LOC_671|>": 100968, "<|LOC_672|>": 100969, "<|LOC_673|>": 100970, "<|LOC_674|>": 100971, "<|LOC_675|>": 100972, "<|LOC_676|>": 100973, "<|LOC_677|>": 100974, "<|LOC_678|>": 100975, "<|LOC_679|>": 100976, "<|LOC_680|>": 100977, "<|LOC_681|>": 100978, "<|LOC_682|>": 100979, "<|LOC_683|>": 100980, "<|LOC_684|>": 100981, "<|LOC_685|>": 100982, "<|LOC_686|>": 100983, "<|LOC_687|>": 100984, "<|LOC_688|>": 100985, "<|LOC_689|>": 100986, "<|LOC_690|>": 100987, "<|LOC_691|>": 100988, "<|LOC_692|>": 100989, "<|LOC_693|>": 100990, "<|LOC_694|>": 100991, "<|LOC_695|>": 100992, "<|LOC_696|>": 100993, "<|LOC_697|>": 100994, "<|LOC_698|>": 100995, "<|LOC_699|>": 100996, "<|LOC_700|>": 100997, "<|LOC_701|>": 100998, "<|LOC_702|>": 100999, "<|LOC_703|>": 101000, "<|LOC_704|>": 101001, "<|LOC_705|>": 101002, "<|LOC_706|>": 101003, "<|LOC_707|>": 101004, "<|LOC_708|>": 101005, "<|LOC_709|>": 101006, "<|LOC_710|>": 101007, "<|LOC_711|>": 101008, "<|LOC_712|>": 101009, "<|LOC_713|>": 101010, "<|LOC_714|>": 101011, "<|LOC_715|>": 101012, "<|LOC_716|>": 101013, "<|LOC_717|>": 101014, "<|LOC_718|>": 101015, "<|LOC_719|>": 101016, "<|LOC_720|>": 101017, "<|LOC_721|>": 101018, "<|LOC_722|>": 101019, "<|LOC_723|>": 101020, "<|LOC_724|>": 101021, "<|LOC_725|>": 101022, "<|LOC_726|>": 101023, "<|LOC_727|>": 101024, "<|LOC_728|>": 101025, "<|LOC_729|>": 101026, "<|LOC_730|>": 101027, "<|LOC_731|>": 101028, "<|LOC_732|>": 101029, "<|LOC_733|>": 101030, "<|LOC_734|>": 101031, "<|LOC_735|>": 101032, "<|LOC_736|>": 101033, "<|LOC_737|>": 101034, "<|LOC_738|>": 101035, "<|LOC_739|>": 101036, "<|LOC_740|>": 101037, "<|LOC_741|>": 101038, "<|LOC_742|>": 101039, "<|LOC_743|>": 101040, "<|LOC_744|>": 101041, "<|LOC_745|>": 101042, "<|LOC_746|>": 101043, "<|LOC_747|>": 101044, "<|LOC_748|>": 101045, "<|LOC_749|>": 101046, "<|LOC_750|>": 101047, "<|LOC_751|>": 101048, "<|LOC_752|>": 101049, "<|LOC_753|>": 101050, "<|LOC_754|>": 101051, "<|LOC_755|>": 101052, "<|LOC_756|>": 101053, "<|LOC_757|>": 101054, "<|LOC_758|>": 101055, "<|LOC_759|>": 101056, "<|LOC_760|>": 101057, "<|LOC_761|>": 101058, "<|LOC_762|>": 101059, "<|LOC_763|>": 101060, "<|LOC_764|>": 101061, "<|LOC_765|>": 101062, "<|LOC_766|>": 101063, "<|LOC_767|>": 101064, "<|LOC_768|>": 101065, "<|LOC_769|>": 101066, "<|LOC_770|>": 101067, "<|LOC_771|>": 101068, "<|LOC_772|>": 101069, "<|LOC_773|>": 101070, "<|LOC_774|>": 101071, "<|LOC_775|>": 101072, "<|LOC_776|>": 101073, "<|LOC_777|>": 101074, "<|LOC_778|>": 101075, "<|LOC_779|>": 101076, "<|LOC_780|>": 101077, "<|LOC_781|>": 101078, "<|LOC_782|>": 101079, "<|LOC_783|>": 101080, "<|LOC_784|>": 101081, "<|LOC_785|>": 101082, "<|LOC_786|>": 101083, "<|LOC_787|>": 101084, "<|LOC_788|>": 101085, "<|LOC_789|>": 101086, "<|LOC_790|>": 101087, "<|LOC_791|>": 101088, "<|LOC_792|>": 101089, "<|LOC_793|>": 101090, "<|LOC_794|>": 101091, "<|LOC_795|>": 101092, "<|LOC_796|>": 101093, "<|LOC_797|>": 101094, "<|LOC_798|>": 101095, "<|LOC_799|>": 101096, "<|LOC_800|>": 101097, "<|LOC_801|>": 101098, "<|LOC_802|>": 101099, "<|LOC_803|>": 101100, "<|LOC_804|>": 101101, "<|LOC_805|>": 101102, "<|LOC_806|>": 101103, "<|LOC_807|>": 101104, "<|LOC_808|>": 101105, "<|LOC_809|>": 101106, "<|LOC_810|>": 101107, "<|LOC_811|>": 101108, "<|LOC_812|>": 101109, "<|LOC_813|>": 101110, "<|LOC_814|>": 101111, "<|LOC_815|>": 101112, "<|LOC_816|>": 101113, "<|LOC_817|>": 101114, "<|LOC_818|>": 101115, "<|LOC_819|>": 101116, "<|LOC_820|>": 101117, "<|LOC_821|>": 101118, "<|LOC_822|>": 101119, "<|LOC_823|>": 101120, "<|LOC_824|>": 101121, "<|LOC_825|>": 101122, "<|LOC_826|>": 101123, "<|LOC_827|>": 101124, "<|LOC_828|>": 101125, "<|LOC_829|>": 101126, "<|LOC_830|>": 101127, "<|LOC_831|>": 101128, "<|LOC_832|>": 101129, "<|LOC_833|>": 101130, "<|LOC_834|>": 101131, "<|LOC_835|>": 101132, "<|LOC_836|>": 101133, "<|LOC_837|>": 101134, "<|LOC_838|>": 101135, "<|LOC_839|>": 101136, "<|LOC_840|>": 101137, "<|LOC_841|>": 101138, "<|LOC_842|>": 101139, "<|LOC_843|>": 101140, "<|LOC_844|>": 101141, "<|LOC_845|>": 101142, "<|LOC_846|>": 101143, "<|LOC_847|>": 101144, "<|LOC_848|>": 101145, "<|LOC_849|>": 101146, "<|LOC_850|>": 101147, "<|LOC_851|>": 101148, "<|LOC_852|>": 101149, "<|LOC_853|>": 101150, "<|LOC_854|>": 101151, "<|LOC_855|>": 101152, "<|LOC_856|>": 101153, "<|LOC_857|>": 101154, "<|LOC_858|>": 101155, "<|LOC_859|>": 101156, "<|LOC_860|>": 101157, "<|LOC_861|>": 101158, "<|LOC_862|>": 101159, "<|LOC_863|>": 101160, "<|LOC_864|>": 101161, "<|LOC_865|>": 101162, "<|LOC_866|>": 101163, "<|LOC_867|>": 101164, "<|LOC_868|>": 101165, "<|LOC_869|>": 101166, "<|LOC_870|>": 101167, "<|LOC_871|>": 101168, "<|LOC_872|>": 101169, "<|LOC_873|>": 101170, "<|LOC_874|>": 101171, "<|LOC_875|>": 101172, "<|LOC_876|>": 101173, "<|LOC_877|>": 101174, "<|LOC_878|>": 101175, "<|LOC_879|>": 101176, "<|LOC_880|>": 101177, "<|LOC_881|>": 101178, "<|LOC_882|>": 101179, "<|LOC_883|>": 101180, "<|LOC_884|>": 101181, "<|LOC_885|>": 101182, "<|LOC_886|>": 101183, "<|LOC_887|>": 101184, "<|LOC_888|>": 101185, "<|LOC_889|>": 101186, "<|LOC_890|>": 101187, "<|LOC_891|>": 101188, "<|LOC_892|>": 101189, "<|LOC_893|>": 101190, "<|LOC_894|>": 101191, "<|LOC_895|>": 101192, "<|LOC_896|>": 101193, "<|LOC_897|>": 101194, "<|LOC_898|>": 101195, "<|LOC_899|>": 101196, "<|LOC_900|>": 101197, "<|LOC_901|>": 101198, "<|LOC_902|>": 101199, "<|LOC_903|>": 101200, "<|LOC_904|>": 101201, "<|LOC_905|>": 101202, "<|LOC_906|>": 101203, "<|LOC_907|>": 101204, "<|LOC_908|>": 101205, "<|LOC_909|>": 101206, "<|LOC_910|>": 101207, "<|LOC_911|>": 101208, "<|LOC_912|>": 101209, "<|LOC_913|>": 101210, "<|LOC_914|>": 101211, "<|LOC_915|>": 101212, "<|LOC_916|>": 101213, "<|LOC_917|>": 101214, "<|LOC_918|>": 101215, "<|LOC_919|>": 101216, "<|LOC_920|>": 101217, "<|LOC_921|>": 101218, "<|LOC_922|>": 101219, "<|LOC_923|>": 101220, "<|LOC_924|>": 101221, "<|LOC_925|>": 101222, "<|LOC_926|>": 101223, "<|LOC_927|>": 101224, "<|LOC_928|>": 101225, "<|LOC_929|>": 101226, "<|LOC_930|>": 101227, "<|LOC_931|>": 101228, "<|LOC_932|>": 101229, "<|LOC_933|>": 101230, "<|LOC_934|>": 101231, "<|LOC_935|>": 101232, "<|LOC_936|>": 101233, "<|LOC_937|>": 101234, "<|LOC_938|>": 101235, "<|LOC_939|>": 101236, "<|LOC_940|>": 101237, "<|LOC_941|>": 101238, "<|LOC_942|>": 101239, "<|LOC_943|>": 101240, "<|LOC_944|>": 101241, "<|LOC_945|>": 101242, "<|LOC_946|>": 101243, "<|LOC_947|>": 101244, "<|LOC_948|>": 101245, "<|LOC_949|>": 101246, "<|LOC_950|>": 101247, "<|LOC_951|>": 101248, "<|LOC_952|>": 101249, "<|LOC_953|>": 101250, "<|LOC_954|>": 101251, "<|LOC_955|>": 101252, "<|LOC_956|>": 101253, "<|LOC_957|>": 101254, "<|LOC_958|>": 101255, "<|LOC_959|>": 101256, "<|LOC_960|>": 101257, "<|LOC_961|>": 101258, "<|LOC_962|>": 101259, "<|LOC_963|>": 101260, "<|LOC_964|>": 101261, "<|LOC_965|>": 101262, "<|LOC_966|>": 101263, "<|LOC_967|>": 101264, "<|LOC_968|>": 101265, "<|LOC_969|>": 101266, "<|LOC_970|>": 101267, "<|LOC_971|>": 101268, "<|LOC_972|>": 101269, "<|LOC_973|>": 101270, "<|LOC_974|>": 101271, "<|LOC_975|>": 101272, "<|LOC_976|>": 101273, "<|LOC_977|>": 101274, "<|LOC_978|>": 101275, "<|LOC_979|>": 101276, "<|LOC_980|>": 101277, "<|LOC_981|>": 101278, "<|LOC_982|>": 101279, "<|LOC_983|>": 101280, "<|LOC_984|>": 101281, "<|LOC_985|>": 101282, "<|LOC_986|>": 101283, "<|LOC_987|>": 101284, "<|LOC_988|>": 101285, "<|LOC_989|>": 101286, "<|LOC_990|>": 101287, "<|LOC_991|>": 101288, "<|LOC_992|>": 101289, "<|LOC_993|>": 101290, "<|LOC_994|>": 101291, "<|LOC_995|>": 101292, "<|LOC_996|>": 101293, "<|LOC_997|>": 101294, "<|LOC_998|>": 101295, "<|LOC_999|>": 101296, "<|LOC_1000|>": 101297, "<|LOC_BEGIN|>": 101298, "<|LOC_END|>": 101299, "<|LOC_SEP|>": 101300, "<|CROP_COL_SEP|>": 101301, "<|CROP_ROW_SEP|>": 101302, "<|IMAGE_SEP|>": 101303}
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Ernie4_5_ForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_ernie4_5.Ernie4_5_Config",
7
+ "AutoModel": "modeling_ernie4_5.Ernie4_5_Model",
8
+ "AutoModelForCausalLM": "modeling_ernie4_5.Ernie4_5_ForCausalLM"
9
+ },
10
+ "bos_token_id": 1,
11
+ "eos_token_id": 2,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 1024,
14
+ "intermediate_size": 3072,
15
+ "max_position_embeddings": 131072,
16
+ "model_type": "ernie4_5",
17
+ "num_attention_heads": 16,
18
+ "num_key_value_heads": 2,
19
+ "head_dim": 128,
20
+ "num_hidden_layers": 18,
21
+ "pad_token_id": 0,
22
+ "rms_norm_eps": 1e-05,
23
+ "use_cache": false,
24
+ "vocab_size": 103424,
25
+ "rope_theta": 500000,
26
+ "use_bias": false,
27
+ "tie_word_embeddings": true,
28
+ "torch_dtype": "bfloat16"
29
+ }
configuration_ernie4_5.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from transformers import PretrainedConfig
16
+
17
+
18
+ class Ernie4_5_Config(PretrainedConfig):
19
+ """
20
+ Configuration class.
21
+
22
+ This class stores the configuration of an Ernie model, defining the model architecture.
23
+ It inherits from PretrainedConfig and can be used to control model outputs.
24
+ """
25
+
26
+ model_type = "ernie4_5"
27
+ keys_to_ignore_at_inference = ["past_key_values"]
28
+
29
+ # Default tensor parallel plan for base model `Qwen3`
30
+ base_model_tp_plan = {
31
+ "layers.*.self_attn.q_proj": "colwise",
32
+ "layers.*.self_attn.k_proj": "colwise",
33
+ "layers.*.self_attn.v_proj": "colwise",
34
+ "layers.*.self_attn.o_proj": "rowwise",
35
+ "layers.*.mlp.gate_proj": "colwise",
36
+ "layers.*.mlp.up_proj": "colwise",
37
+ "layers.*.mlp.down_proj": "rowwise",
38
+ }
39
+ base_model_pp_plan = {
40
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
41
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
42
+ "norm": (["hidden_states"], ["hidden_states"]),
43
+ }
44
+
45
+ def __init__(
46
+ self,
47
+ vocab_size=32000,
48
+ hidden_size=768,
49
+ intermediate_size=11008,
50
+ max_position_embeddings=32768,
51
+ num_hidden_layers=2,
52
+ num_attention_heads=2,
53
+ rms_norm_eps=1e-6,
54
+ use_cache=False,
55
+ use_flash_attention=False,
56
+ pad_token_id=0,
57
+ bos_token_id=1,
58
+ eos_token_id=2,
59
+ use_bias=False,
60
+ rope_theta=10000,
61
+ weight_share_add_bias=True,
62
+ ignored_index=-100,
63
+ attention_probs_dropout_prob=0.0,
64
+ hidden_dropout_prob=0.0,
65
+ compression_ratio: float = 1.0,
66
+ num_key_value_heads=None,
67
+ max_sequence_length=None,
68
+ **kwargs,
69
+ ):
70
+ """
71
+ Initialize configuration with default or specified parameters.
72
+
73
+ Args:
74
+ vocab_size (int): Size of the vocabulary (number of unique tokens)
75
+ hidden_size (int): Dimensionality of the encoder layers and the pooler layer
76
+ intermediate_size (int): Dimensionality of the "intermediate" (feed-forward) layer
77
+ max_position_embeddings (int): Maximum sequence length the model can handle
78
+ num_hidden_layers (int): Number of hidden layers in the Transformer encoder
79
+ num_attention_heads (int): Number of attention heads for each attention layer
80
+ rms_norm_eps (float): The epsilon used by the RMS normalization layers
81
+ use_cache (bool): Whether to use caching for faster generation (decoding)
82
+ use_flash_attention (bool): Whether to use FlashAttention for optimized attention computation
83
+ pad_token_id (int): Token ID used for padding sequences
84
+ bos_token_id (int): Token ID used for beginning-of-sequence
85
+ eos_token_id (int): Token ID used for end-of-sequence
86
+ use_bias (bool): Whether to use bias terms in linear layers
87
+ rope_theta (float): The base period of the RoPE embeddings
88
+ weight_share_add_bias (bool): Whether to share bias weights in certain layers
89
+ ignored_index (int): Target value that is ignored during loss computation
90
+ attention_probs_dropout_prob (float): Dropout probability for attention weights
91
+ hidden_dropout_prob (float): Dropout probability for hidden layers
92
+ compression_ratio (float): Ratio for KV cache compression (1.0 = no compression)
93
+ num_key_value_heads (int): Number of key/value heads (for Grouped Query Attention)
94
+ max_sequence_length (int): Maximum sequence length for positional embeddings
95
+ **kwargs: Additional keyword arguments passed to parent class
96
+ """
97
+
98
+ # Set default for tied embeddings if not specified.
99
+ if "tie_word_embeddings" not in kwargs:
100
+ kwargs["tie_word_embeddings"] = False
101
+ super().__init__(
102
+ pad_token_id=pad_token_id,
103
+ bos_token_id=bos_token_id,
104
+ eos_token_id=eos_token_id,
105
+ **kwargs,
106
+ )
107
+ self.vocab_size = vocab_size
108
+ self.hidden_size = hidden_size
109
+ self.intermediate_size = intermediate_size
110
+ self.max_position_embeddings = max_position_embeddings
111
+ self.num_hidden_layers = num_hidden_layers
112
+ self.num_attention_heads = num_attention_heads
113
+ self.rms_norm_eps = rms_norm_eps
114
+ self.use_cache = use_cache
115
+ self.use_flash_attention = use_flash_attention
116
+ self.pad_token_id = pad_token_id
117
+ self.bos_token_id = bos_token_id
118
+ self.eos_token_id = eos_token_id
119
+ self.use_bias = use_bias
120
+ self.weight_share_add_bias = weight_share_add_bias
121
+ self.rope_theta = rope_theta
122
+ self.ignored_index = ignored_index
123
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
124
+ self.hidden_dropout_prob = hidden_dropout_prob
125
+ self.compression_ratio = compression_ratio
126
+ self.num_key_value_heads = num_key_value_heads
127
+ self.max_sequence_length = max_sequence_length
generation_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_sample": true,
3
+ "top_p": 0.8,
4
+ "temperature": 0.8,
5
+ "bos_token_id": 1,
6
+ "eos_token_id": 2,
7
+ "pad_token_id": 0,
8
+ "repetition_penalty": 1.0,
9
+ "frequency_penalty": 0.0,
10
+ "presence_penalty": 0.0
11
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fa450d7550b5600b030f09b9a954ffad3ef31476e44fb87e2154e9ac23b51d3
3
+ size 721514672
modeling_ernie4_5.py ADDED
@@ -0,0 +1,1068 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ from torch.nn.attention import SDPBackend, sdpa_kernel
21
+
22
+ from transformers.activations import ACT2FN
23
+ from transformers.modeling_utils import PreTrainedModel
24
+ from transformers.generation import GenerationMixin
25
+ from transformers.modeling_outputs import (
26
+ BaseModelOutputWithPast,
27
+ CausalLMOutputWithPast,
28
+ )
29
+ from transformers.utils import logging
30
+
31
+ from .configuration_ernie4_5 import Ernie4_5_Config
32
+
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+
37
+ class Ernie4_5_RMSNorm(nn.Module):
38
+ """
39
+ Root Mean Square Layer Normalization (Ernie4_5_RMSNorm) implementation.
40
+
41
+ Ernie4_5_RMSNorm is a simplified version of LayerNorm that focuses on the root mean square of inputs,
42
+ omitting the mean-centering operation. This provides computational efficiency while maintaining
43
+ good performance.
44
+ """
45
+
46
+ def __init__(self, config):
47
+ """
48
+ Initialize Ernie4_5_RMSNorm layer.
49
+
50
+ Args:
51
+ config: Model configuration.
52
+ """
53
+ super().__init__()
54
+ self.hidden_size = config.hidden_size
55
+ self.weight = nn.Parameter(
56
+ torch.ones(self.hidden_size, dtype=torch.get_default_dtype())
57
+ )
58
+ self.variance_epsilon = config.rms_norm_eps
59
+
60
+ def forward(self, hidden_states):
61
+ """
62
+ Apply RMS normalization to input hidden states.
63
+
64
+ Args:
65
+ hidden_states (Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
66
+
67
+ Returns:
68
+ Tensor: Normalized output tensor of same shape as input
69
+
70
+ Note:
71
+ - computes Ernie4_5_RMSNorm manually:
72
+ 1. Compute variance of features
73
+ 2. Apply reciprocal square root normalization
74
+ 3. Scale by learned weight parameter
75
+ - Maintains original dtype for numerical stability during computation
76
+ """
77
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
78
+ hidden_states = torch.rsqrt(variance + self.variance_epsilon) * hidden_states
79
+ return hidden_states.to(self.weight.dtype) * self.weight
80
+
81
+
82
+ class Ernie4_5_RopeEmbedding(nn.Module):
83
+ """
84
+ Rotary Position Embedding (RoPE) implementation for transformer models.
85
+
86
+ RoPE encodes absolute positional information with rotation matrices and
87
+ naturally incorporates relative position information in self-attention.
88
+
89
+ Args:
90
+ head_dim (int): Dimension size of each attention head
91
+ compression_ratio (float, optional): Sequence length compression ratio. Defaults to 1.0.
92
+ base (int, optional): Base value for frequency calculation. Defaults to 10000.
93
+
94
+ Attributes:
95
+ head_dim (int): Dimension size of each attention head
96
+ compression_ratio (float): Sequence length compression factor
97
+ base (int): Base value for frequency calculation
98
+ """
99
+
100
+ def __init__(self, head_dim, compression_ratio=1.0, base=10000):
101
+ """
102
+ Initialize RoPE embedding layer.
103
+
104
+ Args:
105
+ head_dim: Dimension of each attention head
106
+ compression_ratio: Scaling factor for position indices
107
+ base: Base value for frequency calculation
108
+ """
109
+ super().__init__()
110
+ self.head_dim = head_dim
111
+ self.compression_ratio = compression_ratio
112
+ self.base = base
113
+
114
+ def forward(self, seq_length, position_ids=None):
115
+ """
116
+ Compute rotary position embeddings for given sequence length.
117
+
118
+ Args:
119
+ seq_length (int): Maximum sequence length
120
+ position_ids (Tensor, optional): Custom position indices. Defaults to None.
121
+
122
+ Returns:
123
+ Tensor: Rotary position embeddings of shape [1, 1, seq_length, head_dim]
124
+ """
125
+ indices = torch.arange(0, self.head_dim, 2, dtype=torch.float32)
126
+ indices = 1 / self.base ** (indices / self.head_dim)
127
+ if position_ids is None:
128
+ position_ids = torch.arange(
129
+ 0, seq_length, 1, dtype=torch.float32
130
+ ).unsqueeze(1)
131
+ position_ids = position_ids / self.compression_ratio
132
+ sinusoid_inp = position_ids * indices.unsqueeze(0)
133
+ else:
134
+ position_ids = position_ids / self.compression_ratio
135
+ seq_length = position_ids.shape[-1]
136
+ sinusoid_inp = position_ids.unsqueeze(-1).to(
137
+ torch.float32
138
+ ) * indices.unsqueeze(0)
139
+ pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
140
+ pos_emb = pos_emb.view(-1, 1, seq_length, self.head_dim)
141
+ pos_emb = pos_emb.detach()
142
+ return pos_emb
143
+
144
+ def apply_rotary(self, rp, q, k):
145
+ """
146
+ Apply rotary position embeddings to queries and keys.
147
+
148
+ Args:
149
+ rp (Tensor): Rotary position embeddings
150
+ q (Tensor): Query tensor [batch, heads, seq_len, dim]
151
+ k (Tensor): Key tensor [batch, heads, seq_len, dim]
152
+
153
+ Returns:
154
+ Tuple[Tensor, Tensor]: Rotated queries and keys
155
+ """
156
+ sin, cos = torch.chunk(rp.to(q.device), 2, dim=-1)
157
+ # sin [θ0,θ1,θ2......θd/2-1] -> sin_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
158
+ sin_pos = torch.stack([sin, sin], dim=-1).reshape(rp.shape)
159
+ # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
160
+ cos_pos = torch.stack([cos, cos], dim=-1).reshape(rp.shape)
161
+ # rotate_half_query_layer [-q1,q0,-q3,q2......,-qd-1,qd-2]
162
+ rotate_half_q = torch.stack(
163
+ [-q[:, :, :, 1::2], q[:, :, :, 0::2]], dim=-1
164
+ ).reshape(q.shape)
165
+ query = (q.to(torch.float32) * cos_pos) + (
166
+ rotate_half_q.to(torch.float32) * sin_pos
167
+ )
168
+ # rotate_half_key_layer [-k1,k0,-k3,k2......,-kd-1,kd-2]
169
+ rotate_half_k = torch.stack(
170
+ [-k[:, :, :, 1::2], k[:, :, :, 0::2]], dim=-1
171
+ ).reshape(k.shape)
172
+ key = (k.to(torch.float32) * cos_pos) + (
173
+ rotate_half_k.to(torch.float32) * sin_pos
174
+ )
175
+ return query, key
176
+
177
+
178
+ class Ernie4_5_FusedDropoutImpl(nn.Module):
179
+ """
180
+ Fused dropout implementation with residual connection support.
181
+
182
+ This layer combines dropout and residual addition in a single operation for better performance,
183
+ particularly on GPU devices. The dropout is conditionally applied based on the probability.
184
+
185
+ Args:
186
+ prob (float): Dropout probability (between 0 and 1)
187
+
188
+ Attributes:
189
+ prob (float): Stores the dropout probability
190
+ dropout (nn.Dropout): The actual dropout layer instance
191
+ """
192
+
193
+ def __init__(self, prob):
194
+ """
195
+ Initialize the fused dropout layer.
196
+
197
+ Args:
198
+ prob (float): Dropout probability (0 means no dropout)
199
+ """
200
+ super().__init__()
201
+ self.prob = prob
202
+ self.dropout = nn.Dropout(p=prob)
203
+
204
+ def forward(self, x, y):
205
+ """
206
+ Forward pass of the fused dropout layer.
207
+
208
+ Args:
209
+ x (Tensor): Input tensor to potentially apply dropout
210
+ y (Tensor): Residual tensor to add to the (possibly dropped out) x
211
+
212
+ Returns:
213
+ Tensor: Result of x (with optional dropout) + y
214
+ """
215
+ if self.prob > 0:
216
+ x = self.dropout(x)
217
+ output = x + y
218
+
219
+ return output
220
+
221
+
222
+ class Ernie4_5_MLP(nn.Module):
223
+ """
224
+ Ernie4_5_MLP - Gated Multi-Layer Perceptron module used in Ernie model.
225
+ """
226
+
227
+ def __init__(self, config, layer_idx=0):
228
+ """
229
+ Initialize the MLP module with configuration options.
230
+
231
+ Args:
232
+ config: Model configurations.
233
+ layer_idx (int): Index of current layer (default: 0)
234
+ """
235
+ super().__init__()
236
+ self.config = config
237
+ self.layer_idx = layer_idx
238
+ self.hidden_size = config.hidden_size
239
+ self.intermediate_size = config.intermediate_size
240
+
241
+ self.gate_proj = nn.Linear(
242
+ self.hidden_size, self.intermediate_size, bias=config.use_bias
243
+ )
244
+ self.up_proj = nn.Linear(
245
+ self.hidden_size, self.intermediate_size, bias=config.use_bias
246
+ )
247
+ self.down_proj = nn.Linear(
248
+ self.intermediate_size, self.hidden_size, bias=config.use_bias
249
+ )
250
+ self.act_fn = ACT2FN[config.hidden_act]
251
+
252
+ def forward(self, x):
253
+ """
254
+ Args:
255
+ x (Tensor): shape [batch_size, seq_len, hidden_size]
256
+
257
+ Returns:
258
+ Tensor: shape [batch_size, seq_len, hidden_size]
259
+ """
260
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
261
+ return down_proj
262
+
263
+
264
+ class Ernie4_5_Attention(nn.Module):
265
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
266
+
267
+ def __init__(self, config, layer_idx=0):
268
+ """Initialize the attention layer.
269
+
270
+ Args:
271
+ config: Model configuration.
272
+ layer_idx (int, optional): Index in transformer stack. Defaults to 0.
273
+ """
274
+ super().__init__()
275
+ self.layer_idx = layer_idx
276
+ self.hidden_size = config.hidden_size
277
+ self.num_heads = config.num_attention_heads
278
+ self.num_key_value_heads = config.num_key_value_heads
279
+
280
+ if config.head_dim is None:
281
+ self.head_dim = self.hidden_size // self.num_heads
282
+ else:
283
+ self.head_dim = config.head_dim
284
+
285
+ self.is_gqa = (
286
+ self.num_key_value_heads is not None
287
+ and self.num_key_value_heads != self.num_heads
288
+ )
289
+
290
+ if self.is_gqa:
291
+ logger.info(
292
+ f"use GQA - num_heads: {self.num_heads}- num_key_value_heads: {self.num_key_value_heads}"
293
+ )
294
+ assert (
295
+ self.num_heads % self.num_key_value_heads == 0
296
+ ), f"num_heads: {self.num_heads}, num_key_value_heads: {self.num_key_value_heads}"
297
+ kv_hidden_size = self.head_dim * self.num_key_value_heads
298
+ q_hidden_size = self.head_dim * self.num_heads
299
+ else:
300
+ q_hidden_size = kv_hidden_size = self.head_dim * self.num_heads
301
+
302
+ self.q_proj = nn.Linear(self.hidden_size, q_hidden_size, bias=config.use_bias)
303
+ self.k_proj = nn.Linear(self.hidden_size, kv_hidden_size, bias=config.use_bias)
304
+ self.v_proj = nn.Linear(self.hidden_size, kv_hidden_size, bias=config.use_bias)
305
+ self.o_proj = nn.Linear(q_hidden_size, self.hidden_size, bias=config.use_bias)
306
+
307
+ self.rotary_emb = Ernie4_5_RopeEmbedding(
308
+ self.head_dim,
309
+ compression_ratio=config.compression_ratio,
310
+ base=config.rope_theta,
311
+ )
312
+ self.config = config
313
+
314
+ self.set_attn_func()
315
+
316
+ def set_attn_func(self):
317
+ """Configure attention function based on settings.
318
+
319
+ Selects between flash/core attention.
320
+ """
321
+ config = self.config
322
+ if config.use_flash_attention:
323
+ self.attn_func = self._flash_attention_wrapper
324
+ else:
325
+ self.attn_func = self.core_attn
326
+
327
+ def forward(
328
+ self,
329
+ hidden_states,
330
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
331
+ attention_mask: Optional[torch.Tensor] = None,
332
+ attn_mask_start_row_indices: Optional[torch.Tensor] = None,
333
+ position_ids: Optional[Tuple[torch.Tensor]] = None,
334
+ output_attentions: bool = False,
335
+ use_cache: bool = False,
336
+ token_type_ids: Optional[Tuple[torch.Tensor]] = None,
337
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
338
+ """Compute attention outputs.
339
+
340
+ Args:
341
+ hidden_states (torch.Tensor): Input tensor [bsz, seq_len, hidden_size]
342
+ past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached key/value states
343
+ attention_mask (Optional[torch.Tensor]): Attention mask tensor
344
+ attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length attention indices
345
+ position_ids (Optional[torch.Tensor]): Position indices for RoPE
346
+ output_attentions (bool): Return attention weights if True
347
+ use_cache (bool): Cache key/value states if True
348
+
349
+ Returns:
350
+ Tuple containing:
351
+ - attention_output: [bsz, seq_len, hidden_size]
352
+ - attention_weights: Optional attention probabilities
353
+ - updated_key_value_cache: Optional updated cache
354
+ """
355
+ if token_type_ids is not None:
356
+ token_type_ids = token_type_ids[:, :-1]
357
+
358
+ bsz, q_len, _ = hidden_states.shape
359
+
360
+ query_states = self.q_proj(hidden_states).reshape(
361
+ [bsz, q_len, -1, self.head_dim]
362
+ )
363
+ key_states = self.k_proj(hidden_states).reshape([bsz, q_len, -1, self.head_dim])
364
+ value_states = self.v_proj(hidden_states).reshape(
365
+ [bsz, q_len, -1, self.head_dim]
366
+ )
367
+
368
+ attn_output, attn_weights, past_key_value = self.rope_attn(
369
+ query_states=query_states,
370
+ key_states=key_states,
371
+ value_states=value_states,
372
+ attention_mask=attention_mask,
373
+ position_ids=position_ids,
374
+ output_attentions=output_attentions,
375
+ past_key_value=past_key_value,
376
+ use_cache=use_cache,
377
+ attn_mask_start_row_indices=attn_mask_start_row_indices,
378
+ )
379
+
380
+ attn_output = self.o_proj(attn_output)
381
+
382
+ if not output_attentions:
383
+ attn_weights = None
384
+
385
+ return attn_output, attn_weights, past_key_value
386
+
387
+ def repeat_kv(self, hidden_states, n_rep):
388
+ """
389
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
390
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
391
+ """
392
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
393
+ if n_rep == 1:
394
+ return hidden_states
395
+ hidden_states = hidden_states[:, :, None, :, :].expand(
396
+ batch, num_key_value_heads, n_rep, slen, head_dim
397
+ )
398
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
399
+
400
+ def _flash_attention_wrapper(
401
+ self,
402
+ q,
403
+ k,
404
+ v,
405
+ attention_mask=None,
406
+ attn_mask_start_row_indices=None,
407
+ seq_length=None,
408
+ ):
409
+ """Wrapper for flash attention implementation.
410
+
411
+ Args:
412
+ q (torch.Tensor): Query tensor
413
+ k (torch.Tensor): Key tensor
414
+ v (torch.Tensor): Value tensor
415
+ attention_mask (Optional[torch.Tensor]): Attention mask
416
+ attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices
417
+ seq_length (Optional[int]): Sequence length
418
+
419
+ Returns:
420
+ Tuple[torch.Tensor, torch.Tensor]: Attention output and weights
421
+ """
422
+ q = q.transpose(1, 2)
423
+ k = k.transpose(1, 2)
424
+ v = v.transpose(1, 2)
425
+
426
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
427
+ out = F.scaled_dot_product_attention(
428
+ q,
429
+ k,
430
+ v,
431
+ attn_mask=attention_mask,
432
+ dropout_p=self.config.attention_probs_dropout_prob,
433
+ is_causal=attention_mask is None and q.shape[1] != 1,
434
+ scale=1
435
+ / (getattr(self.config, "scale_qk_coeff", 1.0) * self.head_dim**0.5),
436
+ enable_gqa=self.is_gqa,
437
+ )
438
+ out = out.transpose(1, 2)
439
+ out = out.contiguous().view(out.size(0), out.size(1), -1)
440
+
441
+ return out, None
442
+
443
+ def core_attn(
444
+ self,
445
+ q,
446
+ k,
447
+ v,
448
+ attention_mask=None,
449
+ attn_mask_start_row_indices=None,
450
+ seq_length=None,
451
+ ):
452
+ """Standard self-attention implementation.
453
+
454
+ Args:
455
+ q (torch.Tensor): Query tensor
456
+ k (torch.Tensor): Key tensor
457
+ v (torch.Tensor): Value tensor
458
+ attention_mask (Optional[torch.Tensor]): Attention mask
459
+ attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices
460
+ seq_length (Optional[int]): Sequence length
461
+
462
+ Returns:
463
+ Tuple[torch.Tensor, torch.Tensor]: Attention output and weights
464
+ """
465
+ origin_dtype = q.dtype
466
+
467
+ q = q.permute(0, 2, 1, 3)
468
+ k = k.permute(0, 2, 1, 3)
469
+ v = v.permute(0, 2, 1, 3)
470
+
471
+ scale_qk_coeff = (
472
+ getattr(self.config, "scale_qk_coeff", 1.0) * self.head_dim**0.5
473
+ )
474
+
475
+ q = q / scale_qk_coeff
476
+
477
+ # Handle GQA case - repeat k and v heads to match q heads
478
+ if self.is_gqa:
479
+ # [batch, num_key_value_heads, seq_len, head_dim] -> [batch, num_heads, seq_len, head_dim]
480
+ repeat_factor = self.num_heads // self.num_key_value_heads
481
+ k = self.repeat_kv(k, repeat_factor)
482
+ v = self.repeat_kv(v, repeat_factor)
483
+
484
+ attn_scores = torch.matmul(q, k.transpose(-2, -1))
485
+
486
+ if getattr(self.config, "scale_qk_coeff", 1.0) != 1.0:
487
+ attn_scores = attn_scores * getattr(self.config, "scale_qk_coeff", 1.0)
488
+
489
+ # Causal mask
490
+ seq_len = attn_scores.size(-1)
491
+ mask = torch.triu(
492
+ torch.ones((seq_len, seq_len), dtype=torch.bool, device=attn_scores.device),
493
+ diagonal=1,
494
+ )
495
+ attn_scores = attn_scores.masked_fill(mask, float("-inf"))
496
+ attn_weights = F.softmax(attn_scores, dim=-1)
497
+
498
+ attn_weights = attn_weights.to(origin_dtype)
499
+
500
+ # attention_probs_dropout_prob default 0.0
501
+ if getattr(self.config, "attention_probs_dropout_prob", 0.0) > 0:
502
+ attn_weights = F.dropout(
503
+ attn_weights,
504
+ p=self.config.attention_probs_dropout_prob,
505
+ training=self.training,
506
+ )
507
+
508
+ # [batch, num_heads, q_len, k_len] @ [batch, num_heads, k_len, head_dim] -> [batch, num_heads, q_len, head_dim]
509
+ out = torch.matmul(attn_weights, v)
510
+
511
+ # [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, num_heads, head_dim]
512
+ out = out.permute(0, 2, 1, 3)
513
+ # [batch, seq_len, hidden_size]
514
+ out = out.contiguous().view(out.size(0), out.size(1), -1)
515
+
516
+ return out, attn_weights
517
+
518
+ def rope_attn(
519
+ self,
520
+ query_states,
521
+ key_states,
522
+ value_states,
523
+ attention_mask,
524
+ position_ids,
525
+ output_attentions=False,
526
+ past_key_value=None,
527
+ use_cache=False,
528
+ attn_mask_start_row_indices=None,
529
+ ):
530
+ """Attention computation with rotary embeddings.
531
+
532
+ Args:
533
+ query_states (torch.Tensor): Query states
534
+ key_states (torch.Tensor): Key states
535
+ value_states (torch.Tensor): Value states
536
+ attention_mask (Optional[torch.Tensor]): Attention mask
537
+ position_ids (Optional[torch.Tensor]): Position indices
538
+ output_attentions (bool): Return attention weights
539
+ past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): Cached states
540
+ use_cache (bool): Cache new states
541
+ attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length indices
542
+
543
+ Returns:
544
+ Tuple containing:
545
+ - attention_output: Result tensor
546
+ - attention_weights: Optional weights
547
+ - updated_key_value_cache: Optional cache
548
+ """
549
+
550
+ query_states_dtype = query_states.dtype
551
+
552
+ kv_seq_len = key_states.shape[-3]
553
+ offset = 0
554
+ if past_key_value is not None:
555
+ offset = past_key_value[0].shape[-3]
556
+ kv_seq_len += offset
557
+
558
+ cos_sin = self.rotary_emb(kv_seq_len).permute(
559
+ [0, 2, 1, 3]
560
+ ) # [b,h,s,d]->[b,s,h,d]
561
+ if offset > 0:
562
+ cos_sin = cos_sin[:, offset:]
563
+ query_states, key_states = self.rotary_emb.apply_rotary(
564
+ cos_sin, query_states, key_states
565
+ )
566
+
567
+ query_states = query_states.to(query_states_dtype)
568
+ key_states = key_states.to(query_states_dtype)
569
+ if past_key_value is not None:
570
+ # reuse k, v, self_attention
571
+ key_states = torch.cat([past_key_value[0], key_states], dim=1)
572
+ value_states = torch.cat([past_key_value[1], value_states], dim=1)
573
+
574
+ # shape: [2, b, s, kvh, d]
575
+ past_key_value = [key_states, value_states] if use_cache else None
576
+ seq_length = query_states.shape[1]
577
+ attn_output, attn_weights = self.attn_func(
578
+ query_states,
579
+ key_states,
580
+ value_states,
581
+ attention_mask,
582
+ attn_mask_start_row_indices,
583
+ seq_length,
584
+ )
585
+ return attn_output, attn_weights, past_key_value
586
+
587
+
588
+ class Ernie4_5_DecoderLayer(nn.Module):
589
+ """
590
+ A single transformer decoder layer in ERNIE model.
591
+ """
592
+
593
+ def __init__(self, config, layer_idx):
594
+ """Initialize the decoder layer.
595
+
596
+ Args:
597
+ config: Model configuration.
598
+ layer_idx (int): Index of this layer in the transformer stack
599
+ """
600
+ super().__init__()
601
+ self.hidden_size = config.hidden_size
602
+ self.layer_idx = layer_idx
603
+ self.config = config
604
+
605
+ self.self_attn = Ernie4_5_Attention(config, layer_idx)
606
+ self.mlp = Ernie4_5_MLP(config)
607
+
608
+ self.input_layernorm = Ernie4_5_RMSNorm(config)
609
+ self.post_attention_layernorm = Ernie4_5_RMSNorm(config)
610
+
611
+ self.residual_add1 = Ernie4_5_FusedDropoutImpl(config.hidden_dropout_prob)
612
+ self.residual_add2 = Ernie4_5_FusedDropoutImpl(config.hidden_dropout_prob)
613
+
614
+ def forward(
615
+ self,
616
+ hidden_states: torch.Tensor,
617
+ attention_mask: Optional[torch.Tensor] = None,
618
+ attn_mask_start_row_indices: Optional[torch.Tensor] = None,
619
+ position_ids: Optional[torch.Tensor] = None,
620
+ token_type_ids: Optional[torch.Tensor] = None,
621
+ output_attentions: Optional[bool] = False,
622
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
623
+ use_cache: Optional[bool] = False,
624
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
625
+ """Forward pass through the decoder layer.
626
+
627
+ Args:
628
+ hidden_states (torch.Tensor): Input tensor [batch_size, seq_len, hidden_size]
629
+ attention_mask (Optional[torch.Tensor]): Attention mask tensor
630
+ attn_mask_start_row_indices (Optional[torch.Tensor]): Indices for variable length attention
631
+ position_ids (Optional[torch.Tensor]): Position indices for rotary embeddings
632
+ output_attentions (Optional[bool]): Whether to return attention weights
633
+ past_key_value (Optional[Tuple[torch.Tensor]]): Cached key/value states
634
+ use_cache (Optional[bool]): Whether to cache key/value states
635
+
636
+ Returns:
637
+ Union: Various output combinations depending on arguments:
638
+ - Base case: Hidden states tensor
639
+ - With attention: Tuple of (hidden_states, attention_weights)
640
+ - With cache: Tuple of (hidden_states, cached_key_value)
641
+ """
642
+ residual = hidden_states
643
+
644
+ hidden_states = self.input_layernorm(hidden_states)
645
+
646
+ # Self Attention
647
+ (hidden_states, self_attn_weights, present_key_value) = self.self_attn(
648
+ hidden_states=hidden_states,
649
+ past_key_value=past_key_value,
650
+ attention_mask=attention_mask,
651
+ attn_mask_start_row_indices=attn_mask_start_row_indices,
652
+ position_ids=position_ids,
653
+ output_attentions=output_attentions,
654
+ use_cache=use_cache,
655
+ token_type_ids=token_type_ids,
656
+ )
657
+ hidden_states = self.residual_add1(hidden_states, residual)
658
+
659
+ # Fully Connected
660
+ residual = hidden_states
661
+ hidden_states = self.post_attention_layernorm(hidden_states)
662
+ hidden_states = self.mlp(hidden_states)
663
+
664
+ hidden_states = self.residual_add2(hidden_states, residual)
665
+ outputs = (hidden_states,)
666
+
667
+ if output_attentions:
668
+ outputs += (self_attn_weights,)
669
+
670
+ if use_cache:
671
+ outputs += (present_key_value,)
672
+
673
+ if type(outputs) is tuple and len(outputs) == 1:
674
+ outputs = outputs[0]
675
+
676
+ return outputs
677
+
678
+
679
+ class Ernie4_5_PretrainedModel(PreTrainedModel):
680
+ """Base class for ERNIE pretrained models."""
681
+
682
+ config_class = Ernie4_5_Config
683
+ base_model_prefix = "ernie"
684
+
685
+
686
+ class Ernie4_5_Model(Ernie4_5_PretrainedModel):
687
+
688
+ def __init__(self, config):
689
+ """Initialize the ERNIE model architecture.
690
+
691
+ Args:
692
+ config: Model configuration.
693
+ """
694
+ super().__init__(config)
695
+ self.padding_idx = config.pad_token_id
696
+ self.vocab_size = config.vocab_size
697
+ self.hidden_size = config.hidden_size
698
+ self.config = config
699
+
700
+ self.embed_tokens = nn.Embedding(
701
+ self.vocab_size,
702
+ self.hidden_size,
703
+ )
704
+
705
+ self.layers = nn.ModuleList(
706
+ [Ernie4_5_DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
707
+ )
708
+
709
+ self.norm = Ernie4_5_RMSNorm(config)
710
+
711
+ self.gradient_checkpointing = False
712
+
713
+ def get_input_embeddings(self):
714
+ """Get the input embedding layer.
715
+
716
+ Returns:
717
+ nn.Embedding: The embedding layer for input tokens
718
+ """
719
+ return self.embed_tokens
720
+
721
+ def set_input_embeddings(self, value):
722
+ """Set new input embeddings.
723
+
724
+ Args:
725
+ value (nn.Embedding): New embedding layer to use
726
+ """
727
+ self.embed_tokens = value
728
+
729
+ def forward(
730
+ self,
731
+ input_ids=None,
732
+ position_ids=None,
733
+ token_type_ids=None,
734
+ attention_mask=None,
735
+ attn_mask_start_row_indices=None,
736
+ inputs_embeds=None,
737
+ use_cache=None,
738
+ past_key_values=None,
739
+ output_attentions=False,
740
+ output_hidden_states=None,
741
+ return_dict=False,
742
+ ):
743
+ """Forward pass through the ERNIE model.
744
+
745
+ Args:
746
+ input_ids (Optional[torch.Tensor]): Input token IDs
747
+ position_ids (Optional[torch.Tensor]): Position indices
748
+ attention_mask (Optional[torch.Tensor]): Attention mask
749
+ attn_mask_start_row_indices (Optional[torch.Tensor]): Variable length attention indices
750
+ inputs_embeds (Optional[torch.Tensor]): Precomputed embeddings
751
+ use_cache (Optional[bool]): Whether to cache key/value states
752
+ past_key_values (Optional[Tuple[Tuple[torch.Tensor]]]): Cached key/value states
753
+ output_attentions (Optional[bool]): Whether to output attention weights
754
+ output_hidden_states (Optional[bool]): Whether to output all hidden states
755
+ return_dict (Optional[bool]): Whether to return dict or tuple
756
+
757
+ Returns:
758
+ Union[Tuple, BaseModelOutputWithPast]:
759
+ Various outputs depending on configuration, including:
760
+ - last_hidden_state: Final layer hidden states
761
+ - past_key_values: Cached key/value states if use_cache=True
762
+ - hidden_states: All hidden states if output_hidden_states=True
763
+ - attentions: Attention weights if output_attentions=True
764
+ """
765
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
766
+
767
+ # retrieve input_ids and inputs_embeds
768
+ if input_ids is not None and inputs_embeds is not None:
769
+ raise ValueError(
770
+ "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
771
+ )
772
+ elif input_ids is not None:
773
+ _, seq_length = input_ids.shape
774
+ elif inputs_embeds is not None:
775
+ _, seq_length, _ = inputs_embeds.shape
776
+ else:
777
+ raise ValueError(
778
+ "You have to specify either decoder_input_ids or decoder_inputs_embeds"
779
+ )
780
+
781
+ if past_key_values is None:
782
+ past_key_values = tuple([None] * len(self.layers))
783
+
784
+ if inputs_embeds is None:
785
+ inputs_embeds = self.embed_tokens(input_ids)
786
+ inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype)
787
+
788
+ hidden_states = inputs_embeds
789
+
790
+ # decoder layers
791
+ all_hidden_states = () if output_hidden_states else None
792
+ all_self_attns = () if output_attentions else None
793
+ next_decoder_cache = () if use_cache else None
794
+
795
+ for idx, (decoder_layer) in enumerate(self.layers):
796
+
797
+ if output_hidden_states:
798
+ all_hidden_states += (hidden_states,)
799
+
800
+ past_key_value = (
801
+ past_key_values[idx] if past_key_values is not None else None
802
+ )
803
+
804
+ layer_outputs = decoder_layer(
805
+ hidden_states,
806
+ attention_mask,
807
+ attn_mask_start_row_indices,
808
+ position_ids,
809
+ token_type_ids,
810
+ output_attentions,
811
+ past_key_value,
812
+ use_cache,
813
+ )
814
+
815
+ if isinstance(layer_outputs, (tuple, list)):
816
+ hidden_states = layer_outputs[0]
817
+ else:
818
+ hidden_states = layer_outputs
819
+
820
+ if use_cache:
821
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
822
+
823
+ if output_attentions:
824
+ all_self_attns += (layer_outputs[1],)
825
+
826
+ # apply kv cache
827
+ if past_key_value is not None:
828
+ hidden_states = hidden_states[:, -1:, :]
829
+
830
+ hidden_states = self.norm(hidden_states)
831
+
832
+ # add hidden states from the last decoder layer
833
+ if output_hidden_states:
834
+ all_hidden_states += (hidden_states,)
835
+
836
+ next_cache = next_decoder_cache if use_cache else None
837
+
838
+ if not return_dict:
839
+ return tuple(
840
+ v
841
+ for v in [
842
+ hidden_states,
843
+ next_cache,
844
+ all_hidden_states,
845
+ all_self_attns,
846
+ ]
847
+ if v is not None
848
+ )
849
+
850
+ return BaseModelOutputWithPast(
851
+ last_hidden_state=hidden_states,
852
+ past_key_values=next_cache,
853
+ hidden_states=all_hidden_states,
854
+ attentions=all_self_attns,
855
+ )
856
+
857
+
858
+ class Ernie4_5_LMHead(nn.Module):
859
+ """Language model head for ERNIE"""
860
+
861
+ def __init__(self, config):
862
+ """Initialize the language model head.
863
+
864
+ Args:
865
+ config: Model configuration containing:
866
+ - vocab_size: Size of vocabulary
867
+ - hidden_size: Dimension of hidden states
868
+ - tie_word_embeddings: Whether to tie input/output embeddings
869
+ - weight_share_add_bias: Whether to add bias when weight sharing
870
+ - use_bias: Whether to use bias term
871
+ """
872
+
873
+ super(Ernie4_5_LMHead, self).__init__()
874
+ self.config = config
875
+ vocab_size = config.vocab_size
876
+
877
+ if config.tie_word_embeddings:
878
+ # Weight of shape [vocab_size, hidden_size]
879
+ self.weight = nn.Parameter(
880
+ torch.empty(
881
+ vocab_size, config.hidden_size, dtype=torch.get_default_dtype()
882
+ )
883
+ )
884
+ else:
885
+ # Weight of shape [hidden_size, vocab_size]
886
+ self.weight = nn.Parameter(
887
+ torch.empty(
888
+ config.hidden_size, vocab_size, dtype=torch.get_default_dtype()
889
+ )
890
+ )
891
+ nn.init.xavier_uniform_(self.weight)
892
+
893
+ logger.info(
894
+ f"output-weight: {self.weight.shape}, tie_word_embeddings: {config.tie_word_embeddings}"
895
+ )
896
+
897
+ if config.weight_share_add_bias and config.use_bias:
898
+ self.bias = nn.Parameter(
899
+ torch.zeros(vocab_size, dtype=torch.get_default_dtype())
900
+ )
901
+ else:
902
+ self.bias = None
903
+
904
+ def forward(self, hidden_states):
905
+ """Project hidden states to vocabulary logits.
906
+
907
+ Args:
908
+ hidden_states (torch.Tensor): Input tensor of shape [batch_size, seq_len, hidden_size]
909
+
910
+ Returns:
911
+ Logits tensor of shape [batch_size, seq_len, vocab_size]
912
+ """
913
+ return self.calc_lm_head_logits(
914
+ self.config, hidden_states, self.weight, self.bias
915
+ )
916
+
917
+ def calc_lm_head_logits(self, config, hidden_states, weight, bias):
918
+ """
919
+ Calculate language model head logits.
920
+
921
+ This is the core function that computes the final output logits for a language model.
922
+
923
+ Args:
924
+ config: Model configuration.
925
+ hidden_states (Tensor): Hidden states from the transformer layers
926
+ weight (Tensor): Weight matrix for the language model head
927
+ bias (Tensor): Bias vector for the language model head
928
+
929
+ Returns:
930
+ Tensor: The computed logits for language modeling.
931
+ """
932
+
933
+ if config.tie_word_embeddings:
934
+ logits = torch.matmul(hidden_states, weight.T)
935
+ else:
936
+ logits = torch.matmul(hidden_states, weight)
937
+
938
+ if bias is not None:
939
+ logits = logits + bias
940
+
941
+ return logits
942
+
943
+
944
+ class Ernie4_5_ForCausalLM(Ernie4_5_PretrainedModel, GenerationMixin):
945
+ """ERNIE model for causal language modeling."""
946
+
947
+ _tied_weights_keys = ["lm_head.weight"]
948
+ _tp_plan = {"lm_head": "colwise_rep"}
949
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
950
+
951
+ def __init__(self, config):
952
+ """
953
+ Initializes the ERNIE model for causal language modeling.
954
+
955
+ Args:
956
+ config: Model configuration.
957
+ """
958
+ super().__init__(config)
959
+
960
+ self.config = config
961
+ self.model = Ernie4_5_Model(config)
962
+ self.lm_head = Ernie4_5_LMHead(config)
963
+
964
+ # Initialize weights and apply final processing
965
+ self.post_init()
966
+
967
+ @torch.no_grad()
968
+ def set_state_dict(self, state_dict, *args, **kwargs):
969
+ """
970
+ Loads the model state dictionary.
971
+ """
972
+ ret = super().set_state_dict(state_dict)
973
+ return ret
974
+
975
+ def get_input_embeddings(self):
976
+ """Returns the input embeddings layer."""
977
+ return self.model.embed_tokens
978
+
979
+ def set_input_embeddings(self, value):
980
+ """Sets the input embeddings layer."""
981
+ self.model.embed_tokens = value
982
+
983
+ def get_output_embeddings(self):
984
+ """Returns the output embeddings (LM head)."""
985
+ return self.lm_head
986
+
987
+ def set_output_embeddings(self, new_embeddings):
988
+ """Sets the output embeddings layer."""
989
+ self.lm_head = new_embeddings
990
+
991
+ def set_decoder(self, decoder):
992
+ """Sets the ERNIE decoder model."""
993
+ self.model = decoder
994
+
995
+ def get_decoder(self):
996
+ """Gets the ERNIE decoder model."""
997
+ return self.model
998
+
999
+ def forward(
1000
+ self,
1001
+ input_ids,
1002
+ position_ids=None,
1003
+ attention_mask=None,
1004
+ attn_mask_start_row_indices=None,
1005
+ token_type_ids=None,
1006
+ inputs_embeds=None,
1007
+ labels=None,
1008
+ use_cache=False,
1009
+ past_key_values=None,
1010
+ output_attentions=None,
1011
+ output_hidden_states=None,
1012
+ **kwargs,
1013
+ ):
1014
+ """
1015
+ Forward pass for causal language modeling.
1016
+
1017
+ Args:
1018
+ input_ids (torch.Tensor): Input token IDs.
1019
+ position_ids (torch.Tensor): Position IDs.
1020
+ attention_mask (torch.Tensor): Attention mask.
1021
+ attn_mask_start_row_indices (torch.Tensor): Attention mask start indices.
1022
+ inputs_embeds (torch.Tensor): Optional embedded inputs.
1023
+ labels (torch.Tensor): Target labels.
1024
+ use_cache (bool): Whether to use cached hidden states.
1025
+ past_key_values (dict): Pre-computed hidden states.
1026
+ output_attentions (bool): Whether to output attentions.
1027
+ output_hidden_states (bool): Whether to output hidden states.
1028
+
1029
+ Returns:
1030
+ CausalLMOutputWithPast: Model outputs.
1031
+ """
1032
+
1033
+ if past_key_values is not None:
1034
+ input_ids = input_ids[:, -1:]
1035
+
1036
+ outputs = self.model(
1037
+ input_ids,
1038
+ position_ids=position_ids,
1039
+ attention_mask=attention_mask,
1040
+ token_type_ids=token_type_ids,
1041
+ attn_mask_start_row_indices=attn_mask_start_row_indices,
1042
+ inputs_embeds=inputs_embeds,
1043
+ use_cache=use_cache,
1044
+ past_key_values=past_key_values,
1045
+ output_attentions=output_attentions,
1046
+ output_hidden_states=output_hidden_states,
1047
+ return_dict=True,
1048
+ )
1049
+
1050
+ hidden_states = outputs.last_hidden_state
1051
+ logits = self.lm_head(hidden_states)
1052
+
1053
+ loss = None
1054
+ if labels is not None:
1055
+ loss = self.loss_function(
1056
+ logits=logits,
1057
+ labels=labels,
1058
+ vocab_size=self.config.vocab_size,
1059
+ **kwargs,
1060
+ )
1061
+
1062
+ return CausalLMOutputWithPast(
1063
+ loss=loss,
1064
+ logits=logits,
1065
+ past_key_values=outputs.past_key_values,
1066
+ hidden_states=outputs.hidden_states,
1067
+ attentions=outputs.attentions,
1068
+ )
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "pad_token": "<unk>", "unk_token": "<unk>", "cls_token": "<|begin_of_sentence|>", "sep_token": "<|end_of_sentence|>", "mask_token": "<mask:1>", "sys_start_token": "<mask:4>", "sys_end_token": "<mask:5>", "header_start_token": "<mask:6>", "header_end_token": "<mask:7>", "additional_special_tokens": ["<|IMAGE_PLACEHOLDER|>", "<|AUDIO_PLACEHOLDER|>", "<|LOC_0|>", "<|LOC_1|>", "<|LOC_2|>", "<|LOC_3|>", "<|LOC_4|>", "<|LOC_5|>", "<|LOC_6|>", "<|LOC_7|>", "<|LOC_8|>", "<|LOC_9|>", "<|LOC_10|>", "<|LOC_11|>", "<|LOC_12|>", "<|LOC_13|>", "<|LOC_14|>", "<|LOC_15|>", "<|LOC_16|>", "<|LOC_17|>", "<|LOC_18|>", "<|LOC_19|>", "<|LOC_20|>", "<|LOC_21|>", "<|LOC_22|>", "<|LOC_23|>", "<|LOC_24|>", "<|LOC_25|>", "<|LOC_26|>", "<|LOC_27|>", "<|LOC_28|>", "<|LOC_29|>", "<|LOC_30|>", "<|LOC_31|>", "<|LOC_32|>", "<|LOC_33|>", "<|LOC_34|>", "<|LOC_35|>", "<|LOC_36|>", "<|LOC_37|>", "<|LOC_38|>", "<|LOC_39|>", "<|LOC_40|>", "<|LOC_41|>", "<|LOC_42|>", "<|LOC_43|>", "<|LOC_44|>", "<|LOC_45|>", "<|LOC_46|>", "<|LOC_47|>", "<|LOC_48|>", "<|LOC_49|>", "<|LOC_50|>", "<|LOC_51|>", "<|LOC_52|>", "<|LOC_53|>", "<|LOC_54|>", "<|LOC_55|>", "<|LOC_56|>", "<|LOC_57|>", "<|LOC_58|>", "<|LOC_59|>", "<|LOC_60|>", "<|LOC_61|>", "<|LOC_62|>", "<|LOC_63|>", "<|LOC_64|>", "<|LOC_65|>", "<|LOC_66|>", "<|LOC_67|>", "<|LOC_68|>", "<|LOC_69|>", "<|LOC_70|>", "<|LOC_71|>", "<|LOC_72|>", "<|LOC_73|>", "<|LOC_74|>", "<|LOC_75|>", "<|LOC_76|>", "<|LOC_77|>", "<|LOC_78|>", "<|LOC_79|>", "<|LOC_80|>", "<|LOC_81|>", "<|LOC_82|>", "<|LOC_83|>", "<|LOC_84|>", "<|LOC_85|>", "<|LOC_86|>", "<|LOC_87|>", "<|LOC_88|>", "<|LOC_89|>", "<|LOC_90|>", "<|LOC_91|>", "<|LOC_92|>", "<|LOC_93|>", "<|LOC_94|>", "<|LOC_95|>", "<|LOC_96|>", "<|LOC_97|>", "<|LOC_98|>", "<|LOC_99|>", "<|LOC_100|>", "<|LOC_101|>", "<|LOC_102|>", "<|LOC_103|>", "<|LOC_104|>", "<|LOC_105|>", "<|LOC_106|>", "<|LOC_107|>", "<|LOC_108|>", "<|LOC_109|>", "<|LOC_110|>", "<|LOC_111|>", "<|LOC_112|>", "<|LOC_113|>", "<|LOC_114|>", "<|LOC_115|>", "<|LOC_116|>", "<|LOC_117|>", "<|LOC_118|>", "<|LOC_119|>", "<|LOC_120|>", "<|LOC_121|>", "<|LOC_122|>", "<|LOC_123|>", "<|LOC_124|>", "<|LOC_125|>", "<|LOC_126|>", "<|LOC_127|>", "<|LOC_128|>", "<|LOC_129|>", "<|LOC_130|>", "<|LOC_131|>", "<|LOC_132|>", "<|LOC_133|>", "<|LOC_134|>", "<|LOC_135|>", "<|LOC_136|>", "<|LOC_137|>", "<|LOC_138|>", "<|LOC_139|>", "<|LOC_140|>", "<|LOC_141|>", "<|LOC_142|>", "<|LOC_143|>", "<|LOC_144|>", "<|LOC_145|>", "<|LOC_146|>", "<|LOC_147|>", "<|LOC_148|>", "<|LOC_149|>", "<|LOC_150|>", "<|LOC_151|>", "<|LOC_152|>", "<|LOC_153|>", "<|LOC_154|>", "<|LOC_155|>", "<|LOC_156|>", "<|LOC_157|>", "<|LOC_158|>", "<|LOC_159|>", "<|LOC_160|>", "<|LOC_161|>", "<|LOC_162|>", "<|LOC_163|>", "<|LOC_164|>", "<|LOC_165|>", "<|LOC_166|>", "<|LOC_167|>", "<|LOC_168|>", "<|LOC_169|>", "<|LOC_170|>", "<|LOC_171|>", "<|LOC_172|>", "<|LOC_173|>", "<|LOC_174|>", "<|LOC_175|>", "<|LOC_176|>", "<|LOC_177|>", "<|LOC_178|>", "<|LOC_179|>", "<|LOC_180|>", "<|LOC_181|>", "<|LOC_182|>", "<|LOC_183|>", "<|LOC_184|>", "<|LOC_185|>", "<|LOC_186|>", "<|LOC_187|>", "<|LOC_188|>", "<|LOC_189|>", "<|LOC_190|>", "<|LOC_191|>", "<|LOC_192|>", "<|LOC_193|>", "<|LOC_194|>", "<|LOC_195|>", "<|LOC_196|>", "<|LOC_197|>", "<|LOC_198|>", "<|LOC_199|>", "<|LOC_200|>", "<|LOC_201|>", "<|LOC_202|>", "<|LOC_203|>", "<|LOC_204|>", "<|LOC_205|>", "<|LOC_206|>", "<|LOC_207|>", "<|LOC_208|>", "<|LOC_209|>", "<|LOC_210|>", "<|LOC_211|>", "<|LOC_212|>", "<|LOC_213|>", "<|LOC_214|>", "<|LOC_215|>", "<|LOC_216|>", "<|LOC_217|>", "<|LOC_218|>", "<|LOC_219|>", "<|LOC_220|>", "<|LOC_221|>", "<|LOC_222|>", "<|LOC_223|>", "<|LOC_224|>", "<|LOC_225|>", "<|LOC_226|>", "<|LOC_227|>", "<|LOC_228|>", "<|LOC_229|>", "<|LOC_230|>", "<|LOC_231|>", "<|LOC_232|>", "<|LOC_233|>", "<|LOC_234|>", "<|LOC_235|>", "<|LOC_236|>", "<|LOC_237|>", "<|LOC_238|>", "<|LOC_239|>", "<|LOC_240|>", "<|LOC_241|>", "<|LOC_242|>", "<|LOC_243|>", "<|LOC_244|>", "<|LOC_245|>", "<|LOC_246|>", "<|LOC_247|>", "<|LOC_248|>", "<|LOC_249|>", "<|LOC_250|>", "<|LOC_251|>", "<|LOC_252|>", "<|LOC_253|>", "<|LOC_254|>", "<|LOC_255|>", "<|LOC_256|>", "<|LOC_257|>", "<|LOC_258|>", "<|LOC_259|>", "<|LOC_260|>", "<|LOC_261|>", "<|LOC_262|>", "<|LOC_263|>", "<|LOC_264|>", "<|LOC_265|>", "<|LOC_266|>", "<|LOC_267|>", "<|LOC_268|>", "<|LOC_269|>", "<|LOC_270|>", "<|LOC_271|>", "<|LOC_272|>", "<|LOC_273|>", "<|LOC_274|>", "<|LOC_275|>", "<|LOC_276|>", "<|LOC_277|>", "<|LOC_278|>", "<|LOC_279|>", "<|LOC_280|>", "<|LOC_281|>", "<|LOC_282|>", "<|LOC_283|>", "<|LOC_284|>", "<|LOC_285|>", "<|LOC_286|>", "<|LOC_287|>", "<|LOC_288|>", "<|LOC_289|>", "<|LOC_290|>", "<|LOC_291|>", "<|LOC_292|>", "<|LOC_293|>", "<|LOC_294|>", "<|LOC_295|>", "<|LOC_296|>", "<|LOC_297|>", "<|LOC_298|>", "<|LOC_299|>", "<|LOC_300|>", "<|LOC_301|>", "<|LOC_302|>", "<|LOC_303|>", "<|LOC_304|>", "<|LOC_305|>", "<|LOC_306|>", "<|LOC_307|>", "<|LOC_308|>", "<|LOC_309|>", "<|LOC_310|>", "<|LOC_311|>", "<|LOC_312|>", "<|LOC_313|>", "<|LOC_314|>", "<|LOC_315|>", "<|LOC_316|>", "<|LOC_317|>", "<|LOC_318|>", "<|LOC_319|>", "<|LOC_320|>", "<|LOC_321|>", "<|LOC_322|>", "<|LOC_323|>", "<|LOC_324|>", "<|LOC_325|>", "<|LOC_326|>", "<|LOC_327|>", "<|LOC_328|>", "<|LOC_329|>", "<|LOC_330|>", "<|LOC_331|>", "<|LOC_332|>", "<|LOC_333|>", "<|LOC_334|>", "<|LOC_335|>", "<|LOC_336|>", "<|LOC_337|>", "<|LOC_338|>", "<|LOC_339|>", "<|LOC_340|>", "<|LOC_341|>", "<|LOC_342|>", "<|LOC_343|>", "<|LOC_344|>", "<|LOC_345|>", "<|LOC_346|>", "<|LOC_347|>", "<|LOC_348|>", "<|LOC_349|>", "<|LOC_350|>", "<|LOC_351|>", "<|LOC_352|>", "<|LOC_353|>", "<|LOC_354|>", "<|LOC_355|>", "<|LOC_356|>", "<|LOC_357|>", "<|LOC_358|>", "<|LOC_359|>", "<|LOC_360|>", "<|LOC_361|>", "<|LOC_362|>", "<|LOC_363|>", "<|LOC_364|>", "<|LOC_365|>", "<|LOC_366|>", "<|LOC_367|>", "<|LOC_368|>", "<|LOC_369|>", "<|LOC_370|>", "<|LOC_371|>", "<|LOC_372|>", "<|LOC_373|>", "<|LOC_374|>", "<|LOC_375|>", "<|LOC_376|>", "<|LOC_377|>", "<|LOC_378|>", "<|LOC_379|>", "<|LOC_380|>", "<|LOC_381|>", "<|LOC_382|>", "<|LOC_383|>", "<|LOC_384|>", "<|LOC_385|>", "<|LOC_386|>", "<|LOC_387|>", "<|LOC_388|>", "<|LOC_389|>", "<|LOC_390|>", "<|LOC_391|>", "<|LOC_392|>", "<|LOC_393|>", "<|LOC_394|>", "<|LOC_395|>", "<|LOC_396|>", "<|LOC_397|>", "<|LOC_398|>", "<|LOC_399|>", "<|LOC_400|>", "<|LOC_401|>", "<|LOC_402|>", "<|LOC_403|>", "<|LOC_404|>", "<|LOC_405|>", "<|LOC_406|>", "<|LOC_407|>", "<|LOC_408|>", "<|LOC_409|>", "<|LOC_410|>", "<|LOC_411|>", "<|LOC_412|>", "<|LOC_413|>", "<|LOC_414|>", "<|LOC_415|>", "<|LOC_416|>", "<|LOC_417|>", "<|LOC_418|>", "<|LOC_419|>", "<|LOC_420|>", "<|LOC_421|>", "<|LOC_422|>", "<|LOC_423|>", "<|LOC_424|>", "<|LOC_425|>", "<|LOC_426|>", "<|LOC_427|>", "<|LOC_428|>", "<|LOC_429|>", "<|LOC_430|>", "<|LOC_431|>", "<|LOC_432|>", "<|LOC_433|>", "<|LOC_434|>", "<|LOC_435|>", "<|LOC_436|>", "<|LOC_437|>", "<|LOC_438|>", "<|LOC_439|>", "<|LOC_440|>", "<|LOC_441|>", "<|LOC_442|>", "<|LOC_443|>", "<|LOC_444|>", "<|LOC_445|>", "<|LOC_446|>", "<|LOC_447|>", "<|LOC_448|>", "<|LOC_449|>", "<|LOC_450|>", "<|LOC_451|>", "<|LOC_452|>", "<|LOC_453|>", "<|LOC_454|>", "<|LOC_455|>", "<|LOC_456|>", "<|LOC_457|>", "<|LOC_458|>", "<|LOC_459|>", "<|LOC_460|>", "<|LOC_461|>", "<|LOC_462|>", "<|LOC_463|>", "<|LOC_464|>", "<|LOC_465|>", "<|LOC_466|>", "<|LOC_467|>", "<|LOC_468|>", "<|LOC_469|>", "<|LOC_470|>", "<|LOC_471|>", "<|LOC_472|>", "<|LOC_473|>", "<|LOC_474|>", "<|LOC_475|>", "<|LOC_476|>", "<|LOC_477|>", "<|LOC_478|>", "<|LOC_479|>", "<|LOC_480|>", "<|LOC_481|>", "<|LOC_482|>", "<|LOC_483|>", "<|LOC_484|>", "<|LOC_485|>", "<|LOC_486|>", "<|LOC_487|>", "<|LOC_488|>", "<|LOC_489|>", "<|LOC_490|>", "<|LOC_491|>", "<|LOC_492|>", "<|LOC_493|>", "<|LOC_494|>", "<|LOC_495|>", "<|LOC_496|>", "<|LOC_497|>", "<|LOC_498|>", "<|LOC_499|>", "<|LOC_500|>", "<|LOC_501|>", "<|LOC_502|>", "<|LOC_503|>", "<|LOC_504|>", "<|LOC_505|>", "<|LOC_506|>", "<|LOC_507|>", "<|LOC_508|>", "<|LOC_509|>", "<|LOC_510|>", "<|LOC_511|>", "<|LOC_512|>", "<|LOC_513|>", "<|LOC_514|>", "<|LOC_515|>", "<|LOC_516|>", "<|LOC_517|>", "<|LOC_518|>", "<|LOC_519|>", "<|LOC_520|>", "<|LOC_521|>", "<|LOC_522|>", "<|LOC_523|>", "<|LOC_524|>", "<|LOC_525|>", "<|LOC_526|>", "<|LOC_527|>", "<|LOC_528|>", "<|LOC_529|>", "<|LOC_530|>", "<|LOC_531|>", "<|LOC_532|>", "<|LOC_533|>", "<|LOC_534|>", "<|LOC_535|>", "<|LOC_536|>", "<|LOC_537|>", "<|LOC_538|>", "<|LOC_539|>", "<|LOC_540|>", "<|LOC_541|>", "<|LOC_542|>", "<|LOC_543|>", "<|LOC_544|>", "<|LOC_545|>", "<|LOC_546|>", "<|LOC_547|>", "<|LOC_548|>", "<|LOC_549|>", "<|LOC_550|>", "<|LOC_551|>", "<|LOC_552|>", "<|LOC_553|>", "<|LOC_554|>", "<|LOC_555|>", "<|LOC_556|>", "<|LOC_557|>", "<|LOC_558|>", "<|LOC_559|>", "<|LOC_560|>", "<|LOC_561|>", "<|LOC_562|>", "<|LOC_563|>", "<|LOC_564|>", "<|LOC_565|>", "<|LOC_566|>", "<|LOC_567|>", "<|LOC_568|>", "<|LOC_569|>", "<|LOC_570|>", "<|LOC_571|>", "<|LOC_572|>", "<|LOC_573|>", "<|LOC_574|>", "<|LOC_575|>", "<|LOC_576|>", "<|LOC_577|>", "<|LOC_578|>", "<|LOC_579|>", "<|LOC_580|>", "<|LOC_581|>", "<|LOC_582|>", "<|LOC_583|>", "<|LOC_584|>", "<|LOC_585|>", "<|LOC_586|>", "<|LOC_587|>", "<|LOC_588|>", "<|LOC_589|>", "<|LOC_590|>", "<|LOC_591|>", "<|LOC_592|>", "<|LOC_593|>", "<|LOC_594|>", "<|LOC_595|>", "<|LOC_596|>", "<|LOC_597|>", "<|LOC_598|>", "<|LOC_599|>", "<|LOC_600|>", "<|LOC_601|>", "<|LOC_602|>", "<|LOC_603|>", "<|LOC_604|>", "<|LOC_605|>", "<|LOC_606|>", "<|LOC_607|>", "<|LOC_608|>", "<|LOC_609|>", "<|LOC_610|>", "<|LOC_611|>", "<|LOC_612|>", "<|LOC_613|>", "<|LOC_614|>", "<|LOC_615|>", "<|LOC_616|>", "<|LOC_617|>", "<|LOC_618|>", "<|LOC_619|>", "<|LOC_620|>", "<|LOC_621|>", "<|LOC_622|>", "<|LOC_623|>", "<|LOC_624|>", "<|LOC_625|>", "<|LOC_626|>", "<|LOC_627|>", "<|LOC_628|>", "<|LOC_629|>", "<|LOC_630|>", "<|LOC_631|>", "<|LOC_632|>", "<|LOC_633|>", "<|LOC_634|>", "<|LOC_635|>", "<|LOC_636|>", "<|LOC_637|>", "<|LOC_638|>", "<|LOC_639|>", "<|LOC_640|>", "<|LOC_641|>", "<|LOC_642|>", "<|LOC_643|>", "<|LOC_644|>", "<|LOC_645|>", "<|LOC_646|>", "<|LOC_647|>", "<|LOC_648|>", "<|LOC_649|>", "<|LOC_650|>", "<|LOC_651|>", "<|LOC_652|>", "<|LOC_653|>", "<|LOC_654|>", "<|LOC_655|>", "<|LOC_656|>", "<|LOC_657|>", "<|LOC_658|>", "<|LOC_659|>", "<|LOC_660|>", "<|LOC_661|>", "<|LOC_662|>", "<|LOC_663|>", "<|LOC_664|>", "<|LOC_665|>", "<|LOC_666|>", "<|LOC_667|>", "<|LOC_668|>", "<|LOC_669|>", "<|LOC_670|>", "<|LOC_671|>", "<|LOC_672|>", "<|LOC_673|>", "<|LOC_674|>", "<|LOC_675|>", "<|LOC_676|>", "<|LOC_677|>", "<|LOC_678|>", "<|LOC_679|>", "<|LOC_680|>", "<|LOC_681|>", "<|LOC_682|>", "<|LOC_683|>", "<|LOC_684|>", "<|LOC_685|>", "<|LOC_686|>", "<|LOC_687|>", "<|LOC_688|>", "<|LOC_689|>", "<|LOC_690|>", "<|LOC_691|>", "<|LOC_692|>", "<|LOC_693|>", "<|LOC_694|>", "<|LOC_695|>", "<|LOC_696|>", "<|LOC_697|>", "<|LOC_698|>", "<|LOC_699|>", "<|LOC_700|>", "<|LOC_701|>", "<|LOC_702|>", "<|LOC_703|>", "<|LOC_704|>", "<|LOC_705|>", "<|LOC_706|>", "<|LOC_707|>", "<|LOC_708|>", "<|LOC_709|>", "<|LOC_710|>", "<|LOC_711|>", "<|LOC_712|>", "<|LOC_713|>", "<|LOC_714|>", "<|LOC_715|>", "<|LOC_716|>", "<|LOC_717|>", "<|LOC_718|>", "<|LOC_719|>", "<|LOC_720|>", "<|LOC_721|>", "<|LOC_722|>", "<|LOC_723|>", "<|LOC_724|>", "<|LOC_725|>", "<|LOC_726|>", "<|LOC_727|>", "<|LOC_728|>", "<|LOC_729|>", "<|LOC_730|>", "<|LOC_731|>", "<|LOC_732|>", "<|LOC_733|>", "<|LOC_734|>", "<|LOC_735|>", "<|LOC_736|>", "<|LOC_737|>", "<|LOC_738|>", "<|LOC_739|>", "<|LOC_740|>", "<|LOC_741|>", "<|LOC_742|>", "<|LOC_743|>", "<|LOC_744|>", "<|LOC_745|>", "<|LOC_746|>", "<|LOC_747|>", "<|LOC_748|>", "<|LOC_749|>", "<|LOC_750|>", "<|LOC_751|>", "<|LOC_752|>", "<|LOC_753|>", "<|LOC_754|>", "<|LOC_755|>", "<|LOC_756|>", "<|LOC_757|>", "<|LOC_758|>", "<|LOC_759|>", "<|LOC_760|>", "<|LOC_761|>", "<|LOC_762|>", "<|LOC_763|>", "<|LOC_764|>", "<|LOC_765|>", "<|LOC_766|>", "<|LOC_767|>", "<|LOC_768|>", "<|LOC_769|>", "<|LOC_770|>", "<|LOC_771|>", "<|LOC_772|>", "<|LOC_773|>", "<|LOC_774|>", "<|LOC_775|>", "<|LOC_776|>", "<|LOC_777|>", "<|LOC_778|>", "<|LOC_779|>", "<|LOC_780|>", "<|LOC_781|>", "<|LOC_782|>", "<|LOC_783|>", "<|LOC_784|>", "<|LOC_785|>", "<|LOC_786|>", "<|LOC_787|>", "<|LOC_788|>", "<|LOC_789|>", "<|LOC_790|>", "<|LOC_791|>", "<|LOC_792|>", "<|LOC_793|>", "<|LOC_794|>", "<|LOC_795|>", "<|LOC_796|>", "<|LOC_797|>", "<|LOC_798|>", "<|LOC_799|>", "<|LOC_800|>", "<|LOC_801|>", "<|LOC_802|>", "<|LOC_803|>", "<|LOC_804|>", "<|LOC_805|>", "<|LOC_806|>", "<|LOC_807|>", "<|LOC_808|>", "<|LOC_809|>", "<|LOC_810|>", "<|LOC_811|>", "<|LOC_812|>", "<|LOC_813|>", "<|LOC_814|>", "<|LOC_815|>", "<|LOC_816|>", "<|LOC_817|>", "<|LOC_818|>", "<|LOC_819|>", "<|LOC_820|>", "<|LOC_821|>", "<|LOC_822|>", "<|LOC_823|>", "<|LOC_824|>", "<|LOC_825|>", "<|LOC_826|>", "<|LOC_827|>", "<|LOC_828|>", "<|LOC_829|>", "<|LOC_830|>", "<|LOC_831|>", "<|LOC_832|>", "<|LOC_833|>", "<|LOC_834|>", "<|LOC_835|>", "<|LOC_836|>", "<|LOC_837|>", "<|LOC_838|>", "<|LOC_839|>", "<|LOC_840|>", "<|LOC_841|>", "<|LOC_842|>", "<|LOC_843|>", "<|LOC_844|>", "<|LOC_845|>", "<|LOC_846|>", "<|LOC_847|>", "<|LOC_848|>", "<|LOC_849|>", "<|LOC_850|>", "<|LOC_851|>", "<|LOC_852|>", "<|LOC_853|>", "<|LOC_854|>", "<|LOC_855|>", "<|LOC_856|>", "<|LOC_857|>", "<|LOC_858|>", "<|LOC_859|>", "<|LOC_860|>", "<|LOC_861|>", "<|LOC_862|>", "<|LOC_863|>", "<|LOC_864|>", "<|LOC_865|>", "<|LOC_866|>", "<|LOC_867|>", "<|LOC_868|>", "<|LOC_869|>", "<|LOC_870|>", "<|LOC_871|>", "<|LOC_872|>", "<|LOC_873|>", "<|LOC_874|>", "<|LOC_875|>", "<|LOC_876|>", "<|LOC_877|>", "<|LOC_878|>", "<|LOC_879|>", "<|LOC_880|>", "<|LOC_881|>", "<|LOC_882|>", "<|LOC_883|>", "<|LOC_884|>", "<|LOC_885|>", "<|LOC_886|>", "<|LOC_887|>", "<|LOC_888|>", "<|LOC_889|>", "<|LOC_890|>", "<|LOC_891|>", "<|LOC_892|>", "<|LOC_893|>", "<|LOC_894|>", "<|LOC_895|>", "<|LOC_896|>", "<|LOC_897|>", "<|LOC_898|>", "<|LOC_899|>", "<|LOC_900|>", "<|LOC_901|>", "<|LOC_902|>", "<|LOC_903|>", "<|LOC_904|>", "<|LOC_905|>", "<|LOC_906|>", "<|LOC_907|>", "<|LOC_908|>", "<|LOC_909|>", "<|LOC_910|>", "<|LOC_911|>", "<|LOC_912|>", "<|LOC_913|>", "<|LOC_914|>", "<|LOC_915|>", "<|LOC_916|>", "<|LOC_917|>", "<|LOC_918|>", "<|LOC_919|>", "<|LOC_920|>", "<|LOC_921|>", "<|LOC_922|>", "<|LOC_923|>", "<|LOC_924|>", "<|LOC_925|>", "<|LOC_926|>", "<|LOC_927|>", "<|LOC_928|>", "<|LOC_929|>", "<|LOC_930|>", "<|LOC_931|>", "<|LOC_932|>", "<|LOC_933|>", "<|LOC_934|>", "<|LOC_935|>", "<|LOC_936|>", "<|LOC_937|>", "<|LOC_938|>", "<|LOC_939|>", "<|LOC_940|>", "<|LOC_941|>", "<|LOC_942|>", "<|LOC_943|>", "<|LOC_944|>", "<|LOC_945|>", "<|LOC_946|>", "<|LOC_947|>", "<|LOC_948|>", "<|LOC_949|>", "<|LOC_950|>", "<|LOC_951|>", "<|LOC_952|>", "<|LOC_953|>", "<|LOC_954|>", "<|LOC_955|>", "<|LOC_956|>", "<|LOC_957|>", "<|LOC_958|>", "<|LOC_959|>", "<|LOC_960|>", "<|LOC_961|>", "<|LOC_962|>", "<|LOC_963|>", "<|LOC_964|>", "<|LOC_965|>", "<|LOC_966|>", "<|LOC_967|>", "<|LOC_968|>", "<|LOC_969|>", "<|LOC_970|>", "<|LOC_971|>", "<|LOC_972|>", "<|LOC_973|>", "<|LOC_974|>", "<|LOC_975|>", "<|LOC_976|>", "<|LOC_977|>", "<|LOC_978|>", "<|LOC_979|>", "<|LOC_980|>", "<|LOC_981|>", "<|LOC_982|>", "<|LOC_983|>", "<|LOC_984|>", "<|LOC_985|>", "<|LOC_986|>", "<|LOC_987|>", "<|LOC_988|>", "<|LOC_989|>", "<|LOC_990|>", "<|LOC_991|>", "<|LOC_992|>", "<|LOC_993|>", "<|LOC_994|>", "<|LOC_995|>", "<|LOC_996|>", "<|LOC_997|>", "<|LOC_998|>", "<|LOC_999|>", "<|LOC_1000|>", "<|LOC_BEGIN|>", "<|LOC_END|>", "<|LOC_SEP|>", "<|CROP_COL_SEP|>", "<|CROP_ROW_SEP|>", "<|IMAGE_SEP|>"]}
tokenization_ernie4_5.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 Baidu, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from shutil import copyfile
17
+ from typing import Dict, List, Optional, Tuple, Union
18
+ import torch
19
+ import numpy as np
20
+ import sentencepiece as spm
21
+
22
+ from transformers.tokenization_utils import PreTrainedTokenizer
23
+ from transformers.tokenization_utils_base import (
24
+ PaddingStrategy,
25
+ )
26
+ from transformers.utils import logging
27
+
28
+
29
+ logger = logging.get_logger(__name__)
30
+
31
+
32
+ class Ernie4_5_Tokenizer(PreTrainedTokenizer):
33
+
34
+ vocab_files_names = {
35
+ "vocab_file": "tokenizer.model",
36
+ }
37
+ # Model input names expected by the tokenizer
38
+ model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
39
+ # Padding side (where to add padding tokens)
40
+ padding_side = "right"
41
+
42
+ def __init__(
43
+ self,
44
+ vocab_file,
45
+ bos_token="<s>",
46
+ cls_token="<cls>",
47
+ eos_token="</s>",
48
+ mask_token="<mask:0>",
49
+ pad_token="<pad>",
50
+ sep_token="<sep>",
51
+ unk_token="<unk>",
52
+ additional_special_tokens=None,
53
+ split_special_tokens=False,
54
+ alpha=None,
55
+ tokenizer_alpha=None,
56
+ **kwargs,
57
+ ):
58
+ """
59
+ Initialize the ERNIE tokenizer.
60
+
61
+ Args:
62
+ vocab_file (str): Path to the SentencePiece model file.
63
+ bos_token (str, optional): Beginning of sentence token. Defaults to "<s>".
64
+ cls_token (str, optional): Classification token. Defaults to "<cls>".
65
+ eos_token (str, optional): End of sentence token. Defaults to "</s>".
66
+ mask_token (str, optional): Mask token. Defaults to "<mask:0>".
67
+ pad_token (str, optional): Padding token. Defaults to "<pad>".
68
+ sep_token (str, optional): Separator token. Defaults to "<sep>".
69
+ unk_token (str, optional): Unknown token. Defaults to "<unk>".
70
+ additional_special_tokens (List[str], optional): Additional special tokens.
71
+ Defaults to ["<mask:1>", "<mask:7>"].
72
+ split_special_tokens (bool, optional): Whether to split special tokens. Defaults to False.
73
+ alpha (None, optional): Currently unused parameter. Reserved for future use.
74
+ tokenizer_alpha (float, optional): Alpha parameter for SentencePiece sampling.
75
+ **kwargs: Additional keyword arguments passed to the parent class.
76
+ """
77
+
78
+ self.vocab_file = vocab_file
79
+ self.sp_model = spm.SentencePieceProcessor()
80
+ self.sp_model.Load(vocab_file)
81
+ self.alpha = alpha
82
+ self.pad_id = self._convert_token_to_id(pad_token)
83
+ self.tokenizer_alpha = tokenizer_alpha
84
+
85
+ if additional_special_tokens is None:
86
+ additional_special_tokens = ["<mask:1>", "<mask:7>"]
87
+ super().__init__(
88
+ bos_token=bos_token,
89
+ cls_token=cls_token,
90
+ eos_token=eos_token,
91
+ mask_token=mask_token,
92
+ pad_token=pad_token,
93
+ sep_token=sep_token,
94
+ unk_token=unk_token,
95
+ additional_special_tokens=additional_special_tokens,
96
+ split_special_tokens=split_special_tokens,
97
+ **kwargs,
98
+ )
99
+
100
+ @property
101
+ def vocab_size(self):
102
+ """Returns the size of the vocabulary.
103
+
104
+ Returns:
105
+ int: The number of tokens in the vocabulary.
106
+ """
107
+ return self.sp_model.vocab_size()
108
+
109
+ def get_vocab(self):
110
+ """Get the vocabulary as a dictionary mapping tokens to their IDs.
111
+
112
+ Returns:
113
+ dict: A dictionary mapping tokens to their corresponding IDs.
114
+ """
115
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
116
+ vocab.update(self.added_tokens_encoder)
117
+ return vocab
118
+
119
+ def _tokenize(self, text):
120
+ """Tokenize text using SentencePiece.
121
+
122
+ Args:
123
+ text (str): The text to tokenize.
124
+
125
+ Returns:
126
+ list: A list of tokens.
127
+ """
128
+ if self.tokenizer_alpha is not None:
129
+ return self.sp_model.encode_as_pieces(
130
+ text,
131
+ enable_sampling=True,
132
+ nbest_size=-1,
133
+ alpha=self.tokenizer_alpha,
134
+ )
135
+ else:
136
+ return self.sp_model.encode_as_pieces(text)
137
+
138
+ def _convert_token_to_id(self, token):
139
+ """Convert a token (str) to an ID using the vocabulary.
140
+
141
+ Args:
142
+ token (str): The token to convert.
143
+
144
+ Returns:
145
+ int: The corresponding token ID.
146
+ """
147
+ return self.sp_model.piece_to_id(token)
148
+
149
+ def _convert_id_to_token(self, id):
150
+ """Convert an ID to a token (str) using the vocabulary.
151
+
152
+ Args:
153
+ id (int): The token ID to convert.
154
+
155
+ Returns:
156
+ str: The corresponding token.
157
+ """
158
+ if id >= self.vocab_size:
159
+ return self.unk_token
160
+ else:
161
+ return self.sp_model.id_to_piece(id)
162
+
163
+ def convert_tokens_to_string(self, tokens):
164
+ """Convert a sequence of tokens back to a single string.
165
+
166
+ Args:
167
+ tokens (List[str]): A list of tokens to convert.
168
+
169
+ Returns:
170
+ str: The reconstructed string.
171
+ """
172
+ current_sub_tokens = []
173
+ out_string = ""
174
+ prev_is_special = False
175
+ for token in tokens:
176
+ # make sure that special tokens are not decoded using sentencepiece model
177
+ if token in self.all_special_tokens:
178
+ if not prev_is_special:
179
+ out_string += " "
180
+ out_string += self.sp_model.decode(current_sub_tokens) + token
181
+ prev_is_special = True
182
+ current_sub_tokens = []
183
+ else:
184
+ current_sub_tokens.append(token)
185
+ prev_is_special = False
186
+ out_string += self.sp_model.decode(current_sub_tokens)
187
+ return out_string
188
+
189
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
190
+ """Build model inputs by adding special tokens to sequences.
191
+
192
+ Args:
193
+ token_ids_0 (List[int]): List of token IDs for the first sequence.
194
+ token_ids_1 (List[int], optional): List of token IDs for the second sequence.
195
+
196
+ Returns:
197
+ List[int]: List of token IDs with special tokens added.
198
+ """
199
+ output = token_ids_0
200
+ last_cls_index = -1
201
+ last_sep_index = -1
202
+ if self.cls_token_id in output:
203
+ last_cls_index = len(output) - output[::-1].index(self.cls_token_id) - 1
204
+ if self.sep_token_id in output:
205
+ last_sep_index = len(output) - output[::-1].index(self.sep_token_id) - 1
206
+
207
+ if last_cls_index > last_sep_index:
208
+ next_token_id = self.sep_token_id
209
+ elif last_sep_index > last_cls_index:
210
+ next_token_id = self.cls_token_id
211
+ else:
212
+ output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
213
+ next_token_id = self.cls_token_id
214
+
215
+ output = [self.bos_token_id] + output
216
+ # Assume no markup in text if token_ids_1 is given.
217
+ if token_ids_1 is not None:
218
+ output = output + token_ids_1 + [next_token_id]
219
+ return output
220
+
221
+ def get_special_tokens_mask(
222
+ self, token_ids_0, token_ids_1=None, already_has_special_tokens=False
223
+ ):
224
+ """Get a mask showing which tokens are special tokens.
225
+
226
+ Args:
227
+ token_ids_0 (List[int]): List of token IDs for the first sequence.
228
+ token_ids_1 (List[int], optional): List of token IDs for the second sequence.
229
+ already_has_special_tokens (bool): Whether the tokens already include special tokens.
230
+
231
+ Returns:
232
+ List[int]: A mask where 1 indicates special tokens and 0 indicates regular tokens.
233
+ """
234
+ if already_has_special_tokens:
235
+ return super().get_special_tokens_mask(
236
+ token_ids_0, token_ids_1, already_has_special_tokens=True
237
+ )
238
+
239
+ # [bos_token, cls_token, tokens_0, sep_token]
240
+ if token_ids_1 is None:
241
+ return [1, 1] + ([0] * len(token_ids_0)) + [1]
242
+ # [bos_token, cls_token, tokens_0, sep_token, tokens_1, cls_token]
243
+ return [1, 1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
244
+
245
+ def save_vocabulary(
246
+ self, save_directory, filename_prefix: Optional[str] = None
247
+ ) -> Tuple[str]:
248
+ """
249
+ Save the vocabulary and special tokens file to a directory.
250
+
251
+ Args:
252
+ save_directory (str): The directory in which to save the vocabulary.
253
+ filename_prefix (Optional[str]): Optional prefix for the saved filename.
254
+
255
+ Returns:
256
+ Tuple[str]: Paths to the files saved.
257
+
258
+ Raises:
259
+ ValueError: If the save_directory is not a valid directory.
260
+ """
261
+ if not os.path.isdir(save_directory):
262
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
263
+ return
264
+ out_vocab_file = os.path.join(
265
+ save_directory,
266
+ (filename_prefix + "-" if filename_prefix else "")
267
+ + self.resource_files_names["vocab_file"],
268
+ )
269
+
270
+ if os.path.abspath(self.vocab_file) != os.path.abspath(
271
+ out_vocab_file
272
+ ) and os.path.isfile(self.vocab_file):
273
+ copyfile(self.vocab_file, out_vocab_file)
274
+ elif not os.path.isfile(self.vocab_file):
275
+ with open(out_vocab_file, "wb") as fi:
276
+ content_spiece_model = self.sp_model.serialized_model_proto()
277
+ fi.write(content_spiece_model)
278
+
279
+ return (out_vocab_file,)
280
+
281
+ def _pad(
282
+ self,
283
+ encoded_inputs: Union[Dict],
284
+ max_length: Optional[int] = None,
285
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
286
+ pad_to_multiple_of: Optional[int] = None,
287
+ padding_side: Optional[str] = None,
288
+ return_attention_mask: Optional[bool] = None,
289
+ ) -> dict:
290
+ """
291
+ Pad encoded inputs according to specified strategy.
292
+
293
+ Args:
294
+ encoded_inputs (Union[Dict]): Dictionary of encoded inputs.
295
+ max_length (Optional[int]): Maximum length to pad to.
296
+ padding_strategy (PaddingStrategy): Strategy for padding.
297
+ pad_to_multiple_of (Optional[int]): Pad to a multiple of this value.
298
+ return_attention_mask (Optional[bool]): Whether to return attention mask.
299
+
300
+ Returns:
301
+ dict: Dictionary with padded inputs and optional attention mask.
302
+
303
+ Raises:
304
+ ValueError: If attention_mask has unexpected type or invalid padding strategy.
305
+ """
306
+ if return_attention_mask is None:
307
+ return_attention_mask = "attention_mask" in self.model_input_names
308
+ if return_attention_mask:
309
+ required_input = encoded_inputs[self.model_input_names[0]]
310
+ if padding_strategy == PaddingStrategy.LONGEST:
311
+ max_length = len(required_input)
312
+ if (
313
+ max_length is not None
314
+ and pad_to_multiple_of is not None
315
+ and (max_length % pad_to_multiple_of != 0)
316
+ ):
317
+ max_length = (
318
+ (max_length // pad_to_multiple_of) + 1
319
+ ) * pad_to_multiple_of
320
+ needs_to_be_padded = (
321
+ padding_strategy != PaddingStrategy.DO_NOT_PAD
322
+ and len(required_input) != max_length
323
+ )
324
+
325
+ if (
326
+ "attention_mask" in encoded_inputs
327
+ and encoded_inputs["attention_mask"] is not None
328
+ ):
329
+ attention_mask = encoded_inputs.pop("attention_mask")
330
+ if isinstance(attention_mask, torch.Tensor):
331
+ attention_mask = attention_mask.numpy()
332
+ elif isinstance(attention_mask, list):
333
+ attention_mask = np.array(attention_mask)
334
+ elif not isinstance(attention_mask, np.ndarray):
335
+ raise ValueError(
336
+ f"Unexpected type {type(attention_mask)} of attention_mask, "
337
+ )
338
+ else:
339
+ # Create default attention mask if none provided
340
+ attention_mask = np.tril(
341
+ np.ones((len(required_input), len(required_input)), dtype=np.int64)
342
+ )
343
+ attention_mask = np.expand_dims(attention_mask, axis=0)
344
+
345
+ if needs_to_be_padded:
346
+ difference = max_length - len(required_input)
347
+ if self.padding_side == "right":
348
+ if attention_mask.ndim == 1:
349
+ pad_width = [(0, difference)]
350
+ else:
351
+ pad_width = [(0, 0), (0, difference), (0, difference)]
352
+ elif self.padding_side == "left":
353
+ if attention_mask.ndim == 1:
354
+ pad_width = [(difference, 0)]
355
+ else:
356
+ pad_width = [(0, 0), (difference, 0), (difference, 0)]
357
+ else:
358
+ raise ValueError(
359
+ "Invalid padding strategy:" + str(self.padding_side)
360
+ )
361
+ attention_mask = np.pad(
362
+ attention_mask,
363
+ pad_width=pad_width,
364
+ mode="constant",
365
+ constant_values=0,
366
+ )
367
+
368
+ encoded_inputs = super()._pad(
369
+ encoded_inputs,
370
+ max_length,
371
+ padding_strategy=padding_strategy,
372
+ pad_to_multiple_of=pad_to_multiple_of,
373
+ return_attention_mask=False,
374
+ )
375
+ if return_attention_mask:
376
+ encoded_inputs["attention_mask"] = attention_mask.tolist()
377
+ return encoded_inputs
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34ef7db83df785924fb83d7b887b6e822a031c56e15cff40aaf9b982988180df
3
+ size 1614363
tokenizer_config.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<s>",
3
+ "eos_token": "</s>",
4
+ "pad_token": "<unk>",
5
+ "unk_token": "<unk>",
6
+ "cls_token": "<|begin_of_sentence|>",
7
+ "sep_token": "<|end_of_sentence|>",
8
+ "mask_token": "<mask:1>",
9
+ "sys_start_token": "<mask:4>",
10
+ "sys_end_token": "<mask:5>",
11
+ "header_start_token": "<mask:6>",
12
+ "header_end_token": "<mask:7>",
13
+ "additional_special_tokens": null,
14
+ "tokenizer_class": "Ernie4_5_Tokenizer",
15
+ "auto_map": {
16
+ "AutoTokenizer": [
17
+ "tokenization_ernie4_5.Ernie4_5_Tokenizer",
18
+ null
19
+ ]
20
+ },
21
+ "chat_template": "{%- if not add_generation_prompt is defined -%}\n {%- set add_generation_prompt = true -%}\n{%- endif -%}\n{%- if not cls_token is defined -%}\n {%- set cls_token = \"<|begin_of_sentence|>\" -%}\n{%- endif -%}\n{%- if not sep_token is defined -%}\n {%- set sep_token = \"<|end_of_sentence|>\" -%}\n{%- endif -%}\n{{- cls_token -}}\n{%- for message in messages -%}\n {%- if message[\"role\"] == \"user\" -%}\n {{- \"User: \" + message[\"content\"] + \"\n\" -}}\n {%- elif message[\"role\"] == \"assistant\" -%}\n {{- \"Assistant: \" + message[\"content\"] + sep_token -}}\n {%- elif message[\"role\"] == \"system\" -%}\n {{- message[\"content\"] + \"\n\" -}}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{- \"Assistant: \" -}}\n{%- endif -%}"
22
+ }