huolongguo10 committed
Commit c3a0684 · verified
1 Parent(s): f7f3e51

Upload 5 files

Files changed (3)
  1. ag4masses-public.ipynb +91 -60
  2. download.sh +3 -1
  3. lm_inference.py +189 -0
ag4masses-public.ipynb CHANGED
@@ -2,7 +2,7 @@
   "cells": [
   {
   "cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
   "metadata": {
   "executionInfo": {
   "elapsed": 611,
@@ -14,17 +14,16 @@
   },
   "user_tz": 300
   },
- "id": "-IHoHd-t5sLP",
- "trusted": true
+ "id": "-IHoHd-t5sLP"
   },
   "outputs": [],
   "source": [
   "import sys, os\n",
   "\n",
- "AG4MDIR='/home/user/ag4masses'\n",
- "AGLIB=f'{AG4MDIR}/aglib'\n",
- "AGDIR=f\"{AGLIB}/alphageometry\"\n",
- "MELIAD_PATH=f\"{AGDIR}/meliad\"\n",
+ "AG4MDIR='/home/user/app/aglib/ag4masses'\n",
+ "AGLIB=f'/home/user/app/aglib/'\n",
+ "AGDIR=f\"{AG4MDIR}/alphageometry\"\n",
+ "MELIAD_PATH=f\"{AGLIB}/meliad\"\n",
   "DATA=f\"{AGLIB}/ag_ckpt_vocab\"\n",
   "TESTDIR=f\"/data/ag4mtest\""
   ]
@@ -41,9 +40,7 @@
   {
   "cell_type": "code",
   "execution_count": null,
- "metadata": {
- "trusted": true
- },
+ "metadata": {},
   "outputs": [],
   "source": [
   "# Run this cell to refresh code and get the latest versions\n",
@@ -65,8 +62,7 @@
   },
   "user_tz": 300
   },
- "id": "GgR_vO8XX9Vr",
- "trusted": true
+ "id": "GgR_vO8XX9Vr"
   },
   "outputs": [],
   "source": [
@@ -99,8 +95,7 @@
   "user_tz": 300
   },
   "id": "gP4zAZh2MHcv",
- "outputId": "4796397b-8952-411e-bd33-8fd813865735",
- "trusted": true
+ "outputId": "4796397b-8952-411e-bd33-8fd813865735"
   },
   "outputs": [],
   "source": [
@@ -147,8 +142,7 @@
   "user_tz": 300
   },
   "id": "X8Aj3G0neT6K",
- "outputId": "9538ceba-8065-44d6-a32f-35127e5f2575",
- "trusted": true
+ "outputId": "9538ceba-8065-44d6-a32f-35127e5f2575"
   },
   "outputs": [],
   "source": [
@@ -174,8 +168,7 @@
   "user_tz": 300
   },
   "id": "u9fuBSr2qEwN",
- "outputId": "97bbce78-8b49-4d3b-a831-d188a4a9e536",
- "trusted": true
+ "outputId": "97bbce78-8b49-4d3b-a831-d188a4a9e536"
   },
   "outputs": [],
   "source": [
@@ -190,9 +183,7 @@
   {
   "cell_type": "code",
   "execution_count": null,
- "metadata": {
- "trusted": true
- },
+ "metadata": {},
   "outputs": [],
   "source": [
   "# Linux packages for Nvidia gpu.\n",
@@ -206,8 +197,7 @@
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
- "id": "fChy49CNhf01",
- "trusted": true
+ "id": "fChy49CNhf01"
   },
   "outputs": [],
   "source": [
@@ -216,6 +206,51 @@
   "!nvidia-smi"
   ]
   },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# AlphaGeometry\n",
+ "AlphaGeometry, open-sourced by DeepMind, is a tool for solving geometry problems.\n",
+ "\n",
+ "## I. Usage\n",
+ "\n",
+ "### 1. Upload a problem\n",
+ "\n",
+ "Double-click problems.txt on the left and append the new problem on a new line at the end; the format is described in Part II. The file already contains some examples.\n",
+ "\n",
+ "### 2. Edit the configuration\n",
+ "\n",
+ "In the code cell below, set the value of PROB to the problem name.\n",
+ "\n",
+ "### 3. Run\n",
+ "\n",
+ "Click the run button to the left of each code cell from top to bottom, or click the double-arrow button at the top to run all cells.\n",
+ "\n",
+ "### 4. View the results\n",
+ "\n",
+ "After the run finishes, double-click the ag4mtest folder on the left and open the `<problem name>.out` file.\n",
+ "\n",
+ "## II. Problem format\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "PROB='imo-2024-q4'\n"
+ ]
+ },
   {
   "cell_type": "markdown",
   "metadata": {
@@ -227,10 +262,8 @@
   },
   {
   "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "trusted": true
- },
+ "execution_count": 15,
+ "metadata": {},
   "outputs": [],
   "source": [
   "#!! cannot have ' in the script, including in comments\n",
@@ -301,6 +334,7 @@
   "\n",
   "true \"==========================================\"\n",
   "\n",
+ "cd $AG4MDIR\n",
   "python -m alphageometry \\\n",
   "--alsologtostderr \\\n",
   "--problems_file=$PROB_FILE \\\n",
@@ -310,7 +344,7 @@
   "\"${SEARCH_ARGS[@]}\" \\\n",
   "\"${LM_ARGS[@]}\" \\\n",
   "--out_file=$OUTFILE \\\n",
- "--n_workers=$NWORKERS 2>&1\n",
+ "--n_workers=$NWORKERS # 2>&1\n",
   "\n",
   "'''"
   ]
@@ -318,10 +352,18 @@
   {
   "cell_type": "code",
   "execution_count": null,
- "metadata": {
- "trusted": true
- },
- "outputs": [],
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "+ OUTFILE=/data/ag4mtest/imo-2024-q4.out\n",
+ "+ ERRFILE=/data/ag4mtest/imo-2024-q4.log\n",
+ "+ exec\n"
+ ]
+ }
+ ],
   "source": [
   "os.environ[\"TESTDIR\"]=TESTDIR\n",
   "os.environ[\"AG4MDIR\"]=AG4MDIR\n",
@@ -341,16 +383,16 @@
   "# NWORKERS=2\n",
   "# CUDA_VISIBLE_DEVICES=0,1\n",
   "\n",
- "os.environ[\"BATCH_SIZE\"]=\"16\"\n",
- "os.environ[\"BEAM_SIZE\"]=\"64\"\n",
- "os.environ[\"DEPTH\"]=\"8\"\n",
+ "os.environ[\"BATCH_SIZE\"]=\"2\"\n",
+ "os.environ[\"BEAM_SIZE\"]=\"2\"\n",
+ "os.environ[\"DEPTH\"]=\"2\"\n",
   "os.environ[\"NWORKERS\"]=\"2\"\n",
   "\n",
- "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0,1\"\n",
+ "# os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0,1\"\n",
   "\n",
   "# test problems can be uploaded into a dataset, e.g. for dataset \"tmpfiles\", \"/kaggle/input/tmpfiles/test-problems.txt\"\n",
- "os.environ[\"PROB_FILE\"]=f\"{AG4MDIR}/data/ag4m_problems.txt\"\n",
- "PROB=\"imo-2024-q4\"\n",
+ "os.environ[\"PROBFILE\"]=\"/data/problems.txt\"\n",
+ "# PROB=\"imo-2024-q4\"\n",
   "os.environ[\"PROB\"]=PROB\n",
   "# alphageometry|ddar\n",
   "os.environ[\"MODEL\"]=\"alphageometry\"\n",
@@ -358,18 +400,15 @@
   "# In an interactive Kaggle session, run the job in background, so we can do other things in the Notebook.\n",
   "# For long jobs, commit the Notebook and run in Batch mode.\n",
   "# An interactive session will be terminated after about 20 minutes of idle time.\n",
- "if os.environ[\"KAGGLE_KERNEL_RUN_TYPE\"]==\"Batch\":\n",
- " os.system(f\"echo '{jobScript}'|bash\")\n",
- "else:\n",
- " os.system(f\"echo '{jobScript}'|bash &\")\n"
+ "# if os.environ[\"KAGGLE_KERNEL_RUN_TYPE\"]==\"Batch\":\n",
+ "os.system(f\"echo '{jobScript}'|bash\")\n",
+ "\n"
   ]
   },
   {
   "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "trusted": true
- },
+ "execution_count": 6,
+ "metadata": {},
   "outputs": [],
   "source": [
   "#!cat /kaggle/input/tmpfiles/test-problems.txt"
@@ -378,9 +417,7 @@
   {
   "cell_type": "code",
   "execution_count": null,
- "metadata": {
- "trusted": true
- },
+ "metadata": {},
   "outputs": [],
   "source": [
   "# In an interactive Kaggle session, run this to see the log file. We can cancel this cell's execution\n",
@@ -392,9 +429,7 @@
   {
   "cell_type": "code",
   "execution_count": null,
- "metadata": {
- "trusted": true
- },
+ "metadata": {},
   "outputs": [],
   "source": [
   "# Command to kill the background job in an interactive session\n",
@@ -409,9 +444,7 @@
   {
   "cell_type": "code",
   "execution_count": null,
- "metadata": {
- "trusted": true
- },
+ "metadata": {},
   "outputs": [],
   "source": [
   "# Command to check progress of a running job in an interactive session\n",
@@ -421,9 +454,7 @@
   {
   "cell_type": "code",
   "execution_count": null,
- "metadata": {
- "trusted": true
- },
+ "metadata": {},
   "outputs": [],
   "source": [
   "# In Batch run, after the job completes, list output files\n",
@@ -451,7 +482,7 @@
   "sourceType": "notebook"
   },
   "kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
   },
@@ -465,7 +496,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
- "version": "3.10.12"
+ "version": "3.10.13"
   }
   },
   "nbformat": 4,
download.sh CHANGED
@@ -7,7 +7,9 @@ git clone https://github.com/tpgh24/ag4masses.git
  pip cache purge
  pip install --upgrade pip
  pip install --upgrade packaging setuptools setuptools_scm wheel
+ # pip install typing_extensions==4.6.0
  pip install --require-hashes --no-deps -r /home/user/app/aglib/ag4masses/alphageometry/requirements.txt
+ pip install typing_extensions==4.6.0

  # cd alphageometry
  git clone https://github.com/google-research/meliad.git
@@ -21,5 +23,5 @@ export DATA=ag_ckpt_vocab
  cd /home/user/app
  # some patch for cpu
  cp models.py /home/user/app/aglib/ag4masses/alphageometry/models.py
- cp alphageometry.py /home/user/app/aglib/ag4masses/alphageometry/alphageometry.py
+ cp lm_inference.py /home/user/app/aglib/ag4masses/alphageometry/lm_inference.py
  cp ag4masses-public.ipynb /data/ag4masses-public.ipynb
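
The dependency part of this change installs the requirements with --require-hashes --no-deps and then forces typing_extensions to 4.6.0 afterwards (the commented-out line records the earlier attempt to pin it before the requirements install). A tiny, hypothetical post-install sanity check, not part of the commit, could confirm the pin took effect:

# Verify the pin that download.sh applies after installing requirements.
from importlib.metadata import version

te = version("typing_extensions")
assert te == "4.6.0", f"expected typing_extensions 4.6.0, got {te}"
print("typing_extensions", te)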
lm_inference.py ADDED
@@ -0,0 +1,189 @@
+ # Copyright 2023 DeepMind Technologies Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Wrapper for language modeling inference implemented in Meliad."""
+ from typing import Any, Dict
+
+ import jax
+ import models  # pylint: disable=unused-import
+ import t5.data
+ from transformer import inference_utils
+
+
+ np = jax.numpy
+
+
+ Trainer = inference_utils.Trainer
+
+ MetricsOutput = Dict[str, Any]  # Metrics output by model.
+
+
+ parse_gin_configuration = inference_utils.parse_gin_configuration
+
+
+ class LanguageModelInference:
+   """Meliad wrapper for LM inference."""
+
+   def __init__(self, vocab_path: str, load_dir: str, mode='beam_search'):
+     self.vocab = t5.data.SentencePieceVocabulary(vocab_path)
+
+     # This task won't be pulling from a dataset.
+     def null_iter_fn() -> None:
+       return None
+
+     process_summaries_f = inference_utils.models.process_summaries_function(
+         self.vocab
+     )
+
+     trainer = inference_utils.training_loop.Trainer(
+         get_training_dataset_iterator=null_iter_fn,
+         get_test_dataset_iterator=None,
+         pretty_print_input_function=None,
+         process_summaries_function=process_summaries_f,
+         load_dir=load_dir,
+         workdir='',  # Don't log or save checkpoints.
+         replicate_mode=False,
+     )  # Run on a single device at batch size 1.
+     self.trainer = trainer
+
+     # Create and initialize the model.
+     (tstate, _, imodel, prngs) = trainer.initialize_model()
+     self.imodel = imodel
+     self.batch_size = imodel.task_config.batch_size
+
+     self.n = imodel.num_heads
+     self.h = imodel.head_size
+
+     # Create an inference task.
+     writers = {}
+     self.task = trainer.create_training_task(mode, imodel, prngs, writers)  # pylint: disable=too-many-function-args
+
+     # Register any additional actions.
+     # Actions are cleared first for use with colab.
+     inference_utils.training_loop.clear_interstep_callbacks()
+     inference_utils.training_loop.register_interstep_callbacks()
+     self.tstate = tstate
+
+     # some default parameters.
+     eos = [0] * 1024
+     for idx in self.encode_list(['.', ';']):
+       eos[idx] = 1
+
+     self.eos = np.array(eos, dtype=np.float32)
+     self.mask = jax.numpy.ones([1024], dtype=np.float32)
+
+   def decode(self, ids: list[int]) -> str:
+     return self.vocab.decode(ids)
+
+   def decode_list(self, tokens: list[int]) -> list[str]:
+     return [self.decode([tok]) for tok in tokens]
+
+   def encode(self, inputs_str: str) -> list[int]:
+     return self.vocab.encode(inputs_str)
+
+   def encode_list(self, inputs_strs: list[str]) -> list[int]:
+     result = [self.vocab.encode(x) for x in inputs_strs]
+     assert all([len(x) == 1 for x in result]), [
+         self.decode(x) for x in result if len(x) != 1
+     ]
+     return [x[0] for x in result]
+
+   def call(
+       self,
+       inputs: np.ndarray,
+       dstate: tuple[dict[str, np.ndarray], ...] = None,
+       eos: np.ndarray = None,
+       mask: np.ndarray = None,
+   ) -> MetricsOutput:
+     """Call the meliad model."""
+     batch_size, length = inputs.shape
+     inputs = jax.numpy.pad(inputs, [(0, 0), (0, 1024 - length)])
+
+     if eos is None:
+       eos = self.eos
+     if mask is None:
+       mask = self.mask
+
+     x = {'targets': inputs, 'length': length, 'eos': eos, 'mask': mask}
+
+     if dstate is not None:
+       x['start_of_sequence'] = jax.numpy.array([False] * batch_size)
+     else:
+       dstate = tuple(
+           [{  # this dummy value will never be used.
+               'current_index': np.array([0] * batch_size, dtype=np.int32),
+               'keys': np.zeros(
+                   (batch_size, 2048, self.n, self.h), dtype=np.float32
+               ),
+               'values': np.zeros(
+                   (batch_size, 2048, self.n, self.h), dtype=np.float32
+               ),
+               'recurrent_kvq': None,
+               'relative_position_bias': np.zeros(
+                   (batch_size, self.n, 1, 1024), dtype=np.float32
+               ),
+           }]
+           * 12
+       )
+       x['start_of_sequence'] = jax.numpy.array([True] * batch_size)
+
+     x['dstate'] = dstate
+     _, metrics_np = self.task.run_step(self.tstate, x, 0)
+     return metrics_np
+
+   def beam_decode(
+       self,
+       inputs: str,
+       eos_tokens: np.ndarray = None,
+       mask_tokens: np.ndarray = None,
+       dstate: dict[str, np.ndarray] = None,
+   ) -> MetricsOutput:
+     """Beam search."""
+     inputs = jax.numpy.array([self.vocab.encode(inputs)] * self.batch_size)
+
+     eos = self.eos
+     if eos_tokens is not None:
+       eos_ids = self.encode_list(eos_tokens)
+       eos = np.array(
+           [1 if idx in eos_ids else 0 for idx in range(1024)], dtype=np.float32
+       ).reshape((1, 1, 1024))
+
+     mask = self.mask
+     if mask_tokens is not None:
+       mask_ids = self.encode_list(mask_tokens)
+       mask = np.array(
+           [0 if idx in mask_ids else 1 for idx in range(1024)],
+           dtype=np.float32,
+       ).reshape((1, 1, 1024))
+
+     metrics_np = self.call(inputs, dstate=dstate, eos=eos, mask=mask)
+
+     finished_seqs = metrics_np['finished_seqs']
+     finished_scores = metrics_np['finished_scores']
+
+     seqs = []
+     scores = []
+     for seq, score in zip(finished_seqs, finished_scores):
+       seq = self.decode(seq[1:])
+       seqs.append(seq)
+       scores.append(score)
+
+     return {
+         'finished_seqs': finished_seqs,
+         'finished_scores': finished_scores,
+         'seqs_str': seqs,
+         'scores': scores,
+         'dstate': metrics_np['dstate'],
+     }
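
The added lm_inference.py is the AlphaGeometry wrapper around Meliad inference: it loads a SentencePiece vocabulary, builds an inference Trainer from a checkpoint directory, and exposes beam_decode(), which returns decoded candidate sequences together with their scores and the decoder state. A minimal usage sketch follows; the vocab and checkpoint paths are assumptions based on the ag_ckpt_vocab layout the notebook downloads (not part of this commit), and the gin configuration would need to be parsed beforehand, e.g. via the re-exported parse_gin_configuration, with the gin files alphageometry normally supplies.

import lm_inference as lm

# Assumed locations, matching the notebook's AGLIB/ag_ckpt_vocab layout.
VOCAB_PATH = '/home/user/app/aglib/ag_ckpt_vocab/geometry.757.model'  # assumed file name
CKPT_DIR = '/home/user/app/aglib/ag_ckpt_vocab'

# NOTE: lm.parse_gin_configuration(...) must already have been called with the
# transformer gin files; the exact file list is omitted in this sketch.
model = lm.LanguageModelInference(VOCAB_PATH, CKPT_DIR, mode='beam_search')

# Decode continuations of a serialized problem statement; ';' ends a clause.
out = model.beam_decode('{S} <serialized premises and goal>', eos_tokens=[';'])
for seq, score in zip(out['seqs_str'], out['scores']):
    print(score, seq)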