Upload folder using huggingface_hub
Browse files- .gitattributes +2 -0
- LICENSE +49 -0
- README.md +164 -2
- config.json +65 -0
- configuration_ragplm.py +118 -0
- generation_config.json +6 -0
- modeling_ragplm.py +1260 -0
- performance.png +3 -0
- proteinmoe_architecture.png +3 -0
- pytorch_model-00001-of-00007.bin +3 -0
- pytorch_model-00002-of-00007.bin +3 -0
- pytorch_model-00003-of-00007.bin +3 -0
- pytorch_model-00004-of-00007.bin +3 -0
- pytorch_model-00005-of-00007.bin +3 -0
- pytorch_model-00006-of-00007.bin +3 -0
- pytorch_model-00007-of-00007.bin +3 -0
- pytorch_model.bin.index.json +0 -0
- special_tokens_map.json +1 -0
- tokenization.py +268 -0
- tokenizer_config.json +19 -0
- vocab.json +549 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
performance.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
proteinmoe_architecture.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
CHANGED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
GENBIO AI COMMUNITY LICENSE AGREEMENT
|
2 |
+
|
3 |
+
This GenBio AI Community License Agreement (the “License”) constitutes an agreement between you or the legal entity you represent (“you” or “your”) and GENBIO.AI, INC. (“GenBio”), governing your use of the GenBio Materials. If you are using the GenBio Materials on behalf of a legal entity, you represent and warrant to GenBio that you have full legal authority to act on behalf of that legal entity as applicable under the License. If you do not have the authority to accept this License or if you disagree with any or all of the License, you shall not use the GenBio Materials in any manner. By using or distributing any portion or element of the GenBio Materials, you imply your agreement to be bound by the License.
|
4 |
+
|
5 |
+
“GenBio Materials” means any datasets, code, model weights or any other materials provided by GenBio at the following GitHub Page https://github.com/genbio-ai or Hugging Face Page https://huggingface.co/genbio-ai, including any updates or modifications made from time to time, whether in Source or Object form, and is made available to you under this License.
|
6 |
+
|
7 |
+
|
8 |
+
1. License Grant.
|
9 |
+
1.1 License Scope. Subject to the terms of this License, GenBio grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under GenBio’s intellectual property or other rights owned by GenBio embodied in the GenBio Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the GenBio Materials for any Non-Commercial Purposes.
|
10 |
+
1.2 Use Restrictions. Restricted activities in relation to the License or use of GenBio Materials include:
|
11 |
+
1.2.1 You shall use the GenBio Materials, Contributions, Derivative Works, Outputs and Output Derivatives (as defined below) solely for Non-Commercial Purposes;
|
12 |
+
1.2.2 You shall not, directly or indirectly: (a) use or provide access to any Outputs or Output Derivatives to train, optimize, improve, or otherwise enhance the functionality or performance of any machine learning models or related technologies that are similar to the GenBio Materials; (b) engage in any form of model distillation or other methods that would achieve the purposes described in subsection (a) above. Notwithstanding the foregoing, you may use Outputs and Output Derivatives to train, optimize, improve, or enhance the functionality or performance of: (i) The GenBio Materials itself; and (ii) downstream Derivative Works of the GenBio Materials;
|
13 |
+
1.2.3 Your use of the GenBio Materials shall be subject to any additional terms and conditions that: (a) GenBio provides to you separately; or (b) GenBio otherwise makes available to you.
|
14 |
+
|
15 |
+
2. Sharing and Distribution.
|
16 |
+
2.1 Subject to Section 1, if you distribute or make available the GenBio Materials or a Derivative Work to a third party for your Non-Commercial Purposes, in Source or Object form, you shall:
|
17 |
+
2.1.1 provide a copy of this License to that third party;
|
18 |
+
2.1.2 retain the following attribution notice within a “Notice” text file distributed as a part of such copies: “This is licensed under the GenBio AI Community License Agreement, Copyright © GENBIO.AI, INC. All Rights Reserved”; and
|
19 |
+
2.1.3 prominently display “Powered by GenBio AI” on a related website, user interface, blogpost, about page, or product documentation.
|
20 |
+
2.2 If You create a Derivative Work, you may add your own attribution notice(s) to the “Notice” text file included with that Derivative Work, provided that you clearly indicate which attributions apply to the GenBio Materials and state in the “Notice” text file that you changed the GenBio Materials and how it was modified.
|
21 |
+
|
22 |
+
3. Submission of Contribution.
|
23 |
+
Unless you explicitly state otherwise, any Contribution intentionally submitted for inclusion in the GenBio Materials by you to GenBio shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with GenBio regarding such Contributions.
|
24 |
+
|
25 |
+
4. Export Control.
|
26 |
+
You shall comply with the applicable U.S. Foreign Corrupt Practices Act and all applicable export laws, restrictions and regulations of the U.S. Department of Commerce, and any other applicable U.S. and foreign authority.
|
27 |
+
|
28 |
+
5. Disclaimer of Warranty.
|
29 |
+
GENBIO MATERIALS PROVIDED BY GENBIO OR ANY OUTPUT YOU RECEIVED ARE PROVIDED “AS IS.” EXCEPT TO THE EXTENT PROHIBITED BY LAW. GENBIO MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND, WHETHER EXPRESS, IMPLIED OR OTHERWISE, REGARDING THE ACCURACY, COMPLETENESS OR PERFORMANCE OF THE SERVICES AND YOUR OUTPUT, OR WITH RESPECT TO SATISFACTORY QUALITY, FITNESS FOR A PARTICULAR PURPOSE OR NON-INFRINGEMENT.
|
30 |
+
|
31 |
+
6. Limitation of Liability.
|
32 |
+
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the GenBio Materials (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
33 |
+
|
34 |
+
7. General Terms.
|
35 |
+
7.1 Relationship of Parties. You and GenBio are independent contractors, and nothing herein shall be deemed to constitute either party as the agent or representative of the other or both parties as joint venturers or partners for any purpose.
|
36 |
+
7.2 Assignment. This License and the rights and obligations herein may not be assigned or transferred, in whole or in part, by You without the prior written consent of GenBio. Any assignment in violation of this provision is void. GenBio may freely assign or transfer this License, in whole or in part. This License shall be binding upon, and inure to the benefit of, the successors and permitted assigns of the parties.
|
37 |
+
7.3 Governing Law. This License shall be governed, construed and interpreted in accordance with the laws of the State of California, without giving effect to principles of conflicts of law. Each of the parties to this License consents to the exclusive jurisdiction and venue of the courts of the state and federal courts of California.
|
38 |
+
7.4 Severability. If any provision of this License is held to be invalid, illegal or unenforceable in any respect, that provision shall be limited or eliminated to the minimum extent necessary so that this License otherwise remains in full force and effect and enforceable.
|
39 |
+
|
40 |
+
8. Definitions.
|
41 |
+
8.1 “Commercial Entity” means any entity engaged in any activity intended for or directed toward commercial advantage or monetary compensation, including, without limitation, the development of any product or service intended to be sold or made available for a fee. For the purpose of this License, references to a Commercial Entity expressly exclude any universities, non-profit organizations, not-for-profit entities, research institutes and educational and government bodies.
|
42 |
+
8.2 “Contribution” means any work of authorship, including the original version of the GenBio Materials and any modifications or additions to that GenBio Materials or Derivative Works thereof, that is intentionally submitted to GenBio for inclusion in the GenBio Materials by the copyright owner or by an individual or legal entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to GenBio or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, GenBio for the purpose of discussing and improving the GenBio Materials, but excluding Outputs and all communications that are conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution”.
|
43 |
+
8.3 “Contributor” means GenBio and any individual or legal entity on behalf of whom a Contribution has been received by GenBio and subsequently incorporated within the GenBio Materials.
|
44 |
+
8.4 “Derivative Work” means any work, whether in Source or Object form, that is based on (or derived from) the GenBio Materials and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the GenBio Materials and Derivative Works thereof.
|
45 |
+
8.5 “Non-Commercial Purposes” means uses not intended for or directed toward commercial advantage or monetary compensation, or the facilitation of development of any product or service to be sold or made available for a fee. For the avoidance of doubt, the provision of Outputs as a service is not a Non-Commercial Purpose.
|
46 |
+
8.6 “Object” means any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
47 |
+
8.7 “Output” means any output, including any protein sequence, structure prediction, functional annotation, molecule, descriptions of a molecule, model, sequence, text, and/or image that is elicited directly or indirectly by, or otherwise made available to, you in connection with your use of the GenBio Materials, including, but not limited to, the use of AI-Powered Technology. For the avoidance of doubt, it includes any intermediate results, such as activations across model layers, intermediate outputs from model layers (e.g., attention maps), as well as gradients and embeddings produced by the GenBio Materials.
|
48 |
+
8.8 “Output Derivatives” means any enhancements, modifications and derivative works of Outputs (including, but not limited to, any derivative sequences or molecules).
|
49 |
+
8.9 “Source” means the preferred form for making modifications, including but not limited to GenBio Materials source code, documentation source, and configuration files.
|
README.md
CHANGED
@@ -1,5 +1,167 @@
|
|
1 |
---
|
2 |
license: other
|
3 |
-
license_name: genbio.ai-community-license
|
4 |
-
license_link: LICENSE
|
5 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: other
|
|
|
|
|
3 |
---
|
4 |
+
|
5 |
+
# AIDO.Protein-RAG-16B-proteingym-dms-zeroshot
|
6 |
+
|
7 |
+
AIDO.Protein-RAG-16B-proteingym-dms-zeroshot is a multimodal protein language model that integrates Multiple Sequence Alignment (MSA) and structural data, building upon the [AIDO.Protein-16B](https://huggingface.co/genbio-ai/AIDO.Protein-16B) foundation. The training process comprises three main stages:
|
8 |
+
|
9 |
+
1. 2D RoPE encoding fine-tuning
|
10 |
+
2. Initial training on 100 billion tokens from UniRef50/UniClust30 MSA data
|
11 |
+
3. Subsequent training on 23 billion tokens from AlphaFold Database MSA and structural data
|
12 |
+
|
13 |
+
## Model Architecture
|
14 |
+
|
15 |
+
AIDO.Protein-RAG-16B-proteingym-dms-zeroshot employs a transformer encoder-only architecture featuring sparse Mixture-of-Experts (MoE) layers that replace dense MLP layers in each transformer block. Utilizing single amino acid tokenization and optimized through masked language modeling (MLM), the model activates 2 experts per token via top-2 routing mechanisms.
|
16 |
+
|
17 |
+
<center><img src="proteinmoe_architecture.png" alt="An Overview of AIDO.Protein" style="width:70%; height:auto;" /></center>
|
18 |
+
|
19 |
+
More architecture details are shown below:
|
20 |
+
|
21 |
+
| Model Arch Component | Value |
|
22 |
+
| ----------------------- | :---: |
|
23 |
+
| Num Attention Head | 36 |
|
24 |
+
| Num Hidden Layer | 36 |
|
25 |
+
| Hidden Size | 2304 |
|
26 |
+
| FFN Hidden Size | 7680 |
|
27 |
+
| Num MoE Layer per Block | 8 |
|
28 |
+
| Num MoE Layer per Token | 2 |
|
29 |
+
| Vocab Size | 44 |
|
30 |
+
| Context Length | 2048 |
|
31 |
+
|
32 |
+
## Pre-training of AIDO.Protein-RAG-16B-proteingym-dms-zeroshot
|
33 |
+
|
34 |
+
Here we briefly introduce the details of pre-training of AIDO.Protein-RAG-16B-proteingym-dms-zeroshot. Mainly divided into three stages: (1) 1D -> 2D RoPE encoding fine-tuning; (2) UniRef50/Uniclust30 MSA fine-tuning; (3) AlphaFold Database MSA & Structure tokens fine-tuning
|
35 |
+
|
36 |
+
### Data
|
37 |
+
|
38 |
+
**UniRef50/Uniclust30 MSA dataset**: We utilized sequences from UniRef50 as queries to search for homologous sequences in UniClust30, subsequently constructing multiple sequence alignments (MSAs). UniRef50 comprises a total of 53.6 million sequences. Using HHblits, we searched all sequences, identifying over 25 homologous sequences for 23.7 million of them. This dataset was directly used as the training set, referred to as `HHblits_MSA`. The remaining 29.9 million sequences were input into MSA Retriever, resulting in 7.7 million sequences with more than 25 homologous sequences. This dataset was designated as `Retriever_MSA`. During training, RAGPLM randomly sampled from the two datasets with probabilities of 0.75 and 0.25. Refer to AIDO.Protein-RAG-3B paper ([link](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)) for more information.
|
39 |
+
|
40 |
+
**AlphaFold Database MSA & Structure dataset**: We downloaded all structural data from the AlphaFold Database and kept only those where more than 40% of amino acids had a pLDDT score > 70. The remaining sequences were clustered using `mmseqs` (`seq id=0.5`), and one representative per cluster was retained, resulting in 46.9 million sequence/structure pairs. For each structure, we used [genbio-ai/AIDO.StructureTokenizer](https://huggingface.co/genbio-ai/AIDO.StructureTokenizer) to obtain structure tokens and embeddings. [MSA Retriever](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1) was used to obtain the corresponding MSA.
|
41 |
+
|
42 |
+
### Training Details
|
43 |
+
|
44 |
+
Model training is divided into three stages:
|
45 |
+
|
46 |
+
#### (1) 1D -> 2D RoPE Encoding Fine-tuning
|
47 |
+
|
48 |
+
Same training data as [AIDO.Protein-16B](https://huggingface.co/genbio-ai/AIDO.Protein-16B), but with [2D rotary position embedding](https://arxiv.org/abs/2406.05347) for token encoding.
|
49 |
+
|
50 |
+
#### (2) UniRef50/UniClust30 MSA Fine-tuning
|
51 |
+
|
52 |
+
The model from Stage 1 is further fine-tuned on the UniRef50/Uniclust30 MSA dataset. See the [AIDO.Protein-RAG-3B paper](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1) for more.
|
53 |
+
|
54 |
+
#### (3) AlphaFold Database MSA & Structure Fine-tuning
|
55 |
+
|
56 |
+
We fine-tuned the model with concatenated query and homologous sequences. Structure embeddings (dim = 384) are linearly mapped to 2304 and added to the query token embeddings.
|
57 |
+
|
58 |
+
##### Sequence Masking
|
59 |
+
|
60 |
+
* Randomly sample `0.05 × L` span positions from a query of length `L`. Span lengths follow a geometric distribution (`p=0.2`), capped at length 10. On average, ~15% of query tokens are masked.
|
61 |
+
|
62 |
+
* When a residue is selected, its aligned residues across all sequences (MSA column) are also masked.
|
63 |
+
|
64 |
+
* For masked MSA columns: 80% are replaced with `<MASK>`, 10% with random amino acids, and 10% left unchanged.
|
65 |
+
|
66 |
+
##### Structure Masking
|
67 |
+
|
68 |
+
* In 20% of cases, structure embeddings are replaced with 0.
|
69 |
+
|
70 |
+
* In 80% of cases, a number of amino acids is sampled using the BetaLinear30 distribution and corresponding embeddings are zeroed. (BetaLinear30 = 20% Uniform(0,1) + 80% Beta(3,9)).
|
71 |
+
|
72 |
+
##### Positional Embedding
|
73 |
+
|
74 |
+
We use [2D rotary position embedding](https://arxiv.org/abs/2406.05347) to help the model distinguish token chain identities and residue indices. See AIDO.Protein-RAG-3B paper ([link](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)) for more information.
|
75 |
+
|
76 |
+
##### Loss Function
|
77 |
+
|
78 |
+
Total loss is a weighted sum of sequence loss (weight 1.0) and structure loss (weight 0.025).
|
79 |
+
|
80 |
+
* **Sequence loss**: CrossEntropy loss for masked token prediction.
|
81 |
+
|
82 |
+
* **Structure loss**: CrossEntropy loss for masked structure token prediction.
|
83 |
+
|
84 |
+
| Hyper-params | (1) 1D -> 2D fine-tuning | (2) UniRef50/Uniclust30 MSA fine-tuning | (3) AFDB MSA & Structure tokens fine-tuning |
|
85 |
+
| --------------------------- | :---------------------: | :------------------------------------: | :----------------------------------------: |
|
86 |
+
| Initialized parameters | AIDO.Protein-16B | Stage (1) | Stage (2) |
|
87 |
+
| Data | ColabFoldDB, UniRef | HHblits_MSA, Retriever_MSA | AFDB MSA & Structure tokens |
|
88 |
+
| Global Batch Size | 512 | 256 | 256 |
|
89 |
+
| Sequence length | 2048 | 12800 | 12800 |
|
90 |
+
| Per Device Micro Batch Size | 1 | 1 | 1 |
|
91 |
+
| Precision | Mixed FP32-FP16 | Mixed FP32-FP16 | Mixed FP32-FP16 |
|
92 |
+
| LR | [5e-6,5e-5] | [1e-6, 1e-5] | 1e-5 |
|
93 |
+
| Num Tokens | 10 billion | 100 billion | 23 billion |
|
94 |
+
| Structural loss | N/A | N/A | 0.025 |
|
95 |
+
|
96 |
+
### Tokenization
|
97 |
+
|
98 |
+
We encode protein sequence with single amino acid resolution with 44 vocabularies, where 24 tokens represent amino acid types and 20 are special tokens. Sequences were also suffixed with a `[SEP]` token as hooks for downstream tasks.
|
99 |
+
|
100 |
+
## Results
|
101 |
+
|
102 |
+
### Zero-shot DMS score
|
103 |
+
|
104 |
+
<center><img src="performance.png" alt="performance" style="width:100%; height:auto;" /></center>
|
105 |
+
|
106 |
+
## How to Run
|
107 |
+
|
108 |
+
### Load the model and tokenizer
|
109 |
+
|
110 |
+
```python
|
111 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForMaskedLM
|
112 |
+
|
113 |
+
tokenizer = AutoTokenizer.from_pretrained("genbio-ai/AIDO.Protein-RAG-16B-proteingym-dms-zeroshot", trust_remote_code=True)
|
114 |
+
model = AutoModelForCausalLM.from_pretrained("genbio-ai/AIDO.Protein-RAG-16B-proteingym-dms-zeroshot", trust_remote_code=True, torch_dtype=torch.bfloat16)
|
115 |
+
model = model.bfloat16().eval().to('cuda:0')
|
116 |
+
```
|
117 |
+
|
118 |
+
### Clone the github respository and install environment **TODO**
|
119 |
+
|
120 |
+
Please read introduction of [github respository](https://gitlab.genbio.ai/pan.li/ragplm_zeroshot/-/tree/master) to get the detail of installing environment and running method.
|
121 |
+
|
122 |
+
```bash
|
123 |
+
conda create -n ragplm python=3.11 -y
|
124 |
+
conda activate ragplm
|
125 |
+
|
126 |
+
pip install tabulate seaborn deepspeed
|
127 |
+
pip install git+https://github.com/genbio-ai/ModelGenerator.git
|
128 |
+
|
129 |
+
git clone [email protected]:pan.li/ragplm_zeroshot.git
|
130 |
+
cd ragplm_zeroshot
|
131 |
+
|
132 |
+
tar xf dms_data.tar.gz
|
133 |
+
tar xf struc_data.tar.gz
|
134 |
+
mkdir output
|
135 |
+
```
|
136 |
+
|
137 |
+
### Run zero-shot
|
138 |
+
|
139 |
+
```bash
|
140 |
+
python compute_fitness.py --dms_ids PTEN_HUMAN_Mighell_2018
|
141 |
+
```
|
142 |
+
|
143 |
+
# Citation
|
144 |
+
|
145 |
+
Please cite AIDO.Protein-RAG-16B-proteingym-dms-zeroshot using the following BibTex code:
|
146 |
+
|
147 |
+
```
|
148 |
+
@inproceedings{sun_mixture_2024,
|
149 |
+
title = {Mixture of Experts Enable Efficient and Effective Protein Understanding and Design},
|
150 |
+
url = {https://www.biorxiv.org/content/10.1101/2024.11.29.625425v1},
|
151 |
+
doi = {10.1101/2024.11.29.625425},
|
152 |
+
publisher = {bioRxiv},
|
153 |
+
author = {Sun, Ning and Zou, Shuxian and Tao, Tianhua and Mahbub, Sazan and Li, Dian and Zhuang, Yonghao and Wang, Hongyi and Cheng, Xingyi and Song, Le and Xing, Eric P.},
|
154 |
+
year = {2024},
|
155 |
+
booktitle={NeurIPS 2024 Workshop on AI for New Drug Modalities},
|
156 |
+
}
|
157 |
+
|
158 |
+
@article {Li2024.12.02.626519,
|
159 |
+
author = {Li, Pan and Cheng, Xingyi and Song, Le and Xing, Eric},
|
160 |
+
title = {Retrieval Augmented Protein Language Models for Protein Structure Prediction},
|
161 |
+
url = {https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1},
|
162 |
+
year = {2024},
|
163 |
+
doi = {10.1101/2024.12.02.626519},
|
164 |
+
publisher = {bioRxiv},
|
165 |
+
booktitle={NeurIPS 2024 Workshop on Machine Learning in Structural Biology},
|
166 |
+
}
|
167 |
+
```
|
config.json
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "Protein/RAGPLM",
|
3 |
+
"add_bias_linear": true,
|
4 |
+
"add_qkv_bias": true,
|
5 |
+
"add_seq_emb_ln": false,
|
6 |
+
"add_str_emb_ln": false,
|
7 |
+
"apply_query_key_layer_scaling": true,
|
8 |
+
"apply_residual_connection_post_layernorm": false,
|
9 |
+
"architectures": [
|
10 |
+
"RAGPLMForConditionalGeneration"
|
11 |
+
],
|
12 |
+
"attention_dropout": 0.0,
|
13 |
+
"attention_softmax_in_fp32": true,
|
14 |
+
"auto_map": {
|
15 |
+
"AutoConfig": "configuration_ragplm.RAGPLMConfig",
|
16 |
+
"AutoModel": "modeling_ragplm.RAGPLMModel",
|
17 |
+
"AutoModelForCausalLM": "modeling_ragplm.RAGPLMForConditionalGeneration",
|
18 |
+
"AutoModelForSeq2SeqLM": "modeling_ragplm.RAGPLMForConditionalGeneration"
|
19 |
+
},
|
20 |
+
"bias_dropout_fusion": true,
|
21 |
+
"classifier_dropout": null,
|
22 |
+
"deepnorm": false,
|
23 |
+
"eos_token_id": 34,
|
24 |
+
"experts_per_token": 2,
|
25 |
+
"ffn_hidden_size": 7680,
|
26 |
+
"fp32_residual_connection": false,
|
27 |
+
"glu_activation": "swiglu",
|
28 |
+
"hidden_dropout": 0.0,
|
29 |
+
"hidden_size": 2304,
|
30 |
+
"is_causal": false,
|
31 |
+
"kv_channels": 64,
|
32 |
+
"layernorm_epsilon": 1e-05,
|
33 |
+
"lora": false,
|
34 |
+
"lora_alpha": 16,
|
35 |
+
"lora_before_position": false,
|
36 |
+
"lora_dropout": 0,
|
37 |
+
"lora_r": 8,
|
38 |
+
"mlp_lora": false,
|
39 |
+
"model_type": "ragplm",
|
40 |
+
"moe": true,
|
41 |
+
"multi_query_attention": false,
|
42 |
+
"multi_query_group_num": 2,
|
43 |
+
"num_attention_heads": 36,
|
44 |
+
"num_experts": 8,
|
45 |
+
"num_layers": 36,
|
46 |
+
"original_rope": true,
|
47 |
+
"pad_token_id": 0,
|
48 |
+
"padded_vocab_size": 640,
|
49 |
+
"post_layer_norm": true,
|
50 |
+
"qseq_output_dim": null,
|
51 |
+
"quantization_bit": 0,
|
52 |
+
"rmsnorm": true,
|
53 |
+
"rotary_embedding_2d": true,
|
54 |
+
"rotary_freq_base": 10000,
|
55 |
+
"seq_length": 2048,
|
56 |
+
"str_input_dim": 384,
|
57 |
+
"str_output_dim": 512,
|
58 |
+
"str_vocab_size": null,
|
59 |
+
"tie_word_embeddings": false,
|
60 |
+
"torch_dtype": "torch.bfloat16",
|
61 |
+
"transformers_version": "4.48.3",
|
62 |
+
"use_cache": true,
|
63 |
+
"use_pytorch_sdpa": true,
|
64 |
+
"vocab_size": 128
|
65 |
+
}
|
configuration_ragplm.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
import torch
|
3 |
+
|
4 |
+
class RAGPLMConfig(PretrainedConfig):
|
5 |
+
model_type = "ragplm"
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
num_layers=28,
|
9 |
+
padded_vocab_size=65024,
|
10 |
+
hidden_size=4096,
|
11 |
+
ffn_hidden_size=13696,
|
12 |
+
kv_channels=128,
|
13 |
+
num_attention_heads=32,
|
14 |
+
|
15 |
+
add_str_emb_ln=False, # Add layer norm to the structure embedding layer
|
16 |
+
add_seq_emb_ln=False, # Add layer norm to the sequence embedding layer
|
17 |
+
str_vocab_size=None,
|
18 |
+
str_input_dim=None,
|
19 |
+
str_output_dim=None,
|
20 |
+
qseq_output_dim=None,
|
21 |
+
|
22 |
+
seq_length=2048,
|
23 |
+
hidden_dropout=0.0,
|
24 |
+
classifier_dropout=None,
|
25 |
+
attention_dropout=0.0,
|
26 |
+
layernorm_epsilon=1e-5,
|
27 |
+
glu_activation='geglu',
|
28 |
+
torch_dtype=torch.bfloat16,
|
29 |
+
rmsnorm=True,
|
30 |
+
deepnorm=True,
|
31 |
+
apply_residual_connection_post_layernorm=False,
|
32 |
+
post_layer_norm=True,
|
33 |
+
add_bias_linear=False,
|
34 |
+
add_qkv_bias=False,
|
35 |
+
bias_dropout_fusion=True,
|
36 |
+
multi_query_attention=False,
|
37 |
+
multi_query_group_num=1,
|
38 |
+
apply_query_key_layer_scaling=True,
|
39 |
+
attention_softmax_in_fp32=True,
|
40 |
+
fp32_residual_connection=False,
|
41 |
+
quantization_bit=0,
|
42 |
+
# pre_seq_len=None,
|
43 |
+
# prefix_projection=False,
|
44 |
+
rotary_embedding_2d=True,
|
45 |
+
rotary_freq_base=10000,
|
46 |
+
lora=False,
|
47 |
+
mlp_lora=False,
|
48 |
+
lora_before_position=False, ### Default the QKV LoRA is after the position encoding
|
49 |
+
lora_r=8,
|
50 |
+
lora_alpha=16,
|
51 |
+
lora_dropout=0,
|
52 |
+
use_pytorch_sdpa=True,
|
53 |
+
is_causal=True,
|
54 |
+
moe=False,
|
55 |
+
num_experts=16,
|
56 |
+
experts_per_token=2,
|
57 |
+
**kwargs
|
58 |
+
):
|
59 |
+
|
60 |
+
if not deepnorm and apply_residual_connection_post_layernorm:
|
61 |
+
print(f"Warning: deepnorm is False and apply_residual_connection_post_layernorm is True")
|
62 |
+
|
63 |
+
self.num_layers = num_layers
|
64 |
+
self.vocab_size = padded_vocab_size
|
65 |
+
self.padded_vocab_size = padded_vocab_size
|
66 |
+
self.hidden_size = hidden_size
|
67 |
+
self.ffn_hidden_size = ffn_hidden_size
|
68 |
+
self.kv_channels = kv_channels
|
69 |
+
self.num_attention_heads = num_attention_heads
|
70 |
+
self.add_str_emb_ln = add_str_emb_ln
|
71 |
+
self.add_seq_emb_ln = add_seq_emb_ln
|
72 |
+
self.str_vocab_size = str_vocab_size
|
73 |
+
self.str_input_dim = str_input_dim
|
74 |
+
self.str_output_dim = str_output_dim
|
75 |
+
self.qseq_output_dim = qseq_output_dim
|
76 |
+
self.seq_length = seq_length
|
77 |
+
self.hidden_dropout = hidden_dropout
|
78 |
+
self.classifier_dropout = classifier_dropout
|
79 |
+
self.attention_dropout = attention_dropout
|
80 |
+
self.layernorm_epsilon = layernorm_epsilon
|
81 |
+
self.torch_dtype = torch_dtype
|
82 |
+
self.glu_activation = glu_activation
|
83 |
+
self.rmsnorm = rmsnorm
|
84 |
+
self.deepnorm = deepnorm
|
85 |
+
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
|
86 |
+
self.post_layer_norm = post_layer_norm
|
87 |
+
self.add_bias_linear = add_bias_linear
|
88 |
+
self.add_qkv_bias = add_qkv_bias
|
89 |
+
self.bias_dropout_fusion = bias_dropout_fusion
|
90 |
+
self.multi_query_attention = multi_query_attention
|
91 |
+
self.multi_query_group_num = multi_query_group_num
|
92 |
+
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
|
93 |
+
self.attention_softmax_in_fp32 = attention_softmax_in_fp32
|
94 |
+
self.fp32_residual_connection = fp32_residual_connection
|
95 |
+
self.quantization_bit = quantization_bit
|
96 |
+
#self.pre_seq_len = pre_seq_len
|
97 |
+
#self.prefix_projection = prefix_projection
|
98 |
+
self.rotary_embedding_2d = rotary_embedding_2d
|
99 |
+
self.rotary_freq_base = rotary_freq_base
|
100 |
+
self.is_causal = is_causal
|
101 |
+
self.lora = lora
|
102 |
+
self.mlp_lora = mlp_lora
|
103 |
+
self.lora_before_position = lora_before_position
|
104 |
+
self.lora_r = lora_r
|
105 |
+
self.lora_alpha = lora_alpha
|
106 |
+
self.lora_dropout = lora_dropout
|
107 |
+
self.use_pytorch_sdpa = use_pytorch_sdpa
|
108 |
+
self.moe = moe
|
109 |
+
self.num_experts = num_experts
|
110 |
+
self.experts_per_token = experts_per_token
|
111 |
+
|
112 |
+
super().__init__(**kwargs)
|
113 |
+
|
114 |
+
if isinstance(torch_dtype, str):
|
115 |
+
if torch_dtype.startswith('torch.'):
|
116 |
+
self.torch_dtype = eval(torch_dtype)
|
117 |
+
else:
|
118 |
+
self.torch_dtype = eval(f"torch.{torch_dtype}")
|
generation_config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"eos_token_id": 34,
|
4 |
+
"pad_token_id": 0,
|
5 |
+
"transformers_version": "4.48.3"
|
6 |
+
}
|
modeling_ragplm.py
ADDED
@@ -0,0 +1,1260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" PyTorch AIDO.Protein-DMS-16B model. """
|
2 |
+
|
3 |
+
import math
|
4 |
+
import copy
|
5 |
+
import warnings
|
6 |
+
import re
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
import pathlib
|
10 |
+
import time
|
11 |
+
import argparse
|
12 |
+
import random
|
13 |
+
import numpy as np
|
14 |
+
from tqdm.auto import tqdm, trange
|
15 |
+
from functools import partial
|
16 |
+
|
17 |
+
import torch, deepspeed
|
18 |
+
import torch.utils.checkpoint
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
|
22 |
+
from torch.nn.utils import skip_init
|
23 |
+
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
|
24 |
+
from copy import deepcopy
|
25 |
+
from collections import namedtuple
|
26 |
+
|
27 |
+
from transformers.modeling_outputs import (
|
28 |
+
BaseModelOutputWithPast,
|
29 |
+
CausalLMOutputWithPast,
|
30 |
+
SequenceClassifierOutputWithPast,
|
31 |
+
)
|
32 |
+
from transformers.modeling_utils import PreTrainedModel
|
33 |
+
from transformers.utils import logging
|
34 |
+
from transformers.generation.logits_process import LogitsProcessor
|
35 |
+
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
|
36 |
+
|
37 |
+
from .configuration_ragplm import RAGPLMConfig
|
38 |
+
|
39 |
+
def get_checkpoint_fn():
|
40 |
+
if deepspeed.checkpointing.is_configured():
|
41 |
+
# checkpoint = deepspeed.checkpointing.non_reentrant_checkpoint
|
42 |
+
checkpoint = deepspeed.checkpointing.checkpoint
|
43 |
+
else:
|
44 |
+
checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
45 |
+
# checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=True)
|
46 |
+
return checkpoint
|
47 |
+
|
48 |
+
# flags required to enable jit fusion kernels
|
49 |
+
|
50 |
+
if sys.platform != 'darwin':
|
51 |
+
torch._C._jit_set_profiling_mode(False)
|
52 |
+
torch._C._jit_set_profiling_executor(False)
|
53 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
54 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
55 |
+
|
56 |
+
logger = logging.get_logger(__name__)
|
57 |
+
|
58 |
+
_CHECKPOINT_FOR_DOC = "Protein/Protein_RAGPLM"
|
59 |
+
_CONFIG_FOR_DOC = "RAGPLMConfig"
|
60 |
+
|
61 |
+
|
62 |
+
def default_init(cls, *args, **kwargs):
|
63 |
+
return cls(*args, **kwargs)
|
64 |
+
|
65 |
+
DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])
|
66 |
+
|
67 |
+
def get_deepnorm_coefficients(config: RAGPLMConfig):
|
68 |
+
"""
|
69 |
+
DeepNorm coefficients from : https://kexue.fm/archives/8978
|
70 |
+
"""
|
71 |
+
num_layers = config.num_layers
|
72 |
+
return DeepNormCoefficients(alpha=(2 * num_layers) ** 0.5, beta=(2 * num_layers) ** -0.5)
|
73 |
+
|
74 |
+
|
75 |
+
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
76 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
77 |
+
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
78 |
+
scores.zero_()
|
79 |
+
scores[..., 5] = 5e4
|
80 |
+
return scores
|
81 |
+
|
82 |
+
def split_tensor_along_last_dim(
|
83 |
+
tensor: torch.Tensor,
|
84 |
+
num_partitions: int,
|
85 |
+
contiguous_split_chunks: bool = False,
|
86 |
+
) -> List[torch.Tensor]:
|
87 |
+
"""Split a tensor along its last dimension.
|
88 |
+
|
89 |
+
Arguments:
|
90 |
+
tensor: input tensor.
|
91 |
+
num_partitions: number of partitions to split the tensor
|
92 |
+
contiguous_split_chunks: If True, make each chunk contiguous
|
93 |
+
in memory.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
A list of Tensors
|
97 |
+
"""
|
98 |
+
# Get the size and dimension.
|
99 |
+
last_dim = tensor.dim() - 1
|
100 |
+
last_dim_size = tensor.size()[last_dim] // num_partitions
|
101 |
+
# Split.
|
102 |
+
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
103 |
+
# Note: torch.split does not create contiguous tensors by default.
|
104 |
+
if contiguous_split_chunks:
|
105 |
+
return tuple(chunk.contiguous() for chunk in tensor_list)
|
106 |
+
|
107 |
+
return tensor_list
|
108 |
+
|
109 |
+
class RotaryEmbedding(torch.nn.Module):
|
110 |
+
|
111 |
+
def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
|
112 |
+
super().__init__()
|
113 |
+
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)).to(precision)
|
114 |
+
self.dim = dim
|
115 |
+
self.base = base
|
116 |
+
self.learnable = learnable
|
117 |
+
if learnable:
|
118 |
+
self.inv_freq = torch.nn.Parameter(inv_freq)
|
119 |
+
self.max_seq_len_cached = None
|
120 |
+
else:
|
121 |
+
self.register_buffer('inv_freq', inv_freq)
|
122 |
+
self.max_seq_len_cached = None
|
123 |
+
self.cos_cached = None
|
124 |
+
self.sin_cached = None
|
125 |
+
self.precision = precision
|
126 |
+
|
127 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
128 |
+
# import pdb; pdb.set_trace();
|
129 |
+
if f'{prefix}inv_freq' in state_dict:
|
130 |
+
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
131 |
+
else:
|
132 |
+
self.inv_freq.copy_(1. / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)).to(self.precision))
|
133 |
+
|
134 |
+
def forward(self, x, seq_dim=1, seq_len=None):
|
135 |
+
|
136 |
+
# self.inv_freq = 1. / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)).to(x.device)
|
137 |
+
if seq_len is None:
|
138 |
+
seq_len = x.shape[seq_dim]
|
139 |
+
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
|
140 |
+
self.max_seq_len_cached = None if self.learnable else seq_len
|
141 |
+
t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
|
142 |
+
# import pdb; pdb.set_trace();
|
143 |
+
freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(x.device))
|
144 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
145 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
146 |
+
if self.precision == torch.bfloat16 or self.precision == torch.half:
|
147 |
+
emb = emb.float()
|
148 |
+
# [sx, 1 (b * np), hn]
|
149 |
+
cos_cached = emb.cos()[:, None, :]
|
150 |
+
sin_cached = emb.sin()[:, None, :]
|
151 |
+
if self.precision == torch.bfloat16:
|
152 |
+
cos_cached = cos_cached.bfloat16()
|
153 |
+
sin_cached = sin_cached.bfloat16()
|
154 |
+
elif self.precision == torch.half:
|
155 |
+
cos_cached = cos_cached.half()
|
156 |
+
sin_cached = sin_cached.half()
|
157 |
+
if self.learnable:
|
158 |
+
return cos_cached, sin_cached
|
159 |
+
self.cos_cached, self.sin_cached = cos_cached, sin_cached
|
160 |
+
return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
|
161 |
+
|
162 |
+
def rotate_half(x):
|
163 |
+
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
|
164 |
+
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
|
165 |
+
|
166 |
+
def assert_dim_check(tensor, ndim=None, shape=None):
|
167 |
+
if ndim is not None:
|
168 |
+
assert tensor.ndim == ndim, f"Exepct tensor.ndim={ndim}. gut got tensor.shape={tensor.shape}"
|
169 |
+
if shape is not None:
|
170 |
+
assert list(tensor.shape) == list(shape), f"Exepct tensor.shape={shape}. gut got tensor.shape={tensor.shape}"
|
171 |
+
|
172 |
+
def apply_rotary_pos_emb_index_torch(q, k, cos, sin, position_id): # jitting fails with bf16
|
173 |
+
# position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
|
174 |
+
cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
|
175 |
+
F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
|
176 |
+
q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
177 |
+
return q, k
|
178 |
+
|
179 |
+
try:
|
180 |
+
# raise 'Errror'
|
181 |
+
from apex.normalization import MixedFusedRMSNorm
|
182 |
+
from apex.normalization import FusedLayerNorm
|
183 |
+
print(f"{__file__}: Use apex.normalization.MixedFusedRMSNorm as RMSNorm")
|
184 |
+
|
185 |
+
class RMSNorm(MixedFusedRMSNorm):
|
186 |
+
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, memory_efficient=False):
|
187 |
+
super(RMSNorm, self).__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient)
|
188 |
+
|
189 |
+
def forward(self, input):
|
190 |
+
dtype = input.dtype
|
191 |
+
with torch.autocast('cuda', enabled=True, dtype=torch.float32, cache_enabled=None):
|
192 |
+
output = super().forward(input)
|
193 |
+
return output.to(dtype)
|
194 |
+
|
195 |
+
class LayerNorm(FusedLayerNorm):
|
196 |
+
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, memory_efficient=False):
|
197 |
+
super(LayerNorm, self).__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient)
|
198 |
+
|
199 |
+
def forward(self, input):
|
200 |
+
dtype = input.dtype
|
201 |
+
with torch.autocast('cuda', enabled=True, dtype=torch.float32, cache_enabled=None):
|
202 |
+
output = super().forward(input)
|
203 |
+
return output.to(dtype)
|
204 |
+
|
205 |
+
except:
|
206 |
+
class RMSNorm(torch.nn.Module):
|
207 |
+
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
|
208 |
+
super().__init__()
|
209 |
+
self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
|
210 |
+
self.eps = eps
|
211 |
+
@torch.jit.export
|
212 |
+
def forward(self, hidden_states: torch.Tensor):
|
213 |
+
input_dtype = hidden_states.dtype
|
214 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
215 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
216 |
+
return (self.weight * hidden_states).to(input_dtype)
|
217 |
+
print(f"{__file__}: Use custom RMSNorm")
|
218 |
+
|
219 |
+
class CoreAttention(torch.nn.Module):
|
220 |
+
def __init__(self, config: RAGPLMConfig, layer_number):
|
221 |
+
super(CoreAttention, self).__init__()
|
222 |
+
|
223 |
+
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
|
224 |
+
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
|
225 |
+
if self.apply_query_key_layer_scaling:
|
226 |
+
self.attention_softmax_in_fp32 = True
|
227 |
+
self.layer_number = max(1, layer_number)
|
228 |
+
|
229 |
+
projection_size = config.kv_channels * config.num_attention_heads
|
230 |
+
|
231 |
+
# Per attention head and per partition values.
|
232 |
+
self.hidden_size_per_partition = projection_size
|
233 |
+
self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
|
234 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
235 |
+
|
236 |
+
coeff = None
|
237 |
+
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
238 |
+
if self.apply_query_key_layer_scaling:
|
239 |
+
coeff = self.layer_number
|
240 |
+
self.norm_factor *= coeff
|
241 |
+
self.coeff = coeff
|
242 |
+
|
243 |
+
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
|
244 |
+
|
245 |
+
self.is_causal = config.is_causal
|
246 |
+
self.use_pytorch_sdpa = config.use_pytorch_sdpa
|
247 |
+
|
248 |
+
def forward(self, query_layer, key_layer, value_layer, attention_mask):
|
249 |
+
# query_layer, key_layer, value_layer: [seq_len, batch_size, num_heads, head_dim]
|
250 |
+
# import pdb; pdb.set_trace();
|
251 |
+
pytorch_major_version = int(torch.__version__.split('.')[0])
|
252 |
+
# assert pytorch_major_version >= 2, f"Expect PyTorch version > 2.0"
|
253 |
+
if pytorch_major_version >= 2 and self.use_pytorch_sdpa:
|
254 |
+
dropout_p = self.attention_dropout.p if self.training else 0
|
255 |
+
# [seq_len, batch_size, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
|
256 |
+
query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
|
257 |
+
# import pdb; pdb.set_trace();
|
258 |
+
if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
|
259 |
+
# context_layer: [batch_size, num_heads, seq_len, head_dim]
|
260 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, is_causal=self.is_causal, dropout_p=dropout_p)
|
261 |
+
#print(f"torch.nn.functional.scaled_dot_product_attention")
|
262 |
+
else:
|
263 |
+
if (attention_mask is not None) and (attention_mask.dtype == torch.bool):
|
264 |
+
attention_mask = attention_mask.logical_not() ## DO NOT inplace operation!!!!
|
265 |
+
#print(f"attention_mask.shape={attention_mask.shape}, attention_mask={attention_mask}")
|
266 |
+
else:
|
267 |
+
pass
|
268 |
+
# print(f"query_layer.shape={query_layer.shape}, key_layer.shape={key_layer.shape}, attention_mask={attention_mask}")
|
269 |
+
context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask, dropout_p=dropout_p)
|
270 |
+
# [batch_size, num_heads, seq_len, head_dim] -> [seq_len, batch_size, num_heads, head_dim]
|
271 |
+
context_layer = context_layer.permute(2, 0, 1, 3)
|
272 |
+
# [seq_len, batch_size, 2560]
|
273 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
274 |
+
context_layer = context_layer.reshape(*new_context_layer_shape)
|
275 |
+
else:
|
276 |
+
# Raw attention scores
|
277 |
+
|
278 |
+
# [b, np, sq, sk]
|
279 |
+
output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
|
280 |
+
|
281 |
+
# [sq, b, np, hn] -> [sq, b * np, hn]
|
282 |
+
query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
|
283 |
+
# [sk, b, np, hn] -> [sk, b * np, hn]
|
284 |
+
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
|
285 |
+
|
286 |
+
# preallocting input tensor: [b * np, sq, sk]
|
287 |
+
matmul_input_buffer = torch.empty(
|
288 |
+
output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
|
289 |
+
device=query_layer.device
|
290 |
+
)
|
291 |
+
|
292 |
+
# Raw attention scores. [b * np, sq, sk]
|
293 |
+
matmul_result = torch.baddbmm(
|
294 |
+
matmul_input_buffer,
|
295 |
+
query_layer.transpose(0, 1), # [b * np, sq, hn]
|
296 |
+
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
|
297 |
+
beta=0.0,
|
298 |
+
alpha=(1.0 / self.norm_factor),
|
299 |
+
)
|
300 |
+
|
301 |
+
# change view to [b, np, sq, sk]
|
302 |
+
attention_scores = matmul_result.view(*output_size)
|
303 |
+
|
304 |
+
# ===========================
|
305 |
+
# Attention probs and dropout
|
306 |
+
# ===========================
|
307 |
+
|
308 |
+
# attention scores and attention mask [b, np, sq, sk]
|
309 |
+
if self.attention_softmax_in_fp32:
|
310 |
+
attention_scores = attention_scores.float()
|
311 |
+
if self.coeff is not None:
|
312 |
+
attention_scores = attention_scores * self.coeff
|
313 |
+
if self.is_causal and attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
|
314 |
+
attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
|
315 |
+
device=attention_scores.device, dtype=torch.bool)
|
316 |
+
attention_mask.tril_()
|
317 |
+
attention_mask = ~attention_mask
|
318 |
+
if attention_mask is not None:
|
319 |
+
attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
|
320 |
+
attention_probs = F.softmax(attention_scores, dim=-1)
|
321 |
+
attention_probs = attention_probs.type_as(value_layer)
|
322 |
+
|
323 |
+
# This is actually dropping out entire tokens to attend to, which might
|
324 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
325 |
+
attention_probs = self.attention_dropout(attention_probs)
|
326 |
+
# =========================
|
327 |
+
# Context layer. [sq, b, hp]
|
328 |
+
# =========================
|
329 |
+
|
330 |
+
# value_layer -> context layer.
|
331 |
+
# [sk, b, np, hn] --> [b, np, sq, hn]
|
332 |
+
|
333 |
+
# context layer shape: [b, np, sq, hn]
|
334 |
+
output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
|
335 |
+
# change view [sk, b * np, hn]
|
336 |
+
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
|
337 |
+
# change view [b * np, sq, sk]
|
338 |
+
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
|
339 |
+
# matmul: [b * np, sq, hn]
|
340 |
+
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
|
341 |
+
# change view [b, np, sq, hn]
|
342 |
+
context_layer = context_layer.view(*output_size)
|
343 |
+
# [b, np, sq, hn] --> [sq, b, np, hn]
|
344 |
+
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
345 |
+
# [sq, b, np, hn] --> [sq, b, hp]
|
346 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
347 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
348 |
+
|
349 |
+
return context_layer
|
350 |
+
|
351 |
+
|
352 |
+
class SelfAttention(torch.nn.Module):
|
353 |
+
"""Parallel self-attention layer abstract class.
|
354 |
+
|
355 |
+
Self-attention layer takes input with size [s, b, h]
|
356 |
+
and returns output of the same size.
|
357 |
+
"""
|
358 |
+
|
359 |
+
def __init__(self, config: RAGPLMConfig, layer_number, device=None):
|
360 |
+
super(SelfAttention, self).__init__()
|
361 |
+
self.layer_number = max(1, layer_number)
|
362 |
+
|
363 |
+
self.projection_size = config.kv_channels * config.num_attention_heads
|
364 |
+
|
365 |
+
# Per attention head and per partition values.
|
366 |
+
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
|
367 |
+
self.num_attention_heads_per_partition = config.num_attention_heads
|
368 |
+
|
369 |
+
self.multi_query_attention = config.multi_query_attention
|
370 |
+
self.qkv_hidden_size = 3 * self.projection_size
|
371 |
+
if self.multi_query_attention:
|
372 |
+
self.num_multi_query_groups_per_partition = config.multi_query_group_num
|
373 |
+
self.qkv_hidden_size = (
|
374 |
+
self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
|
375 |
+
)
|
376 |
+
self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
|
377 |
+
bias=config.add_bias_linear or config.add_qkv_bias,
|
378 |
+
device=device, **_config_to_kwargs(config)
|
379 |
+
)
|
380 |
+
|
381 |
+
self.core_attention = CoreAttention(config, self.layer_number)
|
382 |
+
|
383 |
+
# Output.
|
384 |
+
self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, device=device, **_config_to_kwargs(config))
|
385 |
+
|
386 |
+
self.rotary_embedding_2d = config.rotary_embedding_2d
|
387 |
+
# dim, base=10000, precision=torch.half, learnable=False
|
388 |
+
self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head // 2 if self.rotary_embedding_2d else self.hidden_size_per_attention_head,
|
389 |
+
base=config.rotary_freq_base, precision=config.torch_dtype, learnable=False)
|
390 |
+
|
391 |
+
##### LoRA
|
392 |
+
self.lora = config.lora
|
393 |
+
if config.lora:
|
394 |
+
self.lora_linear = torch.nn.ModuleDict()
|
395 |
+
self.lora_dropout = torch.nn.Dropout(config.lora_dropout)
|
396 |
+
self.lora_alpha = config.lora_alpha
|
397 |
+
self.lora_r = config.lora_r
|
398 |
+
self.lora_before_position = config.lora_before_position
|
399 |
+
for name in ('Q', 'K', 'V', 'O'):
|
400 |
+
self.lora_linear[f'{name}_A'] = torch.nn.Linear(config.hidden_size, config.lora_r, bias=False)
|
401 |
+
self.lora_linear[f'{name}_B'] = torch.nn.Linear(config.lora_r, config.hidden_size, bias=False)
|
402 |
+
torch.nn.init.kaiming_uniform_(self.lora_linear[f"{name}_A"].weight, a=math.sqrt(5))
|
403 |
+
torch.nn.init.zeros_(self.lora_linear[f'{name}_B'].weight)
|
404 |
+
|
405 |
+
def forward(
|
406 |
+
self, hidden_states, attention_mask, position_ids, kv_cache=None, use_cache=True
|
407 |
+
):
|
408 |
+
|
409 |
+
# =================================================
|
410 |
+
# Pre-allocate memory for key-values for inference.
|
411 |
+
# =================================================
|
412 |
+
# =====================
|
413 |
+
# Query, Key, and Value
|
414 |
+
# =====================
|
415 |
+
|
416 |
+
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
|
417 |
+
mixed_x_layer = self.query_key_value(hidden_states) # [12800, 1, 6912]
|
418 |
+
|
419 |
+
if self.multi_query_attention:
|
420 |
+
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
|
421 |
+
[
|
422 |
+
self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
|
423 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
424 |
+
self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
|
425 |
+
],
|
426 |
+
dim=-1,
|
427 |
+
)
|
428 |
+
query_layer = query_layer.view(
|
429 |
+
query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
|
430 |
+
)
|
431 |
+
key_layer = key_layer.view(
|
432 |
+
key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
433 |
+
)
|
434 |
+
value_layer = value_layer.view(
|
435 |
+
value_layer.size()[:-1]
|
436 |
+
+ (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
|
437 |
+
)
|
438 |
+
else:
|
439 |
+
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head) # [12800, 1, 36, 192]
|
440 |
+
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [12800, 1, 36, 192]
|
441 |
+
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
|
442 |
+
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
|
443 |
+
|
444 |
+
if self.lora and self.lora_before_position:
|
445 |
+
scaling = self.lora_alpha / self.lora_r
|
446 |
+
query_layer = query_layer + ( self.lora_linear['Q_B'](self.lora_linear['Q_A'](self.lora_dropout(hidden_states))) * scaling ).reshape(query_layer.shape)
|
447 |
+
key_layer = key_layer + ( self.lora_linear['K_B'](self.lora_linear['K_A'](self.lora_dropout(hidden_states))) * scaling ).reshape(key_layer.shape)
|
448 |
+
value_layer = value_layer + ( self.lora_linear['V_B'](self.lora_linear['V_A'](self.lora_dropout(hidden_states))) * scaling ).reshape(value_layer.shape)
|
449 |
+
|
450 |
+
# apply relative positional encoding (rotary embedding)
|
451 |
+
if position_ids is not None: # [seq_len, 2, batch_size, 32, 2]
|
452 |
+
|
453 |
+
|
454 |
+
if self.rotary_embedding_2d:
|
455 |
+
q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) # 32
|
456 |
+
k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
|
457 |
+
# import pdb; pdb.set_trace();
|
458 |
+
cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) # 32
|
459 |
+
position_ids, block_position_ids = \
|
460 |
+
position_ids[:, 0, :].transpose(0, 1).contiguous(), \
|
461 |
+
position_ids[:, 1, :].transpose(0, 1).contiguous()
|
462 |
+
q1, k1 = apply_rotary_pos_emb_index_torch(q1, k1, cos, sin, position_ids)
|
463 |
+
q2, k2 = apply_rotary_pos_emb_index_torch(q2, k2, cos, sin, block_position_ids)
|
464 |
+
query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
|
465 |
+
key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
|
466 |
+
else:
|
467 |
+
# [b, sq] -> [sq, b]
|
468 |
+
position_ids = position_ids.transpose(0, 1)
|
469 |
+
cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
|
470 |
+
query_layer, key_layer = apply_rotary_pos_emb_index_torch(query_layer, key_layer, cos, sin, position_ids)
|
471 |
+
|
472 |
+
|
473 |
+
if self.lora and not self.lora_before_position:
|
474 |
+
# query_layer = query_layer + lora_layer["Q_B"](lora_layer["Q_A"](self.lora_dropout(hidden_states)))* self.scaling
|
475 |
+
scaling = self.lora_alpha / self.lora_r
|
476 |
+
query_layer = query_layer + ( self.lora_linear['Q_B'](self.lora_linear['Q_A'](self.lora_dropout(hidden_states))) * scaling ).reshape(query_layer.shape)
|
477 |
+
key_layer = key_layer + ( self.lora_linear['K_B'](self.lora_linear['K_A'](self.lora_dropout(hidden_states))) * scaling ).reshape(key_layer.shape)
|
478 |
+
value_layer = value_layer + ( self.lora_linear['V_B'](self.lora_linear['V_A'](self.lora_dropout(hidden_states))) * scaling ).reshape(value_layer.shape)
|
479 |
+
|
480 |
+
# adjust key and value for inference
|
481 |
+
if kv_cache is not None:
|
482 |
+
cache_k, cache_v = kv_cache
|
483 |
+
key_layer = torch.cat((cache_k, key_layer), dim=0)
|
484 |
+
value_layer = torch.cat((cache_v, value_layer), dim=0)
|
485 |
+
if use_cache:
|
486 |
+
kv_cache = (key_layer, value_layer)
|
487 |
+
else:
|
488 |
+
kv_cache = None
|
489 |
+
|
490 |
+
if self.multi_query_attention:
|
491 |
+
key_layer = key_layer.unsqueeze(-2)
|
492 |
+
key_layer = key_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
|
493 |
+
key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
|
494 |
+
value_layer = value_layer.unsqueeze(-2)
|
495 |
+
value_layer = value_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
|
496 |
+
value_layer = value_layer.contiguous().view(value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
|
497 |
+
|
498 |
+
# ==================================
|
499 |
+
# core attention computation
|
500 |
+
# ==================================
|
501 |
+
|
502 |
+
context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) # context_layer: [seq_len, batch_size, num_heads*head_dim]
|
503 |
+
output = self.dense(context_layer)
|
504 |
+
if self.lora:
|
505 |
+
scaling = self.lora_alpha / self.lora_r
|
506 |
+
output = output + self.lora_linear['O_B'](self.lora_linear['O_A'](self.lora_dropout(context_layer))) * scaling
|
507 |
+
|
508 |
+
# =================
|
509 |
+
# Output. [sq, b, h]
|
510 |
+
# =================
|
511 |
+
|
512 |
+
# output = context_layer @ self.dense.weight.T + self.dense.bias
|
513 |
+
return output, kv_cache
|
514 |
+
|
515 |
+
|
516 |
+
def _config_to_kwargs(args):
|
517 |
+
common_kwargs = {
|
518 |
+
"dtype": args.torch_dtype,
|
519 |
+
}
|
520 |
+
return common_kwargs
|
521 |
+
|
522 |
+
|
523 |
+
class MLP(torch.nn.Module):
|
524 |
+
"""MLP.
|
525 |
+
|
526 |
+
MLP will take the input with h hidden state, project it to 4*h
|
527 |
+
hidden dimension, perform nonlinear transformation, and project the
|
528 |
+
state back into h hidden dimension.
|
529 |
+
"""
|
530 |
+
|
531 |
+
def __init__(self, config: RAGPLMConfig, device=None):
|
532 |
+
super(MLP, self).__init__()
|
533 |
+
|
534 |
+
self.add_bias = config.add_bias_linear
|
535 |
+
self.moe = config.moe
|
536 |
+
self.mlp_lora = config.mlp_lora
|
537 |
+
self.num_experts = config.num_experts
|
538 |
+
self.experts_per_token = config.experts_per_token # 2
|
539 |
+
|
540 |
+
if self.moe is True and self.mlp_lora is True:
|
541 |
+
raise NotImplementedError(f"moe and mlp_lora are both enabled")
|
542 |
+
|
543 |
+
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
544 |
+
self.dense_h_to_4h = nn.Linear(
|
545 |
+
config.hidden_size,
|
546 |
+
config.ffn_hidden_size * 2,
|
547 |
+
bias=self.add_bias,
|
548 |
+
device=device,
|
549 |
+
**_config_to_kwargs(config)
|
550 |
+
)
|
551 |
+
|
552 |
+
def swiglu(x):
|
553 |
+
x = torch.chunk(x, 2, dim=-1)
|
554 |
+
return x[0] * F.silu(x[1])
|
555 |
+
|
556 |
+
def geglu(x):
|
557 |
+
x = torch.chunk(x, 2, dim=-1)
|
558 |
+
return x[0] * F.gelu(x[1])
|
559 |
+
|
560 |
+
if config.glu_activation == 'geglu':
|
561 |
+
self.activation_func = geglu
|
562 |
+
elif config.glu_activation == 'swiglu':
|
563 |
+
self.activation_func = swiglu
|
564 |
+
else:
|
565 |
+
assert RuntimeError(f"Unsupported glu_activation: {config.glu_activation}")
|
566 |
+
|
567 |
+
# Project back to h.
|
568 |
+
self.dense_4h_to_h = nn.Linear(
|
569 |
+
config.ffn_hidden_size,
|
570 |
+
config.hidden_size,
|
571 |
+
bias=self.add_bias,
|
572 |
+
device=device,
|
573 |
+
**_config_to_kwargs(config)
|
574 |
+
)
|
575 |
+
|
576 |
+
if self.moe:
|
577 |
+
assert self.num_experts > 1
|
578 |
+
del self.dense_h_to_4h
|
579 |
+
del self.dense_4h_to_h
|
580 |
+
self.router = nn.Linear(
|
581 |
+
config.hidden_size,
|
582 |
+
config.num_experts,
|
583 |
+
bias=False,
|
584 |
+
device=device,
|
585 |
+
dtype=torch.float32
|
586 |
+
)
|
587 |
+
for i in range(0, self.num_experts):
|
588 |
+
self.register_module(f"dense_h_to_4h_{i}", nn.Linear(
|
589 |
+
config.hidden_size,
|
590 |
+
config.ffn_hidden_size * 2,
|
591 |
+
bias=self.add_bias,
|
592 |
+
device=device,
|
593 |
+
**_config_to_kwargs(config)
|
594 |
+
))
|
595 |
+
self.register_module(f"dense_4h_to_h_{i}", nn.Linear(
|
596 |
+
config.ffn_hidden_size,
|
597 |
+
config.hidden_size,
|
598 |
+
bias=self.add_bias,
|
599 |
+
device=device,
|
600 |
+
**_config_to_kwargs(config)
|
601 |
+
))
|
602 |
+
|
603 |
+
if self.mlp_lora:
|
604 |
+
self.lora_linear = torch.nn.ModuleDict()
|
605 |
+
self.lora_dropout = torch.nn.Dropout(config.lora_dropout)
|
606 |
+
self.lora_alpha = config.lora_alpha
|
607 |
+
self.lora_r = config.lora_r
|
608 |
+
for name in ('dense_h_to_4h', 'dense_4h_to_h'):
|
609 |
+
if name == 'dense_h_to_4h':
|
610 |
+
self.lora_linear[f'{name}_A'] = torch.nn.Linear(config.hidden_size, config.lora_r, bias=False)
|
611 |
+
self.lora_linear[f'{name}_B'] = torch.nn.Linear(config.lora_r, config.ffn_hidden_size * 2, bias=False)
|
612 |
+
elif name == 'dense_4h_to_h':
|
613 |
+
self.lora_linear[f'{name}_A'] = torch.nn.Linear(config.ffn_hidden_size, config.lora_r, bias=False)
|
614 |
+
self.lora_linear[f'{name}_B'] = torch.nn.Linear(config.lora_r, config.hidden_size, bias=False)
|
615 |
+
torch.nn.init.kaiming_uniform_(self.lora_linear[f"{name}_A"].weight, a=math.sqrt(5))
|
616 |
+
torch.nn.init.zeros_(self.lora_linear[f'{name}_B'].weight)
|
617 |
+
|
618 |
+
def moe_forward(self, hidden_states, expert_idx):
|
619 |
+
# hidden_states: torch.Size([503, 1920])
|
620 |
+
# import pdb; pdb.set_trace();
|
621 |
+
intermediate_parallel = getattr(self, f"dense_h_to_4h_{expert_idx}")(hidden_states) # torch.Size([503, 20480])
|
622 |
+
intermediate_parallel = self.activation_func(intermediate_parallel) # torch.Size([503, 10240])
|
623 |
+
output = getattr(self, f"dense_4h_to_h_{expert_idx}")(intermediate_parallel) # torch.Size([503, 1920])
|
624 |
+
return output
|
625 |
+
|
626 |
+
def forward(self, hidden_states):
|
627 |
+
if self.moe:
|
628 |
+
# import pdb; pdb.set_trace();
|
629 |
+
s, b, n = hidden_states.shape
|
630 |
+
dtype = hidden_states.dtype
|
631 |
+
hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h]
|
632 |
+
|
633 |
+
route = self.router(hidden_states).to(dtype)
|
634 |
+
|
635 |
+
weights, selected_experts = torch.topk(route, self.experts_per_token)
|
636 |
+
weights = F.softmax(weights, dim=1, dtype=torch.float).to(hidden_states.dtype)
|
637 |
+
output = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
|
638 |
+
for expert_idx in range(self.num_experts):
|
639 |
+
batch_idx, nth_expert = torch.where(selected_experts == expert_idx)
|
640 |
+
if nth_expert.shape[0] == 0:
|
641 |
+
continue
|
642 |
+
cur_out = self.moe_forward(hidden_states[batch_idx], expert_idx)
|
643 |
+
output[batch_idx] += weights[batch_idx, nth_expert, None] * cur_out
|
644 |
+
output = output.reshape(s, b, n)
|
645 |
+
else:
|
646 |
+
# [s, b, 4hp]
|
647 |
+
#intermediate_parallel = hidden_states @ self.dense_h_to_4h.weight.T + self.dense_h_to_4h.bias
|
648 |
+
intermediate_parallel = self.dense_h_to_4h(hidden_states)
|
649 |
+
if self.mlp_lora:
|
650 |
+
scaling = self.lora_alpha / self.lora_r
|
651 |
+
intermediate_parallel = intermediate_parallel + ( self.lora_linear['dense_h_to_4h_B'](self.lora_linear['dense_h_to_4h_A'](self.lora_dropout(hidden_states))) * scaling )
|
652 |
+
|
653 |
+
intermediate_parallel = self.activation_func(intermediate_parallel)
|
654 |
+
# [s, b, h]
|
655 |
+
output = self.dense_4h_to_h(intermediate_parallel)
|
656 |
+
if self.mlp_lora:
|
657 |
+
output = output + ( self.lora_linear['dense_4h_to_h_B'](self.lora_linear['dense_4h_to_h_A'](self.lora_dropout(intermediate_parallel))) * scaling )# .reshape(output.shape)
|
658 |
+
|
659 |
+
#output = intermediate_parallel @ self.dense_4h_to_h.weight.T + self.dense_4h_to_h.bias # self.dense_4h_to_h(intermediate_parallel)
|
660 |
+
return output
|
661 |
+
|
662 |
+
class RAGPLMBlock(torch.nn.Module):
|
663 |
+
"""A single transformer layer.
|
664 |
+
|
665 |
+
Transformer layer takes input with size [s, b, h] and returns an
|
666 |
+
output of the same size.
|
667 |
+
"""
|
668 |
+
|
669 |
+
def __init__(self, config: RAGPLMConfig, layer_number, device=None):
|
670 |
+
super(RAGPLMBlock, self).__init__()
|
671 |
+
self.layer_number = layer_number
|
672 |
+
|
673 |
+
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
674 |
+
|
675 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
676 |
+
|
677 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
678 |
+
# Layernorm on the input data.
|
679 |
+
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon)
|
680 |
+
|
681 |
+
# Self attention.
|
682 |
+
self.self_attention = SelfAttention(config, layer_number, device=device)
|
683 |
+
self.hidden_dropout = config.hidden_dropout
|
684 |
+
|
685 |
+
# Layernorm on the attention output
|
686 |
+
self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon)
|
687 |
+
|
688 |
+
# MLP
|
689 |
+
self.mlp = MLP(config, device=device)
|
690 |
+
|
691 |
+
self.deepnorm_coeff = get_deepnorm_coefficients(config) if config.deepnorm else None
|
692 |
+
|
693 |
+
def forward(
|
694 |
+
self, hidden_states, attention_mask, position_ids, kv_cache=None, use_cache=True,
|
695 |
+
):
|
696 |
+
# hidden_states: [s, b, h]
|
697 |
+
|
698 |
+
layernorm_output = self.input_layernorm(hidden_states)
|
699 |
+
# Self attention.
|
700 |
+
|
701 |
+
attention_output, kv_cache = self.self_attention(
|
702 |
+
layernorm_output,
|
703 |
+
attention_mask,
|
704 |
+
position_ids, # [batch_size, 2, seq_len, 32, 2]
|
705 |
+
kv_cache=kv_cache,
|
706 |
+
use_cache=use_cache
|
707 |
+
)
|
708 |
+
|
709 |
+
# Residual connection.
|
710 |
+
if self.apply_residual_connection_post_layernorm:
|
711 |
+
residual = layernorm_output
|
712 |
+
else:
|
713 |
+
residual = hidden_states
|
714 |
+
|
715 |
+
layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
|
716 |
+
if self.deepnorm_coeff is not None:
|
717 |
+
layernorm_input = residual*self.deepnorm_coeff.alpha + layernorm_input
|
718 |
+
else:
|
719 |
+
layernorm_input = residual + layernorm_input
|
720 |
+
|
721 |
+
# Layer norm post the self attention.
|
722 |
+
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
723 |
+
|
724 |
+
|
725 |
+
# MLP.
|
726 |
+
mlp_output = self.mlp(layernorm_output)
|
727 |
+
|
728 |
+
# Second residual connection.
|
729 |
+
if self.apply_residual_connection_post_layernorm:
|
730 |
+
residual = layernorm_output
|
731 |
+
else:
|
732 |
+
residual = layernorm_input
|
733 |
+
|
734 |
+
|
735 |
+
output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
|
736 |
+
if self.deepnorm_coeff is not None:
|
737 |
+
output = residual*self.deepnorm_coeff.alpha + output
|
738 |
+
else:
|
739 |
+
output = residual + output
|
740 |
+
|
741 |
+
return output, kv_cache
|
742 |
+
|
743 |
+
|
744 |
+
class RAGPLMTransformer(torch.nn.Module):
|
745 |
+
"""Transformer class."""
|
746 |
+
|
747 |
+
def __init__(self, config: RAGPLMConfig, device=None):
|
748 |
+
super(RAGPLMTransformer, self).__init__()
|
749 |
+
|
750 |
+
self.config = config
|
751 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
752 |
+
self.post_layer_norm = config.post_layer_norm
|
753 |
+
|
754 |
+
# Number of layers.
|
755 |
+
self.num_layers = config.num_layers
|
756 |
+
|
757 |
+
# Transformer layers.
|
758 |
+
def build_layer(layer_number):
|
759 |
+
return RAGPLMBlock(config, layer_number, device=device)
|
760 |
+
|
761 |
+
self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
|
762 |
+
|
763 |
+
if self.post_layer_norm:
|
764 |
+
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
|
765 |
+
# Final layer norm before output.
|
766 |
+
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon)
|
767 |
+
|
768 |
+
self.gradient_checkpointing = False
|
769 |
+
# Introduce a gradient checkpointing for per num_checkpoint layers
|
770 |
+
# For example: num_checkpoint=1 will checkpoint all layers, num_checkpoint=2 will checkpoint half of layers
|
771 |
+
self.num_checkpoint = 1
|
772 |
+
|
773 |
+
def _get_layer(self, layer_number):
|
774 |
+
return self.layers[layer_number]
|
775 |
+
|
776 |
+
def forward(
|
777 |
+
self, hidden_states, attention_mask, position_ids, kv_caches=None,
|
778 |
+
use_cache: Optional[bool] = True,
|
779 |
+
output_hidden_states: Optional[bool] = False,
|
780 |
+
):
|
781 |
+
if not kv_caches:
|
782 |
+
kv_caches = [None for _ in range(self.num_layers)]
|
783 |
+
presents = () if use_cache else None
|
784 |
+
if self.gradient_checkpointing and self.training:
|
785 |
+
if use_cache:
|
786 |
+
logger.warning_once(
|
787 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
788 |
+
)
|
789 |
+
use_cache = False
|
790 |
+
|
791 |
+
all_self_attentions = None
|
792 |
+
all_hidden_states = () if output_hidden_states else None
|
793 |
+
|
794 |
+
for index in range(self.num_layers):
|
795 |
+
if output_hidden_states:
|
796 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
797 |
+
layer = self._get_layer(index)
|
798 |
+
if self.gradient_checkpointing and self.training and torch.is_grad_enabled() and index % self.num_checkpoint == 0:
|
799 |
+
#### A trick to enable gradient to avoid gradient checkpointing error
|
800 |
+
if hidden_states.requires_grad is False and deepspeed.checkpointing.is_configured() and (self.config.lora or self.config.mlp_lora):
|
801 |
+
# print(f"index={index}, set hidden_states.requires_grad = True")
|
802 |
+
hidden_states = hidden_states.clone()
|
803 |
+
hidden_states.requires_grad = True
|
804 |
+
layer_ret = get_checkpoint_fn()(
|
805 |
+
layer,
|
806 |
+
hidden_states,
|
807 |
+
attention_mask,
|
808 |
+
position_ids,
|
809 |
+
kv_caches[index],
|
810 |
+
use_cache
|
811 |
+
)
|
812 |
+
else:
|
813 |
+
layer_ret = layer(
|
814 |
+
hidden_states,
|
815 |
+
attention_mask,
|
816 |
+
position_ids,
|
817 |
+
kv_cache=kv_caches[index],
|
818 |
+
use_cache=use_cache
|
819 |
+
)
|
820 |
+
|
821 |
+
hidden_states, kv_cache = layer_ret
|
822 |
+
if use_cache:
|
823 |
+
presents = presents + (kv_cache,)
|
824 |
+
|
825 |
+
if output_hidden_states:
|
826 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
827 |
+
|
828 |
+
# Final layer norm.
|
829 |
+
if self.post_layer_norm:
|
830 |
+
hidden_states = self.final_layernorm(hidden_states)
|
831 |
+
|
832 |
+
return hidden_states, presents, all_hidden_states, all_self_attentions
|
833 |
+
|
834 |
+
|
835 |
+
class RAGPLMPreTrainedModel(PreTrainedModel):
|
836 |
+
"""
|
837 |
+
An abstract class to handle weights initialization and
|
838 |
+
a simple interface for downloading and loading pretrained models.
|
839 |
+
"""
|
840 |
+
|
841 |
+
is_parallelizable = False
|
842 |
+
supports_gradient_checkpointing = True
|
843 |
+
config_class = RAGPLMConfig
|
844 |
+
base_model_prefix = "transformer"
|
845 |
+
_no_split_modules = ["RAGPLMBlock"]
|
846 |
+
|
847 |
+
def _init_weights(self, module: nn.Module):
|
848 |
+
"""Initialize the weights."""
|
849 |
+
return
|
850 |
+
|
851 |
+
def get_masks(self, input_ids, past_key_values, padding_mask=None):
|
852 |
+
batch_size, seq_length = input_ids.shape
|
853 |
+
full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
|
854 |
+
full_attention_mask.tril_()
|
855 |
+
past_length = 0
|
856 |
+
if past_key_values:
|
857 |
+
past_length = past_key_values[0][0].shape[0]
|
858 |
+
if past_length:
|
859 |
+
full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
|
860 |
+
device=input_ids.device), full_attention_mask), dim=-1)
|
861 |
+
if padding_mask is not None:
|
862 |
+
full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
|
863 |
+
if not past_length and padding_mask is not None:
|
864 |
+
full_attention_mask -= padding_mask.unsqueeze(-1) - 1
|
865 |
+
full_attention_mask = (full_attention_mask < 0.5).bool()
|
866 |
+
full_attention_mask.unsqueeze_(1)
|
867 |
+
return full_attention_mask
|
868 |
+
|
869 |
+
def get_position_ids(self, input_ids, device):
|
870 |
+
batch_size, seq_length = input_ids.shape
|
871 |
+
position_ids_1 = torch.zeros( seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
|
872 |
+
position_ids_2 = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
|
873 |
+
position_ids = torch.stack([position_ids_1, position_ids_2], axis=1) # [batch_size, 2, seq_len]
|
874 |
+
return position_ids
|
875 |
+
|
876 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
877 |
+
if isinstance(module, RAGPLMTransformer):
|
878 |
+
module.gradient_checkpointing = value
|
879 |
+
|
880 |
+
class Embedding(torch.nn.Module):
|
881 |
+
"""Language model embeddings."""
|
882 |
+
|
883 |
+
def __init__(self, config: RAGPLMConfig, device=None):
|
884 |
+
super(Embedding, self).__init__()
|
885 |
+
|
886 |
+
self.hidden_size = config.hidden_size
|
887 |
+
# Word embeddings (parallel).
|
888 |
+
self.word_embeddings = nn.Embedding(
|
889 |
+
config.padded_vocab_size,
|
890 |
+
self.hidden_size,
|
891 |
+
dtype=config.torch_dtype,
|
892 |
+
device=device
|
893 |
+
)
|
894 |
+
self.fp32_residual_connection = config.fp32_residual_connection
|
895 |
+
|
896 |
+
def forward(self, input_ids):
|
897 |
+
# Embeddings.
|
898 |
+
words_embeddings = self.word_embeddings(input_ids)
|
899 |
+
embeddings = words_embeddings
|
900 |
+
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
|
901 |
+
embeddings = embeddings.transpose(0, 1).contiguous()
|
902 |
+
# If the input flag for fp32 residual connection is set, convert for float.
|
903 |
+
if self.fp32_residual_connection:
|
904 |
+
embeddings = embeddings.float()
|
905 |
+
return embeddings
|
906 |
+
|
907 |
+
class RAGPLMModel(RAGPLMPreTrainedModel):
|
908 |
+
def __init__(self, config: RAGPLMConfig, device=None, empty_init=True):
|
909 |
+
super().__init__(config)
|
910 |
+
if empty_init:
|
911 |
+
init_method = skip_init
|
912 |
+
else:
|
913 |
+
init_method = default_init
|
914 |
+
init_kwargs = {}
|
915 |
+
if device is not None:
|
916 |
+
init_kwargs["device"] = device
|
917 |
+
self.embedding = init_method(Embedding, config, **init_kwargs)
|
918 |
+
self.num_layers = config.num_layers
|
919 |
+
self.multi_query_group_num = config.multi_query_group_num
|
920 |
+
self.kv_channels = config.kv_channels
|
921 |
+
|
922 |
+
self.str_emb_transform = None
|
923 |
+
self.seq_ln = None
|
924 |
+
self.str_ln = None
|
925 |
+
self.str_embedding = None
|
926 |
+
|
927 |
+
self.add_str_emb_ln = config.add_str_emb_ln
|
928 |
+
self.add_seq_emb_ln = config.add_seq_emb_ln
|
929 |
+
if config.str_input_dim is not None and config.str_input_dim > 0:
|
930 |
+
# Structure input as codebook: str_input_dim given, str_output_dim given
|
931 |
+
self.str_emb_transform = torch.nn.Linear(config.str_input_dim, config.hidden_size, bias=False)
|
932 |
+
if config.add_seq_emb_ln:
|
933 |
+
self.seq_ln = torch.nn.LayerNorm(config.hidden_size)
|
934 |
+
if config.add_str_emb_ln:
|
935 |
+
self.str_ln = torch.nn.LayerNorm(config.hidden_size)
|
936 |
+
|
937 |
+
# if config.str_input_dim is None and config.str_output_dim is not None and config.str_output_dim > 0:
|
938 |
+
if config.str_input_dim is None and config.str_vocab_size is not None and config.str_vocab_size > 0:
|
939 |
+
# Structure input as index: str_input_dim not given, str_output_dim is the vocab size, the structure embedding will be nn.Embedding(str_output_dim+1, hidden_size)
|
940 |
+
self.str_embedding = torch.nn.Embedding(config.str_vocab_size+1, config.hidden_size)
|
941 |
+
|
942 |
+
# Rotary positional embeddings
|
943 |
+
self.seq_length = config.seq_length
|
944 |
+
rotary_dim = (
|
945 |
+
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
946 |
+
)
|
947 |
+
|
948 |
+
# self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, base=10000, precision=config.torch_dtype, learnable=False)
|
949 |
+
self.encoder = init_method(RAGPLMTransformer, config, **init_kwargs)
|
950 |
+
self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
|
951 |
+
dtype=config.torch_dtype, **init_kwargs)
|
952 |
+
|
953 |
+
if config.str_output_dim is not None and config.str_output_dim > 0:
|
954 |
+
self.output_layer_str = init_method(nn.Linear, config.hidden_size, config.str_output_dim, bias=False,
|
955 |
+
dtype=config.torch_dtype, **init_kwargs)
|
956 |
+
else:
|
957 |
+
self.output_layer_str = None
|
958 |
+
|
959 |
+
if config.qseq_output_dim is not None and config.qseq_output_dim > 0:
|
960 |
+
self.output_layer_qseq = init_method(nn.Linear, config.hidden_size, config.qseq_output_dim, bias=False,
|
961 |
+
dtype=config.torch_dtype, **init_kwargs)
|
962 |
+
else:
|
963 |
+
self.output_layer_qseq = None
|
964 |
+
|
965 |
+
def init_lora_modules(self):
|
966 |
+
for name, param in self.named_parameters():
|
967 |
+
if 'lora_linear' in name:
|
968 |
+
if '_A' in name:
|
969 |
+
torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
|
970 |
+
elif '_B' in name:
|
971 |
+
torch.nn.init.zeros_(param)
|
972 |
+
|
973 |
+
def get_input_embeddings(self):
|
974 |
+
return self.embedding.word_embeddings
|
975 |
+
|
976 |
+
def forward(
|
977 |
+
self,
|
978 |
+
input_ids,
|
979 |
+
position_ids: Optional[torch.Tensor] = None, # position_ids: [batch_size, 2, seq_len]
|
980 |
+
attention_mask: Optional[torch.BoolTensor] = None,
|
981 |
+
full_attention_mask: Optional[torch.BoolTensor] = None,
|
982 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
983 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
984 |
+
inputs_str_ids: Optional[torch.Tensor] = None,
|
985 |
+
inputs_str_embeds: Optional[torch.Tensor] = None,
|
986 |
+
use_cache: Optional[bool] = None,
|
987 |
+
output_hidden_states: Optional[bool] = None,
|
988 |
+
return_dict: Optional[bool] = None,
|
989 |
+
):
|
990 |
+
output_hidden_states = (
|
991 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
992 |
+
)
|
993 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
994 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
995 |
+
|
996 |
+
batch_size, seq_length = input_ids.shape
|
997 |
+
|
998 |
+
if inputs_embeds is None:
|
999 |
+
inputs_embeds = self.embedding(input_ids) # [L, B, E]
|
1000 |
+
|
1001 |
+
if self.str_emb_transform is not None and inputs_str_embeds is not None:
|
1002 |
+
# inputs_embeds: torch.Size([12800, 1, 2304]), inputs_str_embeds: torch.Size([1, 337, 384])
|
1003 |
+
assert inputs_str_embeds.ndim == 3, f"inputs_embeds: {inputs_embeds.shape}, inputs_str_embeds: {inputs_str_embeds.shape}"
|
1004 |
+
assert inputs_str_embeds.shape[0] == inputs_embeds.shape[1], f"inputs_embeds: {inputs_embeds.shape}, inputs_str_embeds: {inputs_str_embeds.shape}"
|
1005 |
+
inputs_str_embeds = inputs_str_embeds.transpose(0, 1) # [L, B, E]
|
1006 |
+
|
1007 |
+
num_res, num_batch, num_dim = inputs_str_embeds.shape
|
1008 |
+
# inputs_embeds: [L, B, E]
|
1009 |
+
padding = inputs_embeds.shape[0] - num_res
|
1010 |
+
inputs_str_embeds = F.pad(inputs_str_embeds, [0, 0, 0, 0, 0, padding], value=0)
|
1011 |
+
str_embs = self.str_emb_transform(inputs_str_embeds)
|
1012 |
+
|
1013 |
+
if self.add_str_emb_ln:
|
1014 |
+
str_embs = self.str_ln(str_embs)
|
1015 |
+
|
1016 |
+
if self.add_seq_emb_ln:
|
1017 |
+
# seq_ln only apply to the query sequence part
|
1018 |
+
inputs_embeds = torch.cat([ self.seq_ln(inputs_embeds[:num_res]), inputs_embeds[num_res:] ], dim=0)
|
1019 |
+
|
1020 |
+
inputs_embeds = inputs_embeds + str_embs
|
1021 |
+
|
1022 |
+
#if self.add_emb_ln:
|
1023 |
+
# inputs_embeds = self.seq_ln(inputs_embeds) + self.str_ln(self.str_emb_transform(inputs_str_embeds))
|
1024 |
+
#else:
|
1025 |
+
# inputs_embeds = inputs_embeds + self.str_emb_transform(inputs_str_embeds)
|
1026 |
+
|
1027 |
+
if self.str_embedding is not None and inputs_str_ids is not None:
|
1028 |
+
|
1029 |
+
|
1030 |
+
str_embedding_weight = self.str_embedding.weight # [513, 2304]
|
1031 |
+
# Add a dimension represent the padding token
|
1032 |
+
str_embedding_weight = F.pad(str_embedding_weight, (0, 0, 0, 1)) # [514, 2304]
|
1033 |
+
|
1034 |
+
assert inputs_str_ids.max() < str_embedding_weight.shape[0], f"inputs_str_ids.max()={inputs_str_ids.max()}, str_embedding_weight.shape[0]={str_embedding_weight.shape[0]}"
|
1035 |
+
|
1036 |
+
str_embs = str_embedding_weight[inputs_str_ids] # [B, L, E]
|
1037 |
+
str_embs = str_embs.permute([1, 0, 2]) # [L, B, E]
|
1038 |
+
num_res, num_batch, num_dim = str_embs.shape
|
1039 |
+
padding = inputs_embeds.shape[0] - num_res
|
1040 |
+
str_embs = F.pad(str_embs, [0, 0, 0, 0, 0, padding], value=0)
|
1041 |
+
inputs_embeds = inputs_embeds + str_embs
|
1042 |
+
|
1043 |
+
if full_attention_mask is None:
|
1044 |
+
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
|
1045 |
+
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
|
1046 |
+
|
1047 |
+
# Run encoder.
|
1048 |
+
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
|
1049 |
+
inputs_embeds, full_attention_mask, position_ids=position_ids,
|
1050 |
+
kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
|
1051 |
+
)
|
1052 |
+
|
1053 |
+
if not return_dict:
|
1054 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
1055 |
+
|
1056 |
+
return BaseModelOutputWithPast(
|
1057 |
+
last_hidden_state=hidden_states,
|
1058 |
+
past_key_values=presents,
|
1059 |
+
hidden_states=all_hidden_states,
|
1060 |
+
attentions=all_self_attentions,
|
1061 |
+
)
|
1062 |
+
|
1063 |
+
class RAGPLMForConditionalGeneration(RAGPLMPreTrainedModel):
|
1064 |
+
def __init__(self, config: RAGPLMConfig, empty_init=True, device=None):
|
1065 |
+
super().__init__(config)
|
1066 |
+
|
1067 |
+
self.max_sequence_length = config.max_length
|
1068 |
+
self.transformer = RAGPLMModel(config, empty_init=empty_init, device=device)
|
1069 |
+
self.config = config
|
1070 |
+
|
1071 |
+
def _update_model_kwargs_for_generation(
|
1072 |
+
self,
|
1073 |
+
outputs: ModelOutput,
|
1074 |
+
model_kwargs: Dict[str, Any],
|
1075 |
+
is_encoder_decoder: bool = False,
|
1076 |
+
standardize_cache_format: bool = False,
|
1077 |
+
) -> Dict[str, Any]:
|
1078 |
+
|
1079 |
+
# update past_key_values
|
1080 |
+
model_kwargs["past_key_values"] = self._extract_past_from_model_output(
|
1081 |
+
outputs, standardize_cache_format=standardize_cache_format
|
1082 |
+
)
|
1083 |
+
|
1084 |
+
# update attention mask
|
1085 |
+
if "attention_mask" in model_kwargs:
|
1086 |
+
attention_mask = model_kwargs["attention_mask"]
|
1087 |
+
model_kwargs["attention_mask"] = torch.cat(
|
1088 |
+
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
1089 |
+
)
|
1090 |
+
|
1091 |
+
if 'full_attention_mask' in model_kwargs:
|
1092 |
+
raise NotImplementedError(f"full_attention_mask...")
|
1093 |
+
model_kwargs['full_attention_mask'] = F.pad(model_kwargs['full_attention_mask'], [0, 1, 0, 1])
|
1094 |
+
if self.config.is_causal:
|
1095 |
+
model_kwargs['full_attention_mask'][..., -1] = True
|
1096 |
+
|
1097 |
+
# update position ids
|
1098 |
+
if "position_ids" in model_kwargs:
|
1099 |
+
position_ids = model_kwargs["position_ids"]
|
1100 |
+
new_position_id = position_ids[..., -1:].clone() # [batch_size, 2, 1]
|
1101 |
+
if self.config.rotary_embedding_2d:
|
1102 |
+
new_position_id[:, 1] += 1 # Only update the 2nd dimension
|
1103 |
+
else:
|
1104 |
+
new_position_id[:] += 1
|
1105 |
+
model_kwargs["position_ids"] = torch.cat(
|
1106 |
+
[position_ids, new_position_id], dim=-1
|
1107 |
+
) # [batch_size, 2, seq_len+1]
|
1108 |
+
|
1109 |
+
model_kwargs["is_first_forward"] = False
|
1110 |
+
return model_kwargs
|
1111 |
+
|
1112 |
+
def prepare_inputs_for_generation(
|
1113 |
+
self,
|
1114 |
+
input_ids: torch.LongTensor,
|
1115 |
+
past_key_values: Optional[torch.Tensor] = None,
|
1116 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1117 |
+
full_attention_mask: Optional[torch.Tensor] = None,
|
1118 |
+
position_ids: Optional[torch.Tensor] = None,
|
1119 |
+
use_cache: Optional[bool] = None,
|
1120 |
+
is_first_forward: bool = True,
|
1121 |
+
**kwargs
|
1122 |
+
) -> dict:
|
1123 |
+
# only last token for input_ids if past is not None
|
1124 |
+
if position_ids is None:
|
1125 |
+
position_ids = self.get_position_ids(input_ids, device=input_ids.device) # position_ids: [batch_size, 2, seq_len]
|
1126 |
+
if not is_first_forward:
|
1127 |
+
if past_key_values is not None:
|
1128 |
+
position_ids = position_ids[..., -1:]
|
1129 |
+
input_ids = input_ids[:, -1:]
|
1130 |
+
return {
|
1131 |
+
"input_ids": input_ids,
|
1132 |
+
"past_key_values": past_key_values,
|
1133 |
+
"position_ids": position_ids,
|
1134 |
+
"attention_mask": attention_mask,
|
1135 |
+
"full_attention_mask": full_attention_mask,
|
1136 |
+
"return_last_logit": True,
|
1137 |
+
"use_cache": use_cache
|
1138 |
+
}
|
1139 |
+
|
1140 |
+
def forward(
|
1141 |
+
self,
|
1142 |
+
input_ids: Optional[torch.Tensor] = None,
|
1143 |
+
position_ids: Optional[torch.Tensor] = None,
|
1144 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1145 |
+
full_attention_mask: Optional[torch.Tensor] = None,
|
1146 |
+
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
|
1147 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
1148 |
+
labels: Optional[torch.Tensor] = None,
|
1149 |
+
use_cache: Optional[bool] = None,
|
1150 |
+
output_attentions: Optional[bool] = None,
|
1151 |
+
output_hidden_states: Optional[bool] = None,
|
1152 |
+
return_dict: Optional[bool] = None,
|
1153 |
+
return_last_logit: Optional[bool] = False,
|
1154 |
+
):
|
1155 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1156 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1157 |
+
|
1158 |
+
transformer_outputs = self.transformer(
|
1159 |
+
input_ids=input_ids,
|
1160 |
+
position_ids=position_ids, # position_ids: [batch_size, 2, seq_len]
|
1161 |
+
attention_mask=attention_mask,
|
1162 |
+
full_attention_mask=full_attention_mask,
|
1163 |
+
past_key_values=past_key_values,
|
1164 |
+
inputs_embeds=inputs_embeds,
|
1165 |
+
use_cache=use_cache,
|
1166 |
+
output_hidden_states=output_hidden_states,
|
1167 |
+
return_dict=return_dict,
|
1168 |
+
)
|
1169 |
+
|
1170 |
+
hidden_states = transformer_outputs[0]
|
1171 |
+
if return_last_logit:
|
1172 |
+
hidden_states = hidden_states[-1:]
|
1173 |
+
lm_logits = self.transformer.output_layer(hidden_states)
|
1174 |
+
# output_layer_str
|
1175 |
+
lm_logits = lm_logits.transpose(0, 1).contiguous()
|
1176 |
+
|
1177 |
+
loss = None
|
1178 |
+
if labels is not None:
|
1179 |
+
lm_logits = lm_logits.to(torch.float32)
|
1180 |
+
|
1181 |
+
# Shift so that tokens < n predict n
|
1182 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1183 |
+
shift_labels = labels[..., 1:].contiguous()
|
1184 |
+
# Flatten the tokens
|
1185 |
+
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
1186 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1187 |
+
|
1188 |
+
lm_logits = lm_logits.to(hidden_states.dtype)
|
1189 |
+
loss = loss.to(hidden_states.dtype)
|
1190 |
+
|
1191 |
+
if not return_dict:
|
1192 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
1193 |
+
return ((loss,) + output) if loss is not None else output
|
1194 |
+
|
1195 |
+
return CausalLMOutputWithPast(
|
1196 |
+
loss=loss,
|
1197 |
+
logits=lm_logits,
|
1198 |
+
past_key_values=transformer_outputs.past_key_values,
|
1199 |
+
hidden_states=transformer_outputs.hidden_states,
|
1200 |
+
attentions=transformer_outputs.attentions,
|
1201 |
+
)
|
1202 |
+
|
1203 |
+
@staticmethod
|
1204 |
+
def _reorder_cache(
|
1205 |
+
past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
|
1206 |
+
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
|
1207 |
+
"""
|
1208 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
1209 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
1210 |
+
beam_idx at every generation step.
|
1211 |
+
|
1212 |
+
Output shares the same memory storage as `past`.
|
1213 |
+
"""
|
1214 |
+
return tuple(
|
1215 |
+
(
|
1216 |
+
layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
|
1217 |
+
layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
|
1218 |
+
)
|
1219 |
+
for layer_past in past
|
1220 |
+
)
|
1221 |
+
|
1222 |
+
def process_response(self, output, history):
|
1223 |
+
content = ""
|
1224 |
+
history = deepcopy(history)
|
1225 |
+
for response in output.split("<|assistant|>"):
|
1226 |
+
if "\n" in response:
|
1227 |
+
metadata, content = response.split("\n", maxsplit=1)
|
1228 |
+
else:
|
1229 |
+
metadata, content = "", response
|
1230 |
+
if not metadata.strip():
|
1231 |
+
content = content.strip()
|
1232 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
1233 |
+
content = content.replace("[[训练时间]]", "2023年")
|
1234 |
+
else:
|
1235 |
+
history.append({"role": "assistant", "metadata": metadata, "content": content})
|
1236 |
+
if history[0]["role"] == "system" and "tools" in history[0]:
|
1237 |
+
content = "\n".join(content.split("\n")[1:-1])
|
1238 |
+
def tool_call(**kwargs):
|
1239 |
+
return kwargs
|
1240 |
+
parameters = eval(content)
|
1241 |
+
content = {"name": metadata.strip(), "parameters": parameters}
|
1242 |
+
else:
|
1243 |
+
content = {"name": metadata.strip(), "content": content}
|
1244 |
+
return content, history
|
1245 |
+
|
1246 |
+
@torch.inference_mode()
|
1247 |
+
def chat(self, tokenizer, query: str, max_length: int = 2048, num_beams=1,
|
1248 |
+
do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
|
1249 |
+
if logits_processor is None:
|
1250 |
+
logits_processor = LogitsProcessorList()
|
1251 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
1252 |
+
gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
|
1253 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
1254 |
+
inputs = tokenizer.build_chat_input(query)
|
1255 |
+
inputs = inputs.to(self.device)
|
1256 |
+
eos_token_id = [tokenizer.eos_token_id]
|
1257 |
+
outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
|
1258 |
+
outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
|
1259 |
+
response = tokenizer.decode(outputs)
|
1260 |
+
return response
|
performance.png
ADDED
![]() |
Git LFS Details
|
proteinmoe_architecture.png
ADDED
![]() |
Git LFS Details
|
pytorch_model-00001-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2442d42059efa680a6dfcb0098629bb1661582a0b61b1eed4e6ed472dbf1d7b9
|
3 |
+
size 4932953760
|
pytorch_model-00002-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bb9e891095ef8df462d4c192accdf3650764acecf8f3512568c8d25c66ea60ea
|
3 |
+
size 4999044744
|
pytorch_model-00003-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ae85274ab0bc8fcd790fdb55b0993d6bafffe153511d82047b7c12bca28f225e
|
3 |
+
size 4991905042
|
pytorch_model-00004-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0d798373023367553ae425c7e57170ee7dce7ac516aec6edabd8c894f6842731
|
3 |
+
size 4963629472
|
pytorch_model-00005-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:900893052b8f4b2a5e939c777a3c9a07b1cc1e52efdb7f688231a8cd650ef89a
|
3 |
+
size 4991904870
|
pytorch_model-00006-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f2c940f8773468bc21e771d98f35be0250a4ba2ae77e5429f56cd1cbe8b55c84
|
3 |
+
size 4999045148
|
pytorch_model-00007-of-00007.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9a45ba9f924d8ab3f664f8365f07c40b1247f824acbf662629b324a2022af928
|
3 |
+
size 2249880965
|
pytorch_model.bin.index.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{}
|
tokenization.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Sequence, Tuple, List, Union, Optional
|
2 |
+
from abc import ABC
|
3 |
+
from abc import abstractmethod
|
4 |
+
# from .tokenizer import AbstractTokenizer
|
5 |
+
import logging
|
6 |
+
import itertools
|
7 |
+
from transformers import PreTrainedTokenizer
|
8 |
+
import torch
|
9 |
+
import json
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
class ResidueLevelTokenizer(object):
|
15 |
+
"""
|
16 |
+
Tokenizer for Protein Residue Level Tokenization.
|
17 |
+
"""
|
18 |
+
def __init__(self, **kwargs):
|
19 |
+
super(ResidueLevelTokenizer, self).__init__()
|
20 |
+
|
21 |
+
### Set normal tokens
|
22 |
+
self.all_toks = ['[pad]', 'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
|
23 |
+
|
24 |
+
### Set special tokens
|
25 |
+
_special_tokens = ['tMASK', 'gMASK', 'sMASK', 'eod', 'sop', 'eop', '</s>' ] # + ['MSA', 'ID'] + [ str(d) for d in range(0, 64) ]
|
26 |
+
self.special_tokens = { tok: len(self.all_toks)+i for i,tok in enumerate(_special_tokens) }
|
27 |
+
self.special_tokens_decoder = { v:k for k, v in self.special_tokens.items() }
|
28 |
+
self.special_tokens['eos'] = self.special_tokens['</s>']
|
29 |
+
self.all_toks.extend(_special_tokens)
|
30 |
+
|
31 |
+
self.vocab = {tok:idx for idx,tok in enumerate(self.all_toks)}
|
32 |
+
self.command_token = {'[MASK]':'MASK', '[gMASK]': 'gMASK', '[sMASK]':'sMASK'} # , '[MSA]':'MSA', '[ID]':'ID'}
|
33 |
+
|
34 |
+
self.gMASK_token_id = self.convert_token_to_id('gMASK')
|
35 |
+
self.sop_token_id = self.convert_token_to_id('sop')
|
36 |
+
self.eos_token_id = self.convert_token_to_id('</s>')
|
37 |
+
# self.id_token_id = self.convert_token_to_id('ID')
|
38 |
+
self.pad_token_id = self.convert_token_to_id('[pad]')
|
39 |
+
|
40 |
+
def __len__(self):
|
41 |
+
return len(self.vocab)
|
42 |
+
|
43 |
+
def get_special_token(self, token):
|
44 |
+
return self.special_tokens[token]
|
45 |
+
|
46 |
+
def get_vocab(self):
|
47 |
+
return self.vocab
|
48 |
+
|
49 |
+
def convert_id_to_token(self, idx):
|
50 |
+
idx = int(idx)
|
51 |
+
if idx == 0:
|
52 |
+
return '[pad]'
|
53 |
+
elif idx in self.special_tokens_decoder:
|
54 |
+
return f"[{self.special_tokens_decoder[idx]}]"
|
55 |
+
else:
|
56 |
+
return self.all_toks[idx]
|
57 |
+
|
58 |
+
def convert_token_to_id(self, token):
|
59 |
+
if token == '[pad]':
|
60 |
+
return 0
|
61 |
+
elif token in self.special_tokens:
|
62 |
+
return self.special_tokens[token]
|
63 |
+
else:
|
64 |
+
return self.vocab[token]
|
65 |
+
|
66 |
+
def encode(self, sequence, add_eos=True):
|
67 |
+
"""
|
68 |
+
Encode string or list of tokens into array
|
69 |
+
|
70 |
+
Examples
|
71 |
+
----------------
|
72 |
+
encode('[pad]ABDWOKOAKOQA[pad][MSA][3][2][19][3]')
|
73 |
+
encode(['A', 'B', 'D', 'MSA', '34', '2'])
|
74 |
+
"""
|
75 |
+
|
76 |
+
all_toks = set(self.all_toks)
|
77 |
+
a2b = {f"[{t}]":t for t in self.special_tokens if t[0] not in ('[', )}
|
78 |
+
all_toks.update( set(a2b.keys()) )
|
79 |
+
|
80 |
+
if isinstance(sequence, (tuple, list)):
|
81 |
+
if sequence[-1] != '</s>' and add_eos:
|
82 |
+
sequence = sequence + ['</s>']
|
83 |
+
sequence = [ a2b.get(tok, tok) for tok in sequence ]
|
84 |
+
return np.array([ self.convert_token_to_id(t) for t in sequence ])
|
85 |
+
elif isinstance(sequence, str):
|
86 |
+
if not sequence.endswith('</s>') and add_eos:
|
87 |
+
sequence = sequence + '</s>'
|
88 |
+
s = 0
|
89 |
+
e = 1
|
90 |
+
tok_list = []
|
91 |
+
while s < len(sequence):
|
92 |
+
while sequence[s:e] not in all_toks and e < len(sequence):
|
93 |
+
e += 1
|
94 |
+
assert sequence[s:e] in all_toks, f"Error: sub sequence {sequence[s:]} cannot be parsed"
|
95 |
+
tok = sequence[s:e]
|
96 |
+
tok = a2b.get(tok, tok) # [gMASK], [sMASK] ...
|
97 |
+
tok_id = self.convert_token_to_id(tok)
|
98 |
+
tok_list.append(tok_id)
|
99 |
+
s = e
|
100 |
+
return np.array(tok_list)
|
101 |
+
else:
|
102 |
+
raise RuntimeError(f"Error: sequence must be list/tuple/str, but got {type(sequence)}")
|
103 |
+
|
104 |
+
def decode(self, tokens, rem_eos=True, return_str=True):
|
105 |
+
if tokens[-1] == self.eos_token_id and rem_eos:
|
106 |
+
tokens = tokens[:-1]
|
107 |
+
if return_str:
|
108 |
+
return "".join([ self.convert_id_to_token(tok) for tok in tokens ])
|
109 |
+
else:
|
110 |
+
return [ self.convert_id_to_token(tok) for tok in tokens ]
|
111 |
+
|
112 |
+
def tokenize(self, text, add_eos=True):
|
113 |
+
return self.encode(text, add_eos=add_eos)
|
114 |
+
|
115 |
+
def extend_vocab(self, tokens):
|
116 |
+
"""Extend the vocab with the list of tokens."""
|
117 |
+
for token in tokens:
|
118 |
+
if token not in self.vocab:
|
119 |
+
self.vocab[token] = len(self.vocab)
|
120 |
+
self.all_toks.append(token)
|
121 |
+
|
122 |
+
class ProteinTokenizer(PreTrainedTokenizer):
|
123 |
+
"""
|
124 |
+
Protein Tokenizer based on Residue level tokenizer
|
125 |
+
"""
|
126 |
+
|
127 |
+
def __init__(
|
128 |
+
self,
|
129 |
+
vocab_file='xxx',
|
130 |
+
padding_side="right",
|
131 |
+
clean_up_tokenization_spaces=False,
|
132 |
+
encode_special_tokens=True,
|
133 |
+
**kwargs
|
134 |
+
):
|
135 |
+
self.name = "ProteinTokenizer"
|
136 |
+
self.vocab_file = vocab_file
|
137 |
+
self.tokenizer = ResidueLevelTokenizer()
|
138 |
+
self.special_tokens = self.tokenizer.special_tokens
|
139 |
+
self.encode_special_tokens = encode_special_tokens
|
140 |
+
|
141 |
+
super().__init__(
|
142 |
+
padding_side=padding_side,
|
143 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
144 |
+
**kwargs
|
145 |
+
)
|
146 |
+
|
147 |
+
def get_command(self, token):
|
148 |
+
if token in self.special_tokens:
|
149 |
+
return self.special_tokens[token]
|
150 |
+
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
|
151 |
+
return self.tokenizer.special_tokens[token]
|
152 |
+
|
153 |
+
@property
|
154 |
+
def unk_token(self) -> str:
|
155 |
+
return '[pad]'
|
156 |
+
|
157 |
+
@property
|
158 |
+
def pad_token(self) -> str:
|
159 |
+
return '[pad]'
|
160 |
+
|
161 |
+
@property
|
162 |
+
def eos_token(self) -> str:
|
163 |
+
return '</s>'
|
164 |
+
|
165 |
+
@property
|
166 |
+
def unk_token_id(self) -> int:
|
167 |
+
return '[pad]'
|
168 |
+
|
169 |
+
@property
|
170 |
+
def pad_token_id(self) -> int:
|
171 |
+
return self.tokenizer.pad_token_id
|
172 |
+
|
173 |
+
@property
|
174 |
+
def eos_token_id(self):
|
175 |
+
return self.tokenizer.eos_token_id
|
176 |
+
|
177 |
+
@property
|
178 |
+
def gMASK_token_id(self):
|
179 |
+
return self.tokenizer.gMASK_token_id
|
180 |
+
|
181 |
+
@property
|
182 |
+
def sop_token_id(self):
|
183 |
+
return self.tokenizer.sop_token_id
|
184 |
+
|
185 |
+
@property
|
186 |
+
def id_token_id(self):
|
187 |
+
return self.tokenizer.id_token_id
|
188 |
+
|
189 |
+
def IdToToken(self, id_):
|
190 |
+
return self.tokenizer.convert_id_to_token(id_)
|
191 |
+
|
192 |
+
def TokenToId(self, token):
|
193 |
+
return self.tokenizer.convert_token_to_id(token)
|
194 |
+
|
195 |
+
@unk_token.setter
|
196 |
+
def unk_token(self, value):
|
197 |
+
logger.warning("Setting unk_token is not supported, use the default one.")
|
198 |
+
|
199 |
+
@pad_token.setter
|
200 |
+
def pad_token(self, value):
|
201 |
+
logger.warning("Setting pad_token is not supported, use the default one.")
|
202 |
+
|
203 |
+
@eos_token.setter
|
204 |
+
def eos_token(self, value):
|
205 |
+
logger.warning("Setting eos_token is not supported, use the default one.")
|
206 |
+
|
207 |
+
@property
|
208 |
+
def vocab_size(self):
|
209 |
+
return len(self.tokenizer)
|
210 |
+
|
211 |
+
def encode(self, sequence, add_eos=True):
|
212 |
+
return self.tokenizer.encode(sequence, add_eos=add_eos)
|
213 |
+
|
214 |
+
def decode(self, token_ids, rem_eos=True, return_str=True):
|
215 |
+
return self.tokenizer.decode(token_ids, rem_eos=rem_eos, return_str=return_str)
|
216 |
+
|
217 |
+
def _convert_id_to_token(self, index):
|
218 |
+
"""Converts an index (integer) in a token (str) using the vocab."""
|
219 |
+
return self.tokenizer.convert_id_to_token(index)
|
220 |
+
|
221 |
+
def get_vocab(self):
|
222 |
+
""" Returns vocab as a dict """
|
223 |
+
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
224 |
+
return vocab
|
225 |
+
|
226 |
+
@property
|
227 |
+
def eod(self):
|
228 |
+
return self.tokenizer.get_special_token('eos')
|
229 |
+
|
230 |
+
def detokenize(self, Ids, type_token=False):
|
231 |
+
new_tokens = self.tokenizer.decode(Ids)
|
232 |
+
return new_tokens
|
233 |
+
|
234 |
+
def tokenize(self, text):
|
235 |
+
ids = self.tokenizer.tokenize(text)
|
236 |
+
return ids
|
237 |
+
|
238 |
+
def extend_vocab(self, tokens):
|
239 |
+
""" Extend the vocab with the list of tokens """
|
240 |
+
self.tokenizer.extend_vocab(tokens)
|
241 |
+
|
242 |
+
def add_retriever_tokens(self):
|
243 |
+
retriever_tokens = ['MSA', 'ID'] + [ str(d) for d in range(0, 64) ]
|
244 |
+
self.tokenizer.extend_vocab(retriever_tokens)
|
245 |
+
self.tokenizer.command_token['[MSA]'] = 'MSA'
|
246 |
+
self.tokenizer.command_token['[ID]'] = 'ID'
|
247 |
+
|
248 |
+
def add_structure_tokens(self, codebook_size):
|
249 |
+
self.tokenizer.extend_vocab( [ str(i) for i in range(codebook_size) ] )
|
250 |
+
|
251 |
+
def build_chat_input(self, query):
|
252 |
+
input_ids = [ self.tokenizer.convert_token_to_id('gMASK'), self.tokenizer.convert_token_to_id('sop') ]
|
253 |
+
input_ids += [ self.tokenizer.convert_token_to_id(tok) for tok in query ]
|
254 |
+
input_ids += [ self.tokenizer.convert_token_to_id('ID') ]
|
255 |
+
# return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
|
256 |
+
|
257 |
+
position_ids = torch.stack([torch.zeros(len(input_ids)), torch.arange(len(input_ids))], axis=0).unsqueeze(0).long()
|
258 |
+
return {
|
259 |
+
'input_ids': torch.from_numpy(np.array([ input_ids ])).long(),
|
260 |
+
'attention_mask': None,
|
261 |
+
'position_ids': position_ids
|
262 |
+
}
|
263 |
+
|
264 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
265 |
+
vocab = self.get_vocab()
|
266 |
+
with open(f"{save_directory}/vocab.json", 'w') as f:
|
267 |
+
json.dump(vocab, f, indent=4)
|
268 |
+
return ( f"{save_directory}/vocab.json", )
|
tokenizer_config.json
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {},
|
3 |
+
"auto_map": {
|
4 |
+
"AutoTokenizer": [
|
5 |
+
"tokenization.ProteinTokenizer",
|
6 |
+
null
|
7 |
+
]
|
8 |
+
},
|
9 |
+
"clean_up_tokenization_spaces": false,
|
10 |
+
"do_lower_case": false,
|
11 |
+
"eos_token": "</s>",
|
12 |
+
"extra_special_tokens": {},
|
13 |
+
"model_max_length": 1000000000000000019884624838656,
|
14 |
+
"pad_token": "[pad]",
|
15 |
+
"padding_side": "right",
|
16 |
+
"remove_space": false,
|
17 |
+
"tokenizer_class": "ProteinTokenizer",
|
18 |
+
"unk_token": "[pad]"
|
19 |
+
}
|
vocab.json
ADDED
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"[pad]": 0,
|
3 |
+
"L": 1,
|
4 |
+
"A": 2,
|
5 |
+
"G": 3,
|
6 |
+
"V": 4,
|
7 |
+
"S": 5,
|
8 |
+
"E": 6,
|
9 |
+
"R": 7,
|
10 |
+
"T": 8,
|
11 |
+
"I": 9,
|
12 |
+
"D": 10,
|
13 |
+
"P": 11,
|
14 |
+
"K": 12,
|
15 |
+
"Q": 13,
|
16 |
+
"N": 14,
|
17 |
+
"F": 15,
|
18 |
+
"Y": 16,
|
19 |
+
"M": 17,
|
20 |
+
"H": 18,
|
21 |
+
"W": 19,
|
22 |
+
"C": 20,
|
23 |
+
"X": 21,
|
24 |
+
"B": 22,
|
25 |
+
"U": 23,
|
26 |
+
"Z": 24,
|
27 |
+
"O": 25,
|
28 |
+
".": 26,
|
29 |
+
"-": 27,
|
30 |
+
"[tMASK]": 28,
|
31 |
+
"[gMASK]": 29,
|
32 |
+
"[sMASK]": 30,
|
33 |
+
"[eod]": 31,
|
34 |
+
"[sop]": 32,
|
35 |
+
"[eop]": 33,
|
36 |
+
"[</s>]": 34,
|
37 |
+
"0": 35,
|
38 |
+
"1": 36,
|
39 |
+
"2": 37,
|
40 |
+
"3": 38,
|
41 |
+
"4": 39,
|
42 |
+
"5": 40,
|
43 |
+
"6": 41,
|
44 |
+
"7": 42,
|
45 |
+
"8": 43,
|
46 |
+
"9": 44,
|
47 |
+
"10": 45,
|
48 |
+
"11": 46,
|
49 |
+
"12": 47,
|
50 |
+
"13": 48,
|
51 |
+
"14": 49,
|
52 |
+
"15": 50,
|
53 |
+
"16": 51,
|
54 |
+
"17": 52,
|
55 |
+
"18": 53,
|
56 |
+
"19": 54,
|
57 |
+
"20": 55,
|
58 |
+
"21": 56,
|
59 |
+
"22": 57,
|
60 |
+
"23": 58,
|
61 |
+
"24": 59,
|
62 |
+
"25": 60,
|
63 |
+
"26": 61,
|
64 |
+
"27": 62,
|
65 |
+
"28": 63,
|
66 |
+
"29": 64,
|
67 |
+
"30": 65,
|
68 |
+
"31": 66,
|
69 |
+
"32": 67,
|
70 |
+
"33": 68,
|
71 |
+
"34": 69,
|
72 |
+
"35": 70,
|
73 |
+
"36": 71,
|
74 |
+
"37": 72,
|
75 |
+
"38": 73,
|
76 |
+
"39": 74,
|
77 |
+
"40": 75,
|
78 |
+
"41": 76,
|
79 |
+
"42": 77,
|
80 |
+
"43": 78,
|
81 |
+
"44": 79,
|
82 |
+
"45": 80,
|
83 |
+
"46": 81,
|
84 |
+
"47": 82,
|
85 |
+
"48": 83,
|
86 |
+
"49": 84,
|
87 |
+
"50": 85,
|
88 |
+
"51": 86,
|
89 |
+
"52": 87,
|
90 |
+
"53": 88,
|
91 |
+
"54": 89,
|
92 |
+
"55": 90,
|
93 |
+
"56": 91,
|
94 |
+
"57": 92,
|
95 |
+
"58": 93,
|
96 |
+
"59": 94,
|
97 |
+
"60": 95,
|
98 |
+
"61": 96,
|
99 |
+
"62": 97,
|
100 |
+
"63": 98,
|
101 |
+
"64": 99,
|
102 |
+
"65": 100,
|
103 |
+
"66": 101,
|
104 |
+
"67": 102,
|
105 |
+
"68": 103,
|
106 |
+
"69": 104,
|
107 |
+
"70": 105,
|
108 |
+
"71": 106,
|
109 |
+
"72": 107,
|
110 |
+
"73": 108,
|
111 |
+
"74": 109,
|
112 |
+
"75": 110,
|
113 |
+
"76": 111,
|
114 |
+
"77": 112,
|
115 |
+
"78": 113,
|
116 |
+
"79": 114,
|
117 |
+
"80": 115,
|
118 |
+
"81": 116,
|
119 |
+
"82": 117,
|
120 |
+
"83": 118,
|
121 |
+
"84": 119,
|
122 |
+
"85": 120,
|
123 |
+
"86": 121,
|
124 |
+
"87": 122,
|
125 |
+
"88": 123,
|
126 |
+
"89": 124,
|
127 |
+
"90": 125,
|
128 |
+
"91": 126,
|
129 |
+
"92": 127,
|
130 |
+
"93": 128,
|
131 |
+
"94": 129,
|
132 |
+
"95": 130,
|
133 |
+
"96": 131,
|
134 |
+
"97": 132,
|
135 |
+
"98": 133,
|
136 |
+
"99": 134,
|
137 |
+
"100": 135,
|
138 |
+
"101": 136,
|
139 |
+
"102": 137,
|
140 |
+
"103": 138,
|
141 |
+
"104": 139,
|
142 |
+
"105": 140,
|
143 |
+
"106": 141,
|
144 |
+
"107": 142,
|
145 |
+
"108": 143,
|
146 |
+
"109": 144,
|
147 |
+
"110": 145,
|
148 |
+
"111": 146,
|
149 |
+
"112": 147,
|
150 |
+
"113": 148,
|
151 |
+
"114": 149,
|
152 |
+
"115": 150,
|
153 |
+
"116": 151,
|
154 |
+
"117": 152,
|
155 |
+
"118": 153,
|
156 |
+
"119": 154,
|
157 |
+
"120": 155,
|
158 |
+
"121": 156,
|
159 |
+
"122": 157,
|
160 |
+
"123": 158,
|
161 |
+
"124": 159,
|
162 |
+
"125": 160,
|
163 |
+
"126": 161,
|
164 |
+
"127": 162,
|
165 |
+
"128": 163,
|
166 |
+
"129": 164,
|
167 |
+
"130": 165,
|
168 |
+
"131": 166,
|
169 |
+
"132": 167,
|
170 |
+
"133": 168,
|
171 |
+
"134": 169,
|
172 |
+
"135": 170,
|
173 |
+
"136": 171,
|
174 |
+
"137": 172,
|
175 |
+
"138": 173,
|
176 |
+
"139": 174,
|
177 |
+
"140": 175,
|
178 |
+
"141": 176,
|
179 |
+
"142": 177,
|
180 |
+
"143": 178,
|
181 |
+
"144": 179,
|
182 |
+
"145": 180,
|
183 |
+
"146": 181,
|
184 |
+
"147": 182,
|
185 |
+
"148": 183,
|
186 |
+
"149": 184,
|
187 |
+
"150": 185,
|
188 |
+
"151": 186,
|
189 |
+
"152": 187,
|
190 |
+
"153": 188,
|
191 |
+
"154": 189,
|
192 |
+
"155": 190,
|
193 |
+
"156": 191,
|
194 |
+
"157": 192,
|
195 |
+
"158": 193,
|
196 |
+
"159": 194,
|
197 |
+
"160": 195,
|
198 |
+
"161": 196,
|
199 |
+
"162": 197,
|
200 |
+
"163": 198,
|
201 |
+
"164": 199,
|
202 |
+
"165": 200,
|
203 |
+
"166": 201,
|
204 |
+
"167": 202,
|
205 |
+
"168": 203,
|
206 |
+
"169": 204,
|
207 |
+
"170": 205,
|
208 |
+
"171": 206,
|
209 |
+
"172": 207,
|
210 |
+
"173": 208,
|
211 |
+
"174": 209,
|
212 |
+
"175": 210,
|
213 |
+
"176": 211,
|
214 |
+
"177": 212,
|
215 |
+
"178": 213,
|
216 |
+
"179": 214,
|
217 |
+
"180": 215,
|
218 |
+
"181": 216,
|
219 |
+
"182": 217,
|
220 |
+
"183": 218,
|
221 |
+
"184": 219,
|
222 |
+
"185": 220,
|
223 |
+
"186": 221,
|
224 |
+
"187": 222,
|
225 |
+
"188": 223,
|
226 |
+
"189": 224,
|
227 |
+
"190": 225,
|
228 |
+
"191": 226,
|
229 |
+
"192": 227,
|
230 |
+
"193": 228,
|
231 |
+
"194": 229,
|
232 |
+
"195": 230,
|
233 |
+
"196": 231,
|
234 |
+
"197": 232,
|
235 |
+
"198": 233,
|
236 |
+
"199": 234,
|
237 |
+
"200": 235,
|
238 |
+
"201": 236,
|
239 |
+
"202": 237,
|
240 |
+
"203": 238,
|
241 |
+
"204": 239,
|
242 |
+
"205": 240,
|
243 |
+
"206": 241,
|
244 |
+
"207": 242,
|
245 |
+
"208": 243,
|
246 |
+
"209": 244,
|
247 |
+
"210": 245,
|
248 |
+
"211": 246,
|
249 |
+
"212": 247,
|
250 |
+
"213": 248,
|
251 |
+
"214": 249,
|
252 |
+
"215": 250,
|
253 |
+
"216": 251,
|
254 |
+
"217": 252,
|
255 |
+
"218": 253,
|
256 |
+
"219": 254,
|
257 |
+
"220": 255,
|
258 |
+
"221": 256,
|
259 |
+
"222": 257,
|
260 |
+
"223": 258,
|
261 |
+
"224": 259,
|
262 |
+
"225": 260,
|
263 |
+
"226": 261,
|
264 |
+
"227": 262,
|
265 |
+
"228": 263,
|
266 |
+
"229": 264,
|
267 |
+
"230": 265,
|
268 |
+
"231": 266,
|
269 |
+
"232": 267,
|
270 |
+
"233": 268,
|
271 |
+
"234": 269,
|
272 |
+
"235": 270,
|
273 |
+
"236": 271,
|
274 |
+
"237": 272,
|
275 |
+
"238": 273,
|
276 |
+
"239": 274,
|
277 |
+
"240": 275,
|
278 |
+
"241": 276,
|
279 |
+
"242": 277,
|
280 |
+
"243": 278,
|
281 |
+
"244": 279,
|
282 |
+
"245": 280,
|
283 |
+
"246": 281,
|
284 |
+
"247": 282,
|
285 |
+
"248": 283,
|
286 |
+
"249": 284,
|
287 |
+
"250": 285,
|
288 |
+
"251": 286,
|
289 |
+
"252": 287,
|
290 |
+
"253": 288,
|
291 |
+
"254": 289,
|
292 |
+
"255": 290,
|
293 |
+
"256": 291,
|
294 |
+
"257": 292,
|
295 |
+
"258": 293,
|
296 |
+
"259": 294,
|
297 |
+
"260": 295,
|
298 |
+
"261": 296,
|
299 |
+
"262": 297,
|
300 |
+
"263": 298,
|
301 |
+
"264": 299,
|
302 |
+
"265": 300,
|
303 |
+
"266": 301,
|
304 |
+
"267": 302,
|
305 |
+
"268": 303,
|
306 |
+
"269": 304,
|
307 |
+
"270": 305,
|
308 |
+
"271": 306,
|
309 |
+
"272": 307,
|
310 |
+
"273": 308,
|
311 |
+
"274": 309,
|
312 |
+
"275": 310,
|
313 |
+
"276": 311,
|
314 |
+
"277": 312,
|
315 |
+
"278": 313,
|
316 |
+
"279": 314,
|
317 |
+
"280": 315,
|
318 |
+
"281": 316,
|
319 |
+
"282": 317,
|
320 |
+
"283": 318,
|
321 |
+
"284": 319,
|
322 |
+
"285": 320,
|
323 |
+
"286": 321,
|
324 |
+
"287": 322,
|
325 |
+
"288": 323,
|
326 |
+
"289": 324,
|
327 |
+
"290": 325,
|
328 |
+
"291": 326,
|
329 |
+
"292": 327,
|
330 |
+
"293": 328,
|
331 |
+
"294": 329,
|
332 |
+
"295": 330,
|
333 |
+
"296": 331,
|
334 |
+
"297": 332,
|
335 |
+
"298": 333,
|
336 |
+
"299": 334,
|
337 |
+
"300": 335,
|
338 |
+
"301": 336,
|
339 |
+
"302": 337,
|
340 |
+
"303": 338,
|
341 |
+
"304": 339,
|
342 |
+
"305": 340,
|
343 |
+
"306": 341,
|
344 |
+
"307": 342,
|
345 |
+
"308": 343,
|
346 |
+
"309": 344,
|
347 |
+
"310": 345,
|
348 |
+
"311": 346,
|
349 |
+
"312": 347,
|
350 |
+
"313": 348,
|
351 |
+
"314": 349,
|
352 |
+
"315": 350,
|
353 |
+
"316": 351,
|
354 |
+
"317": 352,
|
355 |
+
"318": 353,
|
356 |
+
"319": 354,
|
357 |
+
"320": 355,
|
358 |
+
"321": 356,
|
359 |
+
"322": 357,
|
360 |
+
"323": 358,
|
361 |
+
"324": 359,
|
362 |
+
"325": 360,
|
363 |
+
"326": 361,
|
364 |
+
"327": 362,
|
365 |
+
"328": 363,
|
366 |
+
"329": 364,
|
367 |
+
"330": 365,
|
368 |
+
"331": 366,
|
369 |
+
"332": 367,
|
370 |
+
"333": 368,
|
371 |
+
"334": 369,
|
372 |
+
"335": 370,
|
373 |
+
"336": 371,
|
374 |
+
"337": 372,
|
375 |
+
"338": 373,
|
376 |
+
"339": 374,
|
377 |
+
"340": 375,
|
378 |
+
"341": 376,
|
379 |
+
"342": 377,
|
380 |
+
"343": 378,
|
381 |
+
"344": 379,
|
382 |
+
"345": 380,
|
383 |
+
"346": 381,
|
384 |
+
"347": 382,
|
385 |
+
"348": 383,
|
386 |
+
"349": 384,
|
387 |
+
"350": 385,
|
388 |
+
"351": 386,
|
389 |
+
"352": 387,
|
390 |
+
"353": 388,
|
391 |
+
"354": 389,
|
392 |
+
"355": 390,
|
393 |
+
"356": 391,
|
394 |
+
"357": 392,
|
395 |
+
"358": 393,
|
396 |
+
"359": 394,
|
397 |
+
"360": 395,
|
398 |
+
"361": 396,
|
399 |
+
"362": 397,
|
400 |
+
"363": 398,
|
401 |
+
"364": 399,
|
402 |
+
"365": 400,
|
403 |
+
"366": 401,
|
404 |
+
"367": 402,
|
405 |
+
"368": 403,
|
406 |
+
"369": 404,
|
407 |
+
"370": 405,
|
408 |
+
"371": 406,
|
409 |
+
"372": 407,
|
410 |
+
"373": 408,
|
411 |
+
"374": 409,
|
412 |
+
"375": 410,
|
413 |
+
"376": 411,
|
414 |
+
"377": 412,
|
415 |
+
"378": 413,
|
416 |
+
"379": 414,
|
417 |
+
"380": 415,
|
418 |
+
"381": 416,
|
419 |
+
"382": 417,
|
420 |
+
"383": 418,
|
421 |
+
"384": 419,
|
422 |
+
"385": 420,
|
423 |
+
"386": 421,
|
424 |
+
"387": 422,
|
425 |
+
"388": 423,
|
426 |
+
"389": 424,
|
427 |
+
"390": 425,
|
428 |
+
"391": 426,
|
429 |
+
"392": 427,
|
430 |
+
"393": 428,
|
431 |
+
"394": 429,
|
432 |
+
"395": 430,
|
433 |
+
"396": 431,
|
434 |
+
"397": 432,
|
435 |
+
"398": 433,
|
436 |
+
"399": 434,
|
437 |
+
"400": 435,
|
438 |
+
"401": 436,
|
439 |
+
"402": 437,
|
440 |
+
"403": 438,
|
441 |
+
"404": 439,
|
442 |
+
"405": 440,
|
443 |
+
"406": 441,
|
444 |
+
"407": 442,
|
445 |
+
"408": 443,
|
446 |
+
"409": 444,
|
447 |
+
"410": 445,
|
448 |
+
"411": 446,
|
449 |
+
"412": 447,
|
450 |
+
"413": 448,
|
451 |
+
"414": 449,
|
452 |
+
"415": 450,
|
453 |
+
"416": 451,
|
454 |
+
"417": 452,
|
455 |
+
"418": 453,
|
456 |
+
"419": 454,
|
457 |
+
"420": 455,
|
458 |
+
"421": 456,
|
459 |
+
"422": 457,
|
460 |
+
"423": 458,
|
461 |
+
"424": 459,
|
462 |
+
"425": 460,
|
463 |
+
"426": 461,
|
464 |
+
"427": 462,
|
465 |
+
"428": 463,
|
466 |
+
"429": 464,
|
467 |
+
"430": 465,
|
468 |
+
"431": 466,
|
469 |
+
"432": 467,
|
470 |
+
"433": 468,
|
471 |
+
"434": 469,
|
472 |
+
"435": 470,
|
473 |
+
"436": 471,
|
474 |
+
"437": 472,
|
475 |
+
"438": 473,
|
476 |
+
"439": 474,
|
477 |
+
"440": 475,
|
478 |
+
"441": 476,
|
479 |
+
"442": 477,
|
480 |
+
"443": 478,
|
481 |
+
"444": 479,
|
482 |
+
"445": 480,
|
483 |
+
"446": 481,
|
484 |
+
"447": 482,
|
485 |
+
"448": 483,
|
486 |
+
"449": 484,
|
487 |
+
"450": 485,
|
488 |
+
"451": 486,
|
489 |
+
"452": 487,
|
490 |
+
"453": 488,
|
491 |
+
"454": 489,
|
492 |
+
"455": 490,
|
493 |
+
"456": 491,
|
494 |
+
"457": 492,
|
495 |
+
"458": 493,
|
496 |
+
"459": 494,
|
497 |
+
"460": 495,
|
498 |
+
"461": 496,
|
499 |
+
"462": 497,
|
500 |
+
"463": 498,
|
501 |
+
"464": 499,
|
502 |
+
"465": 500,
|
503 |
+
"466": 501,
|
504 |
+
"467": 502,
|
505 |
+
"468": 503,
|
506 |
+
"469": 504,
|
507 |
+
"470": 505,
|
508 |
+
"471": 506,
|
509 |
+
"472": 507,
|
510 |
+
"473": 508,
|
511 |
+
"474": 509,
|
512 |
+
"475": 510,
|
513 |
+
"476": 511,
|
514 |
+
"477": 512,
|
515 |
+
"478": 513,
|
516 |
+
"479": 514,
|
517 |
+
"480": 515,
|
518 |
+
"481": 516,
|
519 |
+
"482": 517,
|
520 |
+
"483": 518,
|
521 |
+
"484": 519,
|
522 |
+
"485": 520,
|
523 |
+
"486": 521,
|
524 |
+
"487": 522,
|
525 |
+
"488": 523,
|
526 |
+
"489": 524,
|
527 |
+
"490": 525,
|
528 |
+
"491": 526,
|
529 |
+
"492": 527,
|
530 |
+
"493": 528,
|
531 |
+
"494": 529,
|
532 |
+
"495": 530,
|
533 |
+
"496": 531,
|
534 |
+
"497": 532,
|
535 |
+
"498": 533,
|
536 |
+
"499": 534,
|
537 |
+
"500": 535,
|
538 |
+
"501": 536,
|
539 |
+
"502": 537,
|
540 |
+
"503": 538,
|
541 |
+
"504": 539,
|
542 |
+
"505": 540,
|
543 |
+
"506": 541,
|
544 |
+
"507": 542,
|
545 |
+
"508": 543,
|
546 |
+
"509": 544,
|
547 |
+
"510": 545,
|
548 |
+
"511": 546
|
549 |
+
}
|