|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import, division, print_function |
|
|
|
import argparse

import csv

import json

import os

import shutil

from collections import defaultdict

from transformers import BertTokenizer, XLMTokenizer, XLMRobertaTokenizer
|
|
|
|
|
TOKENIZERS = { |
|
'bert': BertTokenizer, |
|
'xlm': XLMTokenizer, |
|
'xlmr': XLMRobertaTokenizer, |
|
} |
|
|
|
def panx_tokenize_preprocess(args): |
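  # Re-tokenize PAN-X "token<TAB>label" files with the model's subword
  # tokenizer and start a new example whenever a sentence would exceed max_len
  # subwords (minus the 2-3 special tokens the model adds). For each split this
  # writes <split>.<model_name> plus a parallel .idx file that maps every
  # output line back to its original sentence index. Only English is expected
  # to have a training split here.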
|
def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len): |
|
if not os.path.exists(infile): |
|
      print(f'{infile} does not exist')
|
return 0 |
|
special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2 |
|
max_seq_len = max_len - special_tokens_count |
|
subword_len_counter = idx = 0 |
|
with open(infile, "rt") as fin, open(outfile, "w") as fout, open(idxfile, "w") as fidx: |
|
for line in fin: |
|
line = line.strip() |
|
if not line: |
|
fout.write('\n') |
|
fidx.write('\n') |
|
idx += 1 |
|
subword_len_counter = 0 |
|
continue |
|
|
|
items = line.split() |
|
token = items[0].strip() |
|
if len(items) == 2: |
|
label = items[1].strip() |
|
else: |
|
label = 'O' |
|
current_subwords_len = len(tokenizer.tokenize(token)) |
|
|
|
if (current_subwords_len == 0 or current_subwords_len > max_seq_len) and len(token) != 0: |
|
token = tokenizer.unk_token |
|
current_subwords_len = 1 |
|
|
|
if (subword_len_counter + current_subwords_len) > max_seq_len: |
|
fout.write(f"\n{token}\t{label}\n") |
|
fidx.write(f"\n{idx}\n") |
|
subword_len_counter = current_subwords_len |
|
else: |
|
fout.write(f"{token}\t{label}\n") |
|
fidx.write(f"{idx}\n") |
|
subword_len_counter += current_subwords_len |
|
return 1 |
|
|
|
model_type = args.model_type |
|
tokenizer = TOKENIZERS[model_type].from_pretrained(args.model_name_or_path, |
|
do_lower_case=args.do_lower_case, |
|
cache_dir=args.cache_dir if args.cache_dir else None) |
|
for lang in args.languages.split(','): |
|
out_dir = os.path.join(args.output_dir, lang) |
|
if not os.path.exists(out_dir): |
|
os.makedirs(out_dir) |
|
if lang == 'en': |
|
files = ['dev', 'test', 'train'] |
|
else: |
|
files = ['dev', 'test'] |
|
for file in files: |
|
infile = os.path.join(args.data_dir, f'{file}-{lang}.tsv') |
|
outfile = os.path.join(out_dir, "{}.{}".format(file, args.model_name_or_path)) |
|
idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, args.model_name_or_path)) |
|
if os.path.exists(outfile) and os.path.exists(idxfile): |
|
print(f'{outfile} and {idxfile} exist') |
|
else: |
|
code = _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len) |
|
if code > 0: |
|
print(f'finish preprocessing {outfile}') |
|
|
|
|
|
def panx_preprocess(args): |
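  # Convert raw PAN-X (WikiAnn) files, one "<lang>:<token>\t<label>" pair per
  # line with blank lines between sentences, into plain "token\tlabel" files
  # named <split>-<lang>.tsv.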
|
  def _process_one_file(infile, outfile):

    if not os.path.exists(infile):

      print(f'{infile} does not exist')

      return

    with open(infile, 'r') as fin:

      lines = fin.readlines()

    # drop a trailing blank line so no empty final example is written
    if lines and lines[-1].strip() == '':

      lines = lines[:-1]

    with open(outfile, 'w') as fout:

      for l in lines:

        items = l.strip().split('\t')

        if len(items) == 2:

          label = items[1].strip()

          # strip the language prefix, e.g. "en:Obama" -> "Obama"
          idx = items[0].find(':')

          if idx != -1:

            token = items[0][idx+1:].strip()

          else:

            token = items[0].strip()

          fout.write(f'{token}\t{label}\n')

        else:

          fout.write('\n')
|
if not os.path.exists(args.output_dir): |
|
os.makedirs(args.output_dir) |
|
langs = 'ar he vi id jv ms tl eu ml ta te af nl en de el bn hi mr ur fa fr it pt es bg ru ja ka ko th sw yo my zh kk tr et fi hu'.split(' ') |
|
for lg in langs: |
|
for split in ['train', 'test', 'dev']: |
|
infile = os.path.join(args.data_dir, f'{lg}-{split}') |
|
outfile = os.path.join(args.output_dir, f'{split}-{lg}.tsv') |
|
_process_one_file(infile, outfile) |
|
|
|
def udpos_tokenize_preprocess(args): |
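  # UD POS twin of panx_tokenize_preprocess: split sentences by subword length
  # against the model's limit; tokens without a tag fall back to the "X" label.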
|
def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len): |
|
if not os.path.exists(infile): |
|
print(f'{infile} does not exist') |
|
      return 0
|
subword_len_counter = idx = 0 |
|
special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2 |
|
max_seq_len = max_len - special_tokens_count |
|
with open(infile, "rt") as fin, open(outfile, "w") as fout, open(idxfile, "w") as fidx: |
|
for line in fin: |
|
line = line.strip() |
|
        if not line:
|
fout.write('\n') |
|
fidx.write('\n') |
|
idx += 1 |
|
subword_len_counter = 0 |
|
continue |
|
|
|
items = line.split() |
|
if len(items) == 2: |
|
label = items[1].strip() |
|
else: |
|
label = "X" |
|
token = items[0].strip() |
|
current_subwords_len = len(tokenizer.tokenize(token)) |
|
|
|
if (current_subwords_len == 0 or current_subwords_len > max_seq_len) and len(token) != 0: |
|
token = tokenizer.unk_token |
|
current_subwords_len = 1 |
|
|
|
if (subword_len_counter + current_subwords_len) > max_seq_len: |
|
fout.write(f"\n{token}\t{label}\n") |
|
fidx.write(f"\n{idx}\n") |
|
subword_len_counter = current_subwords_len |
|
else: |
|
fout.write(f"{token}\t{label}\n") |
|
fidx.write(f"{idx}\n") |
|
          subword_len_counter += current_subwords_len

    return 1
|
|
|
model_type = args.model_type |
|
tokenizer = TOKENIZERS[model_type].from_pretrained(args.model_name_or_path, |
|
do_lower_case=args.do_lower_case, |
|
cache_dir=args.cache_dir if args.cache_dir else None) |
|
for lang in args.languages.split(','): |
|
out_dir = os.path.join(args.output_dir, lang) |
|
if not os.path.exists(out_dir): |
|
os.makedirs(out_dir) |
|
if lang == 'en': |
|
files = ['dev', 'test', 'train'] |
|
else: |
|
files = ['dev', 'test'] |
|
for file in files: |
|
infile = os.path.join(args.data_dir, "{}-{}.tsv".format(file, lang)) |
|
outfile = os.path.join(out_dir, "{}.{}".format(file, args.model_name_or_path)) |
|
idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, args.model_name_or_path)) |
|
if os.path.exists(outfile) and os.path.exists(idxfile): |
|
print(f'{outfile} and {idxfile} exist') |
|
else: |
|
        code = _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len)

        if code > 0:

          print(f'finish preprocessing {outfile}')
|
|
|
def udpos_preprocess(args): |
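  # Read 10-column CoNLL-U files, keeping the surface form (column 2) and UPOS
  # tag (column 4), drop sentences that are empty or (almost) entirely "_"
  # placeholders, and write one "token\ttag" pair per line as <split>-<lang>.tsv.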
|
  def _read_one_file(file):

    data = []

    sent, tag, lines = [], [], []

    with open(file, 'r') as fin:

      for line in fin:

        items = line.strip().split('\t')

        if len(items) != 10:

          # sentence boundary (blank or comment line): keep the sentence unless
          # it is empty or (almost) all surface forms are the "_" placeholder
          num_empty = sum([int(w == '_') for w in sent])

          if len(sent) > 0 and (num_empty == 0 or num_empty < len(sent) - 1):

            data.append((sent, tag, lines))

          sent, tag, lines = [], [], []

        else:

          if not items[0].isdigit():

            # skip multiword-token ranges ("1-2") and empty nodes ("1.1")
            continue

          sent.append(items[1].strip())

          tag.append(items[3].strip())

          lines.append(line.strip())

          assert len(sent) == int(items[0]), 'line={}, sent={}, tag={}'.format(line, sent, tag)

    return data
|
|
|
def isfloat(value): |
|
try: |
|
float(value) |
|
return True |
|
except ValueError: |
|
return False |
|
|
|
def remove_empty_space(data): |
|
new_data = {} |
|
for split in data: |
|
new_data[split] = [] |
|
for sent, tag, lines in data[split]: |
|
new_sent = [''.join(w.replace('\u200c', '').split(' ')) for w in sent] |
|
lines = [line.replace('\u200c', '') for line in lines] |
|
assert len(" ".join(new_sent).split(' ')) == len(tag) |
|
new_data[split].append((new_sent, tag, lines)) |
|
return new_data |
|
|
|
def check_file(file): |
|
for i, l in enumerate(open(file)): |
|
items = l.strip().split('\t') |
|
assert len(items[0].split(' ')) == len(items[1].split(' ')), 'idx={}, line={}'.format(i, l) |
|
|
|
def _write_files(data, output_dir, lang, suffix): |
|
for split in data: |
|
if len(data[split]) > 0: |
|
prefix = os.path.join(output_dir, f'{split}-{lang}') |
|
if suffix == 'mt': |
|
with open(prefix + '.mt.tsv', 'w') as fout: |
|
            for idx, (sent, tag, _) in enumerate(data[split]):

              newline = '\n' if idx != len(data[split]) - 1 else ''

              fout.write('{}\t{}{}'.format(' '.join(sent), ' '.join(tag), newline))
|
check_file(prefix + '.mt.tsv') |
|
print(' - finish checking ' + prefix + '.mt.tsv') |
|
elif suffix == 'tsv': |
|
with open(prefix + '.tsv', 'w') as fout: |
|
            for sidx, (sent, tag, _) in enumerate(data[split]):

              for widx, (w, t) in enumerate(zip(sent, tag)):

                newline = '' if (sidx == len(data[split]) - 1) and (widx == len(sent) - 1) else '\n'

                fout.write('{}\t{}{}'.format(w, t, newline))

              fout.write('\n')
|
elif suffix == 'conll': |
|
with open(prefix + '.conll', 'w') as fout: |
|
for _, _, lines in data[split]: |
|
for l in lines: |
|
fout.write(l.strip() + '\n') |
|
fout.write('\n') |
|
print(f'finish writing file to {prefix}.{suffix}') |
|
|
|
if not os.path.exists(args.output_dir): |
|
os.makedirs(args.output_dir) |
|
|
|
languages = 'af ar bg de el en es et eu fa fi fr he hi hu id it ja kk ko mr nl pt ru ta te th tl tr ur vi yo zh'.split(' ') |
|
for root, dirs, files in os.walk(args.data_dir): |
|
    lg = os.path.basename(root)
|
if root == args.data_dir or lg not in languages: |
|
continue |
|
|
|
data = {k: [] for k in ['train', 'dev', 'test']} |
|
for f in sorted(files): |
|
if f.endswith('conll'): |
|
file = os.path.join(root, f) |
|
examples = _read_one_file(file) |
|
if 'train' in f: |
|
data['train'].extend(examples) |
|
elif 'dev' in f: |
|
data['dev'].extend(examples) |
|
elif 'test' in f: |
|
data['test'].extend(examples) |
|
else: |
|
print('split not found: ', file) |
|
print(' - finish reading {}, {}'.format(file, [(k, len(v)) for k,v in data.items()])) |
|
|
|
data = remove_empty_space(data) |
|
for sub in ['tsv']: |
|
_write_files(data, args.output_dir, lg, sub) |
|
|
|
def pawsx_preprocess(args): |
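  # PAWS-X: skip each file's header row, keep (sentence1, sentence2, label),
  # and emit train-en.tsv plus dev/test files for every language; test files
  # are written without their gold label.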
|
def _preprocess_one_file(infile, outfile, remove_label=False): |
|
data = [] |
|
for i, line in enumerate(open(infile, 'r')): |
|
if i == 0: |
|
continue |
|
items = line.strip().split('\t') |
|
sent1 = ' '.join(items[1].strip().split(' ')) |
|
sent2 = ' '.join(items[2].strip().split(' ')) |
|
label = items[3] |
|
data.append([sent1, sent2, label]) |
|
|
|
    with open(outfile, 'w') as fout:

      writer = csv.writer(fout, delimiter='\t', quoting=csv.QUOTE_NONE, quotechar='')

      for sent1, sent2, label in data:

        if remove_label:

          # test sets are written without gold labels (remove_label=True)

          writer.writerow([sent1, sent2])

        else:

          writer.writerow([sent1, sent2, label])
|
|
|
if not os.path.exists(args.output_dir): |
|
os.makedirs(args.output_dir) |
|
|
|
split2file = {'train': 'train', 'test': 'test_2k', 'dev': 'dev_2k'} |
|
for lang in ['en', 'de', 'es', 'fr', 'ja', 'ko', 'zh']: |
|
for split in ['train', 'test', 'dev']: |
|
if split == 'train' and lang != 'en': |
|
continue |
|
file = split2file[split] |
|
infile = os.path.join(args.data_dir, lang, "{}.tsv".format(file)) |
|
outfile = os.path.join(args.output_dir, "{}-{}.tsv".format(split, lang)) |
|
_preprocess_one_file(infile, outfile, remove_label=(split == 'test')) |
|
print(f'finish preprocessing {outfile}') |
|
|
|
def xnli_preprocess(args): |
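  # XNLI: dev/test come as a single multi-language TSV (language in column 1,
  # label in column 2, premise/hypothesis in columns 7/8); split them into
  # per-language <split>-<lang>.tsv files. English training data comes from the
  # machine-translated MultiNLI file, which uses a different column order. The
  # "contradictory" label is normalized to "contradiction" in both cases.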
|
def _preprocess_file(infile, output_dir, split): |
|
all_langs = defaultdict(list) |
|
for i, line in enumerate(open(infile, 'r')): |
|
if i == 0: |
|
continue |
|
|
|
items = line.strip().split('\t') |
|
lang = items[0].strip() |
|
label = "contradiction" if items[1].strip() == "contradictory" else items[1].strip() |
|
sent1 = ' '.join(items[6].strip().split(' ')) |
|
sent2 = ' '.join(items[7].strip().split(' ')) |
|
all_langs[lang].append((sent1, sent2, label)) |
|
print(f'# langs={len(all_langs)}') |
|
for lang, pairs in all_langs.items(): |
|
outfile = os.path.join(output_dir, '{}-{}.tsv'.format(split, lang)) |
|
with open(outfile, 'w') as fout: |
|
writer = csv.writer(fout, delimiter='\t', quoting=csv.QUOTE_NONE, quotechar='') |
|
        for (sent1, sent2, label) in pairs:

          writer.writerow([sent1, sent2, label])
|
      print(f'finish preprocessing {outfile}')
|
|
|
def _preprocess_train_file(infile, outfile): |
|
with open(outfile, 'w') as fout: |
|
writer = csv.writer(fout, delimiter='\t', quoting=csv.QUOTE_NONE, quotechar='') |
|
for i, line in enumerate(open(infile, 'r')): |
|
if i == 0: |
|
continue |
|
|
|
items = line.strip().split('\t') |
|
sent1 = ' '.join(items[0].strip().split(' ')) |
|
sent2 = ' '.join(items[1].strip().split(' ')) |
|
label = "contradiction" if items[2].strip() == "contradictory" else items[2].strip() |
|
writer.writerow([sent1, sent2, label]) |
|
    print(f'finish preprocessing {outfile}')
|
|
|
infile = os.path.join(args.data_dir, 'XNLI-MT-1.0/multinli/multinli.train.en.tsv') |
|
if not os.path.exists(args.output_dir): |
|
os.makedirs(args.output_dir) |
|
outfile = os.path.join(args.output_dir, 'train-en.tsv') |
|
_preprocess_train_file(infile, outfile) |
|
|
|
for split in ['test', 'dev']: |
|
infile = os.path.join(args.data_dir, 'XNLI-1.0/xnli.{}.tsv'.format(split)) |
|
print(f'reading file {infile}') |
|
_preprocess_file(infile, args.output_dir, split) |
|
|
|
|
|
def tatoeba_preprocess(args): |
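  # Tatoeba retrieval data: map 3-letter codes to 2-letter ISO codes and copy
  # each tatoeba.<xxx>-eng pair of files to <xx>-en.<xx> / <xx>-en.en,
  # preserving the line-by-line sentence alignment between the two files.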
|
lang3_dict = { |
|
'afr':'af', 'ara':'ar', 'bul':'bg', 'ben':'bn', |
|
'deu':'de', 'ell':'el', 'spa':'es', 'est':'et', |
|
'eus':'eu', 'pes':'fa', 'fin':'fi', 'fra':'fr', |
|
'heb':'he', 'hin':'hi', 'hun':'hu', 'ind':'id', |
|
'ita':'it', 'jpn':'ja', 'jav':'jv', 'kat':'ka', |
|
'kaz':'kk', 'kor':'ko', 'mal':'ml', 'mar':'mr', |
|
'nld':'nl', 'por':'pt', 'rus':'ru', 'swh':'sw', |
|
'tam':'ta', 'tel':'te', 'tha':'th', 'tgl':'tl', |
|
'tur':'tr', 'urd':'ur', 'vie':'vi', 'cmn':'zh', |
|
'eng':'en', |
|
} |
|
if not os.path.exists(args.output_dir): |
|
os.makedirs(args.output_dir) |
|
for sl3, sl2 in lang3_dict.items(): |
|
if sl3 != 'eng': |
|
src_file = f'{args.data_dir}/tatoeba.{sl3}-eng.{sl3}' |
|
tgt_file = f'{args.data_dir}/tatoeba.{sl3}-eng.eng' |
|
src_out = f'{args.output_dir}/{sl2}-en.{sl2}' |
|
tgt_out = f'{args.output_dir}/{sl2}-en.en' |
|
      shutil.copy(src_file, src_out)

      # write the English side line by line, preserving the original order so
      # the i-th source line stays aligned with the i-th target line
      with open(tgt_file, 'r') as fin, open(tgt_out, 'w') as ftgt:

        for t in fin:

          ftgt.write(f'{t.strip()}\n')
|
|
|
|
|
def xquad_preprocess(args): |
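  # XQuAD is distributed as per-language SQuAD-format JSON that is presumably
  # consumed as-is, so this is intentionally a no-op placeholder.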
|
|
|
|
|
pass |
|
|
|
|
|
def mlqa_preprocess(args): |
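  # MLQA is likewise presumably used in its released SQuAD-format JSON form;
  # kept as a no-op so --task mlqa remains valid.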
|
|
|
|
|
pass |
|
|
|
|
|
def tydiqa_preprocess(args): |
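  # TyDiQA-GoldP: split the combined training JSON into one SQuAD-format file
  # per language (the language is the prefix of each question id) and rename
  # the per-language dev files to tydiqa.<iso>.dev.json.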
|
LANG2ISO = {'arabic': 'ar', 'bengali': 'bn', 'english': 'en', 'finnish': 'fi', |
|
'indonesian': 'id', 'korean': 'ko', 'russian': 'ru', |
|
'swahili': 'sw', 'telugu': 'te'} |
|
assert os.path.exists(args.data_dir) |
|
train_file = os.path.join(args.data_dir, 'tydiqa-goldp-v1.1-train.json') |
|
os.makedirs(args.output_dir, exist_ok=True) |
|
|
|
|
|
lang2data = defaultdict(list) |
|
with open(train_file, 'r') as f_in: |
|
data = json.load(f_in) |
|
version = data['version'] |
|
for doc in data['data']: |
|
for par in doc['paragraphs']: |
|
context = par['context'] |
|
for qa in par['qas']: |
|
question = qa['question'] |
|
question_id = qa['id'] |
|
example_lang = question_id.split('-')[0] |
|
q_id = question_id.split('-')[-1] |
|
for answer in qa['answers']: |
|
a_start, a_text = answer['answer_start'], answer['text'] |
|
a_end = a_start + len(a_text) |
|
assert context[a_start:a_end] == a_text |
|
lang2data[example_lang].append({'paragraphs': [{ |
|
'context': context, |
|
'qas': [{'answers': qa['answers'], |
|
'question': question, |
|
'id': q_id}]}]}) |
|
|
|
for lang, data in lang2data.items(): |
|
out_file = os.path.join( |
|
args.output_dir, 'tydiqa.%s.train.json' % LANG2ISO[lang]) |
|
with open(out_file, 'w') as f: |
|
json.dump({'data': data, 'version': version}, f) |
|
|
|
|
|
dev_dir = os.path.join(args.data_dir, 'tydiqa-goldp-v1.1-dev') |
|
assert os.path.exists(dev_dir) |
|
  for lang, iso in LANG2ISO.items():

    src_file = os.path.join(dev_dir, 'tydiqa-goldp-dev-%s.json' % lang)

    dst_file = os.path.join(dev_dir, 'tydiqa.%s.dev.json' % iso)

    # guard the rename so re-running the script is a no-op
    if os.path.exists(src_file):

      os.rename(src_file, dst_file)
|
|
|
|
|
|
|
|
|
|
|
def remove_qa_test_annotations(test_dir): |
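  # Strip gold answers from SQuAD-format test files in place, replacing each
  # question's answers with a single empty placeholder span. Note: nothing in
  # __main__ calls this; it is not wired to the --remove_test_label flag.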
|
assert os.path.exists(test_dir) |
|
for file_name in os.listdir(test_dir): |
|
new_data = [] |
|
test_file = os.path.join(test_dir, file_name) |
|
with open(test_file, 'r') as f: |
|
data = json.load(f) |
|
version = data['version'] |
|
for doc in data['data']: |
|
for par in doc['paragraphs']: |
|
context = par['context'] |
|
for qa in par['qas']: |
|
question = qa['question'] |
|
question_id = qa['id'] |
|
for answer in qa['answers']: |
|
a_start, a_text = answer['answer_start'], answer['text'] |
|
a_end = a_start + len(a_text) |
|
assert context[a_start:a_end] == a_text |
|
new_data.append({'paragraphs': [{ |
|
'context': context, |
|
'qas': [{'answers': [{'answer_start': 0, 'text': ''}], |
|
'question': question, |
|
'id': question_id}]}]}) |
|
with open(test_file, 'w') as f: |
|
json.dump({'data': new_data, 'version': version}, f) |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
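
  # Example invocation (the script name and paths below are illustrative, not
  # taken from this repo):
  #
  #   python utils_preprocess.py --task panx_tokenize --model_type bert \
  #     --model_name_or_path bert-base-multilingual-cased \
  #     --data_dir data/panx --output_dir data/panx --languages en,ar,zh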
|
|
|
|
|
parser.add_argument("--data_dir", default=None, type=str, required=True, |
|
help="The input data dir. Should contain the .tsv files (or other data files) for the task.") |
|
parser.add_argument("--output_dir", default=None, type=str, required=True, |
|
help="The output data dir where any processed files will be written to.") |
|
parser.add_argument("--task", default="panx", type=str, required=True, |
|
help="The task name") |
|
parser.add_argument("--model_name_or_path", default="bert-base-multilingual-cased", type=str, |
|
help="The pre-trained model") |
|
parser.add_argument("--model_type", default="bert", type=str, |
|
help="model type") |
|
parser.add_argument("--max_len", default=512, type=int, |
|
help="the maximum length of sentences") |
|
parser.add_argument("--do_lower_case", action='store_true', |
|
help="whether to do lower case") |
|
parser.add_argument("--cache_dir", default=None, type=str, |
|
help="cache directory") |
|
parser.add_argument("--languages", default="en", type=str, |
|
help="process language") |
|
parser.add_argument("--remove_last_token", action='store_true', |
|
help="whether to remove the last token") |
|
parser.add_argument("--remove_test_label", action='store_true', |
|
help="whether to remove test set label") |
|
args = parser.parse_args() |
|
|
|
  if args.task == 'panx_tokenize':

    panx_tokenize_preprocess(args)

  elif args.task == 'panx':

    panx_preprocess(args)

  elif args.task == 'udpos_tokenize':

    udpos_tokenize_preprocess(args)

  elif args.task == 'udpos':

    udpos_preprocess(args)

  elif args.task == 'pawsx':

    pawsx_preprocess(args)

  elif args.task == 'xnli':

    xnli_preprocess(args)

  elif args.task == 'tatoeba':

    tatoeba_preprocess(args)

  elif args.task == 'xquad':

    xquad_preprocess(args)

  elif args.task == 'mlqa':

    mlqa_preprocess(args)

  elif args.task == 'tydiqa':

    tydiqa_preprocess(args)
|
|