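"""Builds a SubwordTextEncoder vocab from a corpus (or vocab files) and merges
it with an existing raw BERT vocab file, keeping the original entries first.
"""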
from __future__ import absolute_import
from __future__ import division

import os
import tempfile

import tensorflow as tf

from text_encoder import SubwordTextEncoder
import tokenizer

tf.flags.DEFINE_string('output_filename', '/tmp/my.subword_text_encoder',
                       'Where to store the SubwordTextEncoder.')
tf.flags.DEFINE_string('corpus_filepattern', '',
                       'Corpus of one or more text files.')
tf.flags.DEFINE_string('vocab_filepattern', '', 'One or more vocabulary files '
                       '(one word per line as "word,count").')
tf.flags.DEFINE_integer('min_count', 5, 'Minimum subtoken count in corpus.')
tf.flags.DEFINE_integer('vocab_size', 30000,
                        'Target vocabulary size; the generated vocab will be '
                        'approximately this size.')
tf.flags.DEFINE_integer('corpus_max_lines', None,
                        'How many lines of corpus to read.')
tf.flags.DEFINE_integer('num_iterations', 5, 'Number of iterations.')
tf.flags.DEFINE_bool('split_on_newlines', True, 'Break corpus into lines.')
tf.flags.DEFINE_string('additional_chars', '',
                       'Special characters to include in the vocab, '
                       'e.g. "~" or "/".')
tf.flags.DEFINE_integer('max_subtoken_length', None, 'Max subtoken length.')
tf.flags.DEFINE_string('raw_vocab', None, 'Raw BERT vocab file.')
tf.flags.DEFINE_bool('do_lower_case', False,
                     'Whether or not to lowercase the input corpus.')

FLAGS = tf.flags.FLAGS


def merge_output_file_with_bert_vocab(output_filename, bert_vocab, temp_path):
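  """Writes the raw BERT vocab to output_filename, then appends every entry
  from temp_path that is not already present, so the original ids are kept.
  """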
  seen = set()
  with open(output_filename, 'w', encoding='utf-8') as writer:
    # Copy the original BERT vocab through unchanged.
    with open(bert_vocab, 'r', encoding='utf-8') as reader:
      for line in reader:
        writer.write(line)
        seen.add(line.strip())
    print(temp_path)
    # Append generated subtokens that are not already in the BERT vocab.
    with open(temp_path, 'r', encoding='utf-8') as reader:
      for line in reader:
        if line.strip() not in seen:
          writer.write(line)


def main(unused_argv):
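  """Counts tokens, builds the subword vocab, and merges it with the BERT vocab."""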
  if FLAGS.corpus_filepattern and FLAGS.vocab_filepattern:
    raise ValueError(
        'Must only provide one of --corpus_filepattern or --vocab_filepattern')

  elif FLAGS.corpus_filepattern:
    token_counts = tokenizer.corpus_token_counts(
        FLAGS.corpus_filepattern,
        FLAGS.corpus_max_lines,
        split_on_newlines=FLAGS.split_on_newlines,
        additional_chars=FLAGS.additional_chars,
        do_lower_case=FLAGS.do_lower_case)

  elif FLAGS.vocab_filepattern:
    token_counts = tokenizer.vocab_token_counts(
        FLAGS.vocab_filepattern, FLAGS.corpus_max_lines, FLAGS.do_lower_case)

  else:
    raise ValueError(
        'Must provide one of --corpus_filepattern or --vocab_filepattern')

  # Treat the raw BERT vocab entries as reserved tokens so they are kept
  # verbatim in the generated vocab.
  reserved_tokens = None
  if FLAGS.raw_vocab:
    with open(FLAGS.raw_vocab, 'r', encoding='utf-8') as f:
      reserved_tokens = [s.strip() for s in f if s.strip()]

  num_reserved = len(reserved_tokens) if reserved_tokens else 0
  print('Distinct tokens in corpus: %d' % len(token_counts))
  print('Reserved tokens: %d' % num_reserved)

  target_size = FLAGS.vocab_size
  if target_size <= num_reserved:
    raise ValueError("The vocab_size must be larger than the original vocab's size.")
  if target_size >= len(token_counts):
    raise ValueError(
        'The vocab_size is too large. Please set it smaller or prepare more corpus.')

  # Bounds for the binary search over the minimum subtoken count that
  # build_to_target_size uses to approach the requested vocab size.
  min_val = 1
  max_val = int(len(token_counts) // (target_size ** 0.5))

  fd, temp_path = tempfile.mkstemp()
  os.close(fd)
  encoder = SubwordTextEncoder.build_to_target_size(
      target_size, token_counts, min_val, max_val,
      num_iterations=FLAGS.num_iterations,
      reserved_tokens=reserved_tokens,
      max_subtoken_length=FLAGS.max_subtoken_length)

  # Store the generated vocab in a temp file, then merge it with the raw BERT
  # vocab so the original entries come first.
  encoder.store_to_file(temp_path, add_single_quotes=False)
  if FLAGS.raw_vocab:
    merge_output_file_with_bert_vocab(FLAGS.output_filename, FLAGS.raw_vocab,
                                      temp_path)
  else:
    # No raw vocab to merge with; write the generated vocab directly.
    encoder.store_to_file(FLAGS.output_filename, add_single_quotes=False)


if __name__ == '__main__':
  tf.app.run()