pan-li committed · Commit d7174d3 · verified · 1 Parent(s): 826f039

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ performance.png filter=lfs diff=lfs merge=lfs -text
37
+ proteinmoe_architecture.png filter=lfs diff=lfs merge=lfs -text
LICENSE CHANGED
@@ -0,0 +1,49 @@
1
+ GENBIO AI COMMUNITY LICENSE AGREEMENT
2
+
3
+ This GenBio AI Community License Agreement (the “License”) constitutes an agreement between you or the legal entity you represent (“you” or “your”) and GENBIO.AI, INC. (“GenBio”), governing your use of the GenBio Materials. If you are using the GenBio Materials on behalf of a legal entity, you represent and warrant to GenBio that you have full legal authority to act on behalf of that legal entity as applicable under the License. If you do not have the authority to accept this License or if you disagree with any or all of the License, you shall not use the GenBio Materials in any manner. By using or distributing any portion or element of the GenBio Materials, you imply your agreement to be bound by the License.
4
+
5
+ “GenBio Materials” means any datasets, code, model weights or any other materials provided by GenBio at the following GitHub Page https://github.com/genbio-ai or Hugging Face Page https://huggingface.co/genbio-ai, including any updates or modifications made from time to time, whether in Source or Object form, and is made available to you under this License.
6
+
7
+
8
+ 1. License Grant.
9
+ 1.1 License Scope. Subject to the terms of this License, GenBio grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable and royalty-free limited license under GenBio’s intellectual property or other rights owned by GenBio embodied in the GenBio Materials to use, reproduce, distribute, and create Derivative Works of, and make modifications to, the GenBio Materials for any Non-Commercial Purposes.
10
+ 1.2 Use Restrictions. Restricted activities in relation to the License or use of GenBio Materials include:
11
+ 1.2.1 You shall use the GenBio Materials, Contributions, Derivative Works, Outputs and Output Derivatives (as defined below) solely for Non-Commercial Purposes;
12
+ 1.2.2 You shall not, directly or indirectly: (a) use or provide access to any Outputs or Output Derivatives to train, optimize, improve, or otherwise enhance the functionality or performance of any machine learning models or related technologies that are similar to the GenBio Materials; (b) engage in any form of model distillation or other methods that would achieve the purposes described in subsection (a) above. Notwithstanding the foregoing, you may use Outputs and Output Derivatives to train, optimize, improve, or enhance the functionality or performance of: (i) The GenBio Materials itself; and (ii) downstream Derivative Works of the GenBio Materials;
13
+ 1.2.3 Your use of the GenBio Materials shall be subject to any additional terms and conditions that: (a) GenBio provides to you separately; or (b) GenBio otherwise makes available to you.
14
+
15
+ 2. Sharing and Distribution.
16
+ 2.1 Subject to Section 1, if you distribute or make available the GenBio Materials or a Derivative Work to a third party for your Non-Commercial Purposes, in Source or Object form, you shall:
17
+ 2.1.1 provide a copy of this License to that third party;
18
+ 2.1.2 retain the following attribution notice within a “Notice” text file distributed as a part of such copies: “This is licensed under the GenBio AI Community License Agreement, Copyright © GENBIO.AI, INC. All Rights Reserved”; and
19
+ 2.1.3 prominently display “Powered by GenBio AI” on a related website, user interface, blogpost, about page, or product documentation.
20
+ 2.2 If You create a Derivative Work, you may add your own attribution notice(s) to the “Notice” text file included with that Derivative Work, provided that you clearly indicate which attributions apply to the GenBio Materials and state in the “Notice” text file that you changed the GenBio Materials and how it was modified.
21
+
22
+ 3. Submission of Contribution.
23
+ Unless you explicitly state otherwise, any Contribution intentionally submitted for inclusion in the GenBio Materials by you to GenBio shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with GenBio regarding such Contributions.
24
+
25
+ 4. Export Control.
26
+ You shall comply with the applicable U.S. Foreign Corrupt Practices Act and all applicable export laws, restrictions and regulations of the U.S. Department of Commerce, and any other applicable U.S. and foreign authority.
27
+
28
+ 5. Disclaimer of Warranty.
29
+ GENBIO MATERIALS PROVIDED BY GENBIO OR ANY OUTPUT YOU RECEIVED ARE PROVIDED “AS IS.” EXCEPT TO THE EXTENT PROHIBITED BY LAW, GENBIO MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND, WHETHER EXPRESS, IMPLIED OR OTHERWISE, REGARDING THE ACCURACY, COMPLETENESS OR PERFORMANCE OF THE SERVICES AND YOUR OUTPUT, OR WITH RESPECT TO SATISFACTORY QUALITY, FITNESS FOR A PARTICULAR PURPOSE OR NON-INFRINGEMENT.
30
+
31
+ 6. Limitation of Liability.
32
+ In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the GenBio Materials (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
33
+
34
+ 7. General Terms.
35
+ 7.1 Relationship of Parties. You and GenBio are independent contractors, and nothing herein shall be deemed to constitute either party as the agent or representative of the other or both parties as joint venturers or partners for any purpose.
36
+ 7.2 Assignment. This License and the rights and obligations herein may not be assigned or transferred, in whole or in part, by You without the prior written consent of GenBio. Any assignment in violation of this provision is void. GenBio may freely assign or transfer this License, in whole or in part. This License shall be binding upon, and inure to the benefit of, the successors and permitted assigns of the parties.
37
+ 7.3 Governing Law. This License shall be governed, construed and interpreted in accordance with the laws of the State of California, without giving effect to principles of conflicts of law. Each of the parties to this License consents to the exclusive jurisdiction and venue of the courts of the state and federal courts of California.
38
+ 7.4 Severability. If any provision of this License is held to be invalid, illegal or unenforceable in any respect, that provision shall be limited or eliminated to the minimum extent necessary so that this License otherwise remains in full force and effect and enforceable.
39
+
40
+ 8. Definitions.
41
+ 8.1 “Commercial Entity” means any entity engaged in any activity intended for or directed toward commercial advantage or monetary compensation, including, without limitation, the development of any product or service intended to be sold or made available for a fee. For the purpose of this License, references to a Commercial Entity expressly exclude any universities, non-profit organizations, not-for-profit entities, research institutes and educational and government bodies.
42
+ 8.2 “Contribution” means any work of authorship, including the original version of the GenBio Materials and any modifications or additions to that GenBio Materials or Derivative Works thereof, that is intentionally submitted to GenBio for inclusion in the GenBio Materials by the copyright owner or by an individual or legal entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to GenBio or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, GenBio for the purpose of discussing and improving the GenBio Materials, but excluding Outputs and all communications that are conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution”.
43
+ 8.3 “Contributor” means GenBio and any individual or legal entity on behalf of whom a Contribution has been received by GenBio and subsequently incorporated within the GenBio Materials.
44
+ 8.4 “Derivative Work” means any work, whether in Source or Object form, that is based on (or derived from) the GenBio Materials and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the GenBio Materials and Derivative Works thereof.
45
+ 8.5 “Non-Commercial Purposes” means uses not intended for or directed toward commercial advantage or monetary compensation, or the facilitation of development of any product or service to be sold or made available for a fee. For the avoidance of doubt, the provision of Outputs as a service is not a Non-Commercial Purpose.
46
+ 8.6 “Object” means any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
47
+ 8.7 “Output” means any output, including any protein sequence, structure prediction, functional annotation, molecule, descriptions of a molecule, model, sequence, text, and/or image that is elicited directly or indirectly by, or otherwise made available to, you in connection with your use of the GenBio Materials, including, but not limited to, the use of AI-Powered Technology. For the avoidance of doubt, it includes any intermediate results, such as activations across model layers, intermediate outputs from model layers (e.g., attention maps), as well as gradients and embeddings produced by the GenBio Materials.
48
+ 8.8 “Output Derivatives” means any enhancements, modifications and derivative works of Outputs (including, but not limited to, any derivative sequences or molecules).
49
+ 8.9 “Source” means the preferred form for making modifications, including but not limited to GenBio Materials source code, documentation source, and configuration files.
README.md CHANGED
@@ -1,5 +1,167 @@
1
  ---
2
  license: other
3
- license_name: genbio.ai-community-license
4
- license_link: LICENSE
5
  ---
1
  ---
2
  license: other
3
  ---
4
+
5
+ # AIDO.Protein-RAG-16B-proteingym-dms-zeroshot
6
+
7
+ AIDO.Protein-RAG-16B-proteingym-dms-zeroshot is a multimodal protein language model that integrates Multiple Sequence Alignment (MSA) and structural data, building upon the [AIDO.Protein-16B](https://huggingface.co/genbio-ai/AIDO.Protein-16B) foundation. The training process comprises three main stages:
8
+
9
+ 1. 2D RoPE encoding fine-tuning
10
+ 2. Initial training on 100 billion tokens from UniRef50/UniClust30 MSA data
11
+ 3. Subsequent training on 23 billion tokens from AlphaFold Database MSA and structural data
12
+
13
+ ## Model Architecture
14
+
15
+ AIDO.Protein-RAG-16B-proteingym-dms-zeroshot employs an encoder-only transformer architecture in which sparse Mixture-of-Experts (MoE) layers replace the dense MLP layer in each transformer block. The model uses single-amino-acid tokenization, is optimized with a masked language modeling (MLM) objective, and activates 2 of its 8 experts per token via top-2 routing (sketched in code after the table below).
16
+
17
+ <center><img src="proteinmoe_architecture.png" alt="An Overview of AIDO.Protein" style="width:70%; height:auto;" /></center>
18
+
19
+ More architecture details are shown below:
20
+
21
+ | Model Arch Component | Value |
22
+ | ----------------------- | :---: |
23
+ | Num Attention Heads | 36 |
24
+ | Num Hidden Layers | 36 |
25
+ | Hidden Size | 2304 |
26
+ | FFN Hidden Size | 7680 |
27
+ | Num Experts per MoE Layer | 8 |
28
+ | Num Experts Activated per Token | 2 |
29
+ | Vocab Size | 44 |
30
+ | Context Length | 2048 |
31
+
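+ The top-2 routing described above can be sketched as follows; this is a minimal, illustrative version of the expert-dispatch loop in `modeling_ragplm.py` (function and tensor names are ours, not the training code). With `num_experts = 8` and `experts_per_token = 2` as in `config.json`, each token is processed by only two expert MLPs while the other six stay idle.
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def top2_moe(hidden, router, experts, k=2):
+     """hidden: [num_tokens, hidden_size]; router: Linear(hidden_size, num_experts); experts: list of MLPs."""
+     logits = router(hidden)                                # [num_tokens, num_experts]
+     weights, selected = torch.topk(logits, k, dim=-1)      # choose the top-2 experts for each token
+     weights = F.softmax(weights, dim=-1)                   # normalize the two routing weights
+     output = torch.zeros_like(hidden)
+     for e, expert in enumerate(experts):                   # dispatch each token to its selected experts
+         token_idx, slot = torch.where(selected == e)
+         if token_idx.numel() > 0:
+             output[token_idx] += weights[token_idx, slot, None] * expert(hidden[token_idx])
+     return output
+ ```
+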
32
+ ## Pre-training of AIDO.Protein-RAG-16B-proteingym-dms-zeroshot
33
+
34
+ Here we briefly introduce the pre-training of AIDO.Protein-RAG-16B-proteingym-dms-zeroshot, which is divided into three stages: (1) 1D -> 2D RoPE encoding fine-tuning; (2) UniRef50/Uniclust30 MSA fine-tuning; (3) AlphaFold Database MSA & structure token fine-tuning.
35
+
36
+ ### Data
37
+
38
+ **UniRef50/Uniclust30 MSA dataset**: We utilized sequences from UniRef50 as queries to search for homologous sequences in UniClust30, subsequently constructing multiple sequence alignments (MSAs). UniRef50 comprises a total of 53.6 million sequences. Using HHblits, we searched all sequences, identifying over 25 homologous sequences for 23.7 million of them. This dataset was directly used as the training set, referred to as `HHblits_MSA`. The remaining 29.9 million sequences were input into MSA Retriever, resulting in 7.7 million sequences with more than 25 homologous sequences. This dataset was designated as `Retriever_MSA`. During training, RAGPLM randomly sampled from the two datasets with probabilities of 0.75 and 0.25. Refer to AIDO.Protein-RAG-3B paper ([link](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)) for more information.
39
+
40
+ **AlphaFold Database MSA & Structure dataset**: We downloaded all structural data from the AlphaFold Database and kept only those where more than 40% of amino acids had a pLDDT score > 70. The remaining sequences were clustered using `mmseqs` (`seq id=0.5`), and one representative per cluster was retained, resulting in 46.9 million sequence/structure pairs. For each structure, we used [genbio-ai/AIDO.StructureTokenizer](https://huggingface.co/genbio-ai/AIDO.StructureTokenizer) to obtain structure tokens and embeddings. [MSA Retriever](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1) was used to obtain the corresponding MSA.
41
+
42
+ ### Training Details
43
+
44
+ Model training is divided into three stages:
45
+
46
+ #### (1) 1D -> 2D RoPE Encoding Fine-tuning
47
+
48
+ Same training data as [AIDO.Protein-16B](https://huggingface.co/genbio-ai/AIDO.Protein-16B), but with [2D rotary position embedding](https://arxiv.org/abs/2406.05347) for token encoding.
49
+
50
+ #### (2) UniRef50/UniClust30 MSA Fine-tuning
51
+
52
+ The model from Stage 1 is further fine-tuned on the UniRef50/Uniclust30 MSA dataset. See the [AIDO.Protein-RAG-3B paper](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1) for more.
53
+
54
+ #### (3) AlphaFold Database MSA & Structure Fine-tuning
55
+
56
+ We fine-tuned the model with concatenated query and homologous sequences. Structure embeddings (dim = 384) are linearly mapped to 2304 and added to the query token embeddings.
57
+
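+ As a minimal sketch of this conditioning step (shapes follow `str_input_dim = 384` and `hidden_size = 2304` from `config.json`; the tensors below are placeholders, not the actual pipeline):
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ L, str_dim, hidden = 128, 384, 2304
+ str_proj = nn.Linear(str_dim, hidden)        # linear map from structure-embedding dim to model dim
+ token_emb = torch.randn(L, hidden)           # query token embeddings (placeholder values)
+ str_emb = torch.randn(L, str_dim)            # per-residue structure embeddings (placeholder values)
+ cond_emb = token_emb + str_proj(str_emb)     # structure-conditioned query embeddings, shape [L, 2304]
+ ```
+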
58
+ ##### Sequence Masking
59
+
60
+ * Randomly sample `0.05 × L` span positions from a query of length `L`. Span lengths follow a geometric distribution (`p=0.2`), capped at length 10. On average, ~15% of query tokens are masked.
61
+
62
+ * When a residue is selected, its aligned residues across all sequences (MSA column) are also masked.
63
+
64
+ * For masked MSA columns: 80% are replaced with `<MASK>`, 10% with random amino acids, and 10% left unchanged.
65
+
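+ The span-sampling procedure above can be sketched as follows (illustrative only; the exact training implementation may differ):
+
+ ```python
+ import numpy as np
+
+ def sample_masked_positions(L, rate=0.05, p=0.2, max_len=10, rng=np.random.default_rng()):
+     """Sample ~rate*L span starts; span lengths follow Geometric(p), capped at max_len (~15% of tokens)."""
+     starts = rng.choice(L, size=max(1, int(rate * L)), replace=False)
+     masked = set()
+     for s in starts:
+         length = min(int(rng.geometric(p)), max_len)
+         masked.update(range(s, min(s + length, L)))
+     return sorted(masked)   # mask the full MSA column at each position, then apply the 80/10/10 rule
+ ```
+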
66
+ ##### Structure Masking
67
+
68
+ * In 20% of cases, structure embeddings are replaced with 0.
69
+
70
+ * In the other 80% of cases, a masking ratio is drawn from the BetaLinear30 distribution (a mixture of 20% Uniform(0,1) and 80% Beta(3,9)), and that fraction of residues has its structure embeddings zeroed (see the sketch below).
71
+
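+ Drawing a masking ratio from BetaLinear30 then amounts to (a sketch of the mixture stated above):
+
+ ```python
+ import numpy as np
+
+ def sample_betalinear30(rng=np.random.default_rng()):
+     # 20% of the time: Uniform(0, 1); 80% of the time: Beta(3, 9)
+     return rng.uniform() if rng.random() < 0.2 else rng.beta(3, 9)
+ ```
+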
72
+ ##### Positional Embedding
73
+
74
+ We use [2D rotary position embedding](https://arxiv.org/abs/2406.05347) to help the model distinguish token chain identities and residue indices. See AIDO.Protein-RAG-3B paper ([link](https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1)) for more information.
75
+
76
+ ##### Loss Function
77
+
78
+ Total loss is a weighted sum of sequence loss (weight 1.0) and structure loss (weight 0.025).
79
+
80
+ * **Sequence loss**: CrossEntropy loss for masked token prediction.
81
+
82
+ * **Structure loss**: CrossEntropy loss for masked structure token prediction.
83
+
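+ In code, the combined objective amounts to the following (placeholder tensors; the structure-token vocabulary size of 512 is our assumption for illustration, taken from `str_output_dim` in `config.json`):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # logits/labels at 10 masked positions: 44-token sequence vocabulary, 512 structure tokens (assumed)
+ seq_logits, seq_labels = torch.randn(10, 44), torch.randint(0, 44, (10,))
+ str_logits, str_labels = torch.randn(10, 512), torch.randint(0, 512, (10,))
+
+ total_loss = 1.0 * F.cross_entropy(seq_logits, seq_labels) + 0.025 * F.cross_entropy(str_logits, str_labels)
+ ```
+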
84
+ | Hyper-params | (1) 1D -> 2D fine-tuning | (2) UniRef50/Uniclust30 MSA fine-tuning | (3) AFDB MSA & Structure tokens fine-tuning |
85
+ | --------------------------- | :---------------------: | :------------------------------------: | :----------------------------------------: |
86
+ | Initialized parameters | AIDO.Protein-16B | Stage (1) | Stage (2) |
87
+ | Data | ColabFoldDB, UniRef | HHblits_MSA, Retriever_MSA | AFDB MSA & Structure tokens |
88
+ | Global Batch Size | 512 | 256 | 256 |
89
+ | Sequence length | 2048 | 12800 | 12800 |
90
+ | Per Device Micro Batch Size | 1 | 1 | 1 |
91
+ | Precision | Mixed FP32-FP16 | Mixed FP32-FP16 | Mixed FP32-FP16 |
92
+ | LR | [5e-6,5e-5] | [1e-6, 1e-5] | 1e-5 |
93
+ | Num Tokens | 10 billion | 100 billion | 23 billion |
94
+ | Structure loss weight | N/A | N/A | 0.025 |
95
+
96
+ ### Tokenization
97
+
98
+ We encode protein sequences at single-amino-acid resolution with a 44-token vocabulary, in which 24 tokens represent amino acid types and 20 are special tokens. Each sequence is also suffixed with a `[SEP]` token as a hook for downstream tasks.
99
+
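+ Assuming the bundled tokenizer follows the standard `PreTrainedTokenizer` API, the vocabulary can be inspected directly:
+
+ ```python
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained("genbio-ai/AIDO.Protein-RAG-16B-proteingym-dms-zeroshot", trust_remote_code=True)
+ print(len(tokenizer.get_vocab()))        # expected: 44 (24 amino-acid tokens + 20 special tokens)
+ print(tokenizer.tokenize("MKTAYIAKQR"))  # one token per residue
+ ```
+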
100
+ ## Results
101
+
102
+ ### Zero-shot DMS score
103
+
104
+ <center><img src="performance.png" alt="performance" style="width:100%; height:auto;" /></center>
105
+
106
+ ## How to Run
107
+
108
+ ### Load the model and tokenizer
109
+
110
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ tokenizer = AutoTokenizer.from_pretrained("genbio-ai/AIDO.Protein-RAG-16B-proteingym-dms-zeroshot", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained("genbio-ai/AIDO.Protein-RAG-16B-proteingym-dms-zeroshot", trust_remote_code=True, torch_dtype=torch.bfloat16)
+ model = model.eval().to('cuda:0')
+ ```
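+
+ For a rough picture of how zero-shot scoring works, the sketch below estimates the effect of a single substitution by masking its position and comparing mutant and wild-type log-probabilities. It deliberately ignores MSA retrieval, structure embeddings, and any special-token offsets, so treat it as an illustration of the masked-marginal idea rather than the supported pipeline; use `compute_fitness.py` from the repository below for actual scoring.
+
+ ```python
+ import torch
+
+ @torch.no_grad()
+ def masked_marginal(seq, pos, wt, mt):
+     """Toy masked-marginal score for substituting residue `wt` at 0-based `pos` with `mt`."""
+     ids = tokenizer(seq, return_tensors="pt").input_ids.to(model.device)
+     ids[0, pos] = tokenizer.mask_token_id            # assumes a mask token and no prepended special token
+     logits = model(input_ids=ids).logits[0, pos]
+     logp = torch.log_softmax(logits.float(), dim=-1)
+     return (logp[tokenizer.convert_tokens_to_ids(mt)] - logp[tokenizer.convert_tokens_to_ids(wt)]).item()
+ ```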
117
+
118
+ ### Clone the repository and install the environment
119
+
120
+ Please read the introduction of the [repository](https://gitlab.genbio.ai/pan.li/ragplm_zeroshot/-/tree/master) for details on environment setup and how to run the model.
121
+
122
+ ```bash
+ # create and activate the environment
+ conda create -n ragplm python=3.11 -y
+ conda activate ragplm
+
+ # install dependencies
+ pip install tabulate seaborn deepspeed
+ pip install git+https://github.com/genbio-ai/ModelGenerator.git
+
+ # clone the zero-shot evaluation repository
+ git clone [email protected]:pan.li/ragplm_zeroshot.git
+ cd ragplm_zeroshot
+
+ # unpack the bundled DMS and structure data and create the output directory
+ tar xf dms_data.tar.gz
+ tar xf struc_data.tar.gz
+ mkdir output
+ ```
136
+
137
+ ### Run zero-shot
138
+
139
+ ```bash
+ # score a single ProteinGym DMS assay by its id
+ python compute_fitness.py --dms_ids PTEN_HUMAN_Mighell_2018
+ ```
142
+
143
+ ## Citation
144
+
145
+ Please cite AIDO.Protein-RAG-16B-proteingym-dms-zeroshot using the following BibTeX entries:
146
+
147
+ ```
+ @inproceedings{sun_mixture_2024,
+   title = {Mixture of Experts Enable Efficient and Effective Protein Understanding and Design},
+   url = {https://www.biorxiv.org/content/10.1101/2024.11.29.625425v1},
+   doi = {10.1101/2024.11.29.625425},
+   publisher = {bioRxiv},
+   author = {Sun, Ning and Zou, Shuxian and Tao, Tianhua and Mahbub, Sazan and Li, Dian and Zhuang, Yonghao and Wang, Hongyi and Cheng, Xingyi and Song, Le and Xing, Eric P.},
+   year = {2024},
+   booktitle = {NeurIPS 2024 Workshop on AI for New Drug Modalities},
+ }
+
+ @article{Li2024.12.02.626519,
+   author = {Li, Pan and Cheng, Xingyi and Song, Le and Xing, Eric},
+   title = {Retrieval Augmented Protein Language Models for Protein Structure Prediction},
+   url = {https://www.biorxiv.org/content/10.1101/2024.12.02.626519v1},
+   year = {2024},
+   doi = {10.1101/2024.12.02.626519},
+   publisher = {bioRxiv},
+   booktitle = {NeurIPS 2024 Workshop on Machine Learning in Structural Biology},
+ }
+ ```
config.json ADDED
@@ -0,0 +1,65 @@
1
+ {
2
+ "_name_or_path": "Protein/RAGPLM",
3
+ "add_bias_linear": true,
4
+ "add_qkv_bias": true,
5
+ "add_seq_emb_ln": false,
6
+ "add_str_emb_ln": false,
7
+ "apply_query_key_layer_scaling": true,
8
+ "apply_residual_connection_post_layernorm": false,
9
+ "architectures": [
10
+ "RAGPLMForConditionalGeneration"
11
+ ],
12
+ "attention_dropout": 0.0,
13
+ "attention_softmax_in_fp32": true,
14
+ "auto_map": {
15
+ "AutoConfig": "configuration_ragplm.RAGPLMConfig",
16
+ "AutoModel": "modeling_ragplm.RAGPLMModel",
17
+ "AutoModelForCausalLM": "modeling_ragplm.RAGPLMForConditionalGeneration",
18
+ "AutoModelForSeq2SeqLM": "modeling_ragplm.RAGPLMForConditionalGeneration"
19
+ },
20
+ "bias_dropout_fusion": true,
21
+ "classifier_dropout": null,
22
+ "deepnorm": false,
23
+ "eos_token_id": 34,
24
+ "experts_per_token": 2,
25
+ "ffn_hidden_size": 7680,
26
+ "fp32_residual_connection": false,
27
+ "glu_activation": "swiglu",
28
+ "hidden_dropout": 0.0,
29
+ "hidden_size": 2304,
30
+ "is_causal": false,
31
+ "kv_channels": 64,
32
+ "layernorm_epsilon": 1e-05,
33
+ "lora": false,
34
+ "lora_alpha": 16,
35
+ "lora_before_position": false,
36
+ "lora_dropout": 0,
37
+ "lora_r": 8,
38
+ "mlp_lora": false,
39
+ "model_type": "ragplm",
40
+ "moe": true,
41
+ "multi_query_attention": false,
42
+ "multi_query_group_num": 2,
43
+ "num_attention_heads": 36,
44
+ "num_experts": 8,
45
+ "num_layers": 36,
46
+ "original_rope": true,
47
+ "pad_token_id": 0,
48
+ "padded_vocab_size": 640,
49
+ "post_layer_norm": true,
50
+ "qseq_output_dim": null,
51
+ "quantization_bit": 0,
52
+ "rmsnorm": true,
53
+ "rotary_embedding_2d": true,
54
+ "rotary_freq_base": 10000,
55
+ "seq_length": 2048,
56
+ "str_input_dim": 384,
57
+ "str_output_dim": 512,
58
+ "str_vocab_size": null,
59
+ "tie_word_embeddings": false,
60
+ "torch_dtype": "torch.bfloat16",
61
+ "transformers_version": "4.48.3",
62
+ "use_cache": true,
63
+ "use_pytorch_sdpa": true,
64
+ "vocab_size": 128
65
+ }
configuration_ragplm.py ADDED
@@ -0,0 +1,118 @@
1
+ from transformers import PretrainedConfig
2
+ import torch
3
+
4
+ class RAGPLMConfig(PretrainedConfig):
5
+ model_type = "ragplm"
6
+ def __init__(
7
+ self,
8
+ num_layers=28,
9
+ padded_vocab_size=65024,
10
+ hidden_size=4096,
11
+ ffn_hidden_size=13696,
12
+ kv_channels=128,
13
+ num_attention_heads=32,
14
+
15
+ add_str_emb_ln=False, # Add layer norm to the structure embedding layer
16
+ add_seq_emb_ln=False, # Add layer norm to the sequence embedding layer
17
+ str_vocab_size=None,
18
+ str_input_dim=None,
19
+ str_output_dim=None,
20
+ qseq_output_dim=None,
21
+
22
+ seq_length=2048,
23
+ hidden_dropout=0.0,
24
+ classifier_dropout=None,
25
+ attention_dropout=0.0,
26
+ layernorm_epsilon=1e-5,
27
+ glu_activation='geglu',
28
+ torch_dtype=torch.bfloat16,
29
+ rmsnorm=True,
30
+ deepnorm=True,
31
+ apply_residual_connection_post_layernorm=False,
32
+ post_layer_norm=True,
33
+ add_bias_linear=False,
34
+ add_qkv_bias=False,
35
+ bias_dropout_fusion=True,
36
+ multi_query_attention=False,
37
+ multi_query_group_num=1,
38
+ apply_query_key_layer_scaling=True,
39
+ attention_softmax_in_fp32=True,
40
+ fp32_residual_connection=False,
41
+ quantization_bit=0,
42
+ # pre_seq_len=None,
43
+ # prefix_projection=False,
44
+ rotary_embedding_2d=True,
45
+ rotary_freq_base=10000,
46
+ lora=False,
47
+ mlp_lora=False,
48
+ lora_before_position=False, ### Default the QKV LoRA is after the position encoding
49
+ lora_r=8,
50
+ lora_alpha=16,
51
+ lora_dropout=0,
52
+ use_pytorch_sdpa=True,
53
+ is_causal=True,
54
+ moe=False,
55
+ num_experts=16,
56
+ experts_per_token=2,
57
+ **kwargs
58
+ ):
59
+
60
+ if not deepnorm and apply_residual_connection_post_layernorm:
61
+ print(f"Warning: deepnorm is False and apply_residual_connection_post_layernorm is True")
62
+
63
+ self.num_layers = num_layers
64
+ self.vocab_size = padded_vocab_size
65
+ self.padded_vocab_size = padded_vocab_size
66
+ self.hidden_size = hidden_size
67
+ self.ffn_hidden_size = ffn_hidden_size
68
+ self.kv_channels = kv_channels
69
+ self.num_attention_heads = num_attention_heads
70
+ self.add_str_emb_ln = add_str_emb_ln
71
+ self.add_seq_emb_ln = add_seq_emb_ln
72
+ self.str_vocab_size = str_vocab_size
73
+ self.str_input_dim = str_input_dim
74
+ self.str_output_dim = str_output_dim
75
+ self.qseq_output_dim = qseq_output_dim
76
+ self.seq_length = seq_length
77
+ self.hidden_dropout = hidden_dropout
78
+ self.classifier_dropout = classifier_dropout
79
+ self.attention_dropout = attention_dropout
80
+ self.layernorm_epsilon = layernorm_epsilon
81
+ self.torch_dtype = torch_dtype
82
+ self.glu_activation = glu_activation
83
+ self.rmsnorm = rmsnorm
84
+ self.deepnorm = deepnorm
85
+ self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
86
+ self.post_layer_norm = post_layer_norm
87
+ self.add_bias_linear = add_bias_linear
88
+ self.add_qkv_bias = add_qkv_bias
89
+ self.bias_dropout_fusion = bias_dropout_fusion
90
+ self.multi_query_attention = multi_query_attention
91
+ self.multi_query_group_num = multi_query_group_num
92
+ self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
93
+ self.attention_softmax_in_fp32 = attention_softmax_in_fp32
94
+ self.fp32_residual_connection = fp32_residual_connection
95
+ self.quantization_bit = quantization_bit
96
+ #self.pre_seq_len = pre_seq_len
97
+ #self.prefix_projection = prefix_projection
98
+ self.rotary_embedding_2d = rotary_embedding_2d
99
+ self.rotary_freq_base = rotary_freq_base
100
+ self.is_causal = is_causal
101
+ self.lora = lora
102
+ self.mlp_lora = mlp_lora
103
+ self.lora_before_position = lora_before_position
104
+ self.lora_r = lora_r
105
+ self.lora_alpha = lora_alpha
106
+ self.lora_dropout = lora_dropout
107
+ self.use_pytorch_sdpa = use_pytorch_sdpa
108
+ self.moe = moe
109
+ self.num_experts = num_experts
110
+ self.experts_per_token = experts_per_token
111
+
112
+ super().__init__(**kwargs)
113
+
114
+ if isinstance(torch_dtype, str):
115
+ if torch_dtype.startswith('torch.'):
116
+ self.torch_dtype = eval(torch_dtype)
117
+ else:
118
+ self.torch_dtype = eval(f"torch.{torch_dtype}")
generation_config.json ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "eos_token_id": 34,
4
+ "pad_token_id": 0,
5
+ "transformers_version": "4.48.3"
6
+ }
modeling_ragplm.py ADDED
@@ -0,0 +1,1260 @@
1
+ """ PyTorch AIDO.Protein-DMS-16B model. """
2
+
3
+ import math
4
+ import copy
5
+ import warnings
6
+ import re
7
+ import sys
8
+ import os
9
+ import pathlib
10
+ import time
11
+ import argparse
12
+ import random
13
+ import numpy as np
14
+ from tqdm.auto import tqdm, trange
15
+ from functools import partial
16
+
17
+ import torch, deepspeed
18
+ import torch.utils.checkpoint
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
22
+ from torch.nn.utils import skip_init
23
+ from typing import Optional, Tuple, Union, List, Callable, Dict, Any
24
+ from copy import deepcopy
25
+ from collections import namedtuple
26
+
27
+ from transformers.modeling_outputs import (
28
+ BaseModelOutputWithPast,
29
+ CausalLMOutputWithPast,
30
+ SequenceClassifierOutputWithPast,
31
+ )
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.utils import logging
34
+ from transformers.generation.logits_process import LogitsProcessor
35
+ from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput
36
+
37
+ from .configuration_ragplm import RAGPLMConfig
38
+
39
+ def get_checkpoint_fn():
40
+ if deepspeed.checkpointing.is_configured():
41
+ # checkpoint = deepspeed.checkpointing.non_reentrant_checkpoint
42
+ checkpoint = deepspeed.checkpointing.checkpoint
43
+ else:
44
+ checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
45
+ # checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=True)
46
+ return checkpoint
47
+
48
+ # flags required to enable jit fusion kernels
49
+
50
+ if sys.platform != 'darwin':
51
+ torch._C._jit_set_profiling_mode(False)
52
+ torch._C._jit_set_profiling_executor(False)
53
+ torch._C._jit_override_can_fuse_on_cpu(True)
54
+ torch._C._jit_override_can_fuse_on_gpu(True)
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ _CHECKPOINT_FOR_DOC = "Protein/Protein_RAGPLM"
59
+ _CONFIG_FOR_DOC = "RAGPLMConfig"
60
+
61
+
62
+ def default_init(cls, *args, **kwargs):
63
+ return cls(*args, **kwargs)
64
+
65
+ DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])
66
+
67
+ def get_deepnorm_coefficients(config: RAGPLMConfig):
68
+ """
69
+ DeepNorm coefficients from : https://kexue.fm/archives/8978
70
+ """
71
+ num_layers = config.num_layers
72
+ return DeepNormCoefficients(alpha=(2 * num_layers) ** 0.5, beta=(2 * num_layers) ** -0.5)
73
+
74
+
75
+ class InvalidScoreLogitsProcessor(LogitsProcessor):
76
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
77
+ if torch.isnan(scores).any() or torch.isinf(scores).any():
78
+ scores.zero_()
79
+ scores[..., 5] = 5e4
80
+ return scores
81
+
82
+ def split_tensor_along_last_dim(
83
+ tensor: torch.Tensor,
84
+ num_partitions: int,
85
+ contiguous_split_chunks: bool = False,
86
+ ) -> List[torch.Tensor]:
87
+ """Split a tensor along its last dimension.
88
+
89
+ Arguments:
90
+ tensor: input tensor.
91
+ num_partitions: number of partitions to split the tensor
92
+ contiguous_split_chunks: If True, make each chunk contiguous
93
+ in memory.
94
+
95
+ Returns:
96
+ A list of Tensors
97
+ """
98
+ # Get the size and dimension.
99
+ last_dim = tensor.dim() - 1
100
+ last_dim_size = tensor.size()[last_dim] // num_partitions
101
+ # Split.
102
+ tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
103
+ # Note: torch.split does not create contiguous tensors by default.
104
+ if contiguous_split_chunks:
105
+ return tuple(chunk.contiguous() for chunk in tensor_list)
106
+
107
+ return tensor_list
108
+
109
+ class RotaryEmbedding(torch.nn.Module):
110
+
111
+ def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
112
+ super().__init__()
113
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)).to(precision)
114
+ self.dim = dim
115
+ self.base = base
116
+ self.learnable = learnable
117
+ if learnable:
118
+ self.inv_freq = torch.nn.Parameter(inv_freq)
119
+ self.max_seq_len_cached = None
120
+ else:
121
+ self.register_buffer('inv_freq', inv_freq)
122
+ self.max_seq_len_cached = None
123
+ self.cos_cached = None
124
+ self.sin_cached = None
125
+ self.precision = precision
126
+
127
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
128
+ # import pdb; pdb.set_trace();
129
+ if f'{prefix}inv_freq' in state_dict:
130
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
131
+ else:
132
+ self.inv_freq.copy_(1. / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)).to(self.precision))
133
+
134
+ def forward(self, x, seq_dim=1, seq_len=None):
135
+
136
+ # self.inv_freq = 1. / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)).to(x.device)
137
+ if seq_len is None:
138
+ seq_len = x.shape[seq_dim]
139
+ if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
140
+ self.max_seq_len_cached = None if self.learnable else seq_len
141
+ t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
142
+ # import pdb; pdb.set_trace();
143
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(x.device))
144
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
145
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
146
+ if self.precision == torch.bfloat16 or self.precision == torch.half:
147
+ emb = emb.float()
148
+ # [sx, 1 (b * np), hn]
149
+ cos_cached = emb.cos()[:, None, :]
150
+ sin_cached = emb.sin()[:, None, :]
151
+ if self.precision == torch.bfloat16:
152
+ cos_cached = cos_cached.bfloat16()
153
+ sin_cached = sin_cached.bfloat16()
154
+ elif self.precision == torch.half:
155
+ cos_cached = cos_cached.half()
156
+ sin_cached = sin_cached.half()
157
+ if self.learnable:
158
+ return cos_cached, sin_cached
159
+ self.cos_cached, self.sin_cached = cos_cached, sin_cached
160
+ return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
161
+
162
+ def rotate_half(x):
163
+ x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
164
+ return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions
165
+
166
+ def assert_dim_check(tensor, ndim=None, shape=None):
167
+ if ndim is not None:
168
+ assert tensor.ndim == ndim, f"Exepct tensor.ndim={ndim}. gut got tensor.shape={tensor.shape}"
169
+ if shape is not None:
170
+ assert list(tensor.shape) == list(shape), f"Exepct tensor.shape={shape}. gut got tensor.shape={tensor.shape}"
171
+
172
+ def apply_rotary_pos_emb_index_torch(q, k, cos, sin, position_id): # jitting fails with bf16
173
+ # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
174
+ cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
175
+ F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
176
+ q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
177
+ return q, k
178
+
179
+ try:
180
+ # raise 'Errror'
181
+ from apex.normalization import MixedFusedRMSNorm
182
+ from apex.normalization import FusedLayerNorm
183
+ print(f"{__file__}: Use apex.normalization.MixedFusedRMSNorm as RMSNorm")
184
+
185
+ class RMSNorm(MixedFusedRMSNorm):
186
+ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, memory_efficient=False):
187
+ super(RMSNorm, self).__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient)
188
+
189
+ def forward(self, input):
190
+ dtype = input.dtype
191
+ with torch.autocast('cuda', enabled=True, dtype=torch.float32, cache_enabled=None):
192
+ output = super().forward(input)
193
+ return output.to(dtype)
194
+
195
+ class LayerNorm(FusedLayerNorm):
196
+ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, memory_efficient=False):
197
+ super(LayerNorm, self).__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient)
198
+
199
+ def forward(self, input):
200
+ dtype = input.dtype
201
+ with torch.autocast('cuda', enabled=True, dtype=torch.float32, cache_enabled=None):
202
+ output = super().forward(input)
203
+ return output.to(dtype)
204
+
205
+ except:
206
+ class RMSNorm(torch.nn.Module):
207
+ def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
208
+ super().__init__()
209
+ self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
210
+ self.eps = eps
211
+ @torch.jit.export
212
+ def forward(self, hidden_states: torch.Tensor):
213
+ input_dtype = hidden_states.dtype
214
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
215
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
216
+ return (self.weight * hidden_states).to(input_dtype)
217
+ print(f"{__file__}: Use custom RMSNorm")
218
+
219
+ class CoreAttention(torch.nn.Module):
220
+ def __init__(self, config: RAGPLMConfig, layer_number):
221
+ super(CoreAttention, self).__init__()
222
+
223
+ self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
224
+ self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
225
+ if self.apply_query_key_layer_scaling:
226
+ self.attention_softmax_in_fp32 = True
227
+ self.layer_number = max(1, layer_number)
228
+
229
+ projection_size = config.kv_channels * config.num_attention_heads
230
+
231
+ # Per attention head and per partition values.
232
+ self.hidden_size_per_partition = projection_size
233
+ self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
234
+ self.num_attention_heads_per_partition = config.num_attention_heads
235
+
236
+ coeff = None
237
+ self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
238
+ if self.apply_query_key_layer_scaling:
239
+ coeff = self.layer_number
240
+ self.norm_factor *= coeff
241
+ self.coeff = coeff
242
+
243
+ self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
244
+
245
+ self.is_causal = config.is_causal
246
+ self.use_pytorch_sdpa = config.use_pytorch_sdpa
247
+
248
+ def forward(self, query_layer, key_layer, value_layer, attention_mask):
249
+ # query_layer, key_layer, value_layer: [seq_len, batch_size, num_heads, head_dim]
250
+ # import pdb; pdb.set_trace();
251
+ pytorch_major_version = int(torch.__version__.split('.')[0])
252
+ # assert pytorch_major_version >= 2, f"Expect PyTorch version > 2.0"
253
+ if pytorch_major_version >= 2 and self.use_pytorch_sdpa:
254
+ dropout_p = self.attention_dropout.p if self.training else 0
255
+ # [seq_len, batch_size, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
256
+ query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
257
+ # import pdb; pdb.set_trace();
258
+ if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
259
+ # context_layer: [batch_size, num_heads, seq_len, head_dim]
260
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, is_causal=self.is_causal, dropout_p=dropout_p)
261
+ #print(f"torch.nn.functional.scaled_dot_product_attention")
262
+ else:
263
+ if (attention_mask is not None) and (attention_mask.dtype == torch.bool):
264
+ attention_mask = attention_mask.logical_not() ## DO NOT inplace operation!!!!
265
+ #print(f"attention_mask.shape={attention_mask.shape}, attention_mask={attention_mask}")
266
+ else:
267
+ pass
268
+ # print(f"query_layer.shape={query_layer.shape}, key_layer.shape={key_layer.shape}, attention_mask={attention_mask}")
269
+ context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask, dropout_p=dropout_p)
270
+ # [batch_size, num_heads, seq_len, head_dim] -> [seq_len, batch_size, num_heads, head_dim]
271
+ context_layer = context_layer.permute(2, 0, 1, 3)
272
+ # [seq_len, batch_size, 2560]
273
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
274
+ context_layer = context_layer.reshape(*new_context_layer_shape)
275
+ else:
276
+ # Raw attention scores
277
+
278
+ # [b, np, sq, sk]
279
+ output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))
280
+
281
+ # [sq, b, np, hn] -> [sq, b * np, hn]
282
+ query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
283
+ # [sk, b, np, hn] -> [sk, b * np, hn]
284
+ key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
285
+
286
+ # preallocting input tensor: [b * np, sq, sk]
287
+ matmul_input_buffer = torch.empty(
288
+ output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
289
+ device=query_layer.device
290
+ )
291
+
292
+ # Raw attention scores. [b * np, sq, sk]
293
+ matmul_result = torch.baddbmm(
294
+ matmul_input_buffer,
295
+ query_layer.transpose(0, 1), # [b * np, sq, hn]
296
+ key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
297
+ beta=0.0,
298
+ alpha=(1.0 / self.norm_factor),
299
+ )
300
+
301
+ # change view to [b, np, sq, sk]
302
+ attention_scores = matmul_result.view(*output_size)
303
+
304
+ # ===========================
305
+ # Attention probs and dropout
306
+ # ===========================
307
+
308
+ # attention scores and attention mask [b, np, sq, sk]
309
+ if self.attention_softmax_in_fp32:
310
+ attention_scores = attention_scores.float()
311
+ if self.coeff is not None:
312
+ attention_scores = attention_scores * self.coeff
313
+ if self.is_causal and attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
314
+ attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
315
+ device=attention_scores.device, dtype=torch.bool)
316
+ attention_mask.tril_()
317
+ attention_mask = ~attention_mask
318
+ if attention_mask is not None:
319
+ attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
320
+ attention_probs = F.softmax(attention_scores, dim=-1)
321
+ attention_probs = attention_probs.type_as(value_layer)
322
+
323
+ # This is actually dropping out entire tokens to attend to, which might
324
+ # seem a bit unusual, but is taken from the original Transformer paper.
325
+ attention_probs = self.attention_dropout(attention_probs)
326
+ # =========================
327
+ # Context layer. [sq, b, hp]
328
+ # =========================
329
+
330
+ # value_layer -> context layer.
331
+ # [sk, b, np, hn] --> [b, np, sq, hn]
332
+
333
+ # context layer shape: [b, np, sq, hn]
334
+ output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
335
+ # change view [sk, b * np, hn]
336
+ value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
337
+ # change view [b * np, sq, sk]
338
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
339
+ # matmul: [b * np, sq, hn]
340
+ context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
341
+ # change view [b, np, sq, hn]
342
+ context_layer = context_layer.view(*output_size)
343
+ # [b, np, sq, hn] --> [sq, b, np, hn]
344
+ context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
345
+ # [sq, b, np, hn] --> [sq, b, hp]
346
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
347
+ context_layer = context_layer.view(*new_context_layer_shape)
348
+
349
+ return context_layer
350
+
351
+
352
+ class SelfAttention(torch.nn.Module):
353
+ """Parallel self-attention layer abstract class.
354
+
355
+ Self-attention layer takes input with size [s, b, h]
356
+ and returns output of the same size.
357
+ """
358
+
359
+ def __init__(self, config: RAGPLMConfig, layer_number, device=None):
360
+ super(SelfAttention, self).__init__()
361
+ self.layer_number = max(1, layer_number)
362
+
363
+ self.projection_size = config.kv_channels * config.num_attention_heads
364
+
365
+ # Per attention head and per partition values.
366
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
367
+ self.num_attention_heads_per_partition = config.num_attention_heads
368
+
369
+ self.multi_query_attention = config.multi_query_attention
370
+ self.qkv_hidden_size = 3 * self.projection_size
371
+ if self.multi_query_attention:
372
+ self.num_multi_query_groups_per_partition = config.multi_query_group_num
373
+ self.qkv_hidden_size = (
374
+ self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
375
+ )
376
+ self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
377
+ bias=config.add_bias_linear or config.add_qkv_bias,
378
+ device=device, **_config_to_kwargs(config)
379
+ )
380
+
381
+ self.core_attention = CoreAttention(config, self.layer_number)
382
+
383
+ # Output.
384
+ self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, device=device, **_config_to_kwargs(config))
385
+
386
+ self.rotary_embedding_2d = config.rotary_embedding_2d
387
+ # dim, base=10000, precision=torch.half, learnable=False
388
+ self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head // 2 if self.rotary_embedding_2d else self.hidden_size_per_attention_head,
389
+ base=config.rotary_freq_base, precision=config.torch_dtype, learnable=False)
390
+
391
+ ##### LoRA
392
+ self.lora = config.lora
393
+ if config.lora:
394
+ self.lora_linear = torch.nn.ModuleDict()
395
+ self.lora_dropout = torch.nn.Dropout(config.lora_dropout)
396
+ self.lora_alpha = config.lora_alpha
397
+ self.lora_r = config.lora_r
398
+ self.lora_before_position = config.lora_before_position
399
+ for name in ('Q', 'K', 'V', 'O'):
400
+ self.lora_linear[f'{name}_A'] = torch.nn.Linear(config.hidden_size, config.lora_r, bias=False)
401
+ self.lora_linear[f'{name}_B'] = torch.nn.Linear(config.lora_r, config.hidden_size, bias=False)
402
+ torch.nn.init.kaiming_uniform_(self.lora_linear[f"{name}_A"].weight, a=math.sqrt(5))
403
+ torch.nn.init.zeros_(self.lora_linear[f'{name}_B'].weight)
404
+
405
+ def forward(
406
+ self, hidden_states, attention_mask, position_ids, kv_cache=None, use_cache=True
407
+ ):
408
+
409
+ # =================================================
410
+ # Pre-allocate memory for key-values for inference.
411
+ # =================================================
412
+ # =====================
413
+ # Query, Key, and Value
414
+ # =====================
415
+
416
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
417
+ mixed_x_layer = self.query_key_value(hidden_states) # [12800, 1, 6912]
418
+
419
+ if self.multi_query_attention:
420
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
421
+ [
422
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
423
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
424
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
425
+ ],
426
+ dim=-1,
427
+ )
428
+ query_layer = query_layer.view(
429
+ query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
430
+ )
431
+ key_layer = key_layer.view(
432
+ key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
433
+ )
434
+ value_layer = value_layer.view(
435
+ value_layer.size()[:-1]
436
+ + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
437
+ )
438
+ else:
439
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head) # [12800, 1, 36, 192]
440
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [12800, 1, 36, 192]
441
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
442
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
443
+
444
+ if self.lora and self.lora_before_position:
445
+ scaling = self.lora_alpha / self.lora_r
446
+ query_layer = query_layer + ( self.lora_linear['Q_B'](self.lora_linear['Q_A'](self.lora_dropout(hidden_states))) * scaling ).reshape(query_layer.shape)
447
+ key_layer = key_layer + ( self.lora_linear['K_B'](self.lora_linear['K_A'](self.lora_dropout(hidden_states))) * scaling ).reshape(key_layer.shape)
448
+ value_layer = value_layer + ( self.lora_linear['V_B'](self.lora_linear['V_A'](self.lora_dropout(hidden_states))) * scaling ).reshape(value_layer.shape)
449
+
450
+ # apply relative positional encoding (rotary embedding)
451
+ if position_ids is not None: # [seq_len, 2, batch_size, 32, 2]
452
+
453
+
454
+ if self.rotary_embedding_2d:
455
+ q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1)) # 32
456
+ k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
457
+ # import pdb; pdb.set_trace();
458
+ cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1) # 32
459
+ position_ids, block_position_ids = \
460
+ position_ids[:, 0, :].transpose(0, 1).contiguous(), \
461
+ position_ids[:, 1, :].transpose(0, 1).contiguous()
462
+ q1, k1 = apply_rotary_pos_emb_index_torch(q1, k1, cos, sin, position_ids)
463
+ q2, k2 = apply_rotary_pos_emb_index_torch(q2, k2, cos, sin, block_position_ids)
464
+ query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
465
+ key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
466
+ else:
467
+ # [b, sq] -> [sq, b]
468
+ position_ids = position_ids.transpose(0, 1)
469
+ cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
470
+ query_layer, key_layer = apply_rotary_pos_emb_index_torch(query_layer, key_layer, cos, sin, position_ids)
471
+
472
+
473
+ if self.lora and not self.lora_before_position:
474
+ # query_layer = query_layer + lora_layer["Q_B"](lora_layer["Q_A"](self.lora_dropout(hidden_states)))* self.scaling
475
+ scaling = self.lora_alpha / self.lora_r
476
+ query_layer = query_layer + ( self.lora_linear['Q_B'](self.lora_linear['Q_A'](self.lora_dropout(hidden_states))) * scaling ).reshape(query_layer.shape)
477
+ key_layer = key_layer + ( self.lora_linear['K_B'](self.lora_linear['K_A'](self.lora_dropout(hidden_states))) * scaling ).reshape(key_layer.shape)
478
+ value_layer = value_layer + ( self.lora_linear['V_B'](self.lora_linear['V_A'](self.lora_dropout(hidden_states))) * scaling ).reshape(value_layer.shape)
479
+
480
+ # adjust key and value for inference
481
+ if kv_cache is not None:
482
+ cache_k, cache_v = kv_cache
483
+ key_layer = torch.cat((cache_k, key_layer), dim=0)
484
+ value_layer = torch.cat((cache_v, value_layer), dim=0)
485
+ if use_cache:
486
+ kv_cache = (key_layer, value_layer)
487
+ else:
488
+ kv_cache = None
489
+
490
+ if self.multi_query_attention:
491
+ key_layer = key_layer.unsqueeze(-2)
492
+ key_layer = key_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
493
+ key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
494
+ value_layer = value_layer.unsqueeze(-2)
495
+ value_layer = value_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
496
+ value_layer = value_layer.contiguous().view(value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
497
+
498
+ # ==================================
499
+ # core attention computation
500
+ # ==================================
501
+
502
+ context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) # context_layer: [seq_len, batch_size, num_heads*head_dim]
503
+ output = self.dense(context_layer)
504
+ if self.lora:
505
+ scaling = self.lora_alpha / self.lora_r
506
+ output = output + self.lora_linear['O_B'](self.lora_linear['O_A'](self.lora_dropout(context_layer))) * scaling
507
+
508
+ # =================
509
+ # Output. [sq, b, h]
510
+ # =================
511
+
512
+ # output = context_layer @ self.dense.weight.T + self.dense.bias
513
+ return output, kv_cache
514
+
515
+
516
+ def _config_to_kwargs(args):
517
+ common_kwargs = {
518
+ "dtype": args.torch_dtype,
519
+ }
520
+ return common_kwargs
521
+
522
+
523
+ class MLP(torch.nn.Module):
524
+ """MLP.
525
+
526
+ MLP will take the input with h hidden state, project it to 4*h
527
+ hidden dimension, perform nonlinear transformation, and project the
528
+ state back into h hidden dimension.
529
+ """
530
+
531
+ def __init__(self, config: RAGPLMConfig, device=None):
532
+ super(MLP, self).__init__()
533
+
534
+ self.add_bias = config.add_bias_linear
535
+ self.moe = config.moe
536
+ self.mlp_lora = config.mlp_lora
537
+ self.num_experts = config.num_experts
538
+ self.experts_per_token = config.experts_per_token # 2
539
+
540
+ if self.moe is True and self.mlp_lora is True:
541
+ raise NotImplementedError(f"moe and mlp_lora are both enabled")
542
+
543
+ # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
544
+ self.dense_h_to_4h = nn.Linear(
545
+ config.hidden_size,
546
+ config.ffn_hidden_size * 2,
547
+ bias=self.add_bias,
548
+ device=device,
549
+ **_config_to_kwargs(config)
550
+ )
551
+
552
+ def swiglu(x):
553
+ x = torch.chunk(x, 2, dim=-1)
554
+ return x[0] * F.silu(x[1])
555
+
556
+ def geglu(x):
557
+ x = torch.chunk(x, 2, dim=-1)
558
+ return x[0] * F.gelu(x[1])
559
+
560
+ if config.glu_activation == 'geglu':
561
+ self.activation_func = geglu
562
+ elif config.glu_activation == 'swiglu':
563
+ self.activation_func = swiglu
564
+ else:
565
+ raise RuntimeError(f"Unsupported glu_activation: {config.glu_activation}")
566
+
567
+ # Project back to h.
568
+ self.dense_4h_to_h = nn.Linear(
569
+ config.ffn_hidden_size,
570
+ config.hidden_size,
571
+ bias=self.add_bias,
572
+ device=device,
573
+ **_config_to_kwargs(config)
574
+ )
575
+
576
+ if self.moe:
577
+ assert self.num_experts > 1
578
+ del self.dense_h_to_4h
579
+ del self.dense_4h_to_h
580
+ self.router = nn.Linear(
581
+ config.hidden_size,
582
+ config.num_experts,
583
+ bias=False,
584
+ device=device,
585
+ dtype=torch.float32
586
+ )
587
+ for i in range(0, self.num_experts):
588
+ self.register_module(f"dense_h_to_4h_{i}", nn.Linear(
589
+ config.hidden_size,
590
+ config.ffn_hidden_size * 2,
591
+ bias=self.add_bias,
592
+ device=device,
593
+ **_config_to_kwargs(config)
594
+ ))
595
+ self.register_module(f"dense_4h_to_h_{i}", nn.Linear(
596
+ config.ffn_hidden_size,
597
+ config.hidden_size,
598
+ bias=self.add_bias,
599
+ device=device,
600
+ **_config_to_kwargs(config)
601
+ ))
602
+
603
+ if self.mlp_lora:
604
+ self.lora_linear = torch.nn.ModuleDict()
605
+ self.lora_dropout = torch.nn.Dropout(config.lora_dropout)
606
+ self.lora_alpha = config.lora_alpha
607
+ self.lora_r = config.lora_r
608
+ for name in ('dense_h_to_4h', 'dense_4h_to_h'):
609
+ if name == 'dense_h_to_4h':
610
+ self.lora_linear[f'{name}_A'] = torch.nn.Linear(config.hidden_size, config.lora_r, bias=False)
611
+ self.lora_linear[f'{name}_B'] = torch.nn.Linear(config.lora_r, config.ffn_hidden_size * 2, bias=False)
612
+ elif name == 'dense_4h_to_h':
613
+ self.lora_linear[f'{name}_A'] = torch.nn.Linear(config.ffn_hidden_size, config.lora_r, bias=False)
614
+ self.lora_linear[f'{name}_B'] = torch.nn.Linear(config.lora_r, config.hidden_size, bias=False)
615
+ torch.nn.init.kaiming_uniform_(self.lora_linear[f"{name}_A"].weight, a=math.sqrt(5))
616
+ torch.nn.init.zeros_(self.lora_linear[f'{name}_B'].weight)
617
+
618
+ def moe_forward(self, hidden_states, expert_idx):
619
+ # hidden_states: torch.Size([503, 1920])
620
+ # import pdb; pdb.set_trace();
621
+ intermediate_parallel = getattr(self, f"dense_h_to_4h_{expert_idx}")(hidden_states) # torch.Size([503, 20480])
622
+ intermediate_parallel = self.activation_func(intermediate_parallel) # torch.Size([503, 10240])
623
+ output = getattr(self, f"dense_4h_to_h_{expert_idx}")(intermediate_parallel) # torch.Size([503, 1920])
624
+ return output
625
+
626
+ def forward(self, hidden_states):
627
+ if self.moe:
628
+ # import pdb; pdb.set_trace();
629
+ s, b, n = hidden_states.shape
630
+ dtype = hidden_states.dtype
631
+ hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h]
632
+
633
+ route = self.router(hidden_states).to(dtype)
634
+
635
+ weights, selected_experts = torch.topk(route, self.experts_per_token)
636
+ weights = F.softmax(weights, dim=1, dtype=torch.float).to(hidden_states.dtype)
637
+ output = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
638
+ for expert_idx in range(self.num_experts):
639
+ batch_idx, nth_expert = torch.where(selected_experts == expert_idx)
640
+ if nth_expert.shape[0] == 0:
641
+ continue
642
+ cur_out = self.moe_forward(hidden_states[batch_idx], expert_idx)
643
+ output[batch_idx] += weights[batch_idx, nth_expert, None] * cur_out
644
+ output = output.reshape(s, b, n)
645
+ else:
646
+ # [s, b, 4hp]
647
+ #intermediate_parallel = hidden_states @ self.dense_h_to_4h.weight.T + self.dense_h_to_4h.bias
648
+ intermediate_parallel = self.dense_h_to_4h(hidden_states)
649
+ if self.mlp_lora:
650
+ scaling = self.lora_alpha / self.lora_r
651
+ intermediate_parallel = intermediate_parallel + ( self.lora_linear['dense_h_to_4h_B'](self.lora_linear['dense_h_to_4h_A'](self.lora_dropout(hidden_states))) * scaling )
652
+
653
+ intermediate_parallel = self.activation_func(intermediate_parallel)
654
+ # [s, b, h]
655
+ output = self.dense_4h_to_h(intermediate_parallel)
656
+ if self.mlp_lora:
657
+ output = output + ( self.lora_linear['dense_4h_to_h_B'](self.lora_linear['dense_4h_to_h_A'](self.lora_dropout(intermediate_parallel))) * scaling )# .reshape(output.shape)
658
+
659
+ #output = intermediate_parallel @ self.dense_4h_to_h.weight.T + self.dense_4h_to_h.bias # self.dense_4h_to_h(intermediate_parallel)
660
+ return output
661
+
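# Illustrative sketch, not part of the original file: the top-k expert routing
# that MLP.forward performs when config.moe is enabled, reduced to plain
# tensors. The experts are stand-in nn.Linear layers and all sizes are toy
# values.
import torch
import torch.nn.functional as F
from torch import nn

tokens, hidden, num_experts, experts_per_token = 6, 4, 4, 2
x = torch.randn(tokens, hidden)
router = nn.Linear(hidden, num_experts, bias=False)
experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_experts))

logits = router(x)                                           # [tokens, num_experts]
weights, selected = torch.topk(logits, experts_per_token)    # best experts_per_token experts per token
weights = F.softmax(weights, dim=1)                          # renormalize over the selected experts only
out = torch.zeros_like(x)
for expert_idx in range(num_experts):
    token_idx, nth = torch.where(selected == expert_idx)     # tokens routed here, and at which rank
    if token_idx.numel() == 0:
        continue
    out[token_idx] += weights[token_idx, nth, None] * experts[expert_idx](x[token_idx])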
662
+ class RAGPLMBlock(torch.nn.Module):
663
+ """A single transformer layer.
664
+
665
+ Transformer layer takes input with size [s, b, h] and returns an
666
+ output of the same size.
667
+ """
668
+
669
+ def __init__(self, config: RAGPLMConfig, layer_number, device=None):
670
+ super(RAGPLMBlock, self).__init__()
671
+ self.layer_number = layer_number
672
+
673
+ self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
674
+
675
+ self.fp32_residual_connection = config.fp32_residual_connection
676
+
677
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
678
+ # Layernorm on the input data.
679
+ self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon)
680
+
681
+ # Self attention.
682
+ self.self_attention = SelfAttention(config, layer_number, device=device)
683
+ self.hidden_dropout = config.hidden_dropout
684
+
685
+ # Layernorm on the attention output
686
+ self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon)
687
+
688
+ # MLP
689
+ self.mlp = MLP(config, device=device)
690
+
691
+ self.deepnorm_coeff = get_deepnorm_coefficients(config) if config.deepnorm else None
692
+
693
+ def forward(
694
+ self, hidden_states, attention_mask, position_ids, kv_cache=None, use_cache=True,
695
+ ):
696
+ # hidden_states: [s, b, h]
697
+
698
+ layernorm_output = self.input_layernorm(hidden_states)
699
+ # Self attention.
700
+
701
+ attention_output, kv_cache = self.self_attention(
702
+ layernorm_output,
703
+ attention_mask,
704
+ position_ids, # [batch_size, 2, seq_len, 32, 2]
705
+ kv_cache=kv_cache,
706
+ use_cache=use_cache
707
+ )
708
+
709
+ # Residual connection.
710
+ if self.apply_residual_connection_post_layernorm:
711
+ residual = layernorm_output
712
+ else:
713
+ residual = hidden_states
714
+
715
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
716
+ if self.deepnorm_coeff is not None:
717
+ layernorm_input = residual*self.deepnorm_coeff.alpha + layernorm_input
718
+ else:
719
+ layernorm_input = residual + layernorm_input
720
+
721
+ # Layer norm post the self attention.
722
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
723
+
724
+
725
+ # MLP.
726
+ mlp_output = self.mlp(layernorm_output)
727
+
728
+ # Second residual connection.
729
+ if self.apply_residual_connection_post_layernorm:
730
+ residual = layernorm_output
731
+ else:
732
+ residual = layernorm_input
733
+
734
+
735
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
736
+ if self.deepnorm_coeff is not None:
737
+ output = residual*self.deepnorm_coeff.alpha + output
738
+ else:
739
+ output = residual + output
740
+
741
+ return output, kv_cache
742
+
743
+
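# Illustrative sketch, not part of the original file: the residual wiring of
# RAGPLMBlock.forward, assuming apply_residual_connection_post_layernorm is
# False. `ln1`, `attn`, `ln2`, `mlp` stand in for the block's submodules and
# `alpha` for the optional DeepNorm coefficient.
import torch
import torch.nn.functional as F
from torch import nn

def block_forward(x, ln1, attn, ln2, mlp, dropout_p=0.0, alpha=None, training=False):
    a = F.dropout(attn(ln1(x)), p=dropout_p, training=training)
    x = x * alpha + a if alpha is not None else x + a      # first residual (DeepNorm-scaled if alpha is set)
    m = F.dropout(mlp(ln2(x)), p=dropout_p, training=training)
    return x * alpha + m if alpha is not None else x + m   # second residual

y = block_forward(torch.randn(4, 1, 8), nn.Identity(), nn.Identity(), nn.Identity(), nn.Identity())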
744
+ class RAGPLMTransformer(torch.nn.Module):
745
+ """Transformer class."""
746
+
747
+ def __init__(self, config: RAGPLMConfig, device=None):
748
+ super(RAGPLMTransformer, self).__init__()
749
+
750
+ self.config = config
751
+ self.fp32_residual_connection = config.fp32_residual_connection
752
+ self.post_layer_norm = config.post_layer_norm
753
+
754
+ # Number of layers.
755
+ self.num_layers = config.num_layers
756
+
757
+ # Transformer layers.
758
+ def build_layer(layer_number):
759
+ return RAGPLMBlock(config, layer_number, device=device)
760
+
761
+ self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])
762
+
763
+ if self.post_layer_norm:
764
+ LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
765
+ # Final layer norm before output.
766
+ self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon)
767
+
768
+ self.gradient_checkpointing = False
769
+ # Apply gradient checkpointing every num_checkpoint layers.
+ # For example: num_checkpoint=1 checkpoints every layer, num_checkpoint=2 checkpoints half of the layers.
771
+ self.num_checkpoint = 1
772
+
773
+ def _get_layer(self, layer_number):
774
+ return self.layers[layer_number]
775
+
776
+ def forward(
777
+ self, hidden_states, attention_mask, position_ids, kv_caches=None,
778
+ use_cache: Optional[bool] = True,
779
+ output_hidden_states: Optional[bool] = False,
780
+ ):
781
+ if not kv_caches:
782
+ kv_caches = [None for _ in range(self.num_layers)]
783
+ presents = () if use_cache else None
784
+ if self.gradient_checkpointing and self.training:
785
+ if use_cache:
786
+ logger.warning_once(
787
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
788
+ )
789
+ use_cache = False
790
+
791
+ all_self_attentions = None
792
+ all_hidden_states = () if output_hidden_states else None
793
+
794
+ for index in range(self.num_layers):
795
+ if output_hidden_states:
796
+ all_hidden_states = all_hidden_states + (hidden_states,)
797
+ layer = self._get_layer(index)
798
+ if self.gradient_checkpointing and self.training and torch.is_grad_enabled() and index % self.num_checkpoint == 0:
799
+ # Trick: re-enable gradients on the input to avoid a gradient-checkpointing error when only LoRA parameters are trainable
800
+ if hidden_states.requires_grad is False and deepspeed.checkpointing.is_configured() and (self.config.lora or self.config.mlp_lora):
801
+ # print(f"index={index}, set hidden_states.requires_grad = True")
802
+ hidden_states = hidden_states.clone()
803
+ hidden_states.requires_grad = True
804
+ layer_ret = get_checkpoint_fn()(
805
+ layer,
806
+ hidden_states,
807
+ attention_mask,
808
+ position_ids,
809
+ kv_caches[index],
810
+ use_cache
811
+ )
812
+ else:
813
+ layer_ret = layer(
814
+ hidden_states,
815
+ attention_mask,
816
+ position_ids,
817
+ kv_cache=kv_caches[index],
818
+ use_cache=use_cache
819
+ )
820
+
821
+ hidden_states, kv_cache = layer_ret
822
+ if use_cache:
823
+ presents = presents + (kv_cache,)
824
+
825
+ if output_hidden_states:
826
+ all_hidden_states = all_hidden_states + (hidden_states,)
827
+
828
+ # Final layer norm.
829
+ if self.post_layer_norm:
830
+ hidden_states = self.final_layernorm(hidden_states)
831
+
832
+ return hidden_states, presents, all_hidden_states, all_self_attentions
833
+
834
+
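# Illustrative sketch, not part of the original file: which layer indices the
# `index % self.num_checkpoint == 0` test in RAGPLMTransformer.forward wraps in
# gradient checkpointing, for a toy 8-layer stack.
num_layers = 8
for num_checkpoint in (1, 2, 4):
    checkpointed = [i for i in range(num_layers) if i % num_checkpoint == 0]
    print(num_checkpoint, checkpointed)
# num_checkpoint=1 -> every layer; num_checkpoint=2 -> layers 0, 2, 4, 6; num_checkpoint=4 -> layers 0 and 4.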
835
+ class RAGPLMPreTrainedModel(PreTrainedModel):
836
+ """
837
+ An abstract class to handle weights initialization and
838
+ a simple interface for downloading and loading pretrained models.
839
+ """
840
+
841
+ is_parallelizable = False
842
+ supports_gradient_checkpointing = True
843
+ config_class = RAGPLMConfig
844
+ base_model_prefix = "transformer"
845
+ _no_split_modules = ["RAGPLMBlock"]
846
+
847
+ def _init_weights(self, module: nn.Module):
848
+ """Initialize the weights."""
849
+ return
850
+
851
+ def get_masks(self, input_ids, past_key_values, padding_mask=None):
852
+ batch_size, seq_length = input_ids.shape
853
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
854
+ full_attention_mask.tril_()
855
+ past_length = 0
856
+ if past_key_values:
857
+ past_length = past_key_values[0][0].shape[0]
858
+ if past_length:
859
+ full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
860
+ device=input_ids.device), full_attention_mask), dim=-1)
861
+ if padding_mask is not None:
862
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
863
+ if not past_length and padding_mask is not None:
864
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
865
+ full_attention_mask = (full_attention_mask < 0.5).bool()
866
+ full_attention_mask.unsqueeze_(1)
867
+ return full_attention_mask
868
+
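# Illustrative sketch, not part of the original file: the boolean mask that
# get_masks builds for a toy prompt with no KV cache. True marks positions the
# core attention is NOT allowed to attend to.
import torch

batch_size, seq_length = 1, 4
full = torch.ones(batch_size, seq_length, seq_length)
full.tril_()                    # lower-triangular causal pattern
mask = (full < 0.5).bool()      # invert: True == masked out
mask.unsqueeze_(1)              # [batch, 1, seq, seq]
# mask[0, 0]:
# [[False,  True,  True,  True],
#  [False, False,  True,  True],
#  [False, False, False,  True],
#  [False, False, False, False]]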
869
+ def get_position_ids(self, input_ids, device):
870
+ batch_size, seq_length = input_ids.shape
871
+ position_ids_1 = torch.zeros( seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
872
+ position_ids_2 = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
873
+ position_ids = torch.stack([position_ids_1, position_ids_2], axis=1) # [batch_size, 2, seq_len]
874
+ return position_ids
875
+
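# Illustrative sketch, not part of the original file: what get_position_ids
# returns for a toy batch. The first channel stays at zero and the second is a
# plain 0..seq_len-1 range, matching the model's 2D rotary position encoding.
import torch

batch_size, seq_length = 2, 5
pos1 = torch.zeros(seq_length, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
pos2 = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
position_ids = torch.stack([pos1, pos2], dim=1)   # [batch_size, 2, seq_len]
# position_ids[0]:
# [[0, 0, 0, 0, 0],
#  [0, 1, 2, 3, 4]]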
876
+ def _set_gradient_checkpointing(self, module, value=False):
877
+ if isinstance(module, RAGPLMTransformer):
878
+ module.gradient_checkpointing = value
879
+
880
+ class Embedding(torch.nn.Module):
881
+ """Language model embeddings."""
882
+
883
+ def __init__(self, config: RAGPLMConfig, device=None):
884
+ super(Embedding, self).__init__()
885
+
886
+ self.hidden_size = config.hidden_size
887
+ # Word embeddings (parallel).
888
+ self.word_embeddings = nn.Embedding(
889
+ config.padded_vocab_size,
890
+ self.hidden_size,
891
+ dtype=config.torch_dtype,
892
+ device=device
893
+ )
894
+ self.fp32_residual_connection = config.fp32_residual_connection
895
+
896
+ def forward(self, input_ids):
897
+ # Embeddings.
898
+ words_embeddings = self.word_embeddings(input_ids)
899
+ embeddings = words_embeddings
900
+ # Data format change to avoid explicit transposes: [b s h] --> [s b h].
901
+ embeddings = embeddings.transpose(0, 1).contiguous()
902
+ # If the input flag for fp32 residual connection is set, convert for float.
903
+ if self.fp32_residual_connection:
904
+ embeddings = embeddings.float()
905
+ return embeddings
906
+
907
+ class RAGPLMModel(RAGPLMPreTrainedModel):
908
+ def __init__(self, config: RAGPLMConfig, device=None, empty_init=True):
909
+ super().__init__(config)
910
+ if empty_init:
911
+ init_method = skip_init
912
+ else:
913
+ init_method = default_init
914
+ init_kwargs = {}
915
+ if device is not None:
916
+ init_kwargs["device"] = device
917
+ self.embedding = init_method(Embedding, config, **init_kwargs)
918
+ self.num_layers = config.num_layers
919
+ self.multi_query_group_num = config.multi_query_group_num
920
+ self.kv_channels = config.kv_channels
921
+
922
+ self.str_emb_transform = None
923
+ self.seq_ln = None
924
+ self.str_ln = None
925
+ self.str_embedding = None
926
+
927
+ self.add_str_emb_ln = config.add_str_emb_ln
928
+ self.add_seq_emb_ln = config.add_seq_emb_ln
929
+ if config.str_input_dim is not None and config.str_input_dim > 0:
930
+ # Structure input as codebook: str_input_dim given, str_output_dim given
931
+ self.str_emb_transform = torch.nn.Linear(config.str_input_dim, config.hidden_size, bias=False)
932
+ if config.add_seq_emb_ln:
933
+ self.seq_ln = torch.nn.LayerNorm(config.hidden_size)
934
+ if config.add_str_emb_ln:
935
+ self.str_ln = torch.nn.LayerNorm(config.hidden_size)
936
+
937
+ # if config.str_input_dim is None and config.str_output_dim is not None and config.str_output_dim > 0:
938
+ if config.str_input_dim is None and config.str_vocab_size is not None and config.str_vocab_size > 0:
939
+ # Structure input as index: str_input_dim not given, str_vocab_size gives the vocab size; the structure embedding is nn.Embedding(str_vocab_size+1, hidden_size)
940
+ self.str_embedding = torch.nn.Embedding(config.str_vocab_size+1, config.hidden_size)
941
+
942
+ # Rotary positional embeddings
943
+ self.seq_length = config.seq_length
944
+ rotary_dim = (
945
+ config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
946
+ )
947
+
948
+ # self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, base=10000, precision=config.torch_dtype, learnable=False)
949
+ self.encoder = init_method(RAGPLMTransformer, config, **init_kwargs)
950
+ self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
951
+ dtype=config.torch_dtype, **init_kwargs)
952
+
953
+ if config.str_output_dim is not None and config.str_output_dim > 0:
954
+ self.output_layer_str = init_method(nn.Linear, config.hidden_size, config.str_output_dim, bias=False,
955
+ dtype=config.torch_dtype, **init_kwargs)
956
+ else:
957
+ self.output_layer_str = None
958
+
959
+ if config.qseq_output_dim is not None and config.qseq_output_dim > 0:
960
+ self.output_layer_qseq = init_method(nn.Linear, config.hidden_size, config.qseq_output_dim, bias=False,
961
+ dtype=config.torch_dtype, **init_kwargs)
962
+ else:
963
+ self.output_layer_qseq = None
964
+
965
+ def init_lora_modules(self):
966
+ for name, param in self.named_parameters():
967
+ if 'lora_linear' in name:
968
+ if '_A' in name:
969
+ torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
970
+ elif '_B' in name:
971
+ torch.nn.init.zeros_(param)
972
+
973
+ def get_input_embeddings(self):
974
+ return self.embedding.word_embeddings
975
+
976
+ def forward(
977
+ self,
978
+ input_ids,
979
+ position_ids: Optional[torch.Tensor] = None, # position_ids: [batch_size, 2, seq_len]
980
+ attention_mask: Optional[torch.BoolTensor] = None,
981
+ full_attention_mask: Optional[torch.BoolTensor] = None,
982
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
983
+ inputs_embeds: Optional[torch.Tensor] = None,
984
+ inputs_str_ids: Optional[torch.Tensor] = None,
985
+ inputs_str_embeds: Optional[torch.Tensor] = None,
986
+ use_cache: Optional[bool] = None,
987
+ output_hidden_states: Optional[bool] = None,
988
+ return_dict: Optional[bool] = None,
989
+ ):
990
+ output_hidden_states = (
991
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
992
+ )
993
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
994
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
995
+
996
+ batch_size, seq_length = input_ids.shape
997
+
998
+ if inputs_embeds is None:
999
+ inputs_embeds = self.embedding(input_ids) # [L, B, E]
1000
+
1001
+ if self.str_emb_transform is not None and inputs_str_embeds is not None:
1002
+ # inputs_embeds: torch.Size([12800, 1, 2304]), inputs_str_embeds: torch.Size([1, 337, 384])
1003
+ assert inputs_str_embeds.ndim == 3, f"inputs_embeds: {inputs_embeds.shape}, inputs_str_embeds: {inputs_str_embeds.shape}"
1004
+ assert inputs_str_embeds.shape[0] == inputs_embeds.shape[1], f"inputs_embeds: {inputs_embeds.shape}, inputs_str_embeds: {inputs_str_embeds.shape}"
1005
+ inputs_str_embeds = inputs_str_embeds.transpose(0, 1) # [L, B, E]
1006
+
1007
+ num_res, num_batch, num_dim = inputs_str_embeds.shape
1008
+ # inputs_embeds: [L, B, E]
1009
+ padding = inputs_embeds.shape[0] - num_res
1010
+ inputs_str_embeds = F.pad(inputs_str_embeds, [0, 0, 0, 0, 0, padding], value=0)
1011
+ str_embs = self.str_emb_transform(inputs_str_embeds)
1012
+
1013
+ if self.add_str_emb_ln:
1014
+ str_embs = self.str_ln(str_embs)
1015
+
1016
+ if self.add_seq_emb_ln:
1017
+ # seq_ln only applies to the query sequence part
1018
+ inputs_embeds = torch.cat([ self.seq_ln(inputs_embeds[:num_res]), inputs_embeds[num_res:] ], dim=0)
1019
+
1020
+ inputs_embeds = inputs_embeds + str_embs
1021
+
1022
+ #if self.add_emb_ln:
1023
+ # inputs_embeds = self.seq_ln(inputs_embeds) + self.str_ln(self.str_emb_transform(inputs_str_embeds))
1024
+ #else:
1025
+ # inputs_embeds = inputs_embeds + self.str_emb_transform(inputs_str_embeds)
1026
+
1027
+ if self.str_embedding is not None and inputs_str_ids is not None:
1028
+
1029
+
1030
+ str_embedding_weight = self.str_embedding.weight # [513, 2304]
1031
+ # Add a dimension represent the padding token
1032
+ str_embedding_weight = F.pad(str_embedding_weight, (0, 0, 0, 1)) # [514, 2304]
1033
+
1034
+ assert inputs_str_ids.max() < str_embedding_weight.shape[0], f"inputs_str_ids.max()={inputs_str_ids.max()}, str_embedding_weight.shape[0]={str_embedding_weight.shape[0]}"
1035
+
1036
+ str_embs = str_embedding_weight[inputs_str_ids] # [B, L, E]
1037
+ str_embs = str_embs.permute([1, 0, 2]) # [L, B, E]
1038
+ num_res, num_batch, num_dim = str_embs.shape
1039
+ padding = inputs_embeds.shape[0] - num_res
1040
+ str_embs = F.pad(str_embs, [0, 0, 0, 0, 0, padding], value=0)
1041
+ inputs_embeds = inputs_embeds + str_embs
1042
+
1043
+ if full_attention_mask is None:
1044
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
1045
+ full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
1046
+
1047
+ # Run encoder.
1048
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
1049
+ inputs_embeds, full_attention_mask, position_ids=position_ids,
1050
+ kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
1051
+ )
1052
+
1053
+ if not return_dict:
1054
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
1055
+
1056
+ return BaseModelOutputWithPast(
1057
+ last_hidden_state=hidden_states,
1058
+ past_key_values=presents,
1059
+ hidden_states=all_hidden_states,
1060
+ attentions=all_self_attentions,
1061
+ )
1062
+
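# Illustrative sketch, not part of the original file: how RAGPLMModel.forward
# pads per-residue structure embeddings to the full token length before adding
# them to the sequence embeddings. Shapes are toy values and `str_proj` stands
# in for self.str_emb_transform.
import torch
import torch.nn.functional as F
from torch import nn

seq_len, num_res, batch, hidden, str_dim = 12, 7, 1, 16, 4
inputs_embeds = torch.randn(seq_len, batch, hidden)       # [L, B, E]
str_embeds = torch.randn(num_res, batch, str_dim)         # structure features for the query residues only
str_proj = nn.Linear(str_dim, hidden, bias=False)

padding = inputs_embeds.shape[0] - num_res
str_embeds = F.pad(str_embeds, [0, 0, 0, 0, 0, padding])  # zero-pad along the length dimension
inputs_embeds = inputs_embeds + str_proj(str_embeds)      # structure signal added to the first num_res tokens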
1063
+ class RAGPLMForConditionalGeneration(RAGPLMPreTrainedModel):
1064
+ def __init__(self, config: RAGPLMConfig, empty_init=True, device=None):
1065
+ super().__init__(config)
1066
+
1067
+ self.max_sequence_length = config.max_length
1068
+ self.transformer = RAGPLMModel(config, empty_init=empty_init, device=device)
1069
+ self.config = config
1070
+
1071
+ def _update_model_kwargs_for_generation(
1072
+ self,
1073
+ outputs: ModelOutput,
1074
+ model_kwargs: Dict[str, Any],
1075
+ is_encoder_decoder: bool = False,
1076
+ standardize_cache_format: bool = False,
1077
+ ) -> Dict[str, Any]:
1078
+
1079
+ # update past_key_values
1080
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
1081
+ outputs, standardize_cache_format=standardize_cache_format
1082
+ )
1083
+
1084
+ # update attention mask
1085
+ if "attention_mask" in model_kwargs:
1086
+ attention_mask = model_kwargs["attention_mask"]
1087
+ model_kwargs["attention_mask"] = torch.cat(
1088
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
1089
+ )
1090
+
1091
+ if 'full_attention_mask' in model_kwargs:
1092
+ raise NotImplementedError(f"full_attention_mask...")
1093
+ model_kwargs['full_attention_mask'] = F.pad(model_kwargs['full_attention_mask'], [0, 1, 0, 1])
1094
+ if self.config.is_causal:
1095
+ model_kwargs['full_attention_mask'][..., -1] = True
1096
+
1097
+ # update position ids
1098
+ if "position_ids" in model_kwargs:
1099
+ position_ids = model_kwargs["position_ids"]
1100
+ new_position_id = position_ids[..., -1:].clone() # [batch_size, 2, 1]
1101
+ if self.config.rotary_embedding_2d:
1102
+ new_position_id[:, 1] += 1 # Only update the 2nd dimension
1103
+ else:
1104
+ new_position_id[:] += 1
1105
+ model_kwargs["position_ids"] = torch.cat(
1106
+ [position_ids, new_position_id], dim=-1
1107
+ ) # [batch_size, 2, seq_len+1]
1108
+
1109
+ model_kwargs["is_first_forward"] = False
1110
+ return model_kwargs
1111
+
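# Illustrative sketch, not part of the original file: how the incremental
# decoding update above extends the [batch, 2, seq_len] position ids when
# rotary_embedding_2d is enabled -- only the second channel advances.
import torch

position_ids = torch.tensor([[[0, 0, 0],
                              [0, 1, 2]]])        # [batch=1, 2, seq_len=3]
new_position_id = position_ids[..., -1:].clone()  # [1, 2, 1]
new_position_id[:, 1] += 1                        # advance only the within-sequence channel
position_ids = torch.cat([position_ids, new_position_id], dim=-1)
# position_ids:
# [[[0, 0, 0, 0],
#   [0, 1, 2, 3]]]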
1112
+ def prepare_inputs_for_generation(
1113
+ self,
1114
+ input_ids: torch.LongTensor,
1115
+ past_key_values: Optional[torch.Tensor] = None,
1116
+ attention_mask: Optional[torch.Tensor] = None,
1117
+ full_attention_mask: Optional[torch.Tensor] = None,
1118
+ position_ids: Optional[torch.Tensor] = None,
1119
+ use_cache: Optional[bool] = None,
1120
+ is_first_forward: bool = True,
1121
+ **kwargs
1122
+ ) -> dict:
1123
+ # only last token for input_ids if past is not None
1124
+ if position_ids is None:
1125
+ position_ids = self.get_position_ids(input_ids, device=input_ids.device) # position_ids: [batch_size, 2, seq_len]
1126
+ if not is_first_forward:
1127
+ if past_key_values is not None:
1128
+ position_ids = position_ids[..., -1:]
1129
+ input_ids = input_ids[:, -1:]
1130
+ return {
1131
+ "input_ids": input_ids,
1132
+ "past_key_values": past_key_values,
1133
+ "position_ids": position_ids,
1134
+ "attention_mask": attention_mask,
1135
+ "full_attention_mask": full_attention_mask,
1136
+ "return_last_logit": True,
1137
+ "use_cache": use_cache
1138
+ }
1139
+
1140
+ def forward(
1141
+ self,
1142
+ input_ids: Optional[torch.Tensor] = None,
1143
+ position_ids: Optional[torch.Tensor] = None,
1144
+ attention_mask: Optional[torch.Tensor] = None,
1145
+ full_attention_mask: Optional[torch.Tensor] = None,
1146
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
1147
+ inputs_embeds: Optional[torch.Tensor] = None,
1148
+ labels: Optional[torch.Tensor] = None,
1149
+ use_cache: Optional[bool] = None,
1150
+ output_attentions: Optional[bool] = None,
1151
+ output_hidden_states: Optional[bool] = None,
1152
+ return_dict: Optional[bool] = None,
1153
+ return_last_logit: Optional[bool] = False,
1154
+ ):
1155
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1156
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1157
+
1158
+ transformer_outputs = self.transformer(
1159
+ input_ids=input_ids,
1160
+ position_ids=position_ids, # position_ids: [batch_size, 2, seq_len]
1161
+ attention_mask=attention_mask,
1162
+ full_attention_mask=full_attention_mask,
1163
+ past_key_values=past_key_values,
1164
+ inputs_embeds=inputs_embeds,
1165
+ use_cache=use_cache,
1166
+ output_hidden_states=output_hidden_states,
1167
+ return_dict=return_dict,
1168
+ )
1169
+
1170
+ hidden_states = transformer_outputs[0]
1171
+ if return_last_logit:
1172
+ hidden_states = hidden_states[-1:]
1173
+ lm_logits = self.transformer.output_layer(hidden_states)
1174
+ # output_layer_str
1175
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
1176
+
1177
+ loss = None
1178
+ if labels is not None:
1179
+ lm_logits = lm_logits.to(torch.float32)
1180
+
1181
+ # Shift so that tokens < n predict n
1182
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1183
+ shift_labels = labels[..., 1:].contiguous()
1184
+ # Flatten the tokens
1185
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1186
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1187
+
1188
+ lm_logits = lm_logits.to(hidden_states.dtype)
1189
+ loss = loss.to(hidden_states.dtype)
1190
+
1191
+ if not return_dict:
1192
+ output = (lm_logits,) + transformer_outputs[1:]
1193
+ return ((loss,) + output) if loss is not None else output
1194
+
1195
+ return CausalLMOutputWithPast(
1196
+ loss=loss,
1197
+ logits=lm_logits,
1198
+ past_key_values=transformer_outputs.past_key_values,
1199
+ hidden_states=transformer_outputs.hidden_states,
1200
+ attentions=transformer_outputs.attentions,
1201
+ )
1202
+
1203
+ @staticmethod
1204
+ def _reorder_cache(
1205
+ past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
1206
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
1207
+ """
1208
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
1209
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
1210
+ beam_idx at every generation step.
1211
+
1212
+ Output shares the same memory storage as `past`.
1213
+ """
1214
+ return tuple(
1215
+ (
1216
+ layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
1217
+ layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
1218
+ )
1219
+ for layer_past in past
1220
+ )
1221
+
1222
+ def process_response(self, output, history):
1223
+ content = ""
1224
+ history = deepcopy(history)
1225
+ for response in output.split("<|assistant|>"):
1226
+ if "\n" in response:
1227
+ metadata, content = response.split("\n", maxsplit=1)
1228
+ else:
1229
+ metadata, content = "", response
1230
+ if not metadata.strip():
1231
+ content = content.strip()
1232
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
1233
+ content = content.replace("[[训练时间]]", "2023年")
1234
+ else:
1235
+ history.append({"role": "assistant", "metadata": metadata, "content": content})
1236
+ if history[0]["role"] == "system" and "tools" in history[0]:
1237
+ content = "\n".join(content.split("\n")[1:-1])
1238
+ def tool_call(**kwargs):
1239
+ return kwargs
1240
+ parameters = eval(content)
1241
+ content = {"name": metadata.strip(), "parameters": parameters}
1242
+ else:
1243
+ content = {"name": metadata.strip(), "content": content}
1244
+ return content, history
1245
+
1246
+ @torch.inference_mode()
1247
+ def chat(self, tokenizer, query: str, max_length: int = 2048, num_beams=1,
1248
+ do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, **kwargs):
1249
+ if logits_processor is None:
1250
+ logits_processor = LogitsProcessorList()
1251
+ logits_processor.append(InvalidScoreLogitsProcessor())
1252
+ gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
1253
+ "temperature": temperature, "logits_processor": logits_processor, **kwargs}
1254
+ inputs = tokenizer.build_chat_input(query)
1255
+ inputs = inputs.to(self.device)
1256
+ eos_token_id = [tokenizer.eos_token_id]
1257
+ outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
1258
+ outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
1259
+ response = tokenizer.decode(outputs)
1260
+ return response
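A minimal usage sketch of the class above. Everything repository-specific is an assumption: <repo-id> is a placeholder for the actual Hugging Face model id, add_retriever_tokens() is called so that the 'ID' token appended by build_chat_input exists in the vocabulary, and the query is an arbitrary toy peptide. Treat this as a sketch, not a verified recipe.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<repo-id>"  # placeholder -- substitute the actual repository id
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
tokenizer.add_retriever_tokens()  # registers the 'ID' token that build_chat_input appends
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16)

response = model.chat(tokenizer, "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", max_length=128)
print(response)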
performance.png ADDED

Git LFS Details

  • SHA256: 1419e63831d60aabbed4476453f0b2fcd8ca235da3e525a3a6481b582b0a8fcb
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
proteinmoe_architecture.png ADDED

Git LFS Details

  • SHA256: 670762ddcb58e41cea704d9fddb6acf9bb109216baff6eaadea401699b56552f
  • Pointer size: 131 Bytes
  • Size of remote file: 450 kB
pytorch_model-00001-of-00007.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2442d42059efa680a6dfcb0098629bb1661582a0b61b1eed4e6ed472dbf1d7b9
3
+ size 4932953760
pytorch_model-00002-of-00007.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb9e891095ef8df462d4c192accdf3650764acecf8f3512568c8d25c66ea60ea
3
+ size 4999044744
pytorch_model-00003-of-00007.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae85274ab0bc8fcd790fdb55b0993d6bafffe153511d82047b7c12bca28f225e
3
+ size 4991905042
pytorch_model-00004-of-00007.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d798373023367553ae425c7e57170ee7dce7ac516aec6edabd8c894f6842731
3
+ size 4963629472
pytorch_model-00005-of-00007.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:900893052b8f4b2a5e939c777a3c9a07b1cc1e52efdb7f688231a8cd650ef89a
3
+ size 4991904870
pytorch_model-00006-of-00007.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2c940f8773468bc21e771d98f35be0250a4ba2ae77e5429f56cd1cbe8b55c84
3
+ size 4999045148
pytorch_model-00007-of-00007.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a45ba9f924d8ab3f664f8365f07c40b1247f824acbf662629b324a2022af928
3
+ size 2249880965
pytorch_model.bin.index.json ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {}
tokenization.py ADDED
@@ -0,0 +1,268 @@
1
+ from typing import Sequence, Tuple, List, Union, Optional
2
+ from abc import ABC
3
+ from abc import abstractmethod
4
+ # from .tokenizer import AbstractTokenizer
5
+ import logging
6
+ import itertools
7
+ from transformers import PreTrainedTokenizer
8
+ import torch
9
+ import json
10
+ import numpy as np
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ class ResidueLevelTokenizer(object):
15
+ """
16
+ Tokenizer for Protein Residue Level Tokenization.
17
+ """
18
+ def __init__(self, **kwargs):
19
+ super(ResidueLevelTokenizer, self).__init__()
20
+
21
+ ### Set normal tokens
22
+ self.all_toks = ['[pad]', 'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
23
+
24
+ ### Set special tokens
25
+ _special_tokens = ['tMASK', 'gMASK', 'sMASK', 'eod', 'sop', 'eop', '</s>' ] # + ['MSA', 'ID'] + [ str(d) for d in range(0, 64) ]
26
+ self.special_tokens = { tok: len(self.all_toks)+i for i,tok in enumerate(_special_tokens) }
27
+ self.special_tokens_decoder = { v:k for k, v in self.special_tokens.items() }
28
+ self.special_tokens['eos'] = self.special_tokens['</s>']
29
+ self.all_toks.extend(_special_tokens)
30
+
31
+ self.vocab = {tok:idx for idx,tok in enumerate(self.all_toks)}
32
+ self.command_token = {'[MASK]':'MASK', '[gMASK]': 'gMASK', '[sMASK]':'sMASK'} # , '[MSA]':'MSA', '[ID]':'ID'}
33
+
34
+ self.gMASK_token_id = self.convert_token_to_id('gMASK')
35
+ self.sop_token_id = self.convert_token_to_id('sop')
36
+ self.eos_token_id = self.convert_token_to_id('</s>')
37
+ # self.id_token_id = self.convert_token_to_id('ID')
38
+ self.pad_token_id = self.convert_token_to_id('[pad]')
39
+
40
+ def __len__(self):
41
+ return len(self.vocab)
42
+
43
+ def get_special_token(self, token):
44
+ return self.special_tokens[token]
45
+
46
+ def get_vocab(self):
47
+ return self.vocab
48
+
49
+ def convert_id_to_token(self, idx):
50
+ idx = int(idx)
51
+ if idx == 0:
52
+ return '[pad]'
53
+ elif idx in self.special_tokens_decoder:
54
+ return f"[{self.special_tokens_decoder[idx]}]"
55
+ else:
56
+ return self.all_toks[idx]
57
+
58
+ def convert_token_to_id(self, token):
59
+ if token == '[pad]':
60
+ return 0
61
+ elif token in self.special_tokens:
62
+ return self.special_tokens[token]
63
+ else:
64
+ return self.vocab[token]
65
+
66
+ def encode(self, sequence, add_eos=True):
67
+ """
68
+ Encode string or list of tokens into array
69
+
70
+ Examples
71
+ ----------------
72
+ encode('[pad]ABDWOKOAKOQA[pad][MSA][3][2][19][3]')
73
+ encode(['A', 'B', 'D', 'MSA', '34', '2'])
74
+ """
75
+
76
+ all_toks = set(self.all_toks)
77
+ a2b = {f"[{t}]":t for t in self.special_tokens if t[0] not in ('[', )}
78
+ all_toks.update( set(a2b.keys()) )
79
+
80
+ if isinstance(sequence, (tuple, list)):
81
+ if sequence[-1] != '</s>' and add_eos:
82
+ sequence = sequence + ['</s>']
83
+ sequence = [ a2b.get(tok, tok) for tok in sequence ]
84
+ return np.array([ self.convert_token_to_id(t) for t in sequence ])
85
+ elif isinstance(sequence, str):
86
+ if not sequence.endswith('</s>') and add_eos:
87
+ sequence = sequence + '</s>'
88
+ s = 0
89
+ e = 1
90
+ tok_list = []
91
+ while s < len(sequence):
92
+ while sequence[s:e] not in all_toks and e < len(sequence):
93
+ e += 1
94
+ assert sequence[s:e] in all_toks, f"Error: sub sequence {sequence[s:]} cannot be parsed"
95
+ tok = sequence[s:e]
96
+ tok = a2b.get(tok, tok) # [gMASK], [sMASK] ...
97
+ tok_id = self.convert_token_to_id(tok)
98
+ tok_list.append(tok_id)
99
+ s = e
100
+ return np.array(tok_list)
101
+ else:
102
+ raise RuntimeError(f"Error: sequence must be list/tuple/str, but got {type(sequence)}")
103
+
104
+ def decode(self, tokens, rem_eos=True, return_str=True):
105
+ if tokens[-1] == self.eos_token_id and rem_eos:
106
+ tokens = tokens[:-1]
107
+ if return_str:
108
+ return "".join([ self.convert_id_to_token(tok) for tok in tokens ])
109
+ else:
110
+ return [ self.convert_id_to_token(tok) for tok in tokens ]
111
+
112
+ def tokenize(self, text, add_eos=True):
113
+ return self.encode(text, add_eos=add_eos)
114
+
115
+ def extend_vocab(self, tokens):
116
+ """Extend the vocab with the list of tokens."""
117
+ for token in tokens:
118
+ if token not in self.vocab:
119
+ self.vocab[token] = len(self.vocab)
120
+ self.all_toks.append(token)
121
+
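# Illustrative sketch, not part of the original file: a round trip through the
# ResidueLevelTokenizer defined above. encode() appends '</s>' by default and
# decode() strips it again, so the original sequence comes back unchanged. The
# peptide string is an arbitrary toy example.
tok = ResidueLevelTokenizer()
ids = tok.encode("MKTAYIA")   # numpy array of residue ids followed by the eos id
seq = tok.decode(ids)         # eos removed, string restored
assert seq == "MKTAYIA"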
122
+ class ProteinTokenizer(PreTrainedTokenizer):
123
+ """
124
+ Protein Tokenizer based on Residue level tokenizer
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ vocab_file='xxx',
130
+ padding_side="right",
131
+ clean_up_tokenization_spaces=False,
132
+ encode_special_tokens=True,
133
+ **kwargs
134
+ ):
135
+ self.name = "ProteinTokenizer"
136
+ self.vocab_file = vocab_file
137
+ self.tokenizer = ResidueLevelTokenizer()
138
+ self.special_tokens = self.tokenizer.special_tokens
139
+ self.encode_special_tokens = encode_special_tokens
140
+
141
+ super().__init__(
142
+ padding_side=padding_side,
143
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
144
+ **kwargs
145
+ )
146
+
147
+ def get_command(self, token):
148
+ if token in self.special_tokens:
149
+ return self.special_tokens[token]
150
+ assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
151
+ return self.tokenizer.special_tokens[token]
152
+
153
+ @property
154
+ def unk_token(self) -> str:
155
+ return '[pad]'
156
+
157
+ @property
158
+ def pad_token(self) -> str:
159
+ return '[pad]'
160
+
161
+ @property
162
+ def eos_token(self) -> str:
163
+ return '</s>'
164
+
165
+ @property
166
+ def unk_token_id(self) -> int:
167
+ return self.tokenizer.pad_token_id  # '[pad]' doubles as the unk token; return its integer id
168
+
169
+ @property
170
+ def pad_token_id(self) -> int:
171
+ return self.tokenizer.pad_token_id
172
+
173
+ @property
174
+ def eos_token_id(self):
175
+ return self.tokenizer.eos_token_id
176
+
177
+ @property
178
+ def gMASK_token_id(self):
179
+ return self.tokenizer.gMASK_token_id
180
+
181
+ @property
182
+ def sop_token_id(self):
183
+ return self.tokenizer.sop_token_id
184
+
185
+ @property
186
+ def id_token_id(self):
187
+ return self.tokenizer.id_token_id
188
+
189
+ def IdToToken(self, id_):
190
+ return self.tokenizer.convert_id_to_token(id_)
191
+
192
+ def TokenToId(self, token):
193
+ return self.tokenizer.convert_token_to_id(token)
194
+
195
+ @unk_token.setter
196
+ def unk_token(self, value):
197
+ logger.warning("Setting unk_token is not supported, use the default one.")
198
+
199
+ @pad_token.setter
200
+ def pad_token(self, value):
201
+ logger.warning("Setting pad_token is not supported, use the default one.")
202
+
203
+ @eos_token.setter
204
+ def eos_token(self, value):
205
+ logger.warning("Setting eos_token is not supported, use the default one.")
206
+
207
+ @property
208
+ def vocab_size(self):
209
+ return len(self.tokenizer)
210
+
211
+ def encode(self, sequence, add_eos=True):
212
+ return self.tokenizer.encode(sequence, add_eos=add_eos)
213
+
214
+ def decode(self, token_ids, rem_eos=True, return_str=True):
215
+ return self.tokenizer.decode(token_ids, rem_eos=rem_eos, return_str=return_str)
216
+
217
+ def _convert_id_to_token(self, index):
218
+ """Converts an index (integer) in a token (str) using the vocab."""
219
+ return self.tokenizer.convert_id_to_token(index)
220
+
221
+ def get_vocab(self):
222
+ """ Returns vocab as a dict """
223
+ vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
224
+ return vocab
225
+
226
+ @property
227
+ def eod(self):
228
+ return self.tokenizer.get_special_token('eos')
229
+
230
+ def detokenize(self, Ids, type_token=False):
231
+ new_tokens = self.tokenizer.decode(Ids)
232
+ return new_tokens
233
+
234
+ def tokenize(self, text):
235
+ ids = self.tokenizer.tokenize(text)
236
+ return ids
237
+
238
+ def extend_vocab(self, tokens):
239
+ """ Extend the vocab with the list of tokens """
240
+ self.tokenizer.extend_vocab(tokens)
241
+
242
+ def add_retriever_tokens(self):
243
+ retriever_tokens = ['MSA', 'ID'] + [ str(d) for d in range(0, 64) ]
244
+ self.tokenizer.extend_vocab(retriever_tokens)
245
+ self.tokenizer.command_token['[MSA]'] = 'MSA'
246
+ self.tokenizer.command_token['[ID]'] = 'ID'
247
+
248
+ def add_structure_tokens(self, codebook_size):
249
+ self.tokenizer.extend_vocab( [ str(i) for i in range(codebook_size) ] )
250
+
251
+ def build_chat_input(self, query):
252
+ input_ids = [ self.tokenizer.convert_token_to_id('gMASK'), self.tokenizer.convert_token_to_id('sop') ]
253
+ input_ids += [ self.tokenizer.convert_token_to_id(tok) for tok in query ]
254
+ input_ids += [ self.tokenizer.convert_token_to_id('ID') ]
255
+ # return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
256
+
257
+ position_ids = torch.stack([torch.zeros(len(input_ids)), torch.arange(len(input_ids))], axis=0).unsqueeze(0).long()
258
+ return {
259
+ 'input_ids': torch.from_numpy(np.array([ input_ids ])).long(),
260
+ 'attention_mask': None,
261
+ 'position_ids': position_ids
262
+ }
263
+
264
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
265
+ vocab = self.get_vocab()
266
+ with open(f"{save_directory}/vocab.json", 'w') as f:
267
+ json.dump(vocab, f, indent=4)
268
+ return ( f"{save_directory}/vocab.json", )
tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ "tokenization.ProteinTokenizer",
6
+ null
7
+ ]
8
+ },
9
+ "clean_up_tokenization_spaces": false,
10
+ "do_lower_case": false,
11
+ "eos_token": "</s>",
12
+ "extra_special_tokens": {},
13
+ "model_max_length": 1000000000000000019884624838656,
14
+ "pad_token": "[pad]",
15
+ "padding_side": "right",
16
+ "remove_space": false,
17
+ "tokenizer_class": "ProteinTokenizer",
18
+ "unk_token": "[pad]"
19
+ }
vocab.json ADDED
@@ -0,0 +1,549 @@
1
+ {
2
+ "[pad]": 0,
3
+ "L": 1,
4
+ "A": 2,
5
+ "G": 3,
6
+ "V": 4,
7
+ "S": 5,
8
+ "E": 6,
9
+ "R": 7,
10
+ "T": 8,
11
+ "I": 9,
12
+ "D": 10,
13
+ "P": 11,
14
+ "K": 12,
15
+ "Q": 13,
16
+ "N": 14,
17
+ "F": 15,
18
+ "Y": 16,
19
+ "M": 17,
20
+ "H": 18,
21
+ "W": 19,
22
+ "C": 20,
23
+ "X": 21,
24
+ "B": 22,
25
+ "U": 23,
26
+ "Z": 24,
27
+ "O": 25,
28
+ ".": 26,
29
+ "-": 27,
30
+ "[tMASK]": 28,
31
+ "[gMASK]": 29,
32
+ "[sMASK]": 30,
33
+ "[eod]": 31,
34
+ "[sop]": 32,
35
+ "[eop]": 33,
36
+ "[</s>]": 34,
37
+ "0": 35,
38
+ "1": 36,
39
+ "2": 37,
40
+ "3": 38,
41
+ "4": 39,
42
+ "5": 40,
43
+ "6": 41,
44
+ "7": 42,
45
+ "8": 43,
46
+ "9": 44,
47
+ "10": 45,
48
+ "11": 46,
49
+ "12": 47,
50
+ "13": 48,
51
+ "14": 49,
52
+ "15": 50,
53
+ "16": 51,
54
+ "17": 52,
55
+ "18": 53,
56
+ "19": 54,
57
+ "20": 55,
58
+ "21": 56,
59
+ "22": 57,
60
+ "23": 58,
61
+ "24": 59,
62
+ "25": 60,
63
+ "26": 61,
64
+ "27": 62,
65
+ "28": 63,
66
+ "29": 64,
67
+ "30": 65,
68
+ "31": 66,
69
+ "32": 67,
70
+ "33": 68,
71
+ "34": 69,
72
+ "35": 70,
73
+ "36": 71,
74
+ "37": 72,
75
+ "38": 73,
76
+ "39": 74,
77
+ "40": 75,
78
+ "41": 76,
79
+ "42": 77,
80
+ "43": 78,
81
+ "44": 79,
82
+ "45": 80,
83
+ "46": 81,
84
+ "47": 82,
85
+ "48": 83,
86
+ "49": 84,
87
+ "50": 85,
88
+ "51": 86,
89
+ "52": 87,
90
+ "53": 88,
91
+ "54": 89,
92
+ "55": 90,
93
+ "56": 91,
94
+ "57": 92,
95
+ "58": 93,
96
+ "59": 94,
97
+ "60": 95,
98
+ "61": 96,
99
+ "62": 97,
100
+ "63": 98,
101
+ "64": 99,
102
+ "65": 100,
103
+ "66": 101,
104
+ "67": 102,
105
+ "68": 103,
106
+ "69": 104,
107
+ "70": 105,
108
+ "71": 106,
109
+ "72": 107,
110
+ "73": 108,
111
+ "74": 109,
112
+ "75": 110,
113
+ "76": 111,
114
+ "77": 112,
115
+ "78": 113,
116
+ "79": 114,
117
+ "80": 115,
118
+ "81": 116,
119
+ "82": 117,
120
+ "83": 118,
121
+ "84": 119,
122
+ "85": 120,
123
+ "86": 121,
124
+ "87": 122,
125
+ "88": 123,
126
+ "89": 124,
127
+ "90": 125,
128
+ "91": 126,
129
+ "92": 127,
130
+ "93": 128,
131
+ "94": 129,
132
+ "95": 130,
133
+ "96": 131,
134
+ "97": 132,
135
+ "98": 133,
136
+ "99": 134,
137
+ "100": 135,
138
+ "101": 136,
139
+ "102": 137,
140
+ "103": 138,
141
+ "104": 139,
142
+ "105": 140,
143
+ "106": 141,
144
+ "107": 142,
145
+ "108": 143,
146
+ "109": 144,
147
+ "110": 145,
148
+ "111": 146,
149
+ "112": 147,
150
+ "113": 148,
151
+ "114": 149,
152
+ "115": 150,
153
+ "116": 151,
154
+ "117": 152,
155
+ "118": 153,
156
+ "119": 154,
157
+ "120": 155,
158
+ "121": 156,
159
+ "122": 157,
160
+ "123": 158,
161
+ "124": 159,
162
+ "125": 160,
163
+ "126": 161,
164
+ "127": 162,
165
+ "128": 163,
166
+ "129": 164,
167
+ "130": 165,
168
+ "131": 166,
169
+ "132": 167,
170
+ "133": 168,
171
+ "134": 169,
172
+ "135": 170,
173
+ "136": 171,
174
+ "137": 172,
175
+ "138": 173,
176
+ "139": 174,
177
+ "140": 175,
178
+ "141": 176,
179
+ "142": 177,
180
+ "143": 178,
181
+ "144": 179,
182
+ "145": 180,
183
+ "146": 181,
184
+ "147": 182,
185
+ "148": 183,
186
+ "149": 184,
187
+ "150": 185,
188
+ "151": 186,
189
+ "152": 187,
190
+ "153": 188,
191
+ "154": 189,
192
+ "155": 190,
193
+ "156": 191,
194
+ "157": 192,
195
+ "158": 193,
196
+ "159": 194,
197
+ "160": 195,
198
+ "161": 196,
199
+ "162": 197,
200
+ "163": 198,
201
+ "164": 199,
202
+ "165": 200,
203
+ "166": 201,
204
+ "167": 202,
205
+ "168": 203,
206
+ "169": 204,
207
+ "170": 205,
208
+ "171": 206,
209
+ "172": 207,
210
+ "173": 208,
211
+ "174": 209,
212
+ "175": 210,
213
+ "176": 211,
214
+ "177": 212,
215
+ "178": 213,
216
+ "179": 214,
217
+ "180": 215,
218
+ "181": 216,
219
+ "182": 217,
220
+ "183": 218,
221
+ "184": 219,
222
+ "185": 220,
223
+ "186": 221,
224
+ "187": 222,
225
+ "188": 223,
226
+ "189": 224,
227
+ "190": 225,
228
+ "191": 226,
229
+ "192": 227,
230
+ "193": 228,
231
+ "194": 229,
232
+ "195": 230,
233
+ "196": 231,
234
+ "197": 232,
235
+ "198": 233,
236
+ "199": 234,
237
+ "200": 235,
238
+ "201": 236,
239
+ "202": 237,
240
+ "203": 238,
241
+ "204": 239,
242
+ "205": 240,
243
+ "206": 241,
244
+ "207": 242,
245
+ "208": 243,
246
+ "209": 244,
247
+ "210": 245,
248
+ "211": 246,
249
+ "212": 247,
250
+ "213": 248,
251
+ "214": 249,
252
+ "215": 250,
253
+ "216": 251,
254
+ "217": 252,
255
+ "218": 253,
256
+ "219": 254,
257
+ "220": 255,
258
+ "221": 256,
259
+ "222": 257,
260
+ "223": 258,
261
+ "224": 259,
262
+ "225": 260,
263
+ "226": 261,
264
+ "227": 262,
265
+ "228": 263,
266
+ "229": 264,
267
+ "230": 265,
268
+ "231": 266,
269
+ "232": 267,
270
+ "233": 268,
271
+ "234": 269,
272
+ "235": 270,
273
+ "236": 271,
274
+ "237": 272,
275
+ "238": 273,
276
+ "239": 274,
277
+ "240": 275,
278
+ "241": 276,
279
+ "242": 277,
280
+ "243": 278,
281
+ "244": 279,
282
+ "245": 280,
283
+ "246": 281,
284
+ "247": 282,
285
+ "248": 283,
286
+ "249": 284,
287
+ "250": 285,
288
+ "251": 286,
289
+ "252": 287,
290
+ "253": 288,
291
+ "254": 289,
292
+ "255": 290,
293
+ "256": 291,
294
+ "257": 292,
295
+ "258": 293,
296
+ "259": 294,
297
+ "260": 295,
298
+ "261": 296,
299
+ "262": 297,
300
+ "263": 298,
301
+ "264": 299,
302
+ "265": 300,
303
+ "266": 301,
304
+ "267": 302,
305
+ "268": 303,
306
+ "269": 304,
307
+ "270": 305,
308
+ "271": 306,
309
+ "272": 307,
310
+ "273": 308,
311
+ "274": 309,
312
+ "275": 310,
313
+ "276": 311,
314
+ "277": 312,
315
+ "278": 313,
316
+ "279": 314,
317
+ "280": 315,
318
+ "281": 316,
319
+ "282": 317,
320
+ "283": 318,
321
+ "284": 319,
322
+ "285": 320,
323
+ "286": 321,
324
+ "287": 322,
325
+ "288": 323,
326
+ "289": 324,
327
+ "290": 325,
328
+ "291": 326,
329
+ "292": 327,
330
+ "293": 328,
331
+ "294": 329,
332
+ "295": 330,
333
+ "296": 331,
334
+ "297": 332,
335
+ "298": 333,
336
+ "299": 334,
337
+ "300": 335,
338
+ "301": 336,
339
+ "302": 337,
340
+ "303": 338,
341
+ "304": 339,
342
+ "305": 340,
343
+ "306": 341,
344
+ "307": 342,
345
+ "308": 343,
346
+ "309": 344,
347
+ "310": 345,
348
+ "311": 346,
349
+ "312": 347,
350
+ "313": 348,
351
+ "314": 349,
352
+ "315": 350,
353
+ "316": 351,
354
+ "317": 352,
355
+ "318": 353,
356
+ "319": 354,
357
+ "320": 355,
358
+ "321": 356,
359
+ "322": 357,
360
+ "323": 358,
361
+ "324": 359,
362
+ "325": 360,
363
+ "326": 361,
364
+ "327": 362,
365
+ "328": 363,
366
+ "329": 364,
367
+ "330": 365,
368
+ "331": 366,
369
+ "332": 367,
370
+ "333": 368,
371
+ "334": 369,
372
+ "335": 370,
373
+ "336": 371,
374
+ "337": 372,
375
+ "338": 373,
376
+ "339": 374,
377
+ "340": 375,
378
+ "341": 376,
379
+ "342": 377,
380
+ "343": 378,
381
+ "344": 379,
382
+ "345": 380,
383
+ "346": 381,
384
+ "347": 382,
385
+ "348": 383,
386
+ "349": 384,
387
+ "350": 385,
388
+ "351": 386,
389
+ "352": 387,
390
+ "353": 388,
391
+ "354": 389,
392
+ "355": 390,
393
+ "356": 391,
394
+ "357": 392,
395
+ "358": 393,
396
+ "359": 394,
397
+ "360": 395,
398
+ "361": 396,
399
+ "362": 397,
400
+ "363": 398,
401
+ "364": 399,
402
+ "365": 400,
403
+ "366": 401,
404
+ "367": 402,
405
+ "368": 403,
406
+ "369": 404,
407
+ "370": 405,
408
+ "371": 406,
409
+ "372": 407,
410
+ "373": 408,
411
+ "374": 409,
412
+ "375": 410,
413
+ "376": 411,
414
+ "377": 412,
415
+ "378": 413,
416
+ "379": 414,
417
+ "380": 415,
418
+ "381": 416,
419
+ "382": 417,
420
+ "383": 418,
421
+ "384": 419,
422
+ "385": 420,
423
+ "386": 421,
424
+ "387": 422,
425
+ "388": 423,
426
+ "389": 424,
427
+ "390": 425,
428
+ "391": 426,
429
+ "392": 427,
430
+ "393": 428,
431
+ "394": 429,
432
+ "395": 430,
433
+ "396": 431,
434
+ "397": 432,
435
+ "398": 433,
436
+ "399": 434,
437
+ "400": 435,
438
+ "401": 436,
439
+ "402": 437,
440
+ "403": 438,
441
+ "404": 439,
442
+ "405": 440,
443
+ "406": 441,
444
+ "407": 442,
445
+ "408": 443,
446
+ "409": 444,
447
+ "410": 445,
448
+ "411": 446,
449
+ "412": 447,
450
+ "413": 448,
451
+ "414": 449,
452
+ "415": 450,
453
+ "416": 451,
454
+ "417": 452,
455
+ "418": 453,
456
+ "419": 454,
457
+ "420": 455,
458
+ "421": 456,
459
+ "422": 457,
460
+ "423": 458,
461
+ "424": 459,
462
+ "425": 460,
463
+ "426": 461,
464
+ "427": 462,
465
+ "428": 463,
466
+ "429": 464,
467
+ "430": 465,
468
+ "431": 466,
469
+ "432": 467,
470
+ "433": 468,
471
+ "434": 469,
472
+ "435": 470,
473
+ "436": 471,
474
+ "437": 472,
475
+ "438": 473,
476
+ "439": 474,
477
+ "440": 475,
478
+ "441": 476,
479
+ "442": 477,
480
+ "443": 478,
481
+ "444": 479,
482
+ "445": 480,
483
+ "446": 481,
484
+ "447": 482,
485
+ "448": 483,
486
+ "449": 484,
487
+ "450": 485,
488
+ "451": 486,
489
+ "452": 487,
490
+ "453": 488,
491
+ "454": 489,
492
+ "455": 490,
493
+ "456": 491,
494
+ "457": 492,
495
+ "458": 493,
496
+ "459": 494,
497
+ "460": 495,
498
+ "461": 496,
499
+ "462": 497,
500
+ "463": 498,
501
+ "464": 499,
502
+ "465": 500,
503
+ "466": 501,
504
+ "467": 502,
505
+ "468": 503,
506
+ "469": 504,
507
+ "470": 505,
508
+ "471": 506,
509
+ "472": 507,
510
+ "473": 508,
511
+ "474": 509,
512
+ "475": 510,
513
+ "476": 511,
514
+ "477": 512,
515
+ "478": 513,
516
+ "479": 514,
517
+ "480": 515,
518
+ "481": 516,
519
+ "482": 517,
520
+ "483": 518,
521
+ "484": 519,
522
+ "485": 520,
523
+ "486": 521,
524
+ "487": 522,
525
+ "488": 523,
526
+ "489": 524,
527
+ "490": 525,
528
+ "491": 526,
529
+ "492": 527,
530
+ "493": 528,
531
+ "494": 529,
532
+ "495": 530,
533
+ "496": 531,
534
+ "497": 532,
535
+ "498": 533,
536
+ "499": 534,
537
+ "500": 535,
538
+ "501": 536,
539
+ "502": 537,
540
+ "503": 538,
541
+ "504": 539,
542
+ "505": 540,
543
+ "506": 541,
544
+ "507": 542,
545
+ "508": 543,
546
+ "509": 544,
547
+ "510": 545,
548
+ "511": 546
549
+ }