Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- ag4masses-public.ipynb +91 -60
- download.sh +3 -1
- lm_inference.py +189 -0
ag4masses-public.ipynb
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
-
"execution_count":
|
6 |
"metadata": {
|
7 |
"executionInfo": {
|
8 |
"elapsed": 611,
|
@@ -14,17 +14,16 @@
|
|
14 |
},
|
15 |
"user_tz": 300
|
16 |
},
|
17 |
-
"id": "-IHoHd-t5sLP"
|
18 |
-
"trusted": true
|
19 |
},
|
20 |
"outputs": [],
|
21 |
"source": [
|
22 |
"import sys, os\n",
|
23 |
"\n",
|
24 |
-
"AG4MDIR='/home/user/ag4masses'\n",
|
25 |
-
"AGLIB=f'
|
26 |
-
"AGDIR=f\"{
|
27 |
-
"MELIAD_PATH=f\"{
|
28 |
"DATA=f\"{AGLIB}/ag_ckpt_vocab\"\n",
|
29 |
"TESTDIR=f\"/data/ag4mtest\""
|
30 |
]
|
@@ -41,9 +40,7 @@
|
|
41 |
{
|
42 |
"cell_type": "code",
|
43 |
"execution_count": null,
|
44 |
-
"metadata": {
|
45 |
-
"trusted": true
|
46 |
-
},
|
47 |
"outputs": [],
|
48 |
"source": [
|
49 |
"# Run this cell to refresh code and get the latest versions\n",
|
@@ -65,8 +62,7 @@
|
|
65 |
},
|
66 |
"user_tz": 300
|
67 |
},
|
68 |
-
"id": "GgR_vO8XX9Vr"
|
69 |
-
"trusted": true
|
70 |
},
|
71 |
"outputs": [],
|
72 |
"source": [
|
@@ -99,8 +95,7 @@
|
|
99 |
"user_tz": 300
|
100 |
},
|
101 |
"id": "gP4zAZh2MHcv",
|
102 |
-
"outputId": "4796397b-8952-411e-bd33-8fd813865735"
|
103 |
-
"trusted": true
|
104 |
},
|
105 |
"outputs": [],
|
106 |
"source": [
|
@@ -147,8 +142,7 @@
|
|
147 |
"user_tz": 300
|
148 |
},
|
149 |
"id": "X8Aj3G0neT6K",
|
150 |
-
"outputId": "9538ceba-8065-44d6-a32f-35127e5f2575"
|
151 |
-
"trusted": true
|
152 |
},
|
153 |
"outputs": [],
|
154 |
"source": [
|
@@ -174,8 +168,7 @@
|
|
174 |
"user_tz": 300
|
175 |
},
|
176 |
"id": "u9fuBSr2qEwN",
|
177 |
-
"outputId": "97bbce78-8b49-4d3b-a831-d188a4a9e536"
|
178 |
-
"trusted": true
|
179 |
},
|
180 |
"outputs": [],
|
181 |
"source": [
|
@@ -190,9 +183,7 @@
|
|
190 |
{
|
191 |
"cell_type": "code",
|
192 |
"execution_count": null,
|
193 |
-
"metadata": {
|
194 |
-
"trusted": true
|
195 |
-
},
|
196 |
"outputs": [],
|
197 |
"source": [
|
198 |
"# Linux packages for Nvidia gpu.\n",
|
@@ -206,8 +197,7 @@
|
|
206 |
"cell_type": "code",
|
207 |
"execution_count": null,
|
208 |
"metadata": {
|
209 |
-
"id": "fChy49CNhf01"
|
210 |
-
"trusted": true
|
211 |
},
|
212 |
"outputs": [],
|
213 |
"source": [
|
@@ -216,6 +206,51 @@
|
|
216 |
"!nvidia-smi"
|
217 |
]
|
218 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
{
|
220 |
"cell_type": "markdown",
|
221 |
"metadata": {
|
@@ -227,10 +262,8 @@
|
|
227 |
},
|
228 |
{
|
229 |
"cell_type": "code",
|
230 |
-
"execution_count":
|
231 |
-
"metadata": {
|
232 |
-
"trusted": true
|
233 |
-
},
|
234 |
"outputs": [],
|
235 |
"source": [
|
236 |
"#!! cannot have ' in the script, including in comments\n",
|
@@ -301,6 +334,7 @@
|
|
301 |
"\n",
|
302 |
"true \"==========================================\"\n",
|
303 |
"\n",
|
|
|
304 |
"python -m alphageometry \\\n",
|
305 |
"--alsologtostderr \\\n",
|
306 |
"--problems_file=$PROB_FILE \\\n",
|
@@ -310,7 +344,7 @@
|
|
310 |
"\"${SEARCH_ARGS[@]}\" \\\n",
|
311 |
"\"${LM_ARGS[@]}\" \\\n",
|
312 |
"--out_file=$OUTFILE \\\n",
|
313 |
-
"--n_workers=$NWORKERS 2>&1\n",
|
314 |
"\n",
|
315 |
"'''"
|
316 |
]
|
@@ -318,10 +352,18 @@
|
|
318 |
{
|
319 |
"cell_type": "code",
|
320 |
"execution_count": null,
|
321 |
-
"metadata": {
|
322 |
-
|
323 |
-
|
324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
325 |
"source": [
|
326 |
"os.environ[\"TESTDIR\"]=TESTDIR\n",
|
327 |
"os.environ[\"AG4MDIR\"]=AG4MDIR\n",
|
@@ -341,16 +383,16 @@
|
|
341 |
"# NWORKERS=2\n",
|
342 |
"# CUDA_VISIBLE_DEVICES=0,1\n",
|
343 |
"\n",
|
344 |
-
"os.environ[\"BATCH_SIZE\"]=\"
|
345 |
-
"os.environ[\"BEAM_SIZE\"]=\"
|
346 |
-
"os.environ[\"DEPTH\"]=\"
|
347 |
"os.environ[\"NWORKERS\"]=\"2\"\n",
|
348 |
"\n",
|
349 |
-
"
|
350 |
"\n",
|
351 |
"# test problems can be uploaded into a dataset, e.g. for dataset \"tmpfiles\", \"/kaggle/input/tmpfiles/test-problems.txt\"\n",
|
352 |
-
"os.environ[\"
|
353 |
-
"PROB=\"imo-2024-q4\"\n",
|
354 |
"os.environ[\"PROB\"]=PROB\n",
|
355 |
"# alphageometry|ddar\n",
|
356 |
"os.environ[\"MODEL\"]=\"alphageometry\"\n",
|
@@ -358,18 +400,15 @@
|
|
358 |
"# In an interactive Kaggle session, run the job in background, so we can do other things in the Notebook.\n",
|
359 |
"# For long jobs, commit the Notebook and run in Batch mode.\n",
|
360 |
"# An interactive session will be terminated after about 20 minutes of idle time.\n",
|
361 |
-
"if os.environ[\"KAGGLE_KERNEL_RUN_TYPE\"]==\"Batch\":\n",
|
362 |
-
"
|
363 |
-
"
|
364 |
-
" os.system(f\"echo '{jobScript}'|bash &\")\n"
|
365 |
]
|
366 |
},
|
367 |
{
|
368 |
"cell_type": "code",
|
369 |
-
"execution_count":
|
370 |
-
"metadata": {
|
371 |
-
"trusted": true
|
372 |
-
},
|
373 |
"outputs": [],
|
374 |
"source": [
|
375 |
"#!cat /kaggle/input/tmpfiles/test-problems.txt"
|
@@ -378,9 +417,7 @@
|
|
378 |
{
|
379 |
"cell_type": "code",
|
380 |
"execution_count": null,
|
381 |
-
"metadata": {
|
382 |
-
"trusted": true
|
383 |
-
},
|
384 |
"outputs": [],
|
385 |
"source": [
|
386 |
"# In an interactive Kaggle session, run this to see the log file. We can cancel this cell's execution\n",
|
@@ -392,9 +429,7 @@
|
|
392 |
{
|
393 |
"cell_type": "code",
|
394 |
"execution_count": null,
|
395 |
-
"metadata": {
|
396 |
-
"trusted": true
|
397 |
-
},
|
398 |
"outputs": [],
|
399 |
"source": [
|
400 |
"# Command to kill the background job in an interactive session\n",
|
@@ -409,9 +444,7 @@
|
|
409 |
{
|
410 |
"cell_type": "code",
|
411 |
"execution_count": null,
|
412 |
-
"metadata": {
|
413 |
-
"trusted": true
|
414 |
-
},
|
415 |
"outputs": [],
|
416 |
"source": [
|
417 |
"# Command to check progress of a running job in an interactive session\n",
|
@@ -421,9 +454,7 @@
|
|
421 |
{
|
422 |
"cell_type": "code",
|
423 |
"execution_count": null,
|
424 |
-
"metadata": {
|
425 |
-
"trusted": true
|
426 |
-
},
|
427 |
"outputs": [],
|
428 |
"source": [
|
429 |
"# In Batch run, after the job completes, list output files\n",
|
@@ -451,7 +482,7 @@
|
|
451 |
"sourceType": "notebook"
|
452 |
},
|
453 |
"kernelspec": {
|
454 |
-
"display_name": "Python 3",
|
455 |
"language": "python",
|
456 |
"name": "python3"
|
457 |
},
|
@@ -465,7 +496,7 @@
|
|
465 |
"name": "python",
|
466 |
"nbconvert_exporter": "python",
|
467 |
"pygments_lexer": "ipython3",
|
468 |
-
"version": "3.10.
|
469 |
}
|
470 |
},
|
471 |
"nbformat": 4,
|
|
|
2 |
"cells": [
|
3 |
{
|
4 |
"cell_type": "code",
|
5 |
+
"execution_count": 10,
|
6 |
"metadata": {
|
7 |
"executionInfo": {
|
8 |
"elapsed": 611,
|
|
|
14 |
},
|
15 |
"user_tz": 300
|
16 |
},
|
17 |
+
"id": "-IHoHd-t5sLP"
|
|
|
18 |
},
|
19 |
"outputs": [],
|
20 |
"source": [
|
21 |
"import sys, os\n",
|
22 |
"\n",
|
23 |
+
"AG4MDIR='/home/user/app/aglib/ag4masses'\n",
|
24 |
+
"AGLIB=f'/home/user/app/aglib/'\n",
|
25 |
+
"AGDIR=f\"{AG4MDIR}/alphageometry\"\n",
|
26 |
+
"MELIAD_PATH=f\"{AGLIB}/meliad\"\n",
|
27 |
"DATA=f\"{AGLIB}/ag_ckpt_vocab\"\n",
|
28 |
"TESTDIR=f\"/data/ag4mtest\""
|
29 |
]
|
|
|
40 |
{
|
41 |
"cell_type": "code",
|
42 |
"execution_count": null,
|
43 |
+
"metadata": {},
|
|
|
|
|
44 |
"outputs": [],
|
45 |
"source": [
|
46 |
"# Run this cell to refresh code and get the latest versions\n",
|
|
|
62 |
},
|
63 |
"user_tz": 300
|
64 |
},
|
65 |
+
"id": "GgR_vO8XX9Vr"
|
|
|
66 |
},
|
67 |
"outputs": [],
|
68 |
"source": [
|
|
|
95 |
"user_tz": 300
|
96 |
},
|
97 |
"id": "gP4zAZh2MHcv",
|
98 |
+
"outputId": "4796397b-8952-411e-bd33-8fd813865735"
|
|
|
99 |
},
|
100 |
"outputs": [],
|
101 |
"source": [
|
|
|
142 |
"user_tz": 300
|
143 |
},
|
144 |
"id": "X8Aj3G0neT6K",
|
145 |
+
"outputId": "9538ceba-8065-44d6-a32f-35127e5f2575"
|
|
|
146 |
},
|
147 |
"outputs": [],
|
148 |
"source": [
|
|
|
168 |
"user_tz": 300
|
169 |
},
|
170 |
"id": "u9fuBSr2qEwN",
|
171 |
+
"outputId": "97bbce78-8b49-4d3b-a831-d188a4a9e536"
|
|
|
172 |
},
|
173 |
"outputs": [],
|
174 |
"source": [
|
|
|
183 |
{
|
184 |
"cell_type": "code",
|
185 |
"execution_count": null,
|
186 |
+
"metadata": {},
|
|
|
|
|
187 |
"outputs": [],
|
188 |
"source": [
|
189 |
"# Linux packages for Nvidia gpu.\n",
|
|
|
197 |
"cell_type": "code",
|
198 |
"execution_count": null,
|
199 |
"metadata": {
|
200 |
+
"id": "fChy49CNhf01"
|
|
|
201 |
},
|
202 |
"outputs": [],
|
203 |
"source": [
|
|
|
206 |
"!nvidia-smi"
|
207 |
]
|
208 |
},
|
209 |
+
{
|
210 |
+
"cell_type": "code",
|
211 |
+
"execution_count": null,
|
212 |
+
"metadata": {},
|
213 |
+
"outputs": [],
|
214 |
+
"source": []
|
215 |
+
},
|
216 |
+
{
|
217 |
+
"cell_type": "markdown",
|
218 |
+
"metadata": {},
|
219 |
+
"source": [
|
220 |
+
"# AlphaGeometry\n",
|
221 |
+
"由DeepMind开源的AlphaGeometry用于几何解题工具。\n",
|
222 |
+
"\n",
|
223 |
+
"## 一.使用方法\n",
|
224 |
+
"\n",
|
225 |
+
"### 1. 上传题目\n",
|
226 |
+
"\n",
|
227 |
+
"双击左侧problems.txt,在末尾换行后添加新的题目,格式见第二部分。该文件已经有部分例子\n",
|
228 |
+
"\n",
|
229 |
+
"### 2. 修改配置\n",
|
230 |
+
"\n",
|
231 |
+
"在下方代码块中直接修改PROB的值,修改为题目名称。\n",
|
232 |
+
"\n",
|
233 |
+
"### 3. 运行\n",
|
234 |
+
"\n",
|
235 |
+
"从上之下依次点击代码块左侧的运行按钮即可,或者点击上方的双箭头按钮运行全部代码块。\n",
|
236 |
+
"\n",
|
237 |
+
"### 4. 查看结果\n",
|
238 |
+
"\n",
|
239 |
+
"运行结束后,双击打开左侧的ag4mtest文件夹,双击打开`题目名.out`文件。\n",
|
240 |
+
"\n",
|
241 |
+
"## 二.题目格式\n"
|
242 |
+
]
|
243 |
+
},
|
244 |
+
{
|
245 |
+
"cell_type": "code",
|
246 |
+
"execution_count": 13,
|
247 |
+
"metadata": {},
|
248 |
+
"outputs": [],
|
249 |
+
"source": [
|
250 |
+
"\n",
|
251 |
+
"PROB='imo-2024-q4'\n"
|
252 |
+
]
|
253 |
+
},
|
254 |
{
|
255 |
"cell_type": "markdown",
|
256 |
"metadata": {
|
|
|
262 |
},
|
263 |
{
|
264 |
"cell_type": "code",
|
265 |
+
"execution_count": 15,
|
266 |
+
"metadata": {},
|
|
|
|
|
267 |
"outputs": [],
|
268 |
"source": [
|
269 |
"#!! cannot have ' in the script, including in comments\n",
|
|
|
334 |
"\n",
|
335 |
"true \"==========================================\"\n",
|
336 |
"\n",
|
337 |
+
"cd $AG4MDIR\n",
|
338 |
"python -m alphageometry \\\n",
|
339 |
"--alsologtostderr \\\n",
|
340 |
"--problems_file=$PROB_FILE \\\n",
|
|
|
344 |
"\"${SEARCH_ARGS[@]}\" \\\n",
|
345 |
"\"${LM_ARGS[@]}\" \\\n",
|
346 |
"--out_file=$OUTFILE \\\n",
|
347 |
+
"--n_workers=$NWORKERS # 2>&1\n",
|
348 |
"\n",
|
349 |
"'''"
|
350 |
]
|
|
|
352 |
{
|
353 |
"cell_type": "code",
|
354 |
"execution_count": null,
|
355 |
+
"metadata": {},
|
356 |
+
"outputs": [
|
357 |
+
{
|
358 |
+
"name": "stderr",
|
359 |
+
"output_type": "stream",
|
360 |
+
"text": [
|
361 |
+
"+ OUTFILE=/data/ag4mtest/imo-2024-q4.out\n",
|
362 |
+
"+ ERRFILE=/data/ag4mtest/imo-2024-q4.log\n",
|
363 |
+
"+ exec\n"
|
364 |
+
]
|
365 |
+
}
|
366 |
+
],
|
367 |
"source": [
|
368 |
"os.environ[\"TESTDIR\"]=TESTDIR\n",
|
369 |
"os.environ[\"AG4MDIR\"]=AG4MDIR\n",
|
|
|
383 |
"# NWORKERS=2\n",
|
384 |
"# CUDA_VISIBLE_DEVICES=0,1\n",
|
385 |
"\n",
|
386 |
+
"os.environ[\"BATCH_SIZE\"]=\"2\"\n",
|
387 |
+
"os.environ[\"BEAM_SIZE\"]=\"2\"\n",
|
388 |
+
"os.environ[\"DEPTH\"]=\"2\"\n",
|
389 |
"os.environ[\"NWORKERS\"]=\"2\"\n",
|
390 |
"\n",
|
391 |
+
"# o# s.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0,1\"\n",
|
392 |
"\n",
|
393 |
"# test problems can be uploaded into a dataset, e.g. for dataset \"tmpfiles\", \"/kaggle/input/tmpfiles/test-problems.txt\"\n",
|
394 |
+
"os.environ[\"PROBFILE\"]=\"/data/problems.txt\"\n",
|
395 |
+
"# PROB=\"imo-2024-q4\"\n",
|
396 |
"os.environ[\"PROB\"]=PROB\n",
|
397 |
"# alphageometry|ddar\n",
|
398 |
"os.environ[\"MODEL\"]=\"alphageometry\"\n",
|
|
|
400 |
"# In an interactive Kaggle session, run the job in background, so we can do other things in the Notebook.\n",
|
401 |
"# For long jobs, commit the Notebook and run in Batch mode.\n",
|
402 |
"# An interactive session will be terminated after about 20 minutes of idle time.\n",
|
403 |
+
"# if os.environ[\"KAGGLE_KERNEL_RUN_TYPE\"]==\"Batch\":\n",
|
404 |
+
"os.system(f\"echo '{jobScript}'|bash\")\n",
|
405 |
+
"\n"
|
|
|
406 |
]
|
407 |
},
|
408 |
{
|
409 |
"cell_type": "code",
|
410 |
+
"execution_count": 6,
|
411 |
+
"metadata": {},
|
|
|
|
|
412 |
"outputs": [],
|
413 |
"source": [
|
414 |
"#!cat /kaggle/input/tmpfiles/test-problems.txt"
|
|
|
417 |
{
|
418 |
"cell_type": "code",
|
419 |
"execution_count": null,
|
420 |
+
"metadata": {},
|
|
|
|
|
421 |
"outputs": [],
|
422 |
"source": [
|
423 |
"# In an interactive Kaggle session, run this to see the log file. We can cancel this cell's execution\n",
|
|
|
429 |
{
|
430 |
"cell_type": "code",
|
431 |
"execution_count": null,
|
432 |
+
"metadata": {},
|
|
|
|
|
433 |
"outputs": [],
|
434 |
"source": [
|
435 |
"# Command to kill the background job in an interactive session\n",
|
|
|
444 |
{
|
445 |
"cell_type": "code",
|
446 |
"execution_count": null,
|
447 |
+
"metadata": {},
|
|
|
|
|
448 |
"outputs": [],
|
449 |
"source": [
|
450 |
"# Command to check progress of a running job in an interactive session\n",
|
|
|
454 |
{
|
455 |
"cell_type": "code",
|
456 |
"execution_count": null,
|
457 |
+
"metadata": {},
|
|
|
|
|
458 |
"outputs": [],
|
459 |
"source": [
|
460 |
"# In Batch run, after the job completes, list output files\n",
|
|
|
482 |
"sourceType": "notebook"
|
483 |
},
|
484 |
"kernelspec": {
|
485 |
+
"display_name": "Python 3 (ipykernel)",
|
486 |
"language": "python",
|
487 |
"name": "python3"
|
488 |
},
|
|
|
496 |
"name": "python",
|
497 |
"nbconvert_exporter": "python",
|
498 |
"pygments_lexer": "ipython3",
|
499 |
+
"version": "3.10.13"
|
500 |
}
|
501 |
},
|
502 |
"nbformat": 4,
|
download.sh
CHANGED
@@ -7,7 +7,9 @@ git clone https://github.com/tpgh24/ag4masses.git
|
|
7 |
pip cache purge
|
8 |
pip install --upgrade pip
|
9 |
pip install --upgrade packaging setuptools setuptools_scm wheel
|
|
|
10 |
pip install --require-hashes --no-deps -r /home/user/app/aglib/ag4masses/alphageometry/requirements.txt
|
|
|
11 |
|
12 |
# cd alphageometry
|
13 |
git clone https://github.com/google-research/meliad.git
|
@@ -21,5 +23,5 @@ export DATA=ag_ckpt_vocab
|
|
21 |
cd /home/user/app
|
22 |
# some patch for cpu
|
23 |
cp models.py /home/user/app/aglib/ag4masses/alphageometry/models.py
|
24 |
-
cp
|
25 |
cp ag4masses-public.ipynb /data/ag4masses-public.ipynb
|
|
|
7 |
pip cache purge
|
8 |
pip install --upgrade pip
|
9 |
pip install --upgrade packaging setuptools setuptools_scm wheel
|
10 |
+
# pip install typing_extensions==4.6.0
|
11 |
pip install --require-hashes --no-deps -r /home/user/app/aglib/ag4masses/alphageometry/requirements.txt
|
12 |
+
pip install typing_extensions==4.6.0
|
13 |
|
14 |
# cd alphageometry
|
15 |
git clone https://github.com/google-research/meliad.git
|
|
|
23 |
cd /home/user/app
|
24 |
# some patch for cpu
|
25 |
cp models.py /home/user/app/aglib/ag4masses/alphageometry/models.py
|
26 |
+
cp lm_inference.py /home/user/app/aglib/ag4masses/alphageometry/lm_inference.py
|
27 |
cp ag4masses-public.ipynb /data/ag4masses-public.ipynb
|
lm_inference.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Wrapper for language modeling inference implemented in Meliad."""
|
17 |
+
from typing import Any, Dict
|
18 |
+
|
19 |
+
import jax
|
20 |
+
import models # pylint: disable=unused-import
|
21 |
+
import t5.data
|
22 |
+
from transformer import inference_utils
|
23 |
+
|
24 |
+
|
25 |
+
np = jax.numpy
|
26 |
+
|
27 |
+
|
28 |
+
Trainer = inference_utils.Trainer
|
29 |
+
|
30 |
+
MetricsOutput = Dict[str, Any] # Metrics output by model.
|
31 |
+
|
32 |
+
|
33 |
+
parse_gin_configuration = inference_utils.parse_gin_configuration
|
34 |
+
|
35 |
+
|
36 |
+
class LanguageModelInference:
|
37 |
+
"""Meliad wrapper for LM inference."""
|
38 |
+
|
39 |
+
def __init__(self, vocab_path: str, load_dir: str, mode='beam_search'):
|
40 |
+
self.vocab = t5.data.SentencePieceVocabulary(vocab_path)
|
41 |
+
|
42 |
+
# This task won't be pulling from a dataset.
|
43 |
+
def null_iter_fn() -> None:
|
44 |
+
return None
|
45 |
+
|
46 |
+
process_summaries_f = inference_utils.models.process_summaries_function(
|
47 |
+
self.vocab
|
48 |
+
)
|
49 |
+
|
50 |
+
trainer = inference_utils.training_loop.Trainer(
|
51 |
+
get_training_dataset_iterator=null_iter_fn,
|
52 |
+
get_test_dataset_iterator=None,
|
53 |
+
pretty_print_input_function=None,
|
54 |
+
process_summaries_function=process_summaries_f,
|
55 |
+
load_dir=load_dir,
|
56 |
+
workdir='', # Don't log or save checkpoints.
|
57 |
+
replicate_mode=False,
|
58 |
+
) # Run on a single device at batch size 1.
|
59 |
+
self.trainer = trainer
|
60 |
+
|
61 |
+
# Create and initialize the model.
|
62 |
+
(tstate, _, imodel, prngs) = trainer.initialize_model()
|
63 |
+
self.imodel = imodel
|
64 |
+
self.batch_size = imodel.task_config.batch_size
|
65 |
+
|
66 |
+
self.n = imodel.num_heads
|
67 |
+
self.h = imodel.head_size
|
68 |
+
|
69 |
+
# Create an inference task.
|
70 |
+
writers = {}
|
71 |
+
self.task = trainer.create_training_task(mode, imodel, prngs, writers) # pylint: disable=too-many-function-args
|
72 |
+
|
73 |
+
# Register any additional actions.
|
74 |
+
# Actions are cleared first for use with colab.
|
75 |
+
inference_utils.training_loop.clear_interstep_callbacks()
|
76 |
+
inference_utils.training_loop.register_interstep_callbacks()
|
77 |
+
self.tstate = tstate
|
78 |
+
|
79 |
+
# some default parameters.
|
80 |
+
eos = [0] * 1024
|
81 |
+
for idx in self.encode_list(['.', ';']):
|
82 |
+
eos[idx] = 1
|
83 |
+
|
84 |
+
self.eos = np.array(eos, dtype=np.float32)
|
85 |
+
self.mask = jax.numpy.ones([1024], dtype=np.float32)
|
86 |
+
|
87 |
+
def decode(self, ids: list[int]) -> str:
|
88 |
+
return self.vocab.decode(ids)
|
89 |
+
|
90 |
+
def decode_list(self, tokens: list[int]) -> list[str]:
|
91 |
+
return [self.decode([tok]) for tok in tokens]
|
92 |
+
|
93 |
+
def encode(self, inputs_str: str) -> list[int]:
|
94 |
+
return self.vocab.encode(inputs_str)
|
95 |
+
|
96 |
+
def encode_list(self, inputs_strs: list[str]) -> list[int]:
|
97 |
+
result = [self.vocab.encode(x) for x in inputs_strs]
|
98 |
+
assert all([len(x) == 1 for x in result]), [
|
99 |
+
self.decode(x) for x in result if len(x) != 1
|
100 |
+
]
|
101 |
+
return [x[0] for x in result]
|
102 |
+
|
103 |
+
def call(
|
104 |
+
self,
|
105 |
+
inputs: np.ndarray,
|
106 |
+
dstate: tuple[dict[str, np.ndarray], ...] = None,
|
107 |
+
eos: np.ndarray = None,
|
108 |
+
mask: np.ndarray = None,
|
109 |
+
) -> MetricsOutput:
|
110 |
+
"""Call the meliad model."""
|
111 |
+
batch_size, length = inputs.shape
|
112 |
+
inputs = jax.numpy.pad(inputs, [(0, 0), (0, 1024 - length)])
|
113 |
+
|
114 |
+
if eos is None:
|
115 |
+
eos = self.eos
|
116 |
+
if mask is None:
|
117 |
+
mask = self.mask
|
118 |
+
|
119 |
+
x = {'targets': inputs, 'length': length, 'eos': eos, 'mask': mask}
|
120 |
+
|
121 |
+
if dstate is not None:
|
122 |
+
x['start_of_sequence'] = jax.numpy.array([False] * batch_size)
|
123 |
+
else:
|
124 |
+
dstate = tuple(
|
125 |
+
[{ # this dummy value will never be used.
|
126 |
+
'current_index': np.array([0] * batch_size, dtype=np.int32),
|
127 |
+
'keys': np.zeros(
|
128 |
+
(batch_size, 2048, self.n, self.h), dtype=np.float32
|
129 |
+
),
|
130 |
+
'values': np.zeros(
|
131 |
+
(batch_size, 2048, self.n, self.h), dtype=np.float32
|
132 |
+
),
|
133 |
+
'recurrent_kvq': None,
|
134 |
+
'relative_position_bias': np.zeros(
|
135 |
+
(batch_size, self.n, 1, 1024), dtype=np.float32
|
136 |
+
),
|
137 |
+
}]
|
138 |
+
* 12
|
139 |
+
)
|
140 |
+
x['start_of_sequence'] = jax.numpy.array([True] * batch_size)
|
141 |
+
|
142 |
+
x['dstate'] = dstate
|
143 |
+
_, metrics_np = self.task.run_step(self.tstate, x, 0)
|
144 |
+
return metrics_np
|
145 |
+
|
146 |
+
def beam_decode(
|
147 |
+
self,
|
148 |
+
inputs: str,
|
149 |
+
eos_tokens: np.ndarray = None,
|
150 |
+
mask_tokens: np.ndarray = None,
|
151 |
+
dstate: dict[str, np.ndarray] = None,
|
152 |
+
) -> MetricsOutput:
|
153 |
+
"""Beam search."""
|
154 |
+
inputs = jax.numpy.array([self.vocab.encode(inputs)] * self.batch_size)
|
155 |
+
|
156 |
+
eos = self.eos
|
157 |
+
if eos_tokens is not None:
|
158 |
+
eos_ids = self.encode_list(eos_tokens)
|
159 |
+
eos = np.array(
|
160 |
+
[1 if idx in eos_ids else 0 for idx in range(1024)], dtype=np.float32
|
161 |
+
).reshape((1, 1, 1024))
|
162 |
+
|
163 |
+
mask = self.mask
|
164 |
+
if mask_tokens is not None:
|
165 |
+
mask_ids = self.encode_list(mask_tokens)
|
166 |
+
mask = np.array(
|
167 |
+
[0 if idx in mask_ids else 1 for idx in range(1024)],
|
168 |
+
dtype=np.float32,
|
169 |
+
).reshape((1, 1, 1024))
|
170 |
+
|
171 |
+
metrics_np = self.call(inputs, dstate=dstate, eos=eos, mask=mask)
|
172 |
+
|
173 |
+
finished_seqs = metrics_np['finished_seqs']
|
174 |
+
finished_scores = metrics_np['finished_scores']
|
175 |
+
|
176 |
+
seqs = []
|
177 |
+
scores = []
|
178 |
+
for seq, score in zip(finished_seqs, finished_scores):
|
179 |
+
seq = self.decode(seq[1:])
|
180 |
+
seqs.append(seq)
|
181 |
+
scores.append(score)
|
182 |
+
|
183 |
+
return {
|
184 |
+
'finished_seqs': finished_seqs,
|
185 |
+
'finished_scores': finished_scores,
|
186 |
+
'seqs_str': seqs,
|
187 |
+
'scores': scores,
|
188 |
+
'dstate': metrics_np['dstate'],
|
189 |
+
}
|