# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A simple invertible tokenizer. | |
Converts from a unicode string to a list of tokens | |
(represented as Unicode strings). | |
This tokenizer has the following desirable properties: | |
- It is invertible. | |
- Alphanumeric characters are broken away from non-alphanumeric characters. | |
- A single space between words does not produce an extra token. | |
- The full Unicode punctuation and separator set is recognized. | |
The tokenization algorithm is as follows: | |
1. Split the text into a list of tokens, splitting at every boundary of an | |
alphanumeric character and a non-alphanumeric character. This produces | |
a list which alternates between "alphanumeric tokens" | |
(strings of alphanumeric characters) and "non-alphanumeric tokens" | |
(strings of non-alphanumeric characters). | |
2. Remove every token consisting of a single space, unless it is | |
the very first or very last token in the list. These tokens are now | |
implied by the fact that there are two adjacent alphanumeric tokens. | |
e.g. u"Dude - that's so cool." | |
-> [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."] | |
""" | |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import glob
import logging
import sys
import time
import unicodedata

import six
from six.moves import range  # pylint: disable=redefined-builtin
# from tensor2tensor.utils import mlperf_log


# Conversion between Unicode and UTF-8, if required (on Python 2).
_native_to_unicode = (lambda s: s.decode("utf-8")) if six.PY2 else (lambda s: s)

logger = logging.getLogger(__name__)

# This set contains all letter, number, and punctuation characters.
_ALPHANUMERIC_CHAR_SET = set(
    six.unichr(i) for i in range(sys.maxunicode)
    if (unicodedata.category(six.unichr(i)).startswith("L") or
        unicodedata.category(six.unichr(i)).startswith("N") or
        unicodedata.category(six.unichr(i)).startswith("P")))
# To also treat symbol characters (category "S", e.g. "$", "+") as word
# characters, add:
#     unicodedata.category(six.unichr(i)).startswith("S")
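
# Illustrative note (added for clarity, not in the original module): with the
# categories above, u"." (Po) and u"-" (Pd) are members of
# _ALPHANUMERIC_CHAR_SET, while u" " (Zs) and u"$" (Sc) are not.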


def encode(text):
  """Encode a unicode string as a list of tokens.

  Runs of non-word characters (spaces, separators, symbols) are dropped, and
  punctuation is split off into single-character tokens.

  Args:
    text: a unicode string

  Returns:
    a list of tokens as Unicode strings
  """
  if not text:
    return []
  ret = []
  token_start = 0
  # Classify each character in the input string.
  is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text]
  # True while the text from token_start onwards is a run of word characters
  # that still has to be emitted after the loop.
  add_remaining = is_alnum[0]
  for pos in range(1, len(text)):
    if is_alnum[pos] != is_alnum[pos - 1]:
      if not is_alnum[pos]:
        # End of a run of word characters: emit it. Runs of non-word
        # characters are dropped.
        ret.append(text[token_start:pos])
        add_remaining = False
      else:
        add_remaining = True
      token_start = pos
  if add_remaining:
    # The remainder of the text is a run of word characters.
    ret.append(text[token_start:])
  # Split punctuation characters off into single-character tokens.
  final_tokens = []
  for token in ret:
    final_tokens.extend(_run_split_on_punc(token))
  return final_tokens
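
# Illustrative example (added for clarity; derived from the logic above, not
# taken from the original module):
#   encode(u"2-year-old") -> [u"2", u"-", u"year", u"-", u"old"]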


def _run_split_on_punc(text, never_split=None):
  """Splits punctuation on a piece of text."""
  if never_split is not None and text in never_split:
    return [text]
  chars = list(text)
  i = 0
  start_new_word = True
  output = []
  while i < len(chars):
    char = chars[i]
    if _is_punctuation(char):
      # Every punctuation character becomes its own token.
      output.append([char])
      start_new_word = True
    else:
      if start_new_word:
        output.append([])
      start_new_word = False
      output[-1].append(char)
    i += 1
  return ["".join(x) for x in output]


def _is_punctuation(char):
  """Checks whether `char` is a punctuation character."""
  cp = ord(char)
  # We treat all non-letter/number ASCII as punctuation. Characters such as
  # "^", "$", and "`" are not in the Unicode Punctuation class, but we treat
  # them as punctuation anyway, for consistency.
  if ((33 <= cp <= 47) or (58 <= cp <= 64) or
      (91 <= cp <= 96) or (123 <= cp <= 126)):
    return True
  cat = unicodedata.category(char)
  if cat.startswith("P"):
    return True
  return False
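
# Illustrative example (added for clarity, not in the original module):
#   _is_punctuation(u"$") -> True   (ASCII 36 falls in the 33-47 range)
#   _is_punctuation(u"a") -> False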


def decode(tokens):
  """Decode a list of tokens to a unicode string.

  Args:
    tokens: a list of Unicode strings

  Returns:
    a unicode string
  """
  token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
  ret = []
  for i, token in enumerate(tokens):
    if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
      # Re-insert a single space between adjacent word-character tokens.
      ret.append(u" ")
    ret.append(token)
  return "".join(ret)


def _read_filepattern(filepattern, max_lines=None, split_on_newlines=True,
                      do_lower_case=False):
  """Reads files matching a wildcard pattern, yielding the contents.

  Args:
    filepattern: A wildcard pattern matching one or more files.
    max_lines: If set, stop reading after reading this many lines.
    split_on_newlines: A boolean. If true, then split files by lines and strip
      leading and trailing whitespace from each line. Otherwise, treat each
      file as a single string.
    do_lower_case: A boolean. If true, lowercase the text before yielding it.

  Yields:
    The contents of the files as lines, if split_on_newlines is True, or
    the entire contents of each file if False.
  """
  filenames = sorted(glob.glob(filepattern))
  print(filenames, "do lower case:", do_lower_case)
  lines_read = 0
  for filename in filenames:
    start = time.time()
    with open(filename) as f:
      if split_on_newlines:
        for line in f:
          if do_lower_case:
            line = line.lower()
          yield line.strip()
          lines_read += 1
          if max_lines and lines_read >= max_lines:
            return
          if lines_read % 100000 == 0:
            print("read", lines_read, "lines,",
                  time.time() - start, "secs elapsed")
      else:
        if max_lines:
          doc = []
          for line in f:
            if do_lower_case:
              line = line.lower()
            doc.append(line)
            lines_read += 1
            if max_lines and lines_read >= max_lines:
              yield "".join(doc)
              return
          yield "".join(doc)
        else:
          yield f.read()
    print(time.time() - start, "secs for reading file:", filename)
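
# Illustrative usage (added for clarity; the file pattern is hypothetical):
#   for line in _read_filepattern("corpus/*.txt", max_lines=1000):
#     ...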


def corpus_token_counts(text_filepattern, corpus_max_lines,
                        split_on_newlines=True, additional_chars="",
                        do_lower_case=False):
  """Read the corpus and compute a dictionary of token counts.

  Args:
    text_filepattern: A pattern matching one or more files.
    corpus_max_lines: An integer; maximum total lines to read.
    split_on_newlines: A boolean. If true, then split files by lines and strip
      leading and trailing whitespace from each line. Otherwise, treat each
      file as a single string.
    additional_chars: A string. Each of its characters is treated as a normal
      alphanumeric character, so that it can appear inside vocab tokens.
    do_lower_case: A boolean. If true, lowercase the corpus before counting.

  Returns:
    a dictionary mapping token to count.
  """
  if additional_chars:
    # Add each additional character individually to the word-character set.
    _ALPHANUMERIC_CHAR_SET.update(additional_chars)
  counts = collections.Counter()
  for doc in _read_filepattern(
      text_filepattern,
      max_lines=corpus_max_lines,
      split_on_newlines=split_on_newlines,
      do_lower_case=do_lower_case):
    counts.update(encode(_native_to_unicode(doc)))
  print("read all files")
  return counts
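
# Illustrative usage (added for clarity; the file pattern is hypothetical):
#   token_counts = corpus_token_counts("corpus/*.txt", corpus_max_lines=100000,
#                                      do_lower_case=True)
#   token_counts.most_common(10)  # the result is a collections.Counter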


def vocab_token_counts(text_filepattern, max_lines, do_lower_case=False):
  """Read a vocab file and return a dictionary of token counts.

  Reads a two-column CSV file of tokens and their frequency in a dataset. The
  tokens are presumed to be generated by encode() or the equivalent.

  Args:
    text_filepattern: A pattern matching one or more files.
    max_lines: An integer; maximum total lines to read.
    do_lower_case: A boolean. If true, lowercase each vocab line before
      parsing it.

  Returns:
    a dictionary mapping token to count.
  """
  ret = {}
  for i, line in enumerate(
      _read_filepattern(text_filepattern, max_lines=max_lines)):
    if "," not in line:
      logger.warning("Malformed vocab line #%d '%s'", i, line)
      continue
    if do_lower_case:
      line = line.lower()
    token, count = line.rsplit(",", 1)
    ret[_native_to_unicode(token)] = int(count)
  return ret
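

# Minimal command-line sketch (added for illustration, not part of the
# original module): builds token counts for a corpus and prints the most
# common tokens. The argument names below are assumptions, not an existing
# CLI of this module.
if __name__ == "__main__":
  import argparse

  parser = argparse.ArgumentParser(description="Count tokens in a corpus.")
  parser.add_argument("filepattern", help="Wildcard pattern of input files.")
  parser.add_argument("--max_lines", type=int, default=None,
                      help="Maximum number of lines to read.")
  parser.add_argument("--lower", action="store_true",
                      help="Lowercase the corpus before counting.")
  args = parser.parse_args()

  token_counts = corpus_token_counts(
      args.filepattern, args.max_lines, do_lower_case=args.lower)
  for tok, cnt in token_counts.most_common(20):
    print(tok, cnt)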