import torch
import numpy as np
import copy
import json
import math
import random
import re
from collections import OrderedDict
from functools import lru_cache

from datasets import ClassLabel
from matplotlib import font_manager
from colorama import Fore, Style, init

# Reset colors after each print so the log helpers below do not bleed
# color into subsequent terminal output.
init(autoreset=True)


class BaseQuantizer:
    @property
    def ignore_tokens(self):
        if self.num_mask_tokens > 0:
            if self.mask_type == 'cm3':
                return [self.predict_start_token] + self.mask_tokens
            elif self.mask_type == 'mask_aug':
                return [self.mask_aug_token]
            else:
                raise ValueError(f'Invalid mask type {self.mask_type}')
        return []

    def __init__(self, simplify_json=False, mask_all=False,
                 num_mask_tokens=0, mask_type='cm3', **kwargs):
        self.simplify_json = simplify_json
        self.io_ignore_replace_tokens = ['<split-text>']
        self.mask_all = mask_all
        self.num_mask_tokens = num_mask_tokens
        self.mask_type = mask_type
        if self.mask_type == 'mask_aug':
            self.mask_aug_token = '<mask-aug>'
        elif self.mask_type == 'cm3':
            self.predict_start_token = '<pred-start>'
        else:
            raise ValueError(f'Invalid mask type {self.mask_type}')

    def get_additional_mask_tokens(self):
        if self.mask_type == 'cm3':
            self.mask_tokens = ['<mask-%d>' % i for i in range(self.num_mask_tokens)]
            return [self.predict_start_token] + self.mask_tokens
        elif self.mask_type == 'mask_aug':
            return [self.mask_aug_token]
        else:
            raise ValueError(f'Invalid mask type {self.mask_type}')

    def dump2json(self, json_example):
        if self.simplify_json:
            content = json.dumps(json_example, separators=(',', ':'))
            # Strip the quotes around special tokens so each one is encoded
            # as a single token by the tokenizer.
            for token in self.additional_special_tokens:
                content = content.replace(f'"{token}"', token)
        else:
            content = json.dumps(json_example)
        return content

    def load_json(self, content):
        # Inverse of dump2json: re-quote the special tokens so the string is
        # valid JSON again. Tokens in io_ignore_replace_tokens are left bare
        # and handled downstream (see postprocess_colorandfont).
        replace_tokens = set(self.additional_special_tokens) - set(self.io_ignore_replace_tokens)
        if self.simplify_json:
            for token in replace_tokens:
                content = content.replace(token, f'"{token}"')
        return json.loads(content)

    def apply_masking(self,
                      json_example,
                      mask_all=None,
                      return_meta=False,
                      target_keys=('width', 'height', 'left', 'top'),
                      target_element_types=None
                      ):
        if mask_all is None:
            mask_all = self.mask_all
        json_example = copy.deepcopy(json_example)
        target_keys = set(target_keys)
        target_tokens = []
        for shape_i, shape in enumerate(json_example['layers']['textlayer']):
            for key_i, key in enumerate(shape.keys()):
                if key in target_keys:
                    target_tokens.append((shape_i, key_i, key, shape[key]))
        if not mask_all:
            target_num_mask_tokens = random.randint(1, self.num_mask_tokens)
            if len(target_tokens) > target_num_mask_tokens:
                random.shuffle(target_tokens)
                target_tokens = target_tokens[:target_num_mask_tokens]
                # Restore document order after the random sampling.
                target_tokens = sorted(target_tokens, key=lambda x: x[0] * 100 + x[1])
        else:
            if len(target_tokens) > self.num_mask_tokens:
                target_tokens = target_tokens[-self.num_mask_tokens:]

        tuples = []
        meta_infos = []
        for mask_i, (shape_i, key_i, key, value) in enumerate(target_tokens):
            if self.mask_type == 'cm3':
                mask_token = self.mask_tokens[mask_i]
            elif self.mask_type == 'mask_aug':
                mask_token = self.mask_aug_token
            else:
                raise ValueError(f'Invalid mask type {self.mask_type}')

            # Estimate the masked span length: quantized values are counted by
            # their '<...>' tokens, free text by its spaces.
            if '<' in value:
                num_token = value.count('<')
            else:
                num_token = value.count(' ')
            json_example['layers']['textlayer'][shape_i][key] = mask_token
            tuples.append((mask_token, value, num_token))
            meta_infos.append((shape_i, key))
        if return_meta:
            return json_example, tuples, meta_infos
        return json_example, tuples

    def make_prediction_postfix(self, tuples):
        postfix = self.predict_start_token
        for mask_token, value, num_token in tuples:
            postfix = postfix + f'{mask_token}{value}'
        return postfix
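
# Usage sketch (cm3 masking): apply_masking replaces sampled values with
# '<mask-i>' tokens, and make_prediction_postfix appends the ground truth as
# '<pred-start><mask-0>{value}<mask-1>{value}...':
#
#   masked, tuples = quantizer.apply_masking(example)
#   sequence = quantizer.dump2json(masked) + quantizer.make_prediction_postfix(tuples)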

# Map each JSON key onto the bin family used to quantize it.
specs = {
    "width": "size",
    "height": "size",
    "x": "pos",
    "y": "pos",
    "color": "color",
    "font": "font"
}

# (min value, max value, number of bins) for each bin family.
min_max_bins = {
    'size': (0, 1, 256),
    'pos': (0, 1, 256),
    'color': (0, 137, 138),
    'font': (0, 511, 512)
}
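
# Note: set_min_max_bins below widens each range to n_bins + 1 edges, so with
# 256 'size' bins the midpoint token '<size-128>' decodes to exactly 0.5.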


def get_keys_and_multipliers(pre_decimal=3, post_decimal=2):
    # Digit positions left of the decimal point, most significant first.
    pre_keys = ['one', 'ten', 'hundred', 'thousand']
    pre_multipliers = [1, 10, 100, 1000]
    assert pre_decimal <= len(pre_keys)
    pre_keys = pre_keys[:pre_decimal][::-1]
    pre_multipliers = pre_multipliers[:pre_decimal][::-1]

    # Digit positions right of the decimal point.
    post_keys = [f'decimal{x}' for x in range(post_decimal)]
    post_multipliers = [10 ** -(x + 1) for x in range(post_decimal)]

    keys = pre_keys + post_keys
    multipliers = pre_multipliers + post_multipliers
    return keys, multipliers
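
# For example, get_keys_and_multipliers(2, 1) returns
# (['ten', 'one', 'decimal0'], [10, 1, 0.1]).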


class DecimalQuantizer:
    """Encode a number as one special token per decimal digit, e.g. with one
    pre- and two post-decimal digits, 1.25 -> '<one-1><decimal0-2><decimal1-5>'
    (plus an optional leading sign token)."""

    def __init__(self, max_pre_decimal=3, max_post_decimal=2):
        self.max_pre_decimal = max_pre_decimal
        self.max_post_decimal = max_post_decimal
        self.keys, self.multipliers = get_keys_and_multipliers(max_pre_decimal, max_post_decimal)
        self.symbols = {
            -1: '<symbol-1>',
            1: '<symbol-0>',
        }

    def get_vocab(self):
        special_tokens = [*self.symbols.values()]
        for key in self.keys:
            special_tokens.extend([f'<{key}-{i}>' for i in range(10)])
        return special_tokens

    def check_valid(self, token):
        prefix = token.lstrip('<').split('-')[0]
        return prefix == 'symbol' or prefix in self.keys

    def __call__(self, val, pre_decimal=None, post_decimal=None, need_symbol=False):
        if pre_decimal is None:
            pre_decimal = self.max_pre_decimal
        if post_decimal is None:
            post_decimal = self.max_post_decimal

        assert pre_decimal <= self.max_pre_decimal
        assert post_decimal <= self.max_post_decimal

        keys, multipliers = get_keys_and_multipliers(pre_decimal, post_decimal)

        symbol = int(np.sign(val))
        if symbol == 0:
            symbol = 1
        val = round(abs(val), post_decimal)

        tokens = []
        if need_symbol:
            tokens.append(self.symbols[symbol])
        else:
            assert symbol >= 0

        for key, multiplier in zip(keys, multipliers):
            # Round before flooring to guard against floating-point drift,
            # e.g. 0.09 / 0.01 == 8.999... would otherwise yield digit 8.
            v = math.floor(round(val / multiplier, 8))
            if v > 9:
                raise ValueError(f'Invalid value {val} for {pre_decimal} pre_decimal and {post_decimal} post_decimal')
            val = val - v * multiplier
            tokens.append(f'<{key}-{v}>')

        return ''.join(tokens)

    def parse_token(self, token):
        # '<one-3>' -> ('one', 3)
        key, val = token[1:-1].split('-')
        return key, int(val)

    def decode(self, tokens_str):
        tokens = tokens_str.split('>')
        tokens = [x + '>' for x in tokens if x != '']
        if tokens[0].startswith('<symbol'):
            symbol_type = tokens[0]
            tokens = tokens[1:]
            inv_map = {v: k for k, v in self.symbols.items()}
            symbol = inv_map[symbol_type]
        else:
            symbol = 1

        accumulator = 0
        for token in tokens:
            key, val = self.parse_token(token)
            multiplier = self.multipliers[self.keys.index(key)]
            accumulator += val * multiplier
        accumulator = accumulator * symbol

        return accumulator
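
# Round-trip sketch (default 3 pre- / 2 post-decimal digits; results are exact
# up to float rounding):
#
#   dq = DecimalQuantizer()
#   dq(1.25, need_symbol=True)
#   # -> '<symbol-0><hundred-0><ten-0><one-1><decimal0-2><decimal1-5>'
#   dq.decode('<symbol-0><hundred-0><ten-0><one-1><decimal0-2><decimal1-5>')
#   # -> 1.25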


# Digit layout handed to the DecimalQuantizer for each attribute type.
pre_post_decimals = {
    'size': {
        'pre_decimal': 1,
        'post_decimal': 2,
        'need_symbol': False
    },
    'pos': {
        'pre_decimal': 1,
        'post_decimal': 2,
        'need_symbol': True
    },
    'opacity': {
        'pre_decimal': 1,
        'post_decimal': 1,
        'need_symbol': False
    },
    'color': {
        'pre_decimal': 3,
        'post_decimal': 0,
        'need_symbol': False
    },
    'angle': {
        'pre_decimal': 1,
        'post_decimal': 2,
        'need_symbol': False
    },
    'font_size': {
        'pre_decimal': 3,
        'post_decimal': 0,
        'need_symbol': False
    },
}


class QuantizerV4(BaseQuantizer):
    def __init__(self, quant=True,
                 decimal_quantize_types=(),
                 decimal_quantize_kwargs={'max_pre_decimal': 3, 'max_post_decimal': 2},
                 mask_values=False,
                 **kwargs):
        super().__init__(**kwargs)
        self.quant = quant
        self.mask_values = mask_values
        self.text_split_token = '<split-text>'
        self.decimal_quantize_types = decimal_quantize_types
        self.decimal_quantize = len(decimal_quantize_types) > 0
        if self.decimal_quantize:
            print('decimal quantize types', decimal_quantize_types)
            self.decimal_quantizer = DecimalQuantizer(**decimal_quantize_kwargs)
        else:
            self.decimal_quantizer = None

        self.set_min_max_bins(min_max_bins)

        # Default canvas size; callers may override via kwargs.
        self.width = int(kwargs.get('width', 1456))
        self.height = int(kwargs.get('height', 1457))

    def set_min_max_bins(self, min_max_bins):
        min_max_bins = copy.deepcopy(min_max_bins)
        # Widen each range to n_bins + 1 edges so the midpoint of the range
        # falls exactly on a bin edge.
        for type_name, (min_val, max_val, n_bins) in min_max_bins.items():
            assert n_bins % 2 == 0
            min_max_bins[type_name] = (min_val, max_val, n_bins + 1)
        self.min_max_bins = min_max_bins

    def setup_tokenizer(self, tokenizer):
        additional_special_tokens = [self.text_split_token]
        if self.decimal_quantize:
            special_tokens = self.decimal_quantizer.get_vocab()
            # Decimal digit tokens are not re-quoted by load_json.
            self.io_ignore_replace_tokens += special_tokens
            additional_special_tokens += special_tokens

        # Plain bin tokens for every type that is not decimal-quantized.
        rest_types = [key for key in self.min_max_bins.keys() if key not in self.decimal_quantize_types]
        for type_name in rest_types:
            min_val, max_val, n_bins = self.min_max_bins[type_name]
            additional_special_tokens += [f'<{type_name}-{i}>' for i in range(n_bins)]

        if self.num_mask_tokens > 0:
            additional_special_tokens.extend(self.get_additional_mask_tokens())

        print('additional_special_tokens', additional_special_tokens)

        tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens})
        self.additional_special_tokens = set(additional_special_tokens)
        return tokenizer
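
    # Typical wiring with a Hugging Face tokenizer (a sketch; the embedding
    # resize is the caller's job):
    #
    #   from transformers import AutoTokenizer
    #   tokenizer = AutoTokenizer.from_pretrained('gpt2')
    #   tokenizer = quantizer.setup_tokenizer(tokenizer)
    #   model.resize_token_embeddings(len(tokenizer))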

    @lru_cache(maxsize=128)
    def get_bins(self, real_type):
        min_val, max_val, n_bins = self.min_max_bins[real_type]
        return min_val, max_val, np.linspace(min_val, max_val, n_bins)

    def quantize(self, x, type):
        """Quantize a float x into a discrete bin token (or decimal token string)."""
        if not self.quant:
            return x
        real_type = specs[type]
        min_val, max_val, bins = self.get_bins(real_type)
        x = np.clip(float(x), min_val, max_val)
        if self.decimal_quantize and real_type in self.decimal_quantize_types:
            return self.decimal_quantizer(x, **pre_post_decimals[real_type])
        val = np.digitize(x, bins) - 1
        n_bins = len(bins)
        assert 0 <= val < n_bins
        return f'<{real_type}-{val}>'

    def dequantize(self, x):
        if self.decimal_quantize and self.decimal_quantizer.check_valid(x):
            return self.decimal_quantizer.decode(x)
        # '<size-12>' -> bin family 'size', bin index 12.
        real_type = x.split('-')[0][1:]
        val = x.split('-')[1].strip('>')
        min_val, max_val, bins = self.get_bins(real_type)
        return bins[int(val)]
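
    # Round-trip sketch (values snap to the nearest bin edge):
    #
    #   q = QuantizerV4(num_mask_tokens=0)
    #   tok = q.quantize(0.5, 'width')   # -> '<size-128>'
    #   q.dequantize(tok)                # -> 0.5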

    def construct_map_dict(self):
        map_dict = {}
        for real_type in ('size', 'pos'):
            for i in range(self.min_max_bins[real_type][2]):
                name = f'<{real_type}-{i}>'
                map_dict[name] = str(self.dequantize(name))
        return map_dict

    def postprocess_colorandfont(self, json_example):
        # Quote bare '<font-*>' / '<color-*>' tokens so the string parses as JSON.
        json_example = re.sub(r'(<font-\d+>)', r'"\1"', json_example)
        json_example = re.sub(r'(<color-\d+>)', r'"\1"', json_example)
        return json_example

    def to_str(self, x, type):
        # NOTE: expects a get_feature() accessor (e.g. a ClassLabel per type)
        # that is not defined in this file.
        feature = self.get_feature(type)
        return feature.int2str(x)

    def convert2layout(self, example):
        new_example = OrderedDict()
        new_example['wholecaption'] = example['wholecaption']
        new_layout = []
        for meta_layer in example['layout']:
            # Normalize pixel coordinates by the canvas size before quantizing.
            new_layout.append({
                "layer": meta_layer["layer"],
                "x": self.quantize(meta_layer["x"] / self.width, 'x'),
                "y": self.quantize(meta_layer["y"] / self.height, 'y'),
                "width": self.quantize(meta_layer["width"] / self.width, 'width'),
                "height": self.quantize(meta_layer["height"] / self.height, 'height')
            })
        new_example['layout'] = new_layout
        return new_example
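
    # For instance, a 728 px wide layer on the default 1456 px canvas maps to
    # 728 / 1456 = 0.5 and is quantized as '<size-128>'.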

    def apply_masking(self,
                      json_example,
                      mask_all=None,
                      return_meta=False,
                      mask_values=True
                      ):
        if mask_all is None:
            mask_all = self.mask_all

        json_example = copy.deepcopy(json_example)

        target_tokens = []
        if self.mask_values and mask_values:
            target_tokens.append((-1, -1, 'globalcaption', json_example['globalcaption']))
            target_tokens.append((-1, -1, 'canvas_width', json_example['canvas_width']))
            target_tokens.append((-1, -1, 'canvas_height', json_example['canvas_height']))
            target_tokens.append((-1, -1, 'category', json_example['category']))
            target_tokens.append((-1, -1, 'keywords', json_example['keywords']))
            target_tokens.append((-1, -1, 'bgcaption', json_example['layers']['bglayer']['bgcaption']))
            target_tokens.append((-1, -1, 'flag', json_example['layers']['objlayer']['flag']))
            target_tokens.append((-1, -1, 'objcaption', json_example['layers']['objlayer']['objcaption']))
            for layer_i, textlayer in enumerate(json_example['layers']['textlayer']):
                target_tokens.append((layer_i, -1, 'text', json_example['layers']['textlayer'][textlayer]))
        if not mask_all:
            target_num_mask_tokens = random.randint(1, self.num_mask_tokens)
            if len(target_tokens) > target_num_mask_tokens:
                random.shuffle(target_tokens)
                target_tokens = target_tokens[:target_num_mask_tokens]
                # Restore document order after the random sampling.
                target_tokens = sorted(target_tokens, key=lambda x: x[0] * 100 + x[1])
        else:
            if len(target_tokens) > self.num_mask_tokens:
                target_tokens = target_tokens[-self.num_mask_tokens:]

        tuples = []
        meta_infos = []
        layer_list = ['heading', 'subheading', 'body']
        for mask_i, (shape_i, key_i, key, value) in enumerate(target_tokens):
            if self.mask_type == 'cm3':
                mask_token = self.mask_tokens[mask_i]
            elif self.mask_type == 'mask_aug':
                mask_token = self.mask_aug_token
            else:
                raise ValueError(f'Invalid mask type {self.mask_type}')

            # Quantized values are counted by their '<...>' tokens, free text by words.
            if '<' in value:
                num_token = value.count('<')
            else:
                num_token = value.count(' ') + 1
            if shape_i == -1:
                if key == 'bgcaption':
                    json_example['layers']['bglayer']['bgcaption'] = mask_token
                elif key == 'objcaption':
                    json_example['layers']['objlayer']['objcaption'] = mask_token
                elif key == 'flag':
                    json_example['layers']['objlayer']['flag'] = mask_token
                else:
                    json_example[key] = mask_token
            else:
                curlayer = layer_list[shape_i]
                json_example['layers']['textlayer'][curlayer] = mask_token
            tuples.append((mask_token, value, num_token))
            meta_infos.append((shape_i, key))
        if return_meta:
            return json_example, tuples, meta_infos
        return json_example, tuples


def is_font_exists(font_name):
    # Substring match against the installed system font file paths.
    font_list = font_manager.findSystemFonts()
    for font in font_list:
        if font_name.lower() in font.lower():
            return True
    return False


def print_info(msg):
    print(Fore.GREEN + "[INFO] " + msg)


def print_warning(msg):
    print(Fore.YELLOW + "[WARNING] " + msg)


def print_error(msg):
    print(Fore.RED + "[ERROR] " + msg)


def load_feature(path):
    # The JSON file maps stringified indices to class names: {"0": ..., "1": ...}.
    with open(path) as f:
        content = json.loads(f.read())
    names = [content[str(i)] for i in range(len(content))]
    return ClassLabel(num_classes=len(names), names=names)


def get_quantizer(version='v1', update_vocab=False, **kwargs):
    # if kwargs.pop('separate_alpha', False):  # useless
    #     kwargs['n_visual_tokens'] *= 2
    if version == 'v4':
        quantizer = QuantizerV4(**kwargs)
    else:
        raise NotImplementedError
    return quantizer
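
# Example (assumed training-script style arguments):
#   quantizer = get_quantizer('v4', num_mask_tokens=8, simplify_json=True)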