# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Encoders for text data.

* TextEncoder: base class
* SubwordTextEncoder: invertible
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from itertools import chain
import logging
import re
import time

import six
from six.moves import range  # pylint: disable=redefined-builtin
# from tensor2tensor.data_generators import tokenizer

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO)
logger = logging.getLogger(__name__)
# Reserved tokens for things like padding and EOS symbols.
PAD = "[PAD]"
EOS = "[EOS]"
UNK = "[UNK]"
CLS = "[CLS]"
SEP = "[SEP]"
MASK = "[MASK]"
RESERVED_TOKENS = [PAD, EOS, UNK, CLS, SEP, MASK]
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
PAD_ID = RESERVED_TOKENS.index(PAD)  # Normally 0.
EOS_ID = RESERVED_TOKENS.index(EOS)  # Normally 1.
if six.PY2:
  RESERVED_TOKENS_BYTES = RESERVED_TOKENS
else:
  RESERVED_TOKENS_BYTES = [bytes(t, "ascii") for t in RESERVED_TOKENS]
# Regular expression for unescaping token strings.
# '\u' is converted to '_'
# '\\' is converted to '\'
# '\213;' is converted to unichr(213)
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
_ESCAPE_CHARS = set(u"\\_u;0123456789")
_SPECIAL_CHARS = set(u"!\"\'#$%&*()`+,-./:;<=>?@[]^_{}~|")


# Unicode utility functions that work with Python 2 and 3.


def native_to_unicode(s):
  if is_unicode(s):
    return s
  try:
    return to_unicode(s)
  except UnicodeDecodeError:
    res = to_unicode(s, ignore_errors=True)
    logger.info("Ignoring Unicode error, outputting: %s" % res)
    return res


def unicode_to_native(s):
  if six.PY2:
    return s.encode("utf-8") if is_unicode(s) else s
  else:
    return s


def is_unicode(s):
  return isinstance(s, six.text_type)


def to_unicode(s, ignore_errors=False):
  if is_unicode(s):
    return s
  error_mode = "ignore" if ignore_errors else "strict"
  return s.decode("utf-8", errors=error_mode)
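

# Illustrative behaviour of the helpers above:
#   native_to_unicode(b"caf\xc3\xa9") == u"caf\xe9"   (UTF-8 bytes are decoded)
#   unicode_to_native(u"caf\xe9") returns UTF-8 bytes on Python 2 and the
#   unchanged unicode string on Python 3.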


# def to_unicode_ignore_errors(s):
#   return to_unicode(s, ignore_errors=True)


# def to_unicode_utf8(s):
#   return unicode(s, "utf-8") if six.PY2 else s.decode("utf-8")


# def strip_ids(ids, ids_to_strip):
#   """Strip ids_to_strip from the end of ids."""
#   ids = list(ids)
#   while ids and ids[-1] in ids_to_strip:
#     ids.pop()
#   return ids


class TextEncoder(object):
  """Base class for converting from ints to/from human readable strings."""

  def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
    self._num_reserved_ids = num_reserved_ids

  def num_reserved_ids(self):
    return self._num_reserved_ids

  # def encode(self, s):
  #   """Transform a human-readable string into a sequence of int ids.
  #
  #   The ids should be in the range [num_reserved_ids, vocab_size). Ids [0,
  #   num_reserved_ids) are reserved.
  #
  #   EOS is not appended.
  #
  #   Args:
  #     s: human-readable string to be converted.
  #
  #   Returns:
  #     ids: list of integers
  #   """
  #   return [int(w) + self._num_reserved_ids for w in s.split()]
  #
  # def decode(self, ids, strip_extraneous=False):
  #   """Transform a sequence of int ids into a human-readable string.
  #
  #   EOS is not expected in ids.
  #
  #   Args:
  #     ids: list of integers to be converted.
  #     strip_extraneous: bool, whether to strip off extraneous tokens
  #       (EOS and PAD).
  #
  #   Returns:
  #     s: human-readable string.
  #   """
  #   if strip_extraneous:
  #     ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
  #   return " ".join(self.decode_list(ids))
  #
  # def decode_list(self, ids):
  #   """Transform a sequence of int ids into their string versions.
  #
  #   This method supports transforming individual input/output ids to their
  #   string versions so that sequence to/from text conversions can be
  #   visualized in a human readable format.
  #
  #   Args:
  #     ids: list of integers to be converted.
  #
  #   Returns:
  #     strs: list of human-readable string.
  #   """
  #   decoded_ids = []
  #   for id_ in ids:
  #     if 0 <= id_ < self._num_reserved_ids:
  #       decoded_ids.append(RESERVED_TOKENS[int(id_)])
  #     else:
  #       decoded_ids.append(id_ - self._num_reserved_ids)
  #   return [str(d) for d in decoded_ids]

  # `vocab_size` is accessed as an attribute elsewhere in this module, so it is
  # declared as a property here and in subclasses.
  @property
  def vocab_size(self):
    raise NotImplementedError()


def _escape_token(token, alphabet):
  """Escape away underscores and OOV characters and append '_'.

  This allows the token to be expressed as the concatenation of a list
  of subtokens from the vocabulary. The underscore acts as a sentinel
  which allows us to invertibly concatenate multiple such lists.

  Args:
    token: A unicode string to be escaped.
    alphabet: A set of all characters in the vocabulary's alphabet.

  Returns:
    escaped_token: An escaped unicode string.

  Raises:
    ValueError: If the provided token is not unicode.
  """
  if not isinstance(token, six.text_type):
    raise ValueError("Expected string type for token, got %s" % type(token))

  token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u")
  ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token]
  return u"".join(ret) + "_"


def _my_escape_token(token, alphabet):
  """Like _escape_token(), but prepends the sentinel '_' instead of appending it."""
  if not isinstance(token, six.text_type):
    raise ValueError("Expected string type for token, got %s" % type(token))

  token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u")
  ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token]
  return "_" + u"".join(ret)


# def _unescape_token(escaped_token):
#   """Inverse of _escape_token().
#
#   Args:
#     escaped_token: a unicode string
#
#   Returns:
#     token: a unicode string
#   """
#
#   def match(m):
#     if m.group(1) is None:
#       return u"_" if m.group(0) == u"\\u" else u"\\"
#
#     try:
#       return six.unichr(int(m.group(1)))
#     except (ValueError, OverflowError) as _:
#       return u"\u3013"  # Unicode for undefined character.
#
#   trimmed = escaped_token[:-1] if escaped_token.endswith("_") else escaped_token
#   return _UNESCAPE_REGEX.sub(match, trimmed)


class SubwordTextEncoder(TextEncoder):
  """Class for invertibly encoding text using a limited vocabulary.

  Invertibly encodes a native string as a sequence of subtokens from a limited
  vocabulary.

  A SubwordTextEncoder is built from a corpus (so it is tailored to the text in
  the corpus), and stored to a file. See text_encoder_build_subword.py.

  It can then be loaded and used to encode/decode any text.

  Encoding has four phases:

  1. Tokenize into a list of tokens.  Each token is a unicode string of either
     all alphanumeric characters or all non-alphanumeric characters.  We drop
     tokens consisting of a single space that are between two alphanumeric
     tokens.

  2. Escape each token.  This escapes away special and out-of-vocabulary
     characters, and makes sure that each token ends with an underscore, and
     has no other underscores.

  3. Represent each escaped token as the concatenation of a list of subtokens
     from the limited vocabulary.  Subtoken selection is done greedily from
     beginning to end.  That is, we construct the list in order, always picking
     the longest subtoken in our vocabulary that matches a prefix of the
     remaining portion of the encoded token.

  4. Concatenate these lists.  This concatenation is invertible due to the
     fact that the trailing underscores indicate when one list is finished.
  """

  def __init__(self, filename=None):
    """Initialize and read from a file, if provided.

    Args:
      filename: filename from which to read vocab. If None, do not load a
        vocab.
    """
    self._alphabet = set()
    # self.filename = filename
    # if filename is not None:
    #   self._load_from_file(filename)
    super(SubwordTextEncoder, self).__init__()

  # def encode(self, s):
  #   """Converts a native string to a list of subtoken ids.
  #
  #   Args:
  #     s: a native string.
  #   Returns:
  #     a list of integers in the range [0, vocab_size)
  #   """
  #   return self._tokens_to_subtoken_ids(
  #       tokenizer.encode(native_to_unicode(s)))
  #
  # def encode_without_tokenizing(self, token_text):
  #   """Converts string to list of subtoken ids without calling tokenizer.
  #
  #   This treats `token_text` as a single token and directly converts it
  #   to subtoken ids. This may be useful when the default tokenizer doesn't
  #   do what we want (e.g., when encoding text with tokens composed of lots of
  #   nonalphanumeric characters). It is then up to the caller to make sure
  #   that raw text is consistently converted into tokens. Only use this if
  #   you are sure that `encode` doesn't suit your needs.
  #
  #   Args:
  #     token_text: A native string representation of a single token.
  #   Returns:
  #     A list of subword token ids; i.e., integers in the range
  #     [0, vocab_size).
  #   """
  #   return self._tokens_to_subtoken_ids([native_to_unicode(token_text)])
  #
  # def decode(self, ids, strip_extraneous=False):
  #   """Converts a sequence of subtoken ids to a native string.
  #
  #   Args:
  #     ids: a list of integers in the range [0, vocab_size)
  #     strip_extraneous: bool, whether to strip off extraneous tokens
  #       (EOS and PAD).
  #
  #   Returns:
  #     a native string
  #   """
  #   if strip_extraneous:
  #     ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
  #   return unicode_to_native(
  #       tokenizer.decode(self._subtoken_ids_to_tokens(ids)))
  #
  # def decode_list(self, ids):
  #   return [self._subtoken_id_to_subtoken_string(s) for s in ids]

  @property
  def vocab_size(self):
    """The subtoken vocabulary size."""
    return len(self._all_subtoken_strings)

  # def _tokens_to_subtoken_ids(self, tokens):
  #   """Converts a list of tokens to a list of subtoken ids.
  #
  #   Args:
  #     tokens: a list of strings.
  #   Returns:
  #     a list of integers in the range [0, vocab_size)
  #   """
  #   ret = []
  #   for token in tokens:
  #     ret.extend(self._token_to_subtoken_ids(token))
  #   return ret
  #
  # def _token_to_subtoken_ids(self, token):
  #   """Converts token to a list of subtoken ids.
  #
  #   Args:
  #     token: a string.
  #   Returns:
  #     a list of integers in the range [0, vocab_size)
  #   """
  #   cache_location = hash(token) % self._cache_size
  #   cache_key, cache_value = self._cache[cache_location]
  #   if cache_key == token:
  #     return cache_value
  #   ret = self._escaped_token_to_subtoken_ids(
  #       _escape_token(token, self._alphabet))
  #   self._cache[cache_location] = (token, ret)
  #   return ret
  #
  # def _subtoken_ids_to_tokens(self, subtokens):
  #   """Converts a list of subtoken ids to a list of tokens.
  #
  #   Args:
  #     subtokens: a list of integers in the range [0, vocab_size)
  #   Returns:
  #     a list of strings.
  #   """
  #   concatenated = "".join(
  #       [self._subtoken_id_to_subtoken_string(s) for s in subtokens])
  #   split = concatenated.split("_")
  #   ret = []
  #   for t in split:
  #     if t:
  #       unescaped = _unescape_token(t + "_")
  #       if unescaped:
  #         ret.append(unescaped)
  #   return ret
  #
  # def _subtoken_id_to_subtoken_string(self, subtoken):
  #   """Converts a subtoken integer ID to a subtoken string."""
  #   if 0 <= subtoken < self.vocab_size:
  #     return self._all_subtoken_strings[subtoken]
  #   return u""

  def _escaped_token_to_subtoken_strings(self, escaped_token):
    """Converts an escaped token string to a list of subtoken strings.

    Args:
      escaped_token: An escaped token as a unicode string.

    Returns:
      A list of subtokens as unicode strings.
    """
    # NOTE: This algorithm is greedy; it won't necessarily produce the "best"
    # list of subtokens.
    ret = []
    start = 0
    token_len = len(escaped_token)
    while start < token_len:
      for end in range(
          min(token_len, start + self._max_subtoken_len), start, -1):
        subtoken = escaped_token[start:end]
        if subtoken in self._subtoken_string_to_id:
          ret.append(subtoken)
          start = end
          break
      else:  # Did not break.
        # If there is no possible encoding of the escaped token then one of the
        # characters in the token is not in the alphabet. This should be
        # impossible and would be indicative of a bug.
        assert False, "Token substring not found in subtoken vocabulary."

    return ret
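
  # Illustrative example with a hypothetical vocabulary: given the subtokens
  # {"_", "_12", "12", "3", "34"} and _max_subtoken_len == 3, the escaped
  # token "_1234" is segmented greedily from the left into ["_12", "34"].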

  # def _escaped_token_to_subtoken_ids(self, escaped_token):
  #   """Converts an escaped token string to a list of subtoken IDs.
  #
  #   Args:
  #     escaped_token: An escaped token as a unicode string.
  #   Returns:
  #     A list of subtoken IDs as integers.
  #   """
  #   return [
  #       self._subtoken_string_to_id[subtoken]
  #       for subtoken in self._escaped_token_to_subtoken_strings(escaped_token)
  #   ]
  #
  # @classmethod
  # def build_from_generator(cls,
  #                          generator,
  #                          target_size,
  #                          max_subtoken_length=None,
  #                          reserved_tokens=None):
  #   """Builds a SubwordTextEncoder from the generated text.
  #
  #   Args:
  #     generator: yields text.
  #     target_size: int, approximate vocabulary size to create.
  #     max_subtoken_length: Maximum length of a subtoken. If this is not set,
  #       then the runtime and memory use of creating the vocab is quadratic in
  #       the length of the longest token. If this is set, then it is instead
  #       O(max_subtoken_length * length of longest token).
  #     reserved_tokens: List of reserved tokens. The global variable
  #       `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this
  #       argument is `None`, it will use `RESERVED_TOKENS`.
  #
  #   Returns:
  #     SubwordTextEncoder with `vocab_size` approximately `target_size`.
  #   """
  #   token_counts = collections.defaultdict(int)
  #   for item in generator:
  #     for tok in tokenizer.encode(native_to_unicode(item)):
  #       token_counts[tok] += 1
  #   encoder = cls.build_to_target_size(
  #       target_size, token_counts, 1, 1e3,
  #       max_subtoken_length=max_subtoken_length,
  #       reserved_tokens=reserved_tokens)
  #   return encoder

  @classmethod
  def build_to_target_size(cls,
                           target_size,
                           token_counts,
                           min_val,
                           max_val,
                           max_subtoken_length=None,
                           reserved_tokens=None,
                           num_iterations=4):
    """Builds a SubwordTextEncoder that has `vocab_size` near `target_size`.

    Uses simple recursive binary search to find a minimum token count that most
    closely matches the `target_size`.

    Args:
      target_size: Desired vocab_size to approximate.
      token_counts: A dictionary of token counts, mapping string to int.
      min_val: An integer; lower bound for the minimum token count.
      max_val: An integer; upper bound for the minimum token count.
      max_subtoken_length: Maximum length of a subtoken. If this is not set,
        then the runtime and memory use of creating the vocab is quadratic in
        the length of the longest token. If this is set, then it is instead
        O(max_subtoken_length * length of longest token).
      reserved_tokens: List of reserved tokens. The global variable
        `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this
        argument is `None`, it will use `RESERVED_TOKENS`.
      num_iterations: An integer; how many iterations of refinement.

    Returns:
      A SubwordTextEncoder instance.

    Raises:
      ValueError: If `min_val` is greater than `max_val`.
    """
    if min_val > max_val:
      raise ValueError("Lower bound for the minimum token count "
                       "is greater than the upper bound.")
    if target_size < 1:
      raise ValueError("Target size must be positive.")

    if reserved_tokens is None:
      reserved_tokens = RESERVED_TOKENS

    def bisect(min_val, max_val):
      """Bisection to find the right size."""
      present_count = (max_val + min_val) // 2
      logger.info("Trying min_count %d" % present_count)
      subtokenizer = cls()
      subtokenizer.build_from_token_counts(
          token_counts, present_count, num_iterations,
          max_subtoken_length=max_subtoken_length,
          reserved_tokens=reserved_tokens)

      # Being within 1% of the target size is ok.
      is_ok = abs(subtokenizer.vocab_size - target_size) * 100 < target_size
      # If min_val == max_val, we can't do any better than this.
      if is_ok or min_val >= max_val or present_count < 2:
        return subtokenizer

      if subtokenizer.vocab_size > target_size:
        other_subtokenizer = bisect(present_count + 1, max_val)
      else:
        other_subtokenizer = bisect(min_val, present_count - 1)

      if other_subtokenizer is None:
        return subtokenizer

      if (abs(other_subtokenizer.vocab_size - target_size) <
          abs(subtokenizer.vocab_size - target_size)):
        return other_subtokenizer
      return subtokenizer

    return bisect(min_val, max_val)
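
  # Hedged usage sketch (with a hypothetical `token_counts` dict mapping
  # unicode tokens to frequencies): this binary-searches min_count in
  # [1, 1000] until the vocabulary size lands within ~1% of the target.
  #   encoder = SubwordTextEncoder.build_to_target_size(
  #       8000, token_counts, min_val=1, max_val=1000)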

  def build_from_token_counts(self,
                              token_counts,
                              min_count,
                              num_iterations=4,
                              reserved_tokens=None,
                              max_subtoken_length=None):
    """Train a SubwordTextEncoder based on a dictionary of word counts.

    Args:
      token_counts: a dictionary of Unicode strings to int.
      min_count: an integer - discard subtokens with lower counts.
      num_iterations: an integer; how many iterations of refinement.
      reserved_tokens: List of reserved tokens. The global variable
        `RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this
        argument is `None`, it will use `RESERVED_TOKENS`.
      max_subtoken_length: Maximum length of a subtoken. If this is not set,
        then the runtime and memory use of creating the vocab is quadratic in
        the length of the longest token. If this is set, then it is instead
        O(max_subtoken_length * length of longest token).

    Raises:
      ValueError: if `RESERVED_TOKENS` is not a prefix of `reserved_tokens`.
    """
    if reserved_tokens is None:
      reserved_tokens = RESERVED_TOKENS
    else:
      # There is not complete freedom in replacing RESERVED_TOKENS. Copy the
      # global list so that extra reserved tokens are appended to a new list
      # instead of mutating RESERVED_TOKENS in place.
      new_reserved_tokens = list(RESERVED_TOKENS)
      for token in reserved_tokens:
        if token in new_reserved_tokens:
          continue
        new_reserved_tokens.append(token)
      reserved_tokens = new_reserved_tokens

    for default, proposed in zip(RESERVED_TOKENS, reserved_tokens):
      if default != proposed:
        raise ValueError("RESERVED_TOKENS must be a prefix of "
                         "reserved_tokens.")

    start_time = time.time()
    # Initialize the alphabet. Note, this must include reserved tokens or it
    # can result in encoding failures. Characters of the default
    # RESERVED_TOKENS are excluded here; only custom reserved tokens beyond
    # them contribute to the alphabet.
    alphabet_tokens = chain(
        six.iterkeys(token_counts),
        [native_to_unicode(t) for t in reserved_tokens[len(RESERVED_TOKENS):]])
    # All characters that appear in the tokens.
    self._init_alphabet_from_tokens(alphabet_tokens)

    # Bootstrap the initial list of subtokens with the characters from the
    # alphabet plus the escaping characters.
    self._init_subtokens_from_list(list(self._alphabet),
                                   reserved_tokens=reserved_tokens)

    # We build iteratively. On each iteration, we segment all the words,
    # then count the resulting potential subtokens, keeping the ones
    # with high enough counts for our new vocabulary.
    if min_count < 1:
      min_count = 1
    for i in range(num_iterations):
      # logger.info("Iteration {0}".format(i))

      # Collect all substrings of the encoded token that break along current
      # subtoken boundaries.
      subtoken_counts = collections.defaultdict(int)
      for token, count in six.iteritems(token_counts):
        iter_start_time = time.time()
        # escaped_token = _escape_token(token, self._alphabet)  # appends "_"
        escaped_token = _my_escape_token(token, self._alphabet)  # prepends "_"
        subtokens = self._escaped_token_to_subtoken_strings(escaped_token)
        # Example: escaped_token "_1234" segmented as ["_12", "34"] adds
        # `count` to "_", "_1", "_12", "_123", "_1234", "3", and "34".
        start = 0
        for subtoken in subtokens:
          last_position = len(escaped_token) + 1
          if max_subtoken_length is not None:
            last_position = min(last_position, start + max_subtoken_length)

          for end in range(start + 1, last_position):
            new_subtoken = escaped_token[start:end]
            subtoken_counts[new_subtoken] += count
          start += len(subtoken)
        iter_time_secs = time.time() - iter_start_time
        if iter_time_secs > 0.1:
          logger.info(u"Processing token [{0}] took {1} seconds, consider "
                      "setting Text2TextProblem.max_subtoken_length to a "
                      "smaller value.".format(token, iter_time_secs))

      # Array of sets of candidate subtoken strings, by length.
      len_to_subtoken_strings = []
      for subtoken_string, count in six.iteritems(subtoken_counts):
        lsub = len(subtoken_string)
        if count >= min_count:
          while len(len_to_subtoken_strings) <= lsub:
            len_to_subtoken_strings.append(set())
          len_to_subtoken_strings[lsub].add(subtoken_string)

      # Consider the candidates longest to shortest, so that if we accept
      # a longer subtoken string, we can decrement the counts of its prefixes.
      new_subtoken_strings_with_count = []
      for lsub in range(len(len_to_subtoken_strings) - 1, 0, -1):
        subtoken_strings = len_to_subtoken_strings[lsub]
        for subtoken_string in subtoken_strings:
          count = subtoken_counts[subtoken_string]
          if count >= min_count:
            # Exclude alphabet tokens here, as they must be included later,
            # explicitly, regardless of count.
            if subtoken_string not in self._alphabet:
              new_subtoken_strings_with_count.append((count, subtoken_string))
            for l in range(1, lsub):
              subtoken_counts[subtoken_string[:l]] -= count

      # Include the alphabet explicitly to guarantee all strings are encodable.
      new_subtoken_strings_with_count.extend((subtoken_counts.get(a, 0), a)
                                             for a in self._alphabet)
      new_subtoken_strings_with_count.sort(reverse=True)

      # Reinitialize to the candidate vocabulary.
      new_subtoken_strings = [
          subtoken for _, subtoken in new_subtoken_strings_with_count]
      if reserved_tokens:
        # escaped_reserved_tokens = [
        #     _escape_token(native_to_unicode(t), self._alphabet)
        #     for t in reserved_tokens
        # ]
        # new_subtoken_strings = escaped_reserved_tokens + new_subtoken_strings
        new_subtoken_strings = reserved_tokens + new_subtoken_strings
      new_subtoken_strings = list(set(new_subtoken_strings))
      self._init_subtokens_from_list(new_subtoken_strings)
      # logger.info("vocab_size = %d" % self.vocab_size)

    self.subtokens_with_counts = new_subtoken_strings_with_count

    # The frequency of "_" is high, so move it from its current position to
    # the end of the vocabulary.
    new_subtoken_strings.remove("_")
    new_subtoken_strings.append("_")

    oov_list = []
    for idx, subtoken in enumerate(new_subtoken_strings):
      if subtoken.startswith("_") and subtoken != "_":
        # Word-initial pieces: drop the leading sentinel "_".
        new_subtoken_strings[idx] = subtoken[1:]
      elif subtoken[0] in self._alphabet and subtoken not in reserved_tokens:
        # Word-internal pieces: mark them WordPiece-style with "##".
        new_subtoken_strings[idx] = "##" + subtoken
      else:
        oov_list.append(subtoken)
    new_subtoken_strings.extend(char for char in self._alphabet
                                if char not in new_subtoken_strings)

    new_subtoken_strings = list(set(new_subtoken_strings))
    self._init_subtokens_from_list(new_subtoken_strings)
    logger.info("total vocab size: {}, {} seconds elapsed".format(
        self.vocab_size, time.time() - start_time))

  # @property
  # def all_subtoken_strings(self):
  #   return tuple(self._all_subtoken_strings)
  #
  # def dump(self):
  #   """Debugging dump of the current subtoken vocabulary."""
  #   subtoken_strings = [(i, s)
  #                       for s, i in six.iteritems(self._subtoken_string_to_id)]
  #   print(u", ".join(u"{0} : '{1}'".format(i, s)
  #                    for i, s in sorted(subtoken_strings)))

  def _init_subtokens_from_list(self, subtoken_strings, reserved_tokens=None):
    """Initialize token information from a list of subtoken strings.

    Args:
      subtoken_strings: a list of subtokens.
      reserved_tokens: List of reserved tokens. We must have `reserved_tokens`
        as None or the empty list, or else the global variable
        `RESERVED_TOKENS` must be a prefix of `reserved_tokens`.
    """
    if reserved_tokens is None:
      reserved_tokens = []

    if reserved_tokens:
      self._all_subtoken_strings = reserved_tokens + subtoken_strings
    else:
      self._all_subtoken_strings = subtoken_strings

    # We remember the maximum length of any subtoken to avoid having to
    # check arbitrarily long strings.
    self._max_subtoken_len = max([len(s) for s in subtoken_strings])
    self._subtoken_string_to_id = {
        s: i + len(reserved_tokens)
        for i, s in enumerate(subtoken_strings) if s
    }
    # Initialize the cache to empty.
    self._cache_size = 2 ** 20
    self._cache = [(None, None)] * self._cache_size
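
  # Illustrative example (hypothetical inputs):
  #   _init_subtokens_from_list(["ab", "c"], reserved_tokens=[PAD, EOS])
  # stores ["[PAD]", "[EOS]", "ab", "c"] and maps "ab" -> 2 and "c" -> 3, so
  # the reserved tokens occupy the first len(reserved_tokens) ids.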

  def _init_alphabet_from_tokens(self, tokens):
    """Initialize alphabet from an iterable of token or subtoken strings."""
    # Include all characters from all tokens in the alphabet to guarantee that
    # any token can be encoded. Additionally, include all escaping characters.
    self._alphabet = {c for token in tokens for c in token}
    self._alphabet |= _ESCAPE_CHARS
    self._alphabet |= _SPECIAL_CHARS

  # def _load_from_file_object(self, f):
  #   """Load from a file object.
  #
  #   Args:
  #     f: File object to load vocabulary from
  #   """
  #   subtoken_strings = []
  #   for line in f:
  #     s = line.strip()
  #     # Some vocab files wrap words in single quotes, but others don't.
  #     if ((s.startswith("'") and s.endswith("'")) or
  #         (s.startswith("\"") and s.endswith("\""))):
  #       s = s[1:-1]
  #     subtoken_strings.append(native_to_unicode(s))
  #   self._init_subtokens_from_list(subtoken_strings)
  #   self._init_alphabet_from_tokens(subtoken_strings)
  #
  # def _load_from_file(self, filename):
  #   """Load from a vocab file."""
  #   if not tf.gfile.Exists(filename):
  #     raise ValueError("File %s not found" % filename)
  #   with tf.gfile.Open(filename) as f:
  #     self._load_from_file_object(f)

  def store_to_file(self, filename, add_single_quotes=True):
    """Write the subtoken vocabulary to `filename`, one subtoken per line."""
    # with tf.gfile.Open(filename, "w") as f:
    with open(filename, "w") as f:
      for subtoken_string in self._all_subtoken_strings:
        if add_single_quotes:
          f.write("'" + unicode_to_native(subtoken_string) + "'\n")
        else:
          f.write(unicode_to_native(subtoken_string) + "\n")

  def store_to_file_with_counts(self, filename):
    """Write `subtoken<TAB>count` lines from `self.subtokens_with_counts`."""
    # with tf.gfile.Open(filename, "w") as f:
    with open(filename, "w") as f:
      for subtoken_string, count in self.subtokens_with_counts:
        f.write(unicode_to_native(subtoken_string + "\t" + str(count)) + "\n")
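

if __name__ == "__main__":
  # Hedged, minimal usage sketch (not part of the original module): build a
  # tiny WordPiece-style vocabulary from a toy token-count dictionary and
  # print its size. Real callers would pass counts collected from a corpus.
  demo_token_counts = {u"hello": 10, u"held": 4, u"world": 7}
  demo_encoder = SubwordTextEncoder()
  demo_encoder.build_from_token_counts(demo_token_counts, min_count=1,
                                       num_iterations=2)
  print("demo vocab size:", demo_encoder.vocab_size)
  # demo_encoder.store_to_file("demo_vocab.txt", add_single_quotes=False)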