"""A simple invertible tokenizer. |
|
|
|
Converts from a unicode string to a list of tokens |
|
(represented as Unicode strings). |
|
|
|
This tokenizer has the following desirable properties: |
|
- It is invertible. |
|
- Alphanumeric characters are broken away from non-alphanumeric characters. |
|
- A single space between words does not produce an extra token. |
|
- The full Unicode punctuation and separator set is recognized. |
|
|
|
The tokenization algorithm is as follows: |
|
|
|
1. Split the text into a list of tokens, splitting at every boundary of an |
|
alphanumeric character and a non-alphanumeric character. This produces |
|
a list which alternates between "alphanumeric tokens" |
|
(strings of alphanumeric characters) and "non-alphanumeric tokens" |
|
(strings of non-alphanumeric characters). |
|
|
|
2. Remove every token consisting of a single space, unless it is |
|
the very first or very last token in the list. These tokens are now |
|
implied by the fact that there are two adjacent alphanumeric tokens. |
|
|
|
e.g. u"Dude - that's so cool." |
|
-> [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."] |
|
""" |
|
|
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import glob
import logging
import sys
import time
import unicodedata

import six
from six.moves import range
|
|
|
|
|
_native_to_unicode = (lambda s: s.decode("utf-8")) if six.PY2 else (lambda s: s)

logger = logging.getLogger(__name__)
|
|
|
# Besides letters ("L") and numbers ("N"), this set also includes punctuation
# ("P"); punctuation is split off into separate tokens in a second pass by
# _run_split_on_punc(). Everything outside the set (spaces, symbols, control
# characters) acts as a token separator and is dropped by encode().
_ALPHANUMERIC_CHAR_SET = set(
    six.unichr(i) for i in range(sys.maxunicode)
    if unicodedata.category(six.unichr(i))[0] in ("L", "N", "P"))
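# A few illustrative membership checks under the definition above (Unicode
# categories noted in parentheses):
#
#   u"a" in _ALPHANUMERIC_CHAR_SET   # True  (Ll, a lowercase letter)
#   u"7" in _ALPHANUMERIC_CHAR_SET   # True  (Nd, a decimal digit)
#   u"." in _ALPHANUMERIC_CHAR_SET   # True  (Po, punctuation)
#   u" " in _ALPHANUMERIC_CHAR_SET   # False (Zs, a space separator)
#   u"$" in _ALPHANUMERIC_CHAR_SET   # False (Sc, a currency symbol)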
|
|
|
|
|
|
|
def encode(text):
  """Encode a unicode string as a list of tokens.

  Args:
    text: a unicode string
  Returns:
    a list of tokens as Unicode strings
  """
  if not text:
    return []

  # Phase 1: split the text into maximal runs of word characters (letters,
  # numbers, punctuation) and runs of everything else; keep only the word
  # runs. Separator runs (spaces, symbols, control characters) are dropped.
  ret = []
  token_start = 0
  is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text]
  for pos in range(1, len(text)):
    if is_alnum[pos] != is_alnum[pos - 1]:
      if is_alnum[pos - 1]:
        # A word-character run just ended.
        ret.append(text[token_start:pos])
      token_start = pos
  # Append the final run if it consists of word characters. Checking the
  # run's first character also covers the case where the whole text is a
  # single run and the loop above never finds a boundary.
  if is_alnum[token_start]:
    ret.append(text[token_start:])

  # Phase 2: split every punctuation character off into its own token.
  final_tokens = []
  for token in ret:
    final_tokens.extend(_run_split_on_punc(token))
  return final_tokens
|
|
|
|
|
def _run_split_on_punc(text, never_split=None):
  """Splits punctuation on a piece of text."""
  if never_split is not None and text in never_split:
    return [text]
  chars = list(text)
  i = 0
  start_new_word = True
  output = []
  while i < len(chars):
    char = chars[i]
    if _is_punctuation(char):
      # Every punctuation character becomes its own token.
      output.append([char])
      start_new_word = True
    else:
      # Non-punctuation characters accumulate into the current token.
      if start_new_word:
        output.append([])
      start_new_word = False
      output[-1].append(char)
    i += 1

  return ["".join(x) for x in output]
|
|
|
|
|
def _is_punctuation(char):
  """Checks whether `char` is a punctuation character."""
  cp = ord(char)
  # We treat all non-letter/number ASCII as punctuation. Characters such as
  # "^", "$", and "`" are not in the Unicode punctuation class, but we treat
  # them as punctuation anyway for consistency.
  if ((33 <= cp <= 47) or (58 <= cp <= 64) or
      (91 <= cp <= 96) or (123 <= cp <= 126)):
    return True
  cat = unicodedata.category(char)
  if cat.startswith("P"):
    return True
  return False
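# Note that the ASCII range checks above deliberately classify characters
# such as "$", "^", and "`" as punctuation even though their Unicode
# categories are symbols rather than "P". For example:
#
#   _is_punctuation(u"$")       # True, caught by the 33 <= cp <= 47 range
#   _is_punctuation(u"\u20ac")  # False: the euro sign is category Sc and
#                               # falls outside the ASCII ranges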
|
|
|
|
|
def decode(tokens):
  """Decode a list of tokens to a unicode string.

  Args:
    tokens: a list of Unicode strings
  Returns:
    a unicode string
  """
  token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
  ret = []
  for i, token in enumerate(tokens):
    if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
      # Re-insert the single space that encode() dropped between word tokens.
      ret.append(u" ")
    ret.append(token)
  return u"".join(ret)
|
|
|
|
|
def _read_filepattern(
    filepattern, max_lines=None, split_on_newlines=True, do_lower_case=False):
  """Reads files matching a wildcard pattern, yielding the contents.

  Args:
    filepattern: A wildcard pattern matching one or more files.
    max_lines: If set, stop reading after reading this many lines.
    split_on_newlines: A boolean. If true, then split files by lines and strip
      leading and trailing whitespace from each line. Otherwise, treat each
      file as a single string.
    do_lower_case: A boolean. If true, lowercase the text before yielding it.

  Yields:
    The contents of the files as lines, if split_on_newlines is True, or
    the entire contents of each file if False.
  """
  filenames = sorted(glob.glob(filepattern))
  print(filenames, "do_lower_case:", do_lower_case)
  lines_read = 0
  for filename in filenames:
    start = time.time()
    with open(filename) as f:
      if split_on_newlines:
        for line in f:
          if do_lower_case:
            line = line.lower()
          yield line.strip()
          lines_read += 1
          if max_lines and lines_read >= max_lines:
            return
          if lines_read % 100000 == 0:
            print("read", lines_read, "lines,", time.time() - start,
                  "secs elapsed")
      else:
        if max_lines:
          doc = []
          for line in f:
            if do_lower_case:
              line = line.lower()
            doc.append(line)
            lines_read += 1
            if max_lines and lines_read >= max_lines:
              yield "".join(doc)
              return
          yield "".join(doc)
        else:
          # Apply lowercasing here too so do_lower_case behaves consistently.
          contents = f.read()
          if do_lower_case:
            contents = contents.lower()
          yield contents

    print(time.time() - start, "secs to read file:", filename)
|
|
|
|
|
def corpus_token_counts(
    text_filepattern, corpus_max_lines, split_on_newlines=True,
    additional_chars="", do_lower_case=False):
  """Read the corpus and compute a dictionary of token counts.

  Args:
    text_filepattern: A pattern matching one or more files.
    corpus_max_lines: An integer; maximum total lines to read.
    split_on_newlines: A boolean. If true, then split files by lines and strip
      leading and trailing whitespace from each line. Otherwise, treat each
      file as a single string.
    additional_chars: A string. Each of its characters is added to the
      word-character set, so those characters are kept inside tokens instead
      of being treated as separators.
    do_lower_case: A boolean. If true, lowercase the text before tokenizing.

  Returns:
    a dictionary mapping token to count.
  """
  if additional_chars:
    # update() adds each character of the string individually; add() would
    # insert the whole string as a single, never-matched set element.
    _ALPHANUMERIC_CHAR_SET.update(additional_chars)

  counts = collections.Counter()
  for doc in _read_filepattern(
      text_filepattern,
      max_lines=corpus_max_lines,
      split_on_newlines=split_on_newlines,
      do_lower_case=do_lower_case):
    counts.update(encode(_native_to_unicode(doc)))
  print("read all files")
  return counts
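# A usage sketch (the file pattern and line limit below are hypothetical):
#
#   counts = corpus_token_counts("data/*.txt", corpus_max_lines=1000000)
#   for token, count in counts.most_common(10):
#     print(token, count)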
|
|
|
|
|
def vocab_token_counts(text_filepattern, max_lines, do_lower_case=False):
  """Read a vocab file and return a dictionary of token counts.

  Reads a two-column CSV file of tokens and their frequency in a dataset. The
  tokens are presumed to be generated by encode() or the equivalent.

  Args:
    text_filepattern: A pattern matching one or more files.
    max_lines: An integer; maximum total lines to read.
    do_lower_case: A boolean. If true, lowercase each vocab line before
      parsing it.

  Returns:
    a dictionary mapping token to count.
  """
  ret = {}
  for i, line in enumerate(
      _read_filepattern(text_filepattern, max_lines=max_lines)):
    if "," not in line:
      logger.warning("Malformed vocab line #%d '%s'", i, line)
      continue

    if do_lower_case:
      line = line.lower()

    token, count = line.rsplit(",", 1)
    ret[_native_to_unicode(token)] = int(count)

  return ret
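# The expected vocab file format is one "token,count" pair per line, with the
# count after the last comma (so a token may itself contain commas), e.g.:
#
#   the,1234567
#   cool,4321
#
# A hypothetical call (the filename is illustrative):
#
#   counts = vocab_token_counts("vocab.csv", max_lines=None)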
|
|