import argparse
import os

import torch
from fairseq.data import (AppendTokenDataset, FairseqDataset,
                          PrependTokenDataset, StripTokenDataset,
                          TruncateDataset)
from fairseq.data.indexed_dataset import make_builder
from tqdm import tqdm
from transformers import AutoTokenizer
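
# Prepares binarized parallel data for translation language modeling (TLM):
# each line of --input_src / --input_trg is tokenized, wrapped with bos/eos,
# and the sentence pairs are written interleaved into a fairseq mmap dataset.
#
# Example invocation (script and file names are illustrative):
#   python prepare_tlm_data.py \
#       --model_name CZWin32768/xlm-align \
#       --input_src train.src --input_trg train.trg \
#       --output ./data-bin --max_pos 256

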
class IndexDataset(FairseqDataset):
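    """A minimal in-memory FairseqDataset over pre-tokenized index lists.

    Only the pieces needed by the wrapper datasets below are implemented;
    batching is unsupported, so collater() deliberately raises.
    """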

    def __init__(self, indices):
        self.indices = indices
        self._sizes = [len(i) for i in indices]

    @property
    def sizes(self):
        return self._sizes

    def size(self, index):
        item = self.__getitem__(index)
        return len(item)

    def __getitem__(self, index):
        item = self.indices[index]
        item = torch.LongTensor(item)
        return item

    def __len__(self):
        return len(self.indices)

    def collater(self, samples):
        raise NotImplementedError


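# Load the Hugging Face tokenizer for the configured checkpoint
# (an XLM-R style model by default).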
def build_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    return tokenizer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="CZWin32768/xlm-align")
    parser.add_argument("--input_src", type=str, default="")
    parser.add_argument("--input_trg", type=str, default="")
    parser.add_argument("--output", type=str, default="")
    parser.add_argument("--max_pos", type=int, default=256)
    args = parser.parse_args()
    return args


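# Serialize the items into a fairseq mmap indexed dataset, producing the
# <prefix>.bin / <prefix>.idx pair that fairseq data loaders expect.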
def save_items(items, prefix, vocab_size):
    bin_fn = "%s.bin" % prefix
    idx_fn = "%s.idx" % prefix
    builder = make_builder(bin_fn, "mmap", vocab_size=vocab_size)
    print("builder: " + str(builder))
    for item in items:
        builder.add_item(item)
    builder.finalize(idx_fn)


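# Tokenize the input file line by line. Special tokens are added later in
# main(), so the tokenizer is asked not to insert them here.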
def get_indices(input_fn, tokenizer):
    indices = []
    with open(input_fn) as fp:
        for line in tqdm(fp):
            line = line.strip()
            # add_special_tokens=False keeps the raw piece ids; bos/eos are
            # attached explicitly by the dataset pipeline in main(), which
            # would otherwise prepend a second bos.
            indices.append(tokenizer.encode(line, add_special_tokens=False))
    print("tokenize finished.")
    return indices


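# Build the TLM training items: each sentence becomes
# [bos] + tokens[:max_pos - 2] + [eos], and src/trg sentences are written
# alternately so that parallel pairs occupy adjacent positions.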
def main(args):
    tokenizer = build_tokenizer(args)
    src_indices = get_indices(args.input_src, tokenizer)
    trg_indices = get_indices(args.input_trg, tokenizer)

    src_dataset = IndexDataset(src_indices)
    trg_dataset = IndexDataset(trg_indices)

    # For XLM-R style tokenizers, sep is </s> and cls is <s>.
    eos = tokenizer.sep_token_id
    bos = tokenizer.cls_token_id
    max_pos = args.max_pos

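    # Strip any stray eos, truncate to max_pos - 2 to leave room for the
    # bos/eos pair, then re-attach both delimiters.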
    src_dataset = TruncateDataset(
        StripTokenDataset(src_dataset, eos), max_pos - 2)
    trg_dataset = TruncateDataset(
        StripTokenDataset(trg_dataset, eos), max_pos - 2)

    src_dataset = PrependTokenDataset(src_dataset, bos)
    trg_dataset = PrependTokenDataset(trg_dataset, bos)

    src_dataset = AppendTokenDataset(src_dataset, eos)
    trg_dataset = AppendTokenDataset(trg_dataset, eos)

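    # Interleave source and target sentences: item 2i is a source sentence
    # and item 2i + 1 is its parallel target.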
print("| get all items ...") |
|
|
|
items = [] |
|
for t1, t2 in tqdm(zip(src_dataset, trg_dataset)): |
|
items.append(t1) |
|
items.append(t2) |
|
|
|
print("| writing binary file ...") |
|
prefix = os.path.join(args.output, "train.0") |
|
save_items(items, prefix, len(tokenizer)) |
|
|
|
if __name__ == "__main__":
    args = get_args()
    main(args)