|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Encoders for text data. |
|
|
|
* TextEncoder: base class |
|
* SubwordTextEncoder: invertible |
|
""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import collections |
|
from itertools import chain |
|
import re |
|
import time |
|
import logging |
|
import six |
|
from six.moves import range |
|
|
|
|
|
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', |
|
datefmt='%m/%d/%Y %H:%M:%S', |
|
level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
PAD = "[PAD]" |
|
EOS = "[EOS]" |
|
UNK = "[UNK]" |
|
CLS = "[CLS]" |
|
SEP = "[SEP]" |
|
MASK = "[MASK]" |
|
RESERVED_TOKENS = [PAD, EOS, UNK, CLS, SEP, MASK] |
|
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS) |
|
PAD_ID = RESERVED_TOKENS.index(PAD) |
|
EOS_ID = RESERVED_TOKENS.index(EOS) |
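# With the reserved list above, PAD_ID == 0 and EOS_ID == 1.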
|
|
|
if six.PY2:
  RESERVED_TOKENS_BYTES = RESERVED_TOKENS
else:
  RESERVED_TOKENS_BYTES = [bytes(t, "ascii") for t in RESERVED_TOKENS]
|
|
|
|
|
|
|
|
|
|
|
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);") |
|
_ESCAPE_CHARS = set(u"\\_u;0123456789") |
|
_SPECIAL_CHARS = set(u"!\"\'#$%&*()`+,-./:;<=>?@[]^_{}~|") |
|
|
|
|
|
def native_to_unicode(s): |
|
if is_unicode(s): |
|
return s |
|
try: |
|
return to_unicode(s) |
|
except UnicodeDecodeError: |
|
res = to_unicode(s, ignore_errors=True) |
|
logger.info("Ignoring Unicode error, outputting: %s" % res) |
|
return res |
|
|
|
|
|
def unicode_to_native(s): |
|
if six.PY2: |
|
return s.encode("utf-8") if is_unicode(s) else s |
|
else: |
|
return s |
|
|
|
|
|
def is_unicode(s): |
|
return isinstance(s, six.text_type) |
|
|
|
|
|
def to_unicode(s, ignore_errors=False): |
|
if is_unicode(s): |
|
return s |
|
error_mode = "ignore" if ignore_errors else "strict" |
|
return s.decode("utf-8", errors=error_mode) |
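# Illustrative examples:
#   to_unicode(b"hello") == u"hello"
#   to_unicode(b"caf\xe9", ignore_errors=True) == u"caf"  # invalid byte dropped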
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TextEncoder(object): |
|
"""Base class for converting from ints to/from human readable strings.""" |
|
|
|
def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS): |
|
self._num_reserved_ids = num_reserved_ids |
|
|
|
@property |
|
def num_reserved_ids(self): |
|
return self._num_reserved_ids |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@property |
|
def vocab_size(self): |
|
raise NotImplementedError() |
|
|
|
|
|
def _escape_token(token, alphabet): |
|
"""Escape away underscores and OOV characters and append '_'. |
|
|
|
This allows the token to be expressed as the concatenation of a list |
|
of subtokens from the vocabulary. The underscore acts as a sentinel |
|
which allows us to invertibly concatenate multiple such lists. |
|
|
|
Args: |
|
token: A unicode string to be escaped. |
|
alphabet: A set of all characters in the vocabulary's alphabet. |
|
|
|
Returns: |
|
escaped_token: An escaped unicode string. |
|
|
|
Raises: |
|
ValueError: If the provided token is not unicode. |
|
""" |
|
if not isinstance(token, six.text_type): |
|
raise ValueError("Expected string type for token, got %s" % type(token)) |
|
|
|
token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u") |
|
ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token] |
|
return u"".join(ret) + "_" |
|
|
|
def _my_escape_token(token, alphabet):
  """Escape `token` like `_escape_token`, but prepend the "_" sentinel."""
|
|
|
if not isinstance(token, six.text_type): |
|
raise ValueError("Expected string type for token, got %s" % type(token)) |
|
|
|
token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u") |
|
ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token] |
|
return "_" + u"".join(ret) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SubwordTextEncoder(TextEncoder): |
|
"""Class for invertibly encoding text using a limited vocabulary. |
|
|
|
Invertibly encodes a native string as a sequence of subtokens from a limited |
|
vocabulary. |
|
|
|
A SubwordTextEncoder is built from a corpus (so it is tailored to the text in |
|
the corpus), and stored to a file. See text_encoder_build_subword.py. |
|
|
|
It can then be loaded and used to encode/decode any text. |
|
|
|
Encoding has four phases: |
|
|
|
1. Tokenize into a list of tokens. Each token is a unicode string of either |
|
all alphanumeric characters or all non-alphanumeric characters. We drop |
|
tokens consisting of a single space that are between two alphanumeric |
|
tokens. |
|
|
|
2. Escape each token. This escapes away special and out-of-vocabulary |
|
characters, and makes sure that each token ends with an underscore, and |
|
has no other underscores. |
|
|
|
  3. Represent each escaped token as the concatenation of a list of subtokens
|
from the limited vocabulary. Subtoken selection is done greedily from |
|
beginning to end. That is, we construct the list in order, always picking |
|
the longest subtoken in our vocabulary that matches a prefix of the |
|
remaining portion of the encoded token. |
|
|
|
4. Concatenate these lists. This concatenation is invertible due to the |
|
fact that the trailing underscores indicate when one list is finished. |
|
|
|
""" |
|
|
|
def __init__(self, filename=None): |
|
"""Initialize and read from a file, if provided. |
|
|
|
Args: |
|
filename: filename from which to read vocab. If None, do not load a |
|
vocab |
|
""" |
|
self._alphabet = set() |
|
|
|
|
|
|
|
super(SubwordTextEncoder, self).__init__() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@property |
|
def vocab_size(self): |
|
"""The subtoken vocabulary size.""" |
|
return len(self._all_subtoken_strings) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _escaped_token_to_subtoken_strings(self, escaped_token): |
|
"""Converts an escaped token string to a list of subtoken strings. |
|
|
|
Args: |
|
escaped_token: An escaped token as a unicode string. |
|
Returns: |
|
A list of subtokens as unicode strings. |
|
""" |
|
|
|
|
|
ret = [] |
|
start = 0 |
|
token_len = len(escaped_token) |
|
while start < token_len: |
|
for end in range( |
|
min(token_len, start + self._max_subtoken_len), start, -1): |
|
subtoken = escaped_token[start:end] |
|
if subtoken in self._subtoken_string_to_id: |
|
ret.append(subtoken) |
|
start = end |
|
break |
|
|
|
      else:
        # No subtoken in the vocabulary matches the remaining characters, so a
        # character of the escaped token is missing from the alphabet. This
        # should be impossible after `_init_alphabet_from_tokens`.
        assert False, "Token substring not found in subtoken vocabulary."
|
|
|
return ret |
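  # Greedy matching sketch with an illustrative vocabulary {"he", "hell", "o_"}:
  # the escaped token u"hello_" is split into ["hell", "o_"], because the
  # longest matching prefix is always taken first.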
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod |
|
def build_to_target_size(cls, |
|
target_size, |
|
token_counts, |
|
min_val, |
|
max_val, |
|
max_subtoken_length=None, |
|
reserved_tokens=None, |
|
num_iterations=4): |
|
"""Builds a SubwordTextEncoder that has `vocab_size` near `target_size`. |
|
|
|
Uses simple recursive binary search to find a minimum token count that most |
|
closely matches the `target_size`. |
|
|
|
Args: |
|
target_size: Desired vocab_size to approximate. |
|
token_counts: A dictionary of token counts, mapping string to int. |
|
min_val: An integer; lower bound for the minimum token count. |
|
max_val: An integer; upper bound for the minimum token count. |
|
max_subtoken_length: Maximum length of a subtoken. If this is not set, |
|
then the runtime and memory use of creating the vocab is quadratic in |
|
the length of the longest token. If this is set, then it is instead |
|
O(max_subtoken_length * length of longest token). |
|
reserved_tokens: List of reserved tokens. The global variable |
|
`RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this |
|
argument is `None`, it will use `RESERVED_TOKENS`. |
|
num_iterations: An integer; how many iterations of refinement. |
|
|
|
Returns: |
|
A SubwordTextEncoder instance. |
|
|
|
Raises: |
|
ValueError: If `min_val` is greater than `max_val`. |
|
""" |
|
if min_val > max_val: |
|
raise ValueError("Lower bound for the minimum token count " |
|
"is greater than the upper bound.") |
|
if target_size < 1: |
|
raise ValueError("Target size must be positive.") |
|
|
|
if reserved_tokens is None: |
|
reserved_tokens = RESERVED_TOKENS |
|
|
|
def bisect(min_val, max_val): |
|
"""Bisection to find the right size.""" |
|
present_count = (max_val + min_val) // 2 |
|
logger.info("Trying min_count %d" % present_count) |
|
subtokenizer = cls() |
|
subtokenizer.build_from_token_counts( |
|
token_counts, present_count, num_iterations, |
|
max_subtoken_length=max_subtoken_length, |
|
reserved_tokens=reserved_tokens) |
|
|
|
|
|
      # Accept the candidate if its vocab size is within 1% of target_size.
      is_ok = abs(subtokenizer.vocab_size - target_size) * 100 < target_size
|
|
|
if is_ok or min_val >= max_val or present_count < 2: |
|
return subtokenizer |
|
|
|
if subtokenizer.vocab_size > target_size: |
|
other_subtokenizer = bisect(present_count + 1, max_val) |
|
else: |
|
other_subtokenizer = bisect(min_val, present_count - 1) |
|
|
|
if other_subtokenizer is None: |
|
return subtokenizer |
|
|
|
if (abs(other_subtokenizer.vocab_size - target_size) < |
|
abs(subtokenizer.vocab_size - target_size)): |
|
return other_subtokenizer |
|
return subtokenizer |
|
|
|
return bisect(min_val, max_val) |
|
|
|
def build_from_token_counts(self, |
|
token_counts, |
|
min_count, |
|
num_iterations=4, |
|
reserved_tokens=None, |
|
max_subtoken_length=None): |
|
"""Train a SubwordTextEncoder based on a dictionary of word counts. |
|
|
|
Args: |
|
token_counts: a dictionary of Unicode strings to int. |
|
min_count: an integer - discard subtokens with lower counts. |
|
num_iterations: an integer. how many iterations of refinement. |
|
reserved_tokens: List of reserved tokens. The global variable |
|
`RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this |
|
argument is `None`, it will use `RESERVED_TOKENS`. |
|
max_subtoken_length: Maximum length of a subtoken. If this is not set, |
|
then the runtime and memory use of creating the vocab is quadratic in |
|
the length of the longest token. If this is set, then it is instead |
|
O(max_subtoken_length * length of longest token). |
|
|
|
Raises: |
|
      ValueError: if `RESERVED_TOKENS` is not a prefix of `reserved_tokens`.
|
""" |
|
|
|
if reserved_tokens is None: |
|
reserved_tokens = RESERVED_TOKENS |
|
else: |
|
|
|
      new_reserved_tokens = list(RESERVED_TOKENS)  # copy; do not mutate the global
|
for token in reserved_tokens: |
|
if token in new_reserved_tokens: |
|
continue |
|
new_reserved_tokens.append(token) |
|
reserved_tokens = new_reserved_tokens |
|
for default, proposed in zip(RESERVED_TOKENS, reserved_tokens): |
|
if default != proposed: |
|
raise ValueError("RESERVED_TOKENS must be a prefix of " |
|
"reserved_tokens.") |
|
|
|
start_time = time.time() |
|
|
|
|
|
|
|
alphabet_tokens = chain(six.iterkeys(token_counts), |
|
[native_to_unicode(t) for t in reserved_tokens[len(RESERVED_TOKENS):]]) |
|
|
|
|
|
self._init_alphabet_from_tokens(alphabet_tokens) |
|
|
|
|
|
|
|
self._init_subtokens_from_list(list(self._alphabet), |
|
reserved_tokens=reserved_tokens) |
|
|
|
|
|
|
|
|
|
if min_count < 1: |
|
min_count = 1 |
|
for i in range(num_iterations): |
|
|
|
|
|
|
|
subtoken_counts = collections.defaultdict(int) |
|
for token, count in six.iteritems(token_counts): |
|
iter_start_time = time.time() |
|
|
|
escaped_token = _my_escape_token(token, self._alphabet) |
|
subtokens = self._escaped_token_to_subtoken_strings(escaped_token) |
|
|
|
|
|
|
|
|
|
|
|
        # Count every candidate substring that starts at a position where one
        # of the current subtokens begins; these counts drive the next round of
        # vocabulary selection.
        start = 0
|
for subtoken in subtokens: |
|
last_position = len(escaped_token) + 1 |
|
if max_subtoken_length is not None: |
|
last_position = min(last_position, start + max_subtoken_length) |
|
|
|
for end in range(start + 1, last_position): |
|
new_subtoken = escaped_token[start:end] |
|
subtoken_counts[new_subtoken] += count |
|
start += len(subtoken) |
|
|
|
iter_time_secs = time.time() - iter_start_time |
|
if iter_time_secs > 0.1: |
|
logger.info(u"Processing token [{0}] took {1} seconds, consider " |
|
"setting Text2TextProblem.max_subtoken_length to a " |
|
"smaller value.".format(token, iter_time_secs)) |
|
|
|
|
|
      # Bucket the candidate subtoken strings by length.
      len_to_subtoken_strings = []
|
for subtoken_string, count in six.iteritems(subtoken_counts): |
|
lsub = len(subtoken_string) |
|
if count >= min_count: |
|
while len(len_to_subtoken_strings) <= lsub: |
|
len_to_subtoken_strings.append(set()) |
|
len_to_subtoken_strings[lsub].add(subtoken_string) |
|
|
|
|
|
|
|
      # Consider candidates from longest to shortest, so that accepting a long
      # subtoken string lets us discount the counts of its prefixes.
      new_subtoken_strings_with_count = []
|
for lsub in range(len(len_to_subtoken_strings) - 1, 0, -1): |
|
subtoken_strings = len_to_subtoken_strings[lsub] |
|
for subtoken_string in subtoken_strings: |
|
count = subtoken_counts[subtoken_string] |
|
if count >= min_count: |
|
|
|
|
|
if subtoken_string not in self._alphabet: |
|
new_subtoken_strings_with_count.append((count, subtoken_string)) |
|
for l in range(1, lsub): |
|
subtoken_counts[subtoken_string[:l]] -= count |
|
|
|
|
|
      # Include the alphabet explicitly to guarantee that every string can be
      # encoded.
      new_subtoken_strings_with_count.extend((subtoken_counts.get(a, 0), a)
                                             for a in self._alphabet)
      new_subtoken_strings_with_count.sort(reverse=True)
|
|
|
|
|
new_subtoken_strings = [subtoken for _, subtoken in new_subtoken_strings_with_count] |
|
      if reserved_tokens:
        # Keep the reserved tokens at the front of the candidate list.
        new_subtoken_strings = reserved_tokens + new_subtoken_strings
      new_subtoken_strings = list(set(new_subtoken_strings))
|
self._init_subtokens_from_list(new_subtoken_strings) |
|
|
|
|
|
|
|
|
|
self.subtokens_with_counts = new_subtoken_strings_with_count |
|
|
|
|
|
|
|
    # Move the bare "_" marker to the end of the vocabulary list.
    new_subtoken_strings.remove("_")
    new_subtoken_strings.append("_")

    # Convert to wordpiece-style strings: strip the leading "_" sentinel from
    # word-initial subtokens and prefix word-internal subtokens with "##".
    # Anything else (e.g. the reserved tokens) is left unchanged and collected
    # in oov_list.
    oov_list = []
    for idx, subtoken in enumerate(new_subtoken_strings):
      if subtoken.startswith("_") and subtoken != "_":
        new_subtoken_strings[idx] = subtoken[1:]
      elif subtoken[0] in self._alphabet and subtoken not in reserved_tokens:
        new_subtoken_strings[idx] = "##" + subtoken
      else:
        oov_list.append(subtoken)

    # Re-add the bare alphabet characters so that every character remains
    # encodable after the renaming above.
    new_subtoken_strings.extend(char for char in self._alphabet
                                if char not in new_subtoken_strings)

    new_subtoken_strings = list(set(new_subtoken_strings))
    self._init_subtokens_from_list(new_subtoken_strings)

    logger.info("Total vocab size: %d, %.2f seconds elapsed",
                self.vocab_size, time.time() - start_time)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_subtokens_from_list(self, subtoken_strings, reserved_tokens=None): |
|
"""Initialize token information from a list of subtoken strings. |
|
|
|
Args: |
|
subtoken_strings: a list of subtokens |
|
reserved_tokens: List of reserved tokens. We must have `reserved_tokens` |
|
as None or the empty list, or else the global variable `RESERVED_TOKENS` |
|
        must be a prefix of `reserved_tokens`.
    """
|
if reserved_tokens is None: |
|
reserved_tokens = [] |
|
|
|
if reserved_tokens: |
|
self._all_subtoken_strings = reserved_tokens + subtoken_strings |
|
else: |
|
self._all_subtoken_strings = subtoken_strings |
|
|
|
|
|
|
|
self._max_subtoken_len = max([len(s) for s in subtoken_strings]) |
|
self._subtoken_string_to_id = { |
|
s: i + len(reserved_tokens) |
|
for i, s in enumerate(subtoken_strings) if s |
|
} |
|
|
|
self._cache_size = 2 ** 20 |
|
self._cache = [(None, None)] * self._cache_size |
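  # Id assignment sketch (inputs are illustrative): with
  # reserved_tokens=["[PAD]", "[EOS]"] and subtoken_strings=["hello", "##lo"],
  # the mapping becomes {"hello": 2, "##lo": 3}; reserved tokens occupy the
  # first ids in `_all_subtoken_strings` but are not added to
  # `_subtoken_string_to_id`.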
|
|
|
def _init_alphabet_from_tokens(self, tokens): |
|
"""Initialize alphabet from an iterable of token or subtoken strings.""" |
|
|
|
|
|
self._alphabet = {c for token in tokens for c in token} |
|
self._alphabet |= _ESCAPE_CHARS |
|
self._alphabet |= _SPECIAL_CHARS |
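  # Example: _init_alphabet_from_tokens([u"ab", u"bc"]) produces an alphabet of
  # {"a", "b", "c"} plus every character in _ESCAPE_CHARS and _SPECIAL_CHARS.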
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
  def store_to_file(self, filename, add_single_quotes=True):
    """Write the subtoken vocabulary to `filename`, one subtoken per line."""
|
|
|
with open(filename, "w") as f: |
|
for subtoken_string in self._all_subtoken_strings: |
|
if add_single_quotes: |
|
f.write("'" + unicode_to_native(subtoken_string) + "'\n") |
|
else: |
|
f.write(unicode_to_native(subtoken_string) + "\n") |
|
|
|
  def store_to_file_with_counts(self, filename):
    """Write "subtoken<TAB>count" lines to `filename`."""
|
|
|
with open(filename, "w") as f: |
|
for subtoken_string, count in self.subtokens_with_counts: |
|
f.write(unicode_to_native(subtoken_string + "\t" + str(count)) + "\n") |
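  # Output format sketch: `store_to_file` writes one subtoken per line
  # (optionally wrapped in single quotes), while `store_to_file_with_counts`
  # writes "<subtoken>\t<count>" per line, e.g. a hypothetical "##ing\t42".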
|
|