|
import os |
|
import sys |
|
sys.path.append(os.getcwd()) |
|
import torch |
|
import torch.nn as nn |
|
import shutil |
|
import logging |
|
import torch.distributed as dist |
|
|
|
|
|
from transformers import ( |
|
BertTokenizer, |
|
RobertaTokenizer |
|
) |
|
|
|
from args import args |
|
from model import ( |
|
Layoutlmv1ForQuestionAnswering, |
|
Layoutlmv1Config, |
|
Layoutlmv1Config_roberta, |
|
Layoutlmv1ForQuestionAnswering_roberta |
|
) |
|
from util import set_seed, set_exp_folder, check_screen |
|
from trainer import train, evaluate |
|
|
|
from websrc import get_websrc_dataset |
|
|
|
def main(args):
    """Entry point: configure logging, build the model/tokenizer pair for the
    requested backbone, then optionally train and always evaluate on WebSRC.

    Args:
        args: parsed command-line namespace (from ``args.py``). Reads at least
            ``output_dir``, ``exp_name``, ``model_type``, ``model_name_or_path``,
            ``cache_dir``, ``add_linear``, ``do_train`` and ``do_test``.

    Raises:
        ValueError: if ``args.model_type`` is not ``'bert'`` or ``'roberta'``.
    """
    set_seed(args)
    set_exp_folder(args)

    # Log to a per-experiment file and mirror everything to stdout.
    logging.basicConfig(
        filename="{}/output/{}/log.txt".format(args.output_dir, args.exp_name),
        level=logging.INFO,
        format='[%(asctime)s.%(msecs)03d] %(message)s',
        datefmt='%H:%M:%S',
    )
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    logging.info('Args ' + str(args))

    # Pick the (config, model, tokenizer) triple for the requested backbone.
    if args.model_type == 'bert':
        config_class, model_class, tokenizer_class = (
            Layoutlmv1Config, Layoutlmv1ForQuestionAnswering, BertTokenizer)
    elif args.model_type == 'roberta':
        config_class, model_class, tokenizer_class = (
            Layoutlmv1Config_roberta, Layoutlmv1ForQuestionAnswering_roberta, RobertaTokenizer)
    else:
        # Fail fast with a clear message; previously an unknown model_type
        # crashed later with an opaque NameError on config_class.
        raise ValueError(
            "Unsupported model_type: {!r} (expected 'bert' or 'roberta')".format(args.model_type))

    config = config_class.from_pretrained(
        args.model_name_or_path, cache_dir=args.cache_dir
    )
    config.add_linear = args.add_linear

    tokenizer = tokenizer_class.from_pretrained(
        args.model_name_or_path, cache_dir=args.cache_dir
    )

    model = model_class.from_pretrained(
        args.model_name_or_path,
        # TF checkpoints are detected by the ".ckpt" suffix convention.
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir,
    )

    # Report the parameter count through the logger so it also lands in the
    # log file (was a bare print, which bypassed the file handler).
    parameters = sum(p.numel() for p in model.parameters())
    logging.info("Total params: %.2fM", parameters / 1e6)

    if args.do_train:
        dataset_web = get_websrc_dataset(args, tokenizer)
        logging.info(f'Web dataset is successfully loaded. Length : {len(dataset_web)}')
        train(args, dataset_web, model, tokenizer)

    # NOTE(review): evaluation runs unconditionally, even when --do_train is
    # off — preserved as-is; confirm there should not be a do_eval guard.
    logging.info('Start evaluating')
    dataset_web, examples, features = get_websrc_dataset(args, tokenizer, evaluate=True, output_examples=True)
    logging.info(f'[Eval] Web dataset is successfully loaded. Length : {len(dataset_web)}')
    evaluate(args, dataset_web, examples, features, model, tokenizer)

    if args.do_test:
        # Test-split handling is not implemented yet.
        pass
|
|
|
|
|
if __name__ == '__main__':
    # `args` is imported from args.py, which parses the command line at
    # import time — so main receives an already-populated namespace.
    main(args)
|
|
|
|
|
|