update test.py
Browse files
test.py
CHANGED
|
@@ -3,8 +3,8 @@ import os
|
|
| 3 |
import fire
|
| 4 |
import torch
|
| 5 |
from functools import partial
|
| 6 |
-
from transformers import
|
| 7 |
-
from transformers import
|
| 8 |
from pya0.preprocess import preprocess_for_transformer
|
| 9 |
|
| 10 |
|
|
@@ -25,14 +25,10 @@ def classifier_hook(tokenizer, tokens, topk, module, inputs, outputs):
|
|
| 25 |
str(tokenizer.convert_ids_to_tokens(top_cands)))
|
| 26 |
|
| 27 |
|
| 28 |
-
def test(
|
| 29 |
-
test_file='test.txt',
|
| 30 |
-
ckpt_bert='ckpt/bert-pretrained-for-math-7ep/6_3_1382',
|
| 31 |
-
ckpt_tokenizer='ckpt/bert-tokenizer-for-math'
|
| 32 |
-
):
|
| 33 |
|
| 34 |
-
tokenizer =
|
| 35 |
-
model =
|
| 36 |
tie_word_embeddings=True
|
| 37 |
)
|
| 38 |
with open(test_file, 'r') as fh:
|
|
|
|
| 3 |
import fire
|
| 4 |
import torch
|
| 5 |
from functools import partial
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
from transformers import AutoModelForPreTraining
|
| 8 |
from pya0.preprocess import preprocess_for_transformer
|
| 9 |
|
| 10 |
|
|
|
|
| 25 |
str(tokenizer.convert_ids_to_tokens(top_cands)))
|
| 26 |
|
| 27 |
|
| 28 |
+
def test(model_name_or_path, tokenizer_name_or_path, test_file='test.txt'):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
|
| 31 |
+
model = AutoModelForPreTraining.from_pretrained(model_name_or_path,
|
| 32 |
tie_word_embeddings=True
|
| 33 |
)
|
| 34 |
with open(test_file, 'r') as fh:
|