Junde committed on
Commit 840328b · verified · 1 Parent(s): e31c91f

Upload folder using huggingface_hub

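The commit message indicates these files were pushed with huggingface_hub's upload_folder. A minimal sketch of such a call is shown below; the repo id and local folder path are placeholders, not values taken from this commit:

from huggingface_hub import HfApi

api = HfApi()
# Push every file in a local checkpoint directory as a single commit.
api.upload_folder(
    repo_id="user/InstructProGen",      # placeholder repo id
    folder_path="./checkpoint",         # placeholder local path
    commit_message="Upload folder using huggingface_hub",
)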
model-00001-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:fe3791a5f2f0773e38adf0c8d1da47e2f73a7f391b87427af2c13840660c38fa
+ size 4840397588
model-00002-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:486db20e6ae740b78145eaacdd773a9f412c8cf2d218c22b53fd1b73a4ad685d
+ size 4838824240
model-00003-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cfdf23318e2bb71f62222dcd9e4916b96e802b869a62f80c691acca9e5a1b23a
+ size 4838824296
model-00004-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c3fa234ad7f37850b16954b2add5b5a766e74a3e6ade87da851dbb9459a71065
+ size 4838824296
model-00005-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1a59d7c1f22faa38f0c00ca0df5fa00bbdb3a45fa2439a9b7dcbcea04b7cc7fe
+ size 4838824296
model-00006-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:36469f8d727a9017afce8172c25852ae6e996ac9c8f4c7b68b9b712826c13652
+ size 1692142452
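Each of the entries above is a Git LFS pointer rather than the weights themselves: the pointer records only the sha256 object id and the byte size of the shard. A quick way to check a downloaded shard against its pointer, assuming the file sits in the current directory (a sketch, not part of the commit):

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    # Stream the file in 1 MiB chunks so large shards are never loaded whole.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()

expected = "fe3791a5f2f0773e38adf0c8d1da47e2f73a7f391b87427af2c13840660c38fa"
print(sha256_of("model-00001-of-00006.safetensors") == expected)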
model.safetensors.index.json ADDED
@@ -0,0 +1,338 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 25887801592
4
+ },
5
+ "weight_map": {
6
+ "lm_head.bias": "model-00006-of-00006.safetensors",
7
+ "lm_head.weight": "model-00006-of-00006.safetensors",
8
+ "transformer.h.0.attn.bias": "model-00001-of-00006.safetensors",
9
+ "transformer.h.0.attn.masked_bias": "model-00001-of-00006.safetensors",
10
+ "transformer.h.0.attn.out_proj.weight": "model-00001-of-00006.safetensors",
11
+ "transformer.h.0.attn.qkv_proj.weight": "model-00001-of-00006.safetensors",
12
+ "transformer.h.0.ln_1.bias": "model-00001-of-00006.safetensors",
13
+ "transformer.h.0.ln_1.weight": "model-00001-of-00006.safetensors",
14
+ "transformer.h.0.mlp.fc_in.bias": "model-00001-of-00006.safetensors",
15
+ "transformer.h.0.mlp.fc_in.weight": "model-00001-of-00006.safetensors",
16
+ "transformer.h.0.mlp.fc_out.bias": "model-00001-of-00006.safetensors",
17
+ "transformer.h.0.mlp.fc_out.weight": "model-00001-of-00006.safetensors",
18
+ "transformer.h.1.attn.bias": "model-00001-of-00006.safetensors",
19
+ "transformer.h.1.attn.masked_bias": "model-00001-of-00006.safetensors",
20
+ "transformer.h.1.attn.out_proj.weight": "model-00001-of-00006.safetensors",
21
+ "transformer.h.1.attn.qkv_proj.weight": "model-00001-of-00006.safetensors",
22
+ "transformer.h.1.ln_1.bias": "model-00001-of-00006.safetensors",
23
+ "transformer.h.1.ln_1.weight": "model-00001-of-00006.safetensors",
24
+ "transformer.h.1.mlp.fc_in.bias": "model-00001-of-00006.safetensors",
25
+ "transformer.h.1.mlp.fc_in.weight": "model-00001-of-00006.safetensors",
26
+ "transformer.h.1.mlp.fc_out.bias": "model-00001-of-00006.safetensors",
27
+ "transformer.h.1.mlp.fc_out.weight": "model-00001-of-00006.safetensors",
28
+ "transformer.h.10.attn.bias": "model-00002-of-00006.safetensors",
29
+ "transformer.h.10.attn.masked_bias": "model-00002-of-00006.safetensors",
30
+ "transformer.h.10.attn.out_proj.weight": "model-00002-of-00006.safetensors",
31
+ "transformer.h.10.attn.qkv_proj.weight": "model-00002-of-00006.safetensors",
32
+ "transformer.h.10.ln_1.bias": "model-00002-of-00006.safetensors",
33
+ "transformer.h.10.ln_1.weight": "model-00002-of-00006.safetensors",
34
+ "transformer.h.10.mlp.fc_in.bias": "model-00002-of-00006.safetensors",
35
+ "transformer.h.10.mlp.fc_in.weight": "model-00002-of-00006.safetensors",
36
+ "transformer.h.10.mlp.fc_out.bias": "model-00002-of-00006.safetensors",
37
+ "transformer.h.10.mlp.fc_out.weight": "model-00002-of-00006.safetensors",
38
+ "transformer.h.11.attn.bias": "model-00002-of-00006.safetensors",
39
+ "transformer.h.11.attn.masked_bias": "model-00002-of-00006.safetensors",
40
+ "transformer.h.11.attn.out_proj.weight": "model-00002-of-00006.safetensors",
41
+ "transformer.h.11.attn.qkv_proj.weight": "model-00002-of-00006.safetensors",
42
+ "transformer.h.11.ln_1.bias": "model-00002-of-00006.safetensors",
43
+ "transformer.h.11.ln_1.weight": "model-00002-of-00006.safetensors",
44
+ "transformer.h.11.mlp.fc_in.bias": "model-00002-of-00006.safetensors",
45
+ "transformer.h.11.mlp.fc_in.weight": "model-00002-of-00006.safetensors",
46
+ "transformer.h.11.mlp.fc_out.bias": "model-00002-of-00006.safetensors",
47
+ "transformer.h.11.mlp.fc_out.weight": "model-00002-of-00006.safetensors",
48
+ "transformer.h.12.attn.bias": "model-00002-of-00006.safetensors",
49
+ "transformer.h.12.attn.masked_bias": "model-00002-of-00006.safetensors",
50
+ "transformer.h.12.attn.out_proj.weight": "model-00003-of-00006.safetensors",
51
+ "transformer.h.12.attn.qkv_proj.weight": "model-00003-of-00006.safetensors",
52
+ "transformer.h.12.ln_1.bias": "model-00002-of-00006.safetensors",
53
+ "transformer.h.12.ln_1.weight": "model-00002-of-00006.safetensors",
54
+ "transformer.h.12.mlp.fc_in.bias": "model-00003-of-00006.safetensors",
55
+ "transformer.h.12.mlp.fc_in.weight": "model-00003-of-00006.safetensors",
56
+ "transformer.h.12.mlp.fc_out.bias": "model-00003-of-00006.safetensors",
57
+ "transformer.h.12.mlp.fc_out.weight": "model-00003-of-00006.safetensors",
58
+ "transformer.h.13.attn.bias": "model-00003-of-00006.safetensors",
59
+ "transformer.h.13.attn.masked_bias": "model-00003-of-00006.safetensors",
60
+ "transformer.h.13.attn.out_proj.weight": "model-00003-of-00006.safetensors",
61
+ "transformer.h.13.attn.qkv_proj.weight": "model-00003-of-00006.safetensors",
62
+ "transformer.h.13.ln_1.bias": "model-00003-of-00006.safetensors",
63
+ "transformer.h.13.ln_1.weight": "model-00003-of-00006.safetensors",
64
+ "transformer.h.13.mlp.fc_in.bias": "model-00003-of-00006.safetensors",
65
+ "transformer.h.13.mlp.fc_in.weight": "model-00003-of-00006.safetensors",
66
+ "transformer.h.13.mlp.fc_out.bias": "model-00003-of-00006.safetensors",
67
+ "transformer.h.13.mlp.fc_out.weight": "model-00003-of-00006.safetensors",
68
+ "transformer.h.14.attn.bias": "model-00003-of-00006.safetensors",
69
+ "transformer.h.14.attn.masked_bias": "model-00003-of-00006.safetensors",
70
+ "transformer.h.14.attn.out_proj.weight": "model-00003-of-00006.safetensors",
71
+ "transformer.h.14.attn.qkv_proj.weight": "model-00003-of-00006.safetensors",
72
+ "transformer.h.14.ln_1.bias": "model-00003-of-00006.safetensors",
73
+ "transformer.h.14.ln_1.weight": "model-00003-of-00006.safetensors",
74
+ "transformer.h.14.mlp.fc_in.bias": "model-00003-of-00006.safetensors",
75
+ "transformer.h.14.mlp.fc_in.weight": "model-00003-of-00006.safetensors",
76
+ "transformer.h.14.mlp.fc_out.bias": "model-00003-of-00006.safetensors",
77
+ "transformer.h.14.mlp.fc_out.weight": "model-00003-of-00006.safetensors",
78
+ "transformer.h.15.attn.bias": "model-00003-of-00006.safetensors",
79
+ "transformer.h.15.attn.masked_bias": "model-00003-of-00006.safetensors",
80
+ "transformer.h.15.attn.out_proj.weight": "model-00003-of-00006.safetensors",
81
+ "transformer.h.15.attn.qkv_proj.weight": "model-00003-of-00006.safetensors",
82
+ "transformer.h.15.ln_1.bias": "model-00003-of-00006.safetensors",
83
+ "transformer.h.15.ln_1.weight": "model-00003-of-00006.safetensors",
84
+ "transformer.h.15.mlp.fc_in.bias": "model-00003-of-00006.safetensors",
85
+ "transformer.h.15.mlp.fc_in.weight": "model-00003-of-00006.safetensors",
86
+ "transformer.h.15.mlp.fc_out.bias": "model-00003-of-00006.safetensors",
87
+ "transformer.h.15.mlp.fc_out.weight": "model-00003-of-00006.safetensors",
88
+ "transformer.h.16.attn.bias": "model-00003-of-00006.safetensors",
89
+ "transformer.h.16.attn.masked_bias": "model-00003-of-00006.safetensors",
90
+ "transformer.h.16.attn.out_proj.weight": "model-00003-of-00006.safetensors",
91
+ "transformer.h.16.attn.qkv_proj.weight": "model-00003-of-00006.safetensors",
92
+ "transformer.h.16.ln_1.bias": "model-00003-of-00006.safetensors",
93
+ "transformer.h.16.ln_1.weight": "model-00003-of-00006.safetensors",
94
+ "transformer.h.16.mlp.fc_in.bias": "model-00003-of-00006.safetensors",
95
+ "transformer.h.16.mlp.fc_in.weight": "model-00003-of-00006.safetensors",
96
+ "transformer.h.16.mlp.fc_out.bias": "model-00003-of-00006.safetensors",
97
+ "transformer.h.16.mlp.fc_out.weight": "model-00003-of-00006.safetensors",
98
+ "transformer.h.17.attn.bias": "model-00003-of-00006.safetensors",
99
+ "transformer.h.17.attn.masked_bias": "model-00003-of-00006.safetensors",
100
+ "transformer.h.17.attn.out_proj.weight": "model-00003-of-00006.safetensors",
101
+ "transformer.h.17.attn.qkv_proj.weight": "model-00003-of-00006.safetensors",
102
+ "transformer.h.17.ln_1.bias": "model-00003-of-00006.safetensors",
103
+ "transformer.h.17.ln_1.weight": "model-00003-of-00006.safetensors",
104
+ "transformer.h.17.mlp.fc_in.bias": "model-00003-of-00006.safetensors",
105
+ "transformer.h.17.mlp.fc_in.weight": "model-00003-of-00006.safetensors",
106
+ "transformer.h.17.mlp.fc_out.bias": "model-00003-of-00006.safetensors",
107
+ "transformer.h.17.mlp.fc_out.weight": "model-00003-of-00006.safetensors",
108
+ "transformer.h.18.attn.bias": "model-00003-of-00006.safetensors",
109
+ "transformer.h.18.attn.masked_bias": "model-00003-of-00006.safetensors",
110
+ "transformer.h.18.attn.out_proj.weight": "model-00004-of-00006.safetensors",
111
+ "transformer.h.18.attn.qkv_proj.weight": "model-00004-of-00006.safetensors",
112
+ "transformer.h.18.ln_1.bias": "model-00003-of-00006.safetensors",
113
+ "transformer.h.18.ln_1.weight": "model-00003-of-00006.safetensors",
114
+ "transformer.h.18.mlp.fc_in.bias": "model-00004-of-00006.safetensors",
115
+ "transformer.h.18.mlp.fc_in.weight": "model-00004-of-00006.safetensors",
116
+ "transformer.h.18.mlp.fc_out.bias": "model-00004-of-00006.safetensors",
117
+ "transformer.h.18.mlp.fc_out.weight": "model-00004-of-00006.safetensors",
118
+ "transformer.h.19.attn.bias": "model-00004-of-00006.safetensors",
119
+ "transformer.h.19.attn.masked_bias": "model-00004-of-00006.safetensors",
120
+ "transformer.h.19.attn.out_proj.weight": "model-00004-of-00006.safetensors",
121
+ "transformer.h.19.attn.qkv_proj.weight": "model-00004-of-00006.safetensors",
122
+ "transformer.h.19.ln_1.bias": "model-00004-of-00006.safetensors",
123
+ "transformer.h.19.ln_1.weight": "model-00004-of-00006.safetensors",
124
+ "transformer.h.19.mlp.fc_in.bias": "model-00004-of-00006.safetensors",
125
+ "transformer.h.19.mlp.fc_in.weight": "model-00004-of-00006.safetensors",
126
+ "transformer.h.19.mlp.fc_out.bias": "model-00004-of-00006.safetensors",
127
+ "transformer.h.19.mlp.fc_out.weight": "model-00004-of-00006.safetensors",
128
+ "transformer.h.2.attn.bias": "model-00001-of-00006.safetensors",
129
+ "transformer.h.2.attn.masked_bias": "model-00001-of-00006.safetensors",
130
+ "transformer.h.2.attn.out_proj.weight": "model-00001-of-00006.safetensors",
131
+ "transformer.h.2.attn.qkv_proj.weight": "model-00001-of-00006.safetensors",
132
+ "transformer.h.2.ln_1.bias": "model-00001-of-00006.safetensors",
133
+ "transformer.h.2.ln_1.weight": "model-00001-of-00006.safetensors",
134
+ "transformer.h.2.mlp.fc_in.bias": "model-00001-of-00006.safetensors",
135
+ "transformer.h.2.mlp.fc_in.weight": "model-00001-of-00006.safetensors",
136
+ "transformer.h.2.mlp.fc_out.bias": "model-00001-of-00006.safetensors",
137
+ "transformer.h.2.mlp.fc_out.weight": "model-00001-of-00006.safetensors",
138
+ "transformer.h.20.attn.bias": "model-00004-of-00006.safetensors",
139
+ "transformer.h.20.attn.masked_bias": "model-00004-of-00006.safetensors",
140
+ "transformer.h.20.attn.out_proj.weight": "model-00004-of-00006.safetensors",
141
+ "transformer.h.20.attn.qkv_proj.weight": "model-00004-of-00006.safetensors",
142
+ "transformer.h.20.ln_1.bias": "model-00004-of-00006.safetensors",
143
+ "transformer.h.20.ln_1.weight": "model-00004-of-00006.safetensors",
144
+ "transformer.h.20.mlp.fc_in.bias": "model-00004-of-00006.safetensors",
145
+ "transformer.h.20.mlp.fc_in.weight": "model-00004-of-00006.safetensors",
146
+ "transformer.h.20.mlp.fc_out.bias": "model-00004-of-00006.safetensors",
147
+ "transformer.h.20.mlp.fc_out.weight": "model-00004-of-00006.safetensors",
148
+ "transformer.h.21.attn.bias": "model-00004-of-00006.safetensors",
149
+ "transformer.h.21.attn.masked_bias": "model-00004-of-00006.safetensors",
150
+ "transformer.h.21.attn.out_proj.weight": "model-00004-of-00006.safetensors",
151
+ "transformer.h.21.attn.qkv_proj.weight": "model-00004-of-00006.safetensors",
152
+ "transformer.h.21.ln_1.bias": "model-00004-of-00006.safetensors",
153
+ "transformer.h.21.ln_1.weight": "model-00004-of-00006.safetensors",
154
+ "transformer.h.21.mlp.fc_in.bias": "model-00004-of-00006.safetensors",
155
+ "transformer.h.21.mlp.fc_in.weight": "model-00004-of-00006.safetensors",
156
+ "transformer.h.21.mlp.fc_out.bias": "model-00004-of-00006.safetensors",
157
+ "transformer.h.21.mlp.fc_out.weight": "model-00004-of-00006.safetensors",
158
+ "transformer.h.22.attn.bias": "model-00004-of-00006.safetensors",
159
+ "transformer.h.22.attn.masked_bias": "model-00004-of-00006.safetensors",
160
+ "transformer.h.22.attn.out_proj.weight": "model-00004-of-00006.safetensors",
161
+ "transformer.h.22.attn.qkv_proj.weight": "model-00004-of-00006.safetensors",
162
+ "transformer.h.22.ln_1.bias": "model-00004-of-00006.safetensors",
163
+ "transformer.h.22.ln_1.weight": "model-00004-of-00006.safetensors",
164
+ "transformer.h.22.mlp.fc_in.bias": "model-00004-of-00006.safetensors",
165
+ "transformer.h.22.mlp.fc_in.weight": "model-00004-of-00006.safetensors",
166
+ "transformer.h.22.mlp.fc_out.bias": "model-00004-of-00006.safetensors",
167
+ "transformer.h.22.mlp.fc_out.weight": "model-00004-of-00006.safetensors",
168
+ "transformer.h.23.attn.bias": "model-00004-of-00006.safetensors",
169
+ "transformer.h.23.attn.masked_bias": "model-00004-of-00006.safetensors",
170
+ "transformer.h.23.attn.out_proj.weight": "model-00004-of-00006.safetensors",
171
+ "transformer.h.23.attn.qkv_proj.weight": "model-00004-of-00006.safetensors",
172
+ "transformer.h.23.ln_1.bias": "model-00004-of-00006.safetensors",
173
+ "transformer.h.23.ln_1.weight": "model-00004-of-00006.safetensors",
174
+ "transformer.h.23.mlp.fc_in.bias": "model-00004-of-00006.safetensors",
175
+ "transformer.h.23.mlp.fc_in.weight": "model-00004-of-00006.safetensors",
176
+ "transformer.h.23.mlp.fc_out.bias": "model-00004-of-00006.safetensors",
177
+ "transformer.h.23.mlp.fc_out.weight": "model-00004-of-00006.safetensors",
178
+ "transformer.h.24.attn.bias": "model-00004-of-00006.safetensors",
179
+ "transformer.h.24.attn.masked_bias": "model-00004-of-00006.safetensors",
180
+ "transformer.h.24.attn.out_proj.weight": "model-00005-of-00006.safetensors",
181
+ "transformer.h.24.attn.qkv_proj.weight": "model-00005-of-00006.safetensors",
182
+ "transformer.h.24.ln_1.bias": "model-00004-of-00006.safetensors",
183
+ "transformer.h.24.ln_1.weight": "model-00004-of-00006.safetensors",
184
+ "transformer.h.24.mlp.fc_in.bias": "model-00005-of-00006.safetensors",
185
+ "transformer.h.24.mlp.fc_in.weight": "model-00005-of-00006.safetensors",
186
+ "transformer.h.24.mlp.fc_out.bias": "model-00005-of-00006.safetensors",
187
+ "transformer.h.24.mlp.fc_out.weight": "model-00005-of-00006.safetensors",
188
+ "transformer.h.25.attn.bias": "model-00005-of-00006.safetensors",
189
+ "transformer.h.25.attn.masked_bias": "model-00005-of-00006.safetensors",
190
+ "transformer.h.25.attn.out_proj.weight": "model-00005-of-00006.safetensors",
191
+ "transformer.h.25.attn.qkv_proj.weight": "model-00005-of-00006.safetensors",
192
+ "transformer.h.25.ln_1.bias": "model-00005-of-00006.safetensors",
193
+ "transformer.h.25.ln_1.weight": "model-00005-of-00006.safetensors",
194
+ "transformer.h.25.mlp.fc_in.bias": "model-00005-of-00006.safetensors",
195
+ "transformer.h.25.mlp.fc_in.weight": "model-00005-of-00006.safetensors",
196
+ "transformer.h.25.mlp.fc_out.bias": "model-00005-of-00006.safetensors",
197
+ "transformer.h.25.mlp.fc_out.weight": "model-00005-of-00006.safetensors",
198
+ "transformer.h.26.attn.bias": "model-00005-of-00006.safetensors",
199
+ "transformer.h.26.attn.masked_bias": "model-00005-of-00006.safetensors",
200
+ "transformer.h.26.attn.out_proj.weight": "model-00005-of-00006.safetensors",
201
+ "transformer.h.26.attn.qkv_proj.weight": "model-00005-of-00006.safetensors",
202
+ "transformer.h.26.ln_1.bias": "model-00005-of-00006.safetensors",
203
+ "transformer.h.26.ln_1.weight": "model-00005-of-00006.safetensors",
204
+ "transformer.h.26.mlp.fc_in.bias": "model-00005-of-00006.safetensors",
205
+ "transformer.h.26.mlp.fc_in.weight": "model-00005-of-00006.safetensors",
206
+ "transformer.h.26.mlp.fc_out.bias": "model-00005-of-00006.safetensors",
207
+ "transformer.h.26.mlp.fc_out.weight": "model-00005-of-00006.safetensors",
208
+ "transformer.h.27.attn.bias": "model-00005-of-00006.safetensors",
209
+ "transformer.h.27.attn.masked_bias": "model-00005-of-00006.safetensors",
210
+ "transformer.h.27.attn.out_proj.weight": "model-00005-of-00006.safetensors",
211
+ "transformer.h.27.attn.qkv_proj.weight": "model-00005-of-00006.safetensors",
212
+ "transformer.h.27.ln_1.bias": "model-00005-of-00006.safetensors",
213
+ "transformer.h.27.ln_1.weight": "model-00005-of-00006.safetensors",
214
+ "transformer.h.27.mlp.fc_in.bias": "model-00005-of-00006.safetensors",
215
+ "transformer.h.27.mlp.fc_in.weight": "model-00005-of-00006.safetensors",
216
+ "transformer.h.27.mlp.fc_out.bias": "model-00005-of-00006.safetensors",
217
+ "transformer.h.27.mlp.fc_out.weight": "model-00005-of-00006.safetensors",
218
+ "transformer.h.28.attn.bias": "model-00005-of-00006.safetensors",
219
+ "transformer.h.28.attn.masked_bias": "model-00005-of-00006.safetensors",
220
+ "transformer.h.28.attn.out_proj.weight": "model-00005-of-00006.safetensors",
221
+ "transformer.h.28.attn.qkv_proj.weight": "model-00005-of-00006.safetensors",
222
+ "transformer.h.28.ln_1.bias": "model-00005-of-00006.safetensors",
223
+ "transformer.h.28.ln_1.weight": "model-00005-of-00006.safetensors",
224
+ "transformer.h.28.mlp.fc_in.bias": "model-00005-of-00006.safetensors",
225
+ "transformer.h.28.mlp.fc_in.weight": "model-00005-of-00006.safetensors",
226
+ "transformer.h.28.mlp.fc_out.bias": "model-00005-of-00006.safetensors",
227
+ "transformer.h.28.mlp.fc_out.weight": "model-00005-of-00006.safetensors",
228
+ "transformer.h.29.attn.bias": "model-00005-of-00006.safetensors",
229
+ "transformer.h.29.attn.masked_bias": "model-00005-of-00006.safetensors",
230
+ "transformer.h.29.attn.out_proj.weight": "model-00005-of-00006.safetensors",
231
+ "transformer.h.29.attn.qkv_proj.weight": "model-00005-of-00006.safetensors",
232
+ "transformer.h.29.ln_1.bias": "model-00005-of-00006.safetensors",
233
+ "transformer.h.29.ln_1.weight": "model-00005-of-00006.safetensors",
234
+ "transformer.h.29.mlp.fc_in.bias": "model-00005-of-00006.safetensors",
235
+ "transformer.h.29.mlp.fc_in.weight": "model-00005-of-00006.safetensors",
236
+ "transformer.h.29.mlp.fc_out.bias": "model-00005-of-00006.safetensors",
237
+ "transformer.h.29.mlp.fc_out.weight": "model-00005-of-00006.safetensors",
238
+ "transformer.h.3.attn.bias": "model-00001-of-00006.safetensors",
239
+ "transformer.h.3.attn.masked_bias": "model-00001-of-00006.safetensors",
240
+ "transformer.h.3.attn.out_proj.weight": "model-00001-of-00006.safetensors",
241
+ "transformer.h.3.attn.qkv_proj.weight": "model-00001-of-00006.safetensors",
242
+ "transformer.h.3.ln_1.bias": "model-00001-of-00006.safetensors",
243
+ "transformer.h.3.ln_1.weight": "model-00001-of-00006.safetensors",
244
+ "transformer.h.3.mlp.fc_in.bias": "model-00001-of-00006.safetensors",
245
+ "transformer.h.3.mlp.fc_in.weight": "model-00001-of-00006.safetensors",
246
+ "transformer.h.3.mlp.fc_out.bias": "model-00001-of-00006.safetensors",
247
+ "transformer.h.3.mlp.fc_out.weight": "model-00001-of-00006.safetensors",
248
+ "transformer.h.30.attn.bias": "model-00005-of-00006.safetensors",
249
+ "transformer.h.30.attn.masked_bias": "model-00005-of-00006.safetensors",
250
+ "transformer.h.30.attn.out_proj.weight": "model-00006-of-00006.safetensors",
251
+ "transformer.h.30.attn.qkv_proj.weight": "model-00006-of-00006.safetensors",
252
+ "transformer.h.30.ln_1.bias": "model-00005-of-00006.safetensors",
253
+ "transformer.h.30.ln_1.weight": "model-00005-of-00006.safetensors",
254
+ "transformer.h.30.mlp.fc_in.bias": "model-00006-of-00006.safetensors",
255
+ "transformer.h.30.mlp.fc_in.weight": "model-00006-of-00006.safetensors",
256
+ "transformer.h.30.mlp.fc_out.bias": "model-00006-of-00006.safetensors",
257
+ "transformer.h.30.mlp.fc_out.weight": "model-00006-of-00006.safetensors",
258
+ "transformer.h.31.attn.bias": "model-00006-of-00006.safetensors",
259
+ "transformer.h.31.attn.masked_bias": "model-00006-of-00006.safetensors",
260
+ "transformer.h.31.attn.out_proj.weight": "model-00006-of-00006.safetensors",
261
+ "transformer.h.31.attn.qkv_proj.weight": "model-00006-of-00006.safetensors",
262
+ "transformer.h.31.ln_1.bias": "model-00006-of-00006.safetensors",
263
+ "transformer.h.31.ln_1.weight": "model-00006-of-00006.safetensors",
264
+ "transformer.h.31.mlp.fc_in.bias": "model-00006-of-00006.safetensors",
265
+ "transformer.h.31.mlp.fc_in.weight": "model-00006-of-00006.safetensors",
266
+ "transformer.h.31.mlp.fc_out.bias": "model-00006-of-00006.safetensors",
267
+ "transformer.h.31.mlp.fc_out.weight": "model-00006-of-00006.safetensors",
268
+ "transformer.h.4.attn.bias": "model-00001-of-00006.safetensors",
269
+ "transformer.h.4.attn.masked_bias": "model-00001-of-00006.safetensors",
270
+ "transformer.h.4.attn.out_proj.weight": "model-00001-of-00006.safetensors",
271
+ "transformer.h.4.attn.qkv_proj.weight": "model-00001-of-00006.safetensors",
272
+ "transformer.h.4.ln_1.bias": "model-00001-of-00006.safetensors",
273
+ "transformer.h.4.ln_1.weight": "model-00001-of-00006.safetensors",
274
+ "transformer.h.4.mlp.fc_in.bias": "model-00001-of-00006.safetensors",
275
+ "transformer.h.4.mlp.fc_in.weight": "model-00001-of-00006.safetensors",
276
+ "transformer.h.4.mlp.fc_out.bias": "model-00001-of-00006.safetensors",
277
+ "transformer.h.4.mlp.fc_out.weight": "model-00001-of-00006.safetensors",
278
+ "transformer.h.5.attn.bias": "model-00001-of-00006.safetensors",
279
+ "transformer.h.5.attn.masked_bias": "model-00001-of-00006.safetensors",
280
+ "transformer.h.5.attn.out_proj.weight": "model-00001-of-00006.safetensors",
281
+ "transformer.h.5.attn.qkv_proj.weight": "model-00001-of-00006.safetensors",
282
+ "transformer.h.5.ln_1.bias": "model-00001-of-00006.safetensors",
283
+ "transformer.h.5.ln_1.weight": "model-00001-of-00006.safetensors",
284
+ "transformer.h.5.mlp.fc_in.bias": "model-00001-of-00006.safetensors",
285
+ "transformer.h.5.mlp.fc_in.weight": "model-00001-of-00006.safetensors",
286
+ "transformer.h.5.mlp.fc_out.bias": "model-00001-of-00006.safetensors",
287
+ "transformer.h.5.mlp.fc_out.weight": "model-00001-of-00006.safetensors",
288
+ "transformer.h.6.attn.bias": "model-00001-of-00006.safetensors",
289
+ "transformer.h.6.attn.masked_bias": "model-00001-of-00006.safetensors",
290
+ "transformer.h.6.attn.out_proj.weight": "model-00002-of-00006.safetensors",
291
+ "transformer.h.6.attn.qkv_proj.weight": "model-00002-of-00006.safetensors",
292
+ "transformer.h.6.ln_1.bias": "model-00001-of-00006.safetensors",
293
+ "transformer.h.6.ln_1.weight": "model-00001-of-00006.safetensors",
294
+ "transformer.h.6.mlp.fc_in.bias": "model-00002-of-00006.safetensors",
295
+ "transformer.h.6.mlp.fc_in.weight": "model-00002-of-00006.safetensors",
296
+ "transformer.h.6.mlp.fc_out.bias": "model-00002-of-00006.safetensors",
297
+ "transformer.h.6.mlp.fc_out.weight": "model-00002-of-00006.safetensors",
298
+ "transformer.h.7.attn.bias": "model-00002-of-00006.safetensors",
299
+ "transformer.h.7.attn.masked_bias": "model-00002-of-00006.safetensors",
300
+ "transformer.h.7.attn.out_proj.weight": "model-00002-of-00006.safetensors",
301
+ "transformer.h.7.attn.qkv_proj.weight": "model-00002-of-00006.safetensors",
302
+ "transformer.h.7.ln_1.bias": "model-00002-of-00006.safetensors",
303
+ "transformer.h.7.ln_1.weight": "model-00002-of-00006.safetensors",
304
+ "transformer.h.7.mlp.fc_in.bias": "model-00002-of-00006.safetensors",
305
+ "transformer.h.7.mlp.fc_in.weight": "model-00002-of-00006.safetensors",
306
+ "transformer.h.7.mlp.fc_out.bias": "model-00002-of-00006.safetensors",
307
+ "transformer.h.7.mlp.fc_out.weight": "model-00002-of-00006.safetensors",
308
+ "transformer.h.8.attn.bias": "model-00002-of-00006.safetensors",
309
+ "transformer.h.8.attn.masked_bias": "model-00002-of-00006.safetensors",
310
+ "transformer.h.8.attn.out_proj.weight": "model-00002-of-00006.safetensors",
311
+ "transformer.h.8.attn.qkv_proj.weight": "model-00002-of-00006.safetensors",
312
+ "transformer.h.8.ln_1.bias": "model-00002-of-00006.safetensors",
313
+ "transformer.h.8.ln_1.weight": "model-00002-of-00006.safetensors",
314
+ "transformer.h.8.mlp.fc_in.bias": "model-00002-of-00006.safetensors",
315
+ "transformer.h.8.mlp.fc_in.weight": "model-00002-of-00006.safetensors",
316
+ "transformer.h.8.mlp.fc_out.bias": "model-00002-of-00006.safetensors",
317
+ "transformer.h.8.mlp.fc_out.weight": "model-00002-of-00006.safetensors",
318
+ "transformer.h.9.attn.bias": "model-00002-of-00006.safetensors",
319
+ "transformer.h.9.attn.masked_bias": "model-00002-of-00006.safetensors",
320
+ "transformer.h.9.attn.out_proj.weight": "model-00002-of-00006.safetensors",
321
+ "transformer.h.9.attn.qkv_proj.weight": "model-00002-of-00006.safetensors",
322
+ "transformer.h.9.ln_1.bias": "model-00002-of-00006.safetensors",
323
+ "transformer.h.9.ln_1.weight": "model-00002-of-00006.safetensors",
324
+ "transformer.h.9.mlp.fc_in.bias": "model-00002-of-00006.safetensors",
325
+ "transformer.h.9.mlp.fc_in.weight": "model-00002-of-00006.safetensors",
326
+ "transformer.h.9.mlp.fc_out.bias": "model-00002-of-00006.safetensors",
327
+ "transformer.h.9.mlp.fc_out.weight": "model-00002-of-00006.safetensors",
328
+ "transformer.ln_f.bias": "model-00006-of-00006.safetensors",
329
+ "transformer.ln_f.weight": "model-00006-of-00006.safetensors",
330
+ "transformer.structure.attn_pool.mlp.0.bias": "model-00006-of-00006.safetensors",
331
+ "transformer.structure.attn_pool.mlp.0.weight": "model-00006-of-00006.safetensors",
332
+ "transformer.structure.attn_pool.mlp.1.bias": "model-00006-of-00006.safetensors",
333
+ "transformer.structure.attn_pool.mlp.1.weight": "model-00006-of-00006.safetensors",
334
+ "transformer.structure.attn_pool.mlp.3.bias": "model-00006-of-00006.safetensors",
335
+ "transformer.structure.attn_pool.mlp.3.weight": "model-00006-of-00006.safetensors",
336
+ "transformer.wte.weight": "model-00001-of-00006.safetensors"
337
+ }
338
+ }
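model.safetensors.index.json maps every parameter name to one of the six shards (total_size 25887801592 bytes), so the checkpoint loads as a single model. Because the repository ships custom modeling code (modeling_InstructProGen.py), loading it through transformers needs trust_remote_code=True; a minimal sketch, with the repo id as a placeholder:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "user/InstructProGen",    # placeholder repo id
    trust_remote_code=True,   # required for the custom ProGen classes in this repo
    torch_dtype="auto",       # keep the dtype stored in the safetensors shards
)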
modeling_InstructProGen.py ADDED
@@ -0,0 +1,847 @@
1
+ from typing import Callable, List, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+
5
+ import torch
6
+ import torch.utils.checkpoint
7
+ from torch import nn
8
+ from torch.nn import CrossEntropyLoss
9
+
10
+ from transformers.activations import ACT2FN
11
+ from transformers.generation.configuration_utils import GenerationConfig
12
+ from transformers.generation.logits_process import LogitsProcessorList
13
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
14
+ from transformers.generation.streamers import BaseStreamer
15
+ from transformers.generation.utils import GenerateOutput
16
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
17
+ from transformers.modeling_utils import PreTrainedModel
18
+ from transformers.utils import logging
19
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
20
+ from .configuration_progen import ProGenConfig
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+ from .structure import StructureTransformer
25
+ # from .structure_ligand import StructureTransformer
26
+
27
+
28
+ import math
29
+
30
+ # Inverse dim formula to find dim based on number of rotations
31
+ def _yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
32
+ return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))
33
+
34
+ # Find dim range bounds based on rotations
35
+ def _yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
36
+ low = math.floor(_yarn_find_correction_dim(
37
+ low_rot, dim, base, max_position_embeddings))
38
+ high = math.ceil(_yarn_find_correction_dim(
39
+ high_rot, dim, base, max_position_embeddings))
40
+ return max(low, 0), min(high, dim-1) # Clamp values just in case
41
+
42
+ def _yarn_linear_ramp_mask(min, max, dim):
43
+ if min == max:
44
+ max += 0.001 # Prevent singularity
45
+
46
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
47
+ ramp_func = torch.clamp(linear_func, 0, 1)
48
+ return ramp_func
49
+
50
+ def _yarn_get_mscale(scale=1):
51
+ if scale <= 1:
52
+ return 1.0
53
+ return 0.1 * math.log(scale) + 1.0
54
+
55
+ def yarn(x, seq_len=None, seq_dim=1, base=10000, scale=4, original_max_position_embeddings=1024, extrapolation_factor=1, attn_factor=1, beta_fast=4, beta_slow=0.5):
56
+ dim = x.shape[-1]
57
+ pos_freqs = base ** (torch.arange(0, dim, 2).float().to(x.device) / dim)
58
+ inv_freq_extrapolation = 1.0 / pos_freqs
59
+ inv_freq_interpolation = 1.0 / (scale * pos_freqs)
60
+
61
+ low, high = _yarn_find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
62
+
63
+ inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, dim // 2).float().to(x.device)) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
64
+ inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
65
+
66
+ mscale = float(_yarn_get_mscale(scale) * attn_factor) # Get n-d magnitude scaling corrected for interpolation
67
+
68
+ # Build here to make `torch.jit.trace` work.
69
+ if seq_len is None:
70
+ seq_len = x.shape[seq_dim]
71
+ t = torch.arange(seq_len, device=x.device, dtype=x.dtype)
72
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
73
+ # torch.save(freqs, 'yarn_freq.pt')
74
+ return torch.sin(freqs)*mscale, torch.cos(freqs)*mscale
75
+
76
+ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
77
+ dim = x.shape[-1]
78
+ if seq_len is None:
79
+ seq_len = x.shape[seq_dim]
80
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
81
+ sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float()
82
+ # return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
83
+ return torch.sin(sinusoid_inp).to(x.dtype), torch.cos(sinusoid_inp).to(x.dtype)
84
+
85
+ def rotate_every_two(x):
86
+ x1 = x[:, :, :, ::2]
87
+ x2 = x[:, :, :, 1::2]
88
+ x = torch.stack((-x2, x1), axis=-1)
89
+ return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')
90
+
91
+
92
+ def apply_rotary_pos_emb(x, sincos, offset=0, position_ids=None):
93
+ # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
94
+ if position_ids is not None:
95
+ sin, cos = map(lambda t: t[None, :, None, :].repeat_interleave(2, 3), sincos)
96
+ sin = sin.take_along_dim(position_ids.unsqueeze(-1).unsqueeze(-1), dim=1)
97
+ cos = cos.take_along_dim(position_ids.unsqueeze(-1).unsqueeze(-1), dim=1)
98
+ else:
99
+ sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos)
100
+ return (x * cos) + (rotate_every_two(x) * sin)
101
+
102
+
103
+ class ProGenAttention(nn.Module):
104
+ def __init__(self, config):
105
+ super().__init__()
106
+
107
+ max_positions = config.max_position_embeddings
108
+ self.yarn = config.yarn
109
+ self.register_buffer(
110
+ "bias",
111
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
112
+ 1, 1, max_positions, max_positions
113
+ ),
114
+ )
115
+ self.register_buffer("masked_bias", torch.tensor(-1e9))
116
+
117
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
118
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
119
+
120
+ self.embed_dim = config.hidden_size
121
+ self.num_attention_heads = config.num_attention_heads
122
+ self.head_dim = self.embed_dim // self.num_attention_heads
123
+ if self.head_dim * self.num_attention_heads != self.embed_dim:
124
+ raise ValueError(
125
+ f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
126
+ )
127
+ self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
128
+ self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)
129
+
130
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
131
+ self.rotary_dim = None
132
+ if config.rotary_dim is not None:
133
+ self.rotary_dim = config.rotary_dim
134
+
135
+ def _split_heads(self, x, n_head, dim_head, mp_num):
136
+ reshaped = x.reshape(x.shape[:-1] + (n_head//mp_num, dim_head))
137
+ reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + reshaped.shape[-1:])
138
+ return reshaped
139
+
140
+ def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
141
+ """
142
+ Merges attn_head_size dim and num_attn_heads dim into n_ctx
143
+ """
144
+ if len(tensor.shape) == 5:
145
+ tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
146
+ elif len(tensor.shape) == 4:
147
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
148
+ else:
149
+ raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
150
+ new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
151
+ return tensor.view(new_shape)
152
+
153
+ def _attn(
154
+ self,
155
+ query,
156
+ key,
157
+ value,
158
+ attention_mask=None,
159
+ head_mask=None,
160
+ ):
161
+
162
+ # compute causal mask from causal mask buffer
163
+ query_length, key_length = query.size(-2), key.size(-2)
164
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
165
+
166
+ # Keep the attention weights computation in fp32 to avoid overflow issues
167
+ query = query.to(torch.float32)
168
+ key = key.to(torch.float32)
169
+
170
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
171
+
172
+ attn_weights = attn_weights / self.scale_attn
173
+ attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
174
+
175
+ if attention_mask is not None:
176
+ # Apply the attention mask
177
+ attn_weights = attn_weights + attention_mask
178
+
179
+ attn_weights = nn.Softmax(dim=-1)(attn_weights)
180
+ attn_weights = attn_weights.to(value.dtype)
181
+ attn_weights = self.attn_dropout(attn_weights)
182
+
183
+ # Mask heads if we want to
184
+ if head_mask is not None:
185
+ attn_weights = attn_weights * head_mask
186
+
187
+ attn_output = torch.matmul(attn_weights, value)
188
+
189
+ return attn_output, attn_weights
190
+
191
+ def forward(
192
+ self,
193
+ hidden_states,
194
+ attention_mask=None,
195
+ layer_past=None,
196
+ head_mask=None,
197
+ use_cache=False,
198
+ output_attentions=False,
199
+ position_ids=None
200
+ ):
201
+
202
+ qkv = self.qkv_proj(hidden_states)
203
+ # TODO(enijkamp): factor out number of logical TPU-v3/v4 cores or make forward pass agnostic
204
+ # mp_num = 4
205
+ mp_num = 8
206
+ qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
207
+
208
+ local_dim = self.head_dim * self.num_attention_heads // mp_num
209
+ query, value, key = torch.split(qkv_split, local_dim, dim=-1)
210
+ query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
211
+ key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
212
+
213
+ value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
214
+ value = value.permute(0, 2, 1, 3)
215
+
216
+ seq_len = key.shape[1] * 2 # for position ids greater than current
217
+ offset = 0
218
+
219
+ if layer_past is not None:
220
+ offset = layer_past[0].shape[-2]
221
+ seq_len += offset
222
+
223
+ if self.rotary_dim is not None:
224
+ k_rot = key[:, :, :, : self.rotary_dim]
225
+ k_pass = key[:, :, :, self.rotary_dim :]
226
+
227
+ q_rot = query[:, :, :, : self.rotary_dim]
228
+ q_pass = query[:, :, :, self.rotary_dim :]
229
+ if self.yarn:
230
+ sincos = yarn(k_rot, seq_dim=1, seq_len=seq_len)
231
+ else:
232
+ sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
233
+ k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset, position_ids=position_ids)
234
+ q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset, position_ids=position_ids)
235
+
236
+ key = torch.cat([k_rot, k_pass], dim=-1)
237
+ query = torch.cat([q_rot, q_pass], dim=-1)
238
+ else:
239
+ if self.yarn:
240
+ sincos = yarn(key, seq_dim=1, seq_len=seq_len)  # rotary_dim is None here, so rotate the full key width
241
+ else:
242
+ sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
243
+ key = apply_rotary_pos_emb(key, sincos, offset=offset, position_ids=position_ids)
244
+ query = apply_rotary_pos_emb(query, sincos, offset=offset, position_ids=position_ids)
245
+
246
+ key = key.permute(0, 2, 1, 3)
247
+ query = query.permute(0, 2, 1, 3)
248
+
249
+ if layer_past is not None:
250
+ past_key = layer_past[0]
251
+ past_value = layer_past[1]
252
+ key = torch.cat((past_key, key), dim=-2)
253
+ value = torch.cat((past_value, value), dim=-2)
254
+
255
+ if use_cache is True:
256
+ present = (key, value)
257
+ else:
258
+ present = None
259
+
260
+ # compute self-attention: V x Softmax(QK^T)
261
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
262
+
263
+ attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
264
+
265
+ attn_output = self.out_proj(attn_output)
266
+ attn_output = self.resid_dropout(attn_output)
267
+
268
+ outputs = (attn_output, present)
269
+ if output_attentions:
270
+ outputs += (attn_weights,)
271
+
272
+ return outputs # a, present, (attentions)
273
+
274
+
275
+ class ProGenMLP(nn.Module):
276
+ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
277
+ super().__init__()
278
+ embed_dim = config.n_embd
279
+
280
+ self.fc_in = nn.Linear(embed_dim, intermediate_size)
281
+ self.fc_out = nn.Linear(intermediate_size, embed_dim)
282
+
283
+ self.act = ACT2FN[config.activation_function]
284
+ self.dropout = nn.Dropout(config.resid_pdrop)
285
+
286
+ def forward(self, hidden_states):
287
+ hidden_states = self.fc_in(hidden_states)
288
+ hidden_states = self.act(hidden_states)
289
+ hidden_states = self.fc_out(hidden_states)
290
+ hidden_states = self.dropout(hidden_states)
291
+ return hidden_states
292
+
293
+
294
+ class ProGenBlock(nn.Module):
295
+ def __init__(self, config):
296
+ super().__init__()
297
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
298
+ self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
299
+ self.attn = ProGenAttention(config)
300
+ self.mlp = ProGenMLP(inner_dim, config)
301
+
302
+ def forward(
303
+ self,
304
+ hidden_states,
305
+ layer_past=None,
306
+ attention_mask=None,
307
+ head_mask=None,
308
+ use_cache=False,
309
+ output_attentions=False,
310
+ position_ids=None,
311
+ ):
312
+ residual = hidden_states
313
+ hidden_states = self.ln_1(hidden_states)
314
+ attn_outputs = self.attn(
315
+ hidden_states,
316
+ layer_past=layer_past,
317
+ attention_mask=attention_mask,
318
+ head_mask=head_mask,
319
+ use_cache=use_cache,
320
+ output_attentions=output_attentions,
321
+ position_ids=position_ids
322
+ )
323
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
324
+ outputs = attn_outputs[1:]
325
+
326
+ feed_forward_hidden_states = self.mlp(hidden_states)
327
+ hidden_states = attn_output + feed_forward_hidden_states + residual
328
+
329
+ if use_cache:
330
+ outputs = (hidden_states,) + outputs
331
+ else:
332
+ outputs = (hidden_states,) + outputs[1:]
333
+
334
+ return outputs # hidden_states, present, (attentions)
335
+
336
+
337
+ class ProGenPreTrainedModel(PreTrainedModel):
338
+ """
339
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
340
+ models.
341
+ """
342
+
343
+ config_class = ProGenConfig
344
+ base_model_prefix = "transformer"
345
+ supports_gradient_checkpointing = True
346
+ is_parallelizable = True
347
+ _no_split_modules = ["ProGenBlock"]
348
+ def __init__(self, *inputs, **kwargs):
349
+ super().__init__(*inputs, **kwargs)
350
+
351
+ def _init_weights(self, module):
352
+ """Initialize the weights."""
353
+ if isinstance(module, (nn.Linear,)):
354
+ # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
355
+ # cf https://github.com/pytorch/pytorch/pull/5617
356
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
357
+ if module.bias is not None:
358
+ module.bias.data.zero_()
359
+ elif isinstance(module, nn.Embedding):
360
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
361
+ if module.padding_idx is not None:
362
+ module.weight.data[module.padding_idx].zero_()
363
+ elif isinstance(module, nn.LayerNorm):
364
+ module.bias.data.zero_()
365
+ module.weight.data.fill_(1.0)
366
+
367
+ def _set_gradient_checkpointing(self, module, value=False):
368
+ if isinstance(module, ProGenModel):
369
+ module.gradient_checkpointing = value
370
+
371
+ class ProGenModel(ProGenPreTrainedModel):
372
+ def __init__(self, config):
373
+ super().__init__(config)
374
+
375
+ self.embed_dim = config.n_embd
376
+ self.vocab_size = config.vocab_size
377
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
378
+ self.drop = nn.Dropout(config.embd_pdrop)
379
+ self.h = nn.ModuleList([ProGenBlock(config) for _ in range(config.n_layer)])
380
+ self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
381
+ self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
382
+
383
+ self.gradient_checkpointing = False
384
+ self.structure = StructureTransformer(**config.structure)
385
+
386
+ self.init_weights()
387
+
388
+ # Model parallel
389
+ self.model_parallel = False
390
+ self.device_map = None
391
+
392
+
393
+ def parallelize(self, device_map=None):
394
+ # Check validity of device_map
395
+ self.device_map = (
396
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
397
+ )
398
+ assert_device_map(self.device_map, len(self.h))
399
+ self.model_parallel = True
400
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
401
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
402
+ self.wte = self.wte.to(self.first_device)
403
+ # Load onto devices
404
+ for k, v in self.device_map.items():
405
+ for block in v:
406
+ cuda_device = "cuda:" + str(k)
407
+ self.h[block] = self.h[block].to(cuda_device)
408
+ # ln_f to last
409
+ self.ln_f = self.ln_f.to(self.last_device)
410
+
411
+
412
+ def deparallelize(self):
413
+ self.model_parallel = False
414
+ self.device_map = None
415
+ self.first_device = "cpu"
416
+ self.last_device = "cpu"
417
+ self.wte = self.wte.to("cpu")
418
+ for index in range(len(self.h)):
419
+ self.h[index] = self.h[index].to("cpu")
420
+ self.ln_f = self.ln_f.to("cpu")
421
+ torch.cuda.empty_cache()
422
+
423
+ def get_input_embeddings(self):
424
+ return self.wte
425
+
426
+ def set_input_embeddings(self, new_embeddings):
427
+ self.wte = new_embeddings
428
+
429
+ def forward(
430
+ self,
431
+ input_ids=None,
432
+ past_key_values=None,
433
+ attention_mask=None,
434
+ token_type_ids=None,
435
+ position_ids=None,
436
+ head_mask=None,
437
+ inputs_embeds=None,
438
+ query_embeds=None,
439
+ use_cache=None,
440
+ output_attentions=None,
441
+ output_hidden_states=None,
442
+ return_dict=None,
443
+ ):
444
+ if past_key_values is None:
445
+ # structure encode will check if input_ids contains valid
446
+ # structure_embs: Tensor with size of batchsize * maxlen * width
447
+ # structure_mask: Tensor with size of batchsize * maxlen * 1
448
+ # input_ids: new ids without structure path (ascii code)
449
+ # clone so the original input_ids are not modified
450
+ (structure_embs, structure_mask), input_ids = self.structure.encode(input_ids.clone())
451
+
452
+ else:
453
+ structure_embs = None
454
+
455
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
456
+ output_hidden_states = (
457
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
458
+ )
459
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
460
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
461
+
462
+ if input_ids is not None and inputs_embeds is not None:
463
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
464
+ elif input_ids is not None:
465
+ input_shape = input_ids.size()
466
+ input_ids = input_ids.view(-1, input_shape[-1])
467
+ batch_size = input_ids.shape[0]
468
+ elif inputs_embeds is not None:
469
+ input_shape = inputs_embeds.size()[:-1]
470
+ batch_size = inputs_embeds.shape[0]
471
+ else:
472
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
473
+
474
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
475
+
476
+ # if token_type_ids is not None:
477
+ # token_type_ids = token_type_ids.view(-1, input_shape[-1])
478
+
479
+ if past_key_values is None:
480
+ past_length = 0
481
+ past_key_values = tuple([None] * len(self.h))
482
+ else:
483
+ past_length = past_key_values[0][0].size(-2)
484
+
485
+
486
+
487
+ # Attention mask.
488
+ if attention_mask is not None:
489
+ assert batch_size > 0, "batch_size has to be defined and > 0"
490
+ attention_mask = attention_mask.view(batch_size, -1)
491
+ # We create a 3D attention mask from a 2D tensor mask.
492
+ # Sizes are [batch_size, 1, 1, to_seq_length]
493
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
494
+ # this attention mask is simpler than the triangular masking of causal attention
495
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
496
+ attention_mask = attention_mask[:, None, None, :]
497
+
498
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
499
+ # masked positions, this operation will create a tensor which is 0.0 for
500
+ # positions we want to attend and -10000.0 for masked positions.
501
+ # Since we are adding it to the raw scores before the softmax, this is
502
+ # effectively the same as removing these entirely.
503
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
504
+ attention_mask = (1.0 - attention_mask) * -10000.0
505
+
506
+ # Prepare head mask if needed
507
+ # 1.0 in head_mask indicate we keep the head
508
+ # attention_probs has shape bsz x num_attention_heads x N x N
509
+ # head_mask has shape n_layer x batch x num_attention_heads x N x N
510
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
511
+
512
+ if inputs_embeds is None:
513
+ inputs_embeds = self.wte(input_ids.int())
514
+
515
+ if query_embeds is not None:
516
+ inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
517
+ input_shape = inputs_embeds.size()[:-1]
518
+
519
+ if structure_embs is not None:
520
+ try:
521
+ inputs_embeds_mask = (input_ids == 1).unsqueeze(-1) # the bos token is used as the structure placeholder
522
+ inputs_embeds = inputs_embeds.to(structure_embs.dtype) # change dtype manually, autocast won't change it.
523
+ inputs_embeds = inputs_embeds.masked_scatter(inputs_embeds_mask, structure_embs.masked_select(structure_mask.unsqueeze(-1)))
524
+ input_shape = inputs_embeds.size()[:-1]
525
+ except Exception:  # dump the offending tensors for debugging if the structure-embedding scatter fails
526
+ torch.save(input_ids, f'input_ids_{inputs_embeds.device}.pt')
527
+ torch.save(inputs_embeds, f'inputs_embeds_{inputs_embeds.device}.pt')
528
+ torch.save(structure_embs, f'structure_embs_{inputs_embeds.device}.pt')
529
+
530
+
531
+ if position_ids is not None:
532
+ position_ids = position_ids.view(-1, input_shape[-1])
533
+
534
+ if position_ids is None:
535
+ if any(attention_mask[:, ..., 0] < 0): # padding left
536
+ position_ids = ((attention_mask >= 0).cumsum(dim=-1).squeeze() - 1).clamp(min=0)
537
+ position_ids = position_ids[:, past_length:past_length + input_shape[-1]]
538
+ else:
539
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
540
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
541
+
542
+ hidden_states = inputs_embeds
543
+
544
+ # disable token_type_ids
545
+ # if token_type_ids is not None:
546
+ # token_type_embeds = self.wte(token_type_ids)
547
+ # hidden_states = hidden_states + token_type_embeds
548
+
549
+ hidden_states = self.drop(hidden_states)
550
+
551
+ output_shape = input_shape + (hidden_states.size(-1),)
552
+
553
+ presents = () if use_cache else None
554
+ all_self_attentions = () if output_attentions else None
555
+ all_hidden_states = () if output_hidden_states else None
556
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
557
+
558
+ # Model parallel
559
+ if self.model_parallel:
560
+ torch.cuda.set_device(hidden_states.device)
561
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
562
+ if layer_past is not None:
563
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
564
+ # Ensure that attention_mask is always on the same device as hidden_states
565
+ if attention_mask is not None:
566
+ attention_mask = attention_mask.to(hidden_states.device)
567
+ if isinstance(head_mask, torch.Tensor):
568
+ head_mask = head_mask.to(hidden_states.device)
569
+ if output_hidden_states:
570
+ all_hidden_states = all_hidden_states + (hidden_states,)
571
+
572
+ if self.gradient_checkpointing and self.training:
573
+
574
+ if use_cache:
575
+ # logger.warning(
576
+ # "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
577
+ # "`use_cache=False`..."
578
+ # )
579
+ use_cache = False
580
+
581
+ def create_custom_forward(module):
582
+ def custom_forward(*inputs):
583
+ # None for past_key_value
584
+ return module(*inputs, use_cache, output_attentions)
585
+
586
+ return custom_forward
587
+
588
+ outputs = torch.utils.checkpoint.checkpoint(
589
+ create_custom_forward(block),
590
+ hidden_states,
591
+ None,
592
+ attention_mask,
593
+ head_mask[i],
594
+ )
595
+ else:
596
+ outputs = block(
597
+ hidden_states,
598
+ layer_past=layer_past,
599
+ attention_mask=attention_mask,
600
+ head_mask=head_mask[i],
601
+ use_cache=use_cache,
602
+ output_attentions=output_attentions,
603
+ position_ids=position_ids
604
+ )
605
+
606
+ hidden_states = outputs[0]
607
+ if use_cache is True:
608
+ presents = presents + (outputs[1],)
609
+
610
+ if output_attentions:
611
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
612
+
613
+ # Model Parallel: If it's the last layer for that device, put things on the next device
614
+ if self.model_parallel:
615
+ for k, v in self.device_map.items():
616
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
617
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
618
+
619
+ hidden_states = self.ln_f(hidden_states)
620
+
621
+ hidden_states = hidden_states.view(*output_shape)
622
+ # Add last hidden state
623
+ if output_hidden_states:
624
+ all_hidden_states = all_hidden_states + (hidden_states,)
625
+
626
+ if not return_dict:
627
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
628
+
629
+ return BaseModelOutputWithPast(
630
+ last_hidden_state=hidden_states,
631
+ past_key_values=presents,
632
+ hidden_states=all_hidden_states,
633
+ attentions=all_self_attentions,
634
+ )
635
+
636
+
637
+ class ProGenForCausalLM(ProGenPreTrainedModel):
638
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
639
+
640
+ def __init__(self, config):
641
+ super().__init__(config)
642
+ self.transformer = ProGenModel(config)
643
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
644
+ self.init_weights()
645
+
646
+ # Model parallel
647
+ self.model_parallel = False
648
+ self.device_map = None
649
+
650
+ def parallelize(self, device_map=None):
651
+ self.device_map = (
652
+ get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
653
+ if device_map is None
654
+ else device_map
655
+ )
656
+ assert_device_map(self.device_map, len(self.transformer.h))
657
+ self.transformer.parallelize(self.device_map)
658
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
659
+ self.model_parallel = True
660
+
661
+ def deparallelize(self):
662
+ self.transformer.deparallelize()
663
+ self.transformer = self.transformer.to("cpu")
664
+ self.lm_head = self.lm_head.to("cpu")
665
+ self.model_parallel = False
666
+ torch.cuda.empty_cache()
667
+
668
+ def get_output_embeddings(self):
669
+ return self.lm_head
670
+
671
+ def set_output_embeddings(self, new_embeddings):
672
+ self.lm_head = new_embeddings
673
+
674
+ def _update_model_kwargs_for_generation(self, *args, **kwargs):
675
+ model_kwargs = super()._update_model_kwargs_for_generation(*args, **kwargs)
676
+
677
+ position_ids = model_kwargs.get("position_ids", None)
678
+ attention_mask = model_kwargs.get("attention_mask", None)
679
+ use_cache = model_kwargs.get("use_cache", False)
680
+
681
+ if attention_mask is not None and position_ids is not None:
682
+ # create position_ids on the fly for batch generation
683
+ position_ids = position_ids
684
+
685
+ # if past_key_values:
686
+ last_position = position_ids[:, -1:] # all positions in a batch should be the same
687
+ current_position = last_position + 1
688
+ if not use_cache:
689
+ position_ids = torch.concat([position_ids, current_position], dim=-1)
690
+ else:
691
+ position_ids = current_position
692
+
693
+ else:
694
+ position_ids = None
695
+ model_kwargs['position_ids'] = position_ids
696
+ return model_kwargs
697
+
698
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
699
+ token_type_ids = kwargs.get("token_type_ids", None)
700
+ use_cache = kwargs.get("use_cache", False)
701
+ # only keep the last token of input_ids if past_key_values is defined in kwargs
702
+ if past_key_values:
703
+ input_ids = input_ids[:, -1].unsqueeze(-1)
704
+ if token_type_ids is not None:
705
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
706
+
707
+ attention_mask = kwargs.get("attention_mask", None)
708
+ position_ids = kwargs.get("position_ids", None)
709
+
710
+ # if attention_mask is not None and position_ids is not None:
711
+ # # create position_ids on the fly for batch generation
712
+ # position_ids = position_ids
713
+ # if input_ids[0, -1] != 3:
714
+ # # if past_key_values:
715
+ # last_position = position_ids[:, -1:] # all positions in a batch should be the same
716
+ # current_position = last_position + 1
717
+ # if not use_cache:
718
+ # position_ids = torch.concat([position_ids, current_position], dim=-1)
719
+ # else:
720
+ # position_ids = current_position
721
+
722
+ # else:
723
+ # position_ids = None
724
+ return {
725
+ "input_ids": input_ids,
726
+ "past_key_values": past_key_values,
727
+ "use_cache": use_cache,
728
+ "position_ids": position_ids,
729
+ "attention_mask": attention_mask,
730
+ "token_type_ids": token_type_ids,
731
+ }
732
+
733
+ # def prepare_inputs_for_generation(
734
+ # self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
735
+ # ):
736
+ # if past_key_values:
737
+ # input_ids = input_ids[:, -1:]
738
+
739
+ # position_ids = kwargs.get("position_ids")
740
+ # if position_ids is not None:
741
+ # last_position = position_ids[:, -1:] # all positions in a batch should be the same
742
+ # current_position = last_position + 1
743
+ # position_ids = torch.concat([position_ids, current_position], dim=-1)
744
+
745
+ # # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
746
+ # if inputs_embeds is not None and past_key_values is None:
747
+ # model_inputs = {"inputs_embeds": inputs_embeds}
748
+ # else:
749
+ # model_inputs = {"input_ids": input_ids}
750
+
751
+ # model_inputs.update(
752
+ # {
753
+ # "past_key_values": past_key_values,
754
+ # "use_cache": kwargs.get("use_cache"),
755
+ # "attention_mask": attention_mask,
756
+ # "position_ids": position_ids
757
+ # }
758
+ # )
759
+ # return model_inputs
760
+
761
+ def forward(
762
+ self,
763
+ input_ids=None,
764
+ past_key_values=None,
765
+ attention_mask=None,
766
+ token_type_ids=None,
767
+ position_ids=None,
768
+ head_mask=None,
769
+ inputs_embeds=None,
770
+ labels=None,
771
+ use_cache=None,
772
+ query_embeds=None,
773
+ output_attentions=None,
774
+ output_hidden_states=None,
775
+ return_dict=None,
776
+ ):
777
+ r"""
778
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
779
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
780
+ ``labels = input_ids``. Indices are selected in ``[-100, 0, ..., config.vocab_size]``. All labels set to
781
+ ``-100`` are ignored (masked); the loss is only computed for labels in ``[0, ..., config.vocab_size]``.
782
+ """
783
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
784
+
785
+ transformer_outputs = self.transformer(
786
+ input_ids,
787
+ past_key_values=past_key_values,
788
+ attention_mask=attention_mask,
789
+ token_type_ids=token_type_ids,
790
+ position_ids=position_ids,
791
+ head_mask=head_mask,
792
+ inputs_embeds=inputs_embeds,
793
+ query_embeds=query_embeds,
794
+ use_cache=use_cache,
795
+ output_attentions=output_attentions,
796
+ output_hidden_states=output_hidden_states,
797
+ return_dict=return_dict,
798
+ )
799
+ hidden_states = transformer_outputs[0]
800
+
801
+ # Set device for model parallelism
802
+ if self.model_parallel:
803
+ torch.cuda.set_device(self.transformer.first_device)
804
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
805
+
806
+ # make sure sampling in fp16 works correctly and
807
+ # compute loss in fp32 to match with mesh-tf version
808
+ # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
809
+ lm_logits = self.lm_head(hidden_states).to(torch.float32)
810
+
811
+ loss = None
812
+ if labels is not None:
813
+ # Shift so that tokens < n predict n
814
+ shift_logits = lm_logits[..., :-1, :].contiguous()
815
+ shift_labels = labels[..., 1:].contiguous()
816
+ # Flatten the tokens
817
+ loss_fct = CrossEntropyLoss()
818
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
819
+
820
+ loss = loss.to(hidden_states.dtype)
821
+
822
+ if not return_dict:
823
+ output = (lm_logits,) + transformer_outputs[1:]
824
+ return ((loss,) + output) if loss is not None else output
825
+
826
+ return CausalLMOutputWithPast(
827
+ loss=loss,
828
+ logits=lm_logits,
829
+ past_key_values=transformer_outputs.past_key_values,
830
+ hidden_states=transformer_outputs.hidden_states,
831
+ attentions=transformer_outputs.attentions,
832
+ )
833
+
834
+ @staticmethod
835
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
836
+ """
837
+ This function is used to re-order the :obj:`past_key_values` cache if
838
+ :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
839
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
840
+ """
841
+ return tuple(
842
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
843
+ for layer_past in past
844
+ )
845
+
846
+ # def generate(self, inputs: Tensor | None = None, generation_config: GenerationConfig | None = None, logits_processor: LogitsProcessorList | None = None, stopping_criteria: StoppingCriteriaList | None = None, prefix_allowed_tokens_fn: Callable[[int, Tensor], List[int]] | None = None, synced_gpus: bool | None = None, assistant_model: PreTrainedModel | None = None, streamer: BaseStreamer | None = None, negative_prompt_ids: Tensor | None = None, negative_prompt_attention_mask: Tensor | None = None, **kwargs) -> GenerateOutput | LongTensor:
847
+ # return super().generate(inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
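The loss computation in ProGenForCausalLM.forward above follows the standard causal-LM convention: logits at position t are scored against the token at position t + 1. A small self-contained check with toy dimensions (illustrative only, not tied to this checkpoint):

    import torch
    from torch.nn import CrossEntropyLoss

    vocab_size, seq_len = 32, 6
    lm_logits = torch.randn(1, seq_len, vocab_size)         # stand-in for the lm_head output
    labels = torch.randint(0, vocab_size, (1, seq_len))     # labels == input_ids; shifting happens here

    shift_logits = lm_logits[..., :-1, :].contiguous()      # predictions for positions 0..T-2
    shift_labels = labels[..., 1:].contiguous()             # targets are the next tokens
    loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
    print(loss.item())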
structure.py ADDED
@@ -0,0 +1,315 @@
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import OrderedDict
7
+ import math
8
+ import requests
9
+ from io import BytesIO
10
+ from functools import partial
11
+ import pickle
12
+ from typing import Callable, Optional, Sequence, Tuple, List
13
+ import numpy as np
14
+ import os
15
+ import torch
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ from torch.nn.init import trunc_normal_
19
+ from torchvision import transforms
20
+ from torchvision.transforms import InterpolationMode
21
+
22
+ class GLU(nn.Module):
23
+ def __init__(self,hidden_size):
24
+ super().__init__()
25
+ self.linear_proj = nn.Linear(hidden_size,hidden_size,bias=False)
26
+ self.norm1 = nn.LayerNorm(hidden_size)
27
+ self.act1 = nn.GELU()
28
+ self.act2 = nn.functional.silu
29
+ self.dense_h_to_4h = nn.Linear(hidden_size,hidden_size*4,bias=False)
30
+ self.gate_proj = nn.Linear(hidden_size,hidden_size*4,bias=False)
31
+ self.dense_4h_to_h = nn.Linear(hidden_size*4,hidden_size,bias=False)
32
+
33
+ def forward(self,x):
34
+ x = self.linear_proj(x)
35
+ x = self.act1(self.norm1(x))
36
+ x = self.act2(self.gate_proj(x))*self.dense_h_to_4h(x)
37
+ x = self.dense_4h_to_h(x)
38
+ return x
39
+ def swiglu(x):
40
+ x = torch.chunk(x, 2, dim=-1)
41
+ return nn.functional.silu(x[0]) * x[1]
42
+
43
+ class GLU_new(nn.Module):
44
+ def __init__(self,hidden_size, dropout=0.1):
45
+ super().__init__()
46
+ intermediate_size = int((4 * hidden_size * 2 / 3) / 64) * 64
47
+ intermediate_size = 1280  # hard-coded override of the value computed above
48
+
49
+ self.act = swiglu
50
+ self.dense_h_to_4h = nn.Linear(hidden_size, intermediate_size * 2, bias=False)
51
+ self.dense_4h_to_h = nn.Linear(intermediate_size, hidden_size, bias=False)
52
+ self.dropout = nn.Dropout(p=dropout)
53
+
54
+ def forward(self,x):
55
+ x = self.dense_h_to_4h(x)
56
+ x = self.act(x)
57
+ x = self.dense_4h_to_h(x)
58
+ x = self.dropout(x)
59
+ return x
60
+
61
+
62
+ n_queries = 32
63
+ def get_abs_pos(abs_pos, tgt_size):
64
+ # abs_pos: L, C
65
+ # tgt_size: M
66
+ # return: M, C
67
+ src_size = int(math.sqrt(abs_pos.size(0)))
68
+ tgt_size = int(math.sqrt(tgt_size))
69
+ dtype = abs_pos.dtype
70
+
71
+ if src_size != tgt_size:
72
+ return F.interpolate(
73
+ abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
74
+ size=(tgt_size, tgt_size),
75
+ mode="bicubic",
76
+ align_corners=False,
77
+ ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
78
+ else:
79
+ return abs_pos
80
+
81
+ from einops import rearrange, repeat
82
+
83
+ def get_1d_sincos_pos_embed(embed_dim, pos):
84
+ """
85
+ embed_dim: output dimension for each position
86
+ pos: a list of positions to be encoded: size (M,)
87
+ out: (M, D)
88
+ """
89
+ assert embed_dim % 2 == 0
90
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
91
+ omega /= embed_dim / 2.
92
+ omega = 1. / 10000**omega # (D/2,)
93
+
94
+ pos = pos.reshape(-1) # (M,)
95
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
96
+
97
+ emb_sin = np.sin(out) # (M, D/2)
98
+ emb_cos = np.cos(out) # (M, D/2)
99
+
100
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
101
+ return emb
102
+
103
+ class Resampler(nn.Module):
104
+ def __init__(
105
+ self,
106
+ kv_dim,
107
+ embed_dim,
108
+ num_heads=8,
109
+ n_queries=64,
110
+ max_seqlen=1024,
111
+ perceiver_resampler_positional_emb=True,
112
+ use_GLU=False,
113
+ bos_init=False,
114
+ dropout=0.0
115
+ ):
116
+ super().__init__()
117
+ self.perceiver_resampler_positional_emb = perceiver_resampler_positional_emb
118
+
119
+ if self.perceiver_resampler_positional_emb:
120
+ assert n_queries <= max_seqlen
121
+ self.stride = max_seqlen // n_queries
122
+ # self.nan_emb = nn.Parameter(torch.randn(1, kv_dim))
123
+ # nn.init.trunc_normal_(self.nan_emb, std=.02)
124
+ pos = np.arange(max_seqlen, dtype=np.float32)
125
+ self.register_buffer(
126
+ "pos_embed",
127
+ torch.from_numpy(get_1d_sincos_pos_embed(embed_dim, pos)).float()
128
+ )
129
+ self.latents = nn.Parameter(torch.randn(n_queries, embed_dim))
130
+ if bos_init:
131
+ self.latents.load('')
132
+ else:
133
+ nn.init.trunc_normal_(self.latents, std=1e-3)
134
+
135
+ self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
136
+ self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, dropout=dropout)
137
+ self.ln_q = nn.LayerNorm(embed_dim)
138
+ self.ln_kv = nn.LayerNorm(embed_dim)
139
+ self.ln_post = nn.LayerNorm(embed_dim)
140
+ if use_GLU:
141
+ print('GLU *********************************')
142
+ self.proj = GLU_new(embed_dim, dropout=dropout)
143
+ else:
144
+ self.proj = nn.Linear(embed_dim, embed_dim, bias=False)
145
+
146
+ self.apply(self._init_weights)
147
+
148
+ def _init_weights(self, m):
149
+ if isinstance(m, nn.Linear):
150
+ nn.init.trunc_normal_(m.weight, std=1e-3)
151
+ if isinstance(m, nn.Linear) and m.bias is not None:
152
+ nn.init.constant_(m.bias, 0)
153
+ elif isinstance(m, nn.LayerNorm):
154
+ nn.init.constant_(m.bias, 0)
155
+ nn.init.constant_(m.weight, 1.0)
156
+
157
+ def forward(self, struc_x):
158
+ """
159
+ Args:
160
+ x (torch.Tensor): protein structure features
161
+ shape (B, L, C)
162
+ Returns:
163
+ shape (B, n, D) where n is the number of learned query latents (n_queries) and D is embed_dim
164
+ """
165
+ x = struc_x["encoder_out"]
166
+ mask = struc_x["encoder_padding_mask"]
167
+
168
+
169
+ nan_mask = torch.isnan(x)
170
+ if nan_mask.any():
171
+ x = x.masked_fill(nan_mask, 0.0)
172
+ # nan_mask = nan_mask.sum(dim=-1).bool()
173
+ # x[nan_mask] += self.nan_emb
174
+
175
+ x = self.kv_proj(x)
176
+ x = self.ln_kv(x)
177
+
178
+ b, seqlen = x.shape[:2]
179
+
180
+ latents = self.ln_q(self.latents)
181
+ if self.perceiver_resampler_positional_emb:
182
+ # TODO: interpolate
183
+ latents = latents + self.pos_embed[::self.stride].contiguous().to(latents.device)
184
+ pos_emb = self.pos_embed[:seqlen].unsqueeze(0).to(latents.device)
185
+ x = x + pos_emb.contiguous()
186
+
187
+ # blocks
188
+ latents = repeat(latents, "n d -> b n d", b=b)
189
+ out = self.attn(latents, x, x, key_padding_mask=~mask)[0]
190
+
191
+ out = self.ln_post(out)
192
+ out = self.proj(out)
193
+
194
+ return out
195
+
196
+ class mlp(nn.Module):
197
+ def __init__(self, width, output_dim, **kwargs):
198
+ super().__init__(**kwargs)
199
+
200
+ self.mlp = nn.Sequential(
201
+ nn.Linear(width, output_dim),
202
+ nn.LayerNorm(output_dim),
203
+ nn.GELU(),
204
+ nn.Linear(output_dim, output_dim)
205
+ )
206
+
207
+ def forward(self, struc_x):
208
+ x = struc_x["encoder_out"]
209
+ mask = struc_x["encoder_padding_mask"]
210
+ return self.mlp(x), mask
211
+
212
+ class StructureTransformer(nn.Module):
213
+
214
+ def __init__(
215
+ self,
216
+ width: int = 640,
217
+ n_queries: int = 32,
218
+ output_dim: int = 4096,
219
+ embedding_keys=set(["mpnn_emb"]),
220
+ max_seqlen: int=1024,
221
+ num_heads: int=8,
222
+ structure_emb_path_prefix='structure_emb',
223
+ projector='mlp',
224
+ **kwargs
225
+ ):
226
+ super().__init__()
227
+
228
+ self.structure_emb_path_prefix = structure_emb_path_prefix
229
+ # self.transformer = None # replace None with a pretrained strucure encoder
230
+ self.embedding_keys = embedding_keys
231
+ self.max_seqlen = max_seqlen
232
+ self.width = width
233
+ self.n_queries = n_queries
234
+ if projector == 'mlp':
235
+ self.attn_pool = mlp(
236
+ width=width,
237
+ output_dim=output_dim,
238
+ **kwargs
239
+ )
240
+ else:
241
+ self.attn_pool = Resampler(
242
+ embed_dim=output_dim,
243
+ kv_dim=width,
244
+ n_queries=n_queries,
245
+ max_seqlen=max_seqlen,
246
+ num_heads=num_heads,
247
+ **kwargs
248
+ )
249
+
250
+ def prepare_structure(self, sample):
251
+ emb_pad = torch.zeros((self.max_seqlen, self.width))
252
+ emb_mask = torch.zeros((self.max_seqlen), dtype=bool)
253
+
254
+
255
+ ### domains ###
256
+ emb = []
257
+ for ek in self.embedding_keys:
258
+ if ek in sample:
259
+ if isinstance( sample[ek], List):
260
+ emb.append(torch.cat(sample[ek]))
261
+ else:
262
+ emb.append(sample[ek].squeeze())
263
+ # emb = [sample[ek] for ek in self.embedding_keys if ek in sample]
264
+ emb = torch.cat(emb, dim=-1)
265
+
266
+ emb_pad[:len(emb)] = emb
267
+ emb_mask[:len(emb)] = 1
268
+ return emb_pad, emb_mask
269
+
270
+ def forward(self, x):
271
+
272
+ # x = self.transformer(x)
273
+ x = self.attn_pool(x)
274
+
275
+ return x
276
+
277
+ def encode(self, input_ids: List[torch.Tensor]):
278
+ structure_embs = []
279
+ structure_mask = []
280
+
281
+ for structure_path in input_ids:
282
+
283
+ if structure_path[0] == 1: # bos token for bypassing DPO trainer
284
+ structure_path[0] = 0
285
+ if structure_path[0] == 0: # left padding
286
+ structure_path = structure_path[structure_path > 0]
287
+
288
+ path_length = (structure_path > 32).sum() # structure path characters should have ASCII codes greater than 32
289
+ structure_path = [chr(s) for s in structure_path[:path_length].int().tolist() if s > 0]
290
+
291
+ structure_path = os.path.join(self.structure_emb_path_prefix, ''.join(structure_path))
292
+
293
+ if not os.path.exists(structure_path):
294
+ print(structure_path)
295
+ print('no structure found')
296
+ return None
297
+
298
+ with open(structure_path, 'rb') as f:
299
+ structure, struc_mask = self.prepare_structure(pickle.load(f))
300
+
301
+
302
+ structure_embs.append(structure)
303
+ structure_mask.append(struc_mask)
304
+
305
+ input_ids[input_ids > 32] = 1 # change ASCII codes back to <|bos|>
306
+ structure_embs = torch.stack(structure_embs, dim=0).to(
307
+ device=next(self.attn_pool.parameters()).device,
308
+ dtype=next(self.attn_pool.parameters()).dtype)
309
+ structure_mask = torch.stack(structure_mask, dim=0).to(
310
+ device=next(self.attn_pool.parameters()).device)
311
+
312
+ return self({
313
+ 'encoder_out': structure_embs,
314
+ 'encoder_padding_mask': structure_mask
315
+ }), input_ids
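The Resampler above compresses a variable-length per-residue feature map into a fixed set of learned query tokens via cross-attention. A minimal shape check, assuming structure.py is importable and using illustrative dimensions rather than the released configuration:

    import torch
    from structure import Resampler

    resampler = Resampler(kv_dim=640, embed_dim=1024, n_queries=32, max_seqlen=1024, num_heads=8)

    B, L = 2, 200
    feats = torch.randn(B, L, 640)                 # per-residue structure features
    mask = torch.zeros(B, L, dtype=torch.bool)
    mask[:, :150] = True                           # True marks valid (non-padded) positions
    out = resampler({"encoder_out": feats, "encoder_padding_mask": mask})
    print(out.shape)                               # torch.Size([2, 32, 1024])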
tokenization_iPLM.py ADDED
@@ -0,0 +1,110 @@
1
+ from typing import List, Optional, Union
2
+ from transformers import PreTrainedTokenizerFast
3
+ from tokenizers.processors import TemplateProcessing
4
+ from tokenizers import Tokenizer
5
+ from transformers.tokenization_utils_base import BatchEncoding, EncodedInput, PreTokenizedInput, TextInput, TruncationStrategy
6
+ from transformers.utils import PaddingStrategy, TensorType
7
+ import torch
8
+ import numpy as np
9
+
10
+ def create_tokenizer_custom(file):
11
+ with open(file, 'r') as f:
12
+ return Tokenizer.from_str(f.read())
13
+
14
+
15
+ class iPLMTokenizer(PreTrainedTokenizerFast):
16
+ def __init__(self, parallel=False, **kwargs):
17
+ super().__init__(tokenizer_object=create_tokenizer_custom(kwargs.get('tokenizer_file')), **kwargs)
18
+ self.add_special_tokens({'pad_token': '<|pad|>'})
19
+ self.parallel = parallel
20
+ def __call__(
21
+ self,
22
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
23
+ n_queries = -1, # -1 for a variable-length prompt, an int greater than 0 for a fixed-length prompt, 0 for no prompt
24
+ text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
25
+ text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
26
+ text_pair_target: Optional[
27
+ Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
28
+ ] = None,
29
+ add_special_tokens: bool = True,
30
+ padding: Union[bool, str, PaddingStrategy] = False,
31
+ truncation: Union[bool, str, TruncationStrategy] = None,
32
+ max_length: Optional[int] = None,
33
+ stride: int = 0,
34
+ is_split_into_words: bool = False,
35
+ pad_to_multiple_of: Optional[int] = None,
36
+ return_tensors: Optional[Union[str, TensorType]] = None,
37
+ return_token_type_ids: Optional[bool] = None,
38
+ return_attention_mask: Optional[bool] = None,
39
+ return_overflowing_tokens: bool = False,
40
+ return_special_tokens_mask: bool = False,
41
+ return_offsets_mapping: bool = False,
42
+ return_length: bool = False,
43
+ verbose: bool = True,
44
+ **kwargs,
45
+ ) -> BatchEncoding:
46
+
47
+ if not isinstance(text, list):
48
+ text = [text]
49
+ batching = False
50
+ else:
51
+ batching = True
52
+
53
+ # add prompt
54
+ text_with_prompt = []
55
+ for t in text:
56
+ prompt_length = 0
57
+ assert '|' in t, 'prompt not found'
58
+
59
+ raw_text = t.split('|')[-1]
60
+
61
+ if n_queries > 0: # use a fixed-length prompt
62
+ prompt_length = n_queries
63
+ elif n_queries < 0:
64
+ prompt_length = len(raw_text.replace('1', '').replace('2', ''))
65
+
66
+ text_with_prompt.append('<|bos|>' * prompt_length + raw_text)
67
+
68
+ batch = super().__call__(
69
+ text=text_with_prompt,
70
+ text_pair=text_pair,
71
+ text_target=text_target,
72
+ text_pair_target=text_pair_target,
73
+ add_special_tokens=add_special_tokens,
74
+ padding=padding,
75
+ truncation=truncation,
76
+ max_length=max_length,
77
+ stride=stride,
78
+ is_split_into_words=is_split_into_words,
79
+ pad_to_multiple_of=pad_to_multiple_of,
80
+ padding_side=None,
81
+ return_tensors=return_tensors,
82
+ return_token_type_ids=return_token_type_ids,
83
+ return_attention_mask=return_attention_mask,
84
+ return_overflowing_tokens=return_overflowing_tokens,
85
+ return_special_tokens_mask=return_special_tokens_mask,
86
+ return_offsets_mapping=return_offsets_mapping,
87
+ return_length=return_length,
88
+ verbose=verbose,
89
+ **kwargs
90
+ )
91
+
92
+ # add structure ids
93
+ for i in range(len(text)):
94
+ if '|' not in text[i]:
95
+ continue
96
+
97
+ structure_ids = text[i].split('|')[0]
98
+ if return_tensors is None:
99
+ for j in range(len(structure_ids)):
100
+ batch['input_ids'][i][j] = ord(structure_ids[j])
101
+ else:
102
+ batch['input_ids'][i, :len(structure_ids)] = torch.tensor([ord(c) for c in structure_ids])
103
+
104
+ if "token_type_ids" in batch:
105
+ del batch["token_type_ids"]
106
+
107
+ if batching:
108
+ return batch
109
+ else:
110
+ return {k:v[0] for k, v in batch.items()}
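iPLMTokenizer expects inputs of the form "structure_path|SEQUENCE": the sequence after the bar is tokenized normally behind a run of <|bos|> placeholder tokens, and the path before the bar is written into those leading slots as raw ASCII codes (later resolved by StructureTransformer.encode). A usage sketch with a placeholder path, assuming a recent transformers version; note the path must be no longer than the number of placeholder tokens:

    tok = iPLMTokenizer(tokenizer_file="tokenizer.json")

    batch = tok(["s/a.pkl|1MKVLATGHH2"], return_tensors="pt")
    ids = batch["input_ids"][0]
    print(ids[:7])   # ASCII codes of "s/a.pkl"
    print(ids[7:])   # remaining <|bos|> ids followed by the amino-acid token ids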
tokenizer.json ADDED
@@ -0,0 +1,95 @@
1
+ {
2
+ "version": "1.0",
3
+ "truncation": null,
4
+ "padding": null,
5
+ "added_tokens": [
6
+ {
7
+ "id": 0,
8
+ "content": "<|pad|>",
9
+ "single_word": false,
10
+ "lstrip": false,
11
+ "rstrip": false,
12
+ "normalized": false,
13
+ "special": true
14
+ },
15
+ {
16
+ "id": 1,
17
+ "content": "<|bos|>",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false,
22
+ "special": true
23
+ },
24
+ {
25
+ "id": 2,
26
+ "content": "<|eos|>",
27
+ "single_word": false,
28
+ "lstrip": false,
29
+ "rstrip": false,
30
+ "normalized": false,
31
+ "special": true
32
+ }
33
+ ],
34
+ "normalizer": null,
35
+ "pre_tokenizer": {
36
+ "type": "ByteLevel",
37
+ "add_prefix_space": false,
38
+ "trim_offsets": true,
39
+ "use_regex": true
40
+ },
41
+ "post_processor": {
42
+ "type": "ByteLevel",
43
+ "add_prefix_space": true,
44
+ "trim_offsets": true,
45
+ "use_regex": true
46
+ },
47
+ "decoder": {
48
+ "type": "ByteLevel",
49
+ "add_prefix_space": true,
50
+ "trim_offsets": true,
51
+ "use_regex": true
52
+ },
53
+ "model": {
54
+ "type": "BPE",
55
+ "dropout": null,
56
+ "unk_token": null,
57
+ "continuing_subword_prefix": null,
58
+ "end_of_word_suffix": null,
59
+ "fuse_unk": false,
60
+ "byte_fallback": false,
61
+ "vocab": {
62
+ "<|pad|>": 0,
63
+ "<|bos|>": 1,
64
+ "<|eos|>": 2,
65
+ "1": 3,
66
+ "2": 4,
67
+ "A": 5,
68
+ "B": 6,
69
+ "C": 7,
70
+ "D": 8,
71
+ "E": 9,
72
+ "F": 10,
73
+ "G": 11,
74
+ "H": 12,
75
+ "I": 13,
76
+ "K": 14,
77
+ "L": 15,
78
+ "M": 16,
79
+ "N": 17,
80
+ "O": 18,
81
+ "P": 19,
82
+ "Q": 20,
83
+ "R": 21,
84
+ "S": 22,
85
+ "T": 23,
86
+ "U": 24,
87
+ "V": 25,
88
+ "W": 26,
89
+ "X": 27,
90
+ "Y": 28,
91
+ "Z": 29
92
+ },
93
+ "merges": []
94
+ }
95
+ }
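The vocabulary above is character-level: every amino-acid letter plus the terminus markers "1" and "2" is its own token, and the merge list is empty, so encoding reduces to a per-character lookup. A quick check, assuming the file above is saved locally as tokenizer.json:

    from tokenizers import Tokenizer

    tok = Tokenizer.from_file("tokenizer.json")
    enc = tok.encode("1MKV2")
    print(enc.tokens)  # ['1', 'M', 'K', 'V', '2']
    print(enc.ids)     # [3, 16, 14, 25, 4]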
tokenizer_config.json ADDED
@@ -0,0 +1,11 @@
1
+ {
2
+ "n_queries": 256,
3
+ "use_structure": true,
4
+ "tokenizer_class": "iPLMTokenizer",
5
+ "auto_map": {
6
+ "AutoTokenizer": [
7
+ "tokenization_iPLM.iPLMTokenizer",
8
+ null
9
+ ]
10
+ }
11
+ }
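The auto_map entry above lets AutoTokenizer resolve the custom class from tokenization_iPLM.py when the repository is loaded with remote code enabled. A sketch, where "your-org/iPLM" is a placeholder repository id:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("your-org/iPLM", trust_remote_code=True)
    print(type(tok).__name__)  # iPLMTokenizer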