# ART_v1.0 / quantizer.py
import torch
import numpy as np
import copy
from collections import OrderedDict
import json
import re
from datasets import ClassLabel
import random
import math
from functools import lru_cache
from matplotlib import font_manager
from colorama import Fore, Style, init
class BaseQuantizer:
@property
def ignore_tokens(self):
if self.num_mask_tokens > 0:
if self.mask_type == 'cm3':
return [self.predict_start_token] + self.mask_tokens
elif self.mask_type == 'mask_aug':
return [self.mask_aug_token]
else:
raise ValueError(f'Invalid mask type {self.mask_type}')
else:
return []
def __init__(self, simplify_json=False, mask_all=False,
num_mask_tokens=0, mask_type='cm3', **kwargs):
self.simplify_json=simplify_json
self.io_ignore_replace_tokens = ['<split-text>']
self.mask_all = mask_all
self.num_mask_tokens = num_mask_tokens
self.mask_type = mask_type
if self.mask_type == 'mask_aug':
self.mask_aug_token = '<mask-aug>'
elif self.mask_type == 'cm3':
self.predict_start_token = '<pred-start>'
else:
raise ValueError(f'Invalid mask type {self.mask_type}')
def get_additional_mask_tokens(self):
        if self.mask_type == 'cm3':  # two configurations: 1. ['<pred-start>'] plus '<mask-%d>' tokens, with the count tied to self.num_mask_tokens; 2. ['<mask-aug>']
self.mask_tokens = ['<mask-%d>' % i for i in range(self.num_mask_tokens)]
return [self.predict_start_token] + self.mask_tokens
elif self.mask_type == 'mask_aug':
return [self.mask_aug_token]
else:
raise ValueError(f'Invalid mask type {self.mask_type}')
def dump2json(self, json_example):
        if self.simplify_json:  # serialize the dict to a string; when simplify_json is True, drop spaces/newlines and strip the double quotes around special tokens
content = json.dumps(json_example, separators=(',',':'))
for token in self.additional_special_tokens:
content = content.replace(f'"{token}"', token)
else:
content = json.dumps(json_example)
return content
    def load_json(self, content):  # parse a serialized string back into JSON
        # tokens in io_ignore_replace_tokens (e.g. '<split-text>') occur inside quoted
        # string values, so they must not be re-wrapped in quotes
        replace_tokens = set(self.additional_special_tokens) - set(self.io_ignore_replace_tokens)  # sirui change
        if self.simplify_json:
            for token in replace_tokens:  # when simplify_json is True, re-wrap each special token in double quotes
                content = content.replace(token, f'"{token}"')
        return json.loads(content)
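    # Illustrative round trip (hypothetical token values), assuming simplify_json=True:
    #   dump2json({'width': '<size-128>'})  -> '{"width":<size-128>}'
    #   load_json('{"width":<size-128>}')   -> {'width': '<size-128>'}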
def apply_masking(self,
json_example,
mask_all=None,
return_meta=False,
target_keys=['width', 'height', 'left', 'top'],
target_element_types=None
):
if mask_all is None:
mask_all = self.mask_all
json_example = copy.deepcopy(json_example)
target_keys = set(target_keys)
target_tokens = []
for shape_i, shape in enumerate(json_example['layers']['textlayer']):
# element_type = self.general_dequantize(shape['type'],'type',to_float=False)
# if target_element_types is not None:
# if element_type not in target_element_types:
# continue
for key_i, key in enumerate(shape.keys()):
if key in target_keys:
target_tokens.append((shape_i, key_i, key, shape[key]))
if not mask_all:
target_num_mask_tokens = random.randint(1, self.num_mask_tokens)
if len(target_tokens) > target_num_mask_tokens:
random.shuffle(target_tokens)
target_tokens = target_tokens[:target_num_mask_tokens]
# sort by shape_i and key_i
target_tokens = sorted(target_tokens, key=lambda x: x[0]*100+x[1])
else:
if len(target_tokens) > self.num_mask_tokens:
                # keep only the last num_mask_tokens entries
target_tokens = target_tokens[-self.num_mask_tokens:]
tuples = []
meta_infos = []
for mask_i, (shape_i, key_i, key, value) in enumerate(target_tokens):
if self.mask_type == 'cm3':
mask_token = self.mask_tokens[mask_i]
elif self.mask_type == 'mask_aug':
mask_token = self.mask_aug_token
else:
raise ValueError(f'Invalid mask type {self.mask_type}')
# <one-1><decimal0-1><decimal1-2>
if '<' in value:
num_token = value.count('<')
else:
num_token = value.count(' ')
json_example['layers']['textlayer'][shape_i][key] = mask_token
tuples.append((mask_token, value, num_token))
meta_infos.append((shape_i,key))
if return_meta:
return json_example, tuples, meta_infos
else:
return json_example, tuples
def make_prediction_postfix(self, tuples):
postfix = self.predict_start_token
for mask_token, value, num_token in tuples:
            postfix = postfix + f'{mask_token}{value}'
return postfix
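    # Illustrative postfix (hypothetical mask/value), cm3 style:
    #   make_prediction_postfix([('<mask-0>', '<pos-64>', 1)])
    #   -> '<pred-start><mask-0><pos-64>'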
# specs={
# "width":"size",
# "height":"size",
# "left":"pos",
# "top":"pos",
# "x":"pos", # center x
# "y":"pos", # center y
# "opacity":"opacity",
# "color":"color",
# "angle":"angle",
# "font_size":"font_size",
# 'ratio':'ratio',
# 'letter_spacing': 'spacing',
# 'textlen': 'textlen'
# }
specs={
"width":"size",
"height":"size",
"x":"pos", # center x
"y":"pos", # center y
"color":"color",
"font":"font"
}
# TODO change min_max_bins
# min_max_bins = {
# 'size':(0,2,256),
# 'pos':(-1,1,256),
# # 'opacity':(0,1,8),
# 'opacity':(0,255,8),
# 'color':(0,255,32),
# 'angle':(0,2*np.pi,64),
# 'font_size':(2,200,100),
# 'spacing': (0,1,40),
# 'textlen': (1,20,20)
# }
min_max_bins = {
'size': (0,1,256),
'pos': (0,1,256),
'color': (0,137,138),
'font': (0,511,512)
}
# 'pre' and 'post' count the digits before and after the decimal point;
# each digit position maps to a power of ten
def get_keys_and_multipliers(pre_decimal=3, post_decimal=2):
    pre_keys = ['one', 'ten', 'hundred', 'thousand']
    pre_multipliers = [1, 10, 100, 1000]
    assert pre_decimal <= len(pre_keys)
    pre_keys = pre_keys[:pre_decimal][::-1]
    pre_multipliers = pre_multipliers[:pre_decimal][::-1]
    post_keys = [f'decimal{x}' for x in range(post_decimal)]
    post_multipliers = [10 ** -(x+1) for x in range(post_decimal)]
    keys = pre_keys + post_keys
    multipliers = pre_multipliers + post_multipliers
    return keys, multipliers
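# Example of the defaults:
#   get_keys_and_multipliers(3, 2)
#   -> (['hundred', 'ten', 'one', 'decimal0', 'decimal1'], [100, 10, 1, 0.1, 0.01])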
class DecimalQuantizer:
def __init__(self, max_pre_decimal=3, max_post_decimal=2):
self.max_pre_decimal = max_pre_decimal
self.max_post_decimal = max_post_decimal
        self.keys, self.multipliers = get_keys_and_multipliers(max_pre_decimal, max_post_decimal)
self.symbols = {
-1: '<symbol-1>',
1: '<symbol-0>',
}
def get_vocab(self):
        special_tokens = [*self.symbols.values()]  # ['<symbol-1>', '<symbol-0>']
        for key in self.keys:  # e.g. ['hundred', 'ten', 'one'] + ['decimal0', 'decimal1']
            special_tokens.extend([f'<{key}-{i}>' for i in range(10)])
return special_tokens
    def check_valid(self, token):
        prefix = token.lstrip('<').split('-')[0]  # '<symbol-1>' -> 'symbol-1>' -> 'symbol'
        return prefix == 'symbol' or prefix in self.keys
    # keep post_decimal digits after the decimal point
    def __call__(self, val, pre_decimal=None, post_decimal=None, need_symbol=False):  # e.g. 100.00
        if pre_decimal is None:
            pre_decimal = self.max_pre_decimal
        if post_decimal is None:
            post_decimal = self.max_post_decimal
        assert pre_decimal <= self.max_pre_decimal
        assert post_decimal <= self.max_post_decimal
        keys, multipliers = get_keys_and_multipliers(pre_decimal, post_decimal)
        symbol = int(np.sign(val))  # np.sign returns a float (1.0, -1.0, or 0.0) for positive/negative/zero
        if symbol == 0:  # collapse to two classes: >= 0 and < 0
            symbol = 1
        val = round(abs(val), post_decimal)  # round |val| to post_decimal decimal places
        tokens = []
        if need_symbol:  # self.symbols = {-1: '<symbol-1>', 1: '<symbol-0>'}
            symbol_type = self.symbols[symbol]
            tokens.append(symbol_type)
        else:
            assert symbol >= 0
        for key, multiplier in zip(keys, multipliers):
            # extract the digit at each position and emit a token such as '<one-7>'
            # (note: floating-point division can occasionally make floor() land one
            # digit low on the fractional positions)
            v = math.floor(val / multiplier)
            if v > 9:
                raise ValueError(f'Invalid value {val} for {pre_decimal} pre_decimal and {post_decimal} post_decimal')
            val = val - v * multiplier
            tokens.append(f'<{key}-{v}>')
        # one token per digit; with need_symbol=True a leading '<symbol-0>'/'<symbol-1>'
        # marks >= 0 / < 0 respectively
        return ''.join(tokens)
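    # Illustrative encoding (defaults: 3 pre-decimal and 2 post-decimal digits; the value is hypothetical):
    #   DecimalQuantizer()(12.34) -> '<hundred-0><ten-1><one-2><decimal0-3><decimal1-4>'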
def parse_token(self, token):
# <hundred-1> -> hundred, 1
key, val = token[1:-1].split('-')
return key, int(val)
    def decode(self, tokens_str):  # split tokens_str on '>', restore the '>', and collect into a list
        tokens = tokens_str.split('>')
        tokens = [x+'>' for x in tokens if x != '']
        if tokens[0].startswith('<symbol'):
            symbol_type = tokens[0]
            tokens = tokens[1:]
            inv_map = {v: k for k, v in self.symbols.items()}  # invert the symbols dict (swap keys and values)
            symbol = inv_map[symbol_type]
        else:
            symbol = 1
        accumulator = 0
        for token in tokens:
            key, val = self.parse_token(token)
            multiplier_index = self.keys.index(key)
            multiplier = self.multipliers[multiplier_index]
            actual_val = val * multiplier
            # print(key, val, multiplier, actual_val)
            accumulator += actual_val
        accumulator = accumulator * symbol
        # reconstruct the original signed value; precision is governed by the
        # pre/post_decimal digit counts
        return accumulator
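# Illustrative decode round trip (hypothetical token string):
#   DecimalQuantizer().decode('<hundred-0><ten-1><one-2><decimal0-3><decimal1-4>')
#   -> 12.34 (up to floating-point rounding)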
# min_max_bins = {
# 'size': (0,1,256),
# 'pos': (0,1,256),
# 'color': (0,137,138),
# 'font': (0,511,512)
# }
pre_post_decimals={
'size': {
'pre_decimal': 1,
'post_decimal': 2,
'need_symbol': False
},
'pos': {
'pre_decimal': 1,
'post_decimal': 2,
'need_symbol': True
},
'opacity': {
'pre_decimal': 1,
'post_decimal': 1,
'need_symbol': False
},
'color':{
'pre_decimal': 3,
'post_decimal': 0,
'need_symbol': False
},
'angle':{
'pre_decimal': 1,
'post_decimal': 2,
'need_symbol': False
},
'font_size':{
'pre_decimal': 3,
'post_decimal': 0,
'need_symbol': False
},
}
class QuantizerV4(BaseQuantizer):
def __init__(self, quant=True,
decimal_quantize_types = [],
decimal_quantize_kwargs = {'max_pre_decimal':3, 'max_post_decimal':2},
mask_values=False,
**kwargs):
super().__init__(**kwargs)
self.quant = quant
self.mask_values = mask_values
self.text_split_token = '<split-text>'
self.decimal_quantize_types = decimal_quantize_types
self.decimal_quantize = len(decimal_quantize_types) > 0
if len(decimal_quantize_types) > 0:
print('decimal quantize types', decimal_quantize_types)
self.decimal_quantizer = DecimalQuantizer(**decimal_quantize_kwargs)
else:
self.decimal_quantizer = None
self.set_min_max_bins(min_max_bins)
# min_max_bins = {
# 'size': (0,1,256),
# 'pos': (0,1,256),
# 'color': (0,137,138),
# 'font': (0,511,512)
# }
        self.width = int(kwargs.get('width', 1456))
        self.height = int(kwargs.get('height', 1457))
    def set_min_max_bins(self, min_max_bins):  # check that each n_bins is even, then store n_bins + 1
        min_max_bins = copy.deepcopy(min_max_bins)
        # bump each bin count to n_bins + 1 so np.linspace includes both endpoints
        for type_name, (min_val, max_val, n_bins) in min_max_bins.items():
            assert n_bins % 2 == 0  # must be even
            min_max_bins[type_name] = (min_val, max_val, n_bins+1)
self.min_max_bins = min_max_bins
def setup_tokenizer(self, tokenizer):
        # Builds additional_special_tokens from: 1. '<split-text>'; 2. the decimal-quantizer
        # vocab (e.g. <one-1>, <symbol-1>); 3. the binned QuantizerV4 tokens (e.g. <size-255>);
        # 4. self.get_additional_mask_tokens(). Then registers them all via
        # tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens})
        additional_special_tokens = [self.text_split_token]  # self.text_split_token = '<split-text>'
        if self.decimal_quantize:
            special_tokens = self.decimal_quantizer.get_vocab()  # <one-1> <symbol-1> ...
            # io_ignore_replace_tokens = ['<split-text>'] is declared in BaseQuantizer
            self.io_ignore_replace_tokens += special_tokens
            additional_special_tokens += special_tokens
        # the order must be preserved, otherwise the tokenizer vocabulary will be wrong
rest_types = [key for key in self.min_max_bins.keys() if key not in self.decimal_quantize_types]
for type_name in rest_types:
min_val, max_val, n_bins = self.min_max_bins[type_name]
additional_special_tokens += [f'<{type_name}-{i}>' for i in range(n_bins)] # <size-256>
if self.num_mask_tokens > 0:
additional_special_tokens.extend(self.get_additional_mask_tokens())
print('additional_special_tokens', additional_special_tokens)
tokenizer.add_special_tokens({'additional_special_tokens': additional_special_tokens})
self.additional_special_tokens = set(additional_special_tokens)
return tokenizer
    @lru_cache(maxsize=128)  # memoize results for speed; at most 128 distinct inputs are cached
    def get_bins(self, real_type):  # real_type: size, pos, font, color
        # returns the min value, max value, and an evenly spaced array of bin edges
min_val, max_val, n_bins = self.min_max_bins[real_type]
return min_val, max_val, np.linspace(min_val, max_val, n_bins)
    def quantize(self, x, type):  # e.g. (0.25, 'y') -> '<pos-64>'
        """Quantize a float value x into one of n_bins discrete tokens."""
        if not self.quant:
            return x
        real_type = specs[type]  # x, y, width, height, color, font -> pos, pos, size, size, color, font
        min_val, max_val, bins = self.get_bins(real_type)
        x = np.clip(float(x), min_val, max_val)  # clamp x into [min_val, max_val]
        if self.decimal_quantize and real_type in self.decimal_quantize_types:
            return self.decimal_quantizer(x, **pre_post_decimals[real_type])
        val = np.digitize(x, bins) - 1  # integer index into the bins array
        n_bins = len(bins)
        assert val >= 0 and val < n_bins
        return f'<{real_type}-{val}>'  # e.g. '<size-255>'
    def dequantize(self, x):  # e.g. '<size-255>' -> 255/256 ~= 0.996
        # '<pos-1>' -> '1'
        val = x.split('-')[1].strip('>')
        # '<pos-1>' -> 'pos'
        real_type = x.split('-')[0][1:]
if self.decimal_quantize and self.decimal_quantizer.check_valid(x):
return self.decimal_quantizer.decode(x)
min_val, max_val, bins = self.get_bins(real_type)
return bins[int(val)]
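    # Illustrative round trip (hypothetical value; default bins give 257 edges over [0, 1]):
    #   quantize(0.25, 'x')    -> '<pos-64>'
    #   dequantize('<pos-64>') -> 0.25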
def construct_map_dict(self):
map_dict = {}
for i in range(self.min_max_bins['size'][2]): # 'size': (0, 1, 256),
name = "<size-%d>" % i
value = self.dequantize(name)
            map_dict[name] = str(value)  # e.g. '<size-255>' -> '0.99609375'
for i in range(self.min_max_bins['pos'][2]):
name = "<pos-%d>" % i
value = self.dequantize(name)
map_dict[name] = str(value)
return map_dict
    def postprocess_colorandfont(self, json_example):
        # wrap each <font-N> / <color-N> token in double quotes so the string parses as JSON
        json_example = re.sub(r'(<font-\d+>)', r'"\1"', json_example)
        json_example = re.sub(r'(<color-\d+>)', r'"\1"', json_example)
        return json_example
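    # Illustrative (hypothetical serialized string):
    #   postprocess_colorandfont('{"color":<color-12>,"font":<font-3>}')
    #   -> '{"color":"<color-12>","font":"<font-3>"}'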
    def to_str(self, x, type):  # legacy helper; get_feature is not defined in this file
feature = self.get_feature(type)
return feature.int2str(x)
    def convert2layout(self, example):  # convert raw pixel coordinates into '<size-255>'-style tokens
new_example = OrderedDict()
new_example['wholecaption'] = example['wholecaption']
new_layout = []
for meta_layer in example['layout']:
new_layout.append({
"layer": meta_layer["layer"],
"x": self.quantize(meta_layer["x"]/self.width, 'x'),
"y": self.quantize(meta_layer["y"]/self.height, 'y'),
"width": self.quantize(meta_layer["width"]/self.width, 'width'),
"height": self.quantize(meta_layer["height"]/self.height, 'height')
})
new_example['layout'] = new_layout
return new_example
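    # Illustrative conversion (hypothetical layer, default 1456x1457 canvas):
    #   {'layer': 'text', 'x': 364, 'y': 728.5, 'width': 728, 'height': 364.25}
    #   -> {'layer': 'text', 'x': '<pos-64>', 'y': '<pos-128>', 'width': '<size-128>', 'height': '<size-64>'}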
def apply_masking(self,
json_example,
mask_all=None,
return_meta=False,
# target_keys=['width', 'height', 'left', 'top'], # useless
# target_element_types=None, # useless
mask_values = True
):
if mask_all is None:
mask_all = self.mask_all
json_example = copy.deepcopy(json_example)
        # Replace selected values in the JSON with <mask-i> tokens, capping the number of
        # masks at self.num_mask_tokens (randomly sampled unless mask_all). Records
        # (<mask-i>, value, num_token) triples, where num_token counts '<' for token
        # strings and whitespace-separated words otherwise.
target_tokens = []
if self.mask_values and mask_values:
target_tokens.append((-1,-1,'globalcaption', json_example['globalcaption']))
target_tokens.append((-1,-1,'canvas_width', json_example['canvas_width']))
target_tokens.append((-1,-1,'canvas_height', json_example['canvas_height']))
target_tokens.append((-1,-1,'category', json_example['category']))
target_tokens.append((-1,-1,'keywords', json_example['keywords']))
target_tokens.append((-1,-1,'bgcaption', json_example['layers']['bglayer']['bgcaption']))
target_tokens.append((-1,-1,'flag', json_example['layers']['objlayer']['flag']))
target_tokens.append((-1,-1,'objcaption', json_example['layers']['objlayer']['objcaption']))
for layer_i, textlayer in enumerate(json_example['layers']['textlayer']):
target_tokens.append((layer_i, -1, 'text', json_example['layers']['textlayer'][textlayer]))
        if not mask_all:  # sample target_num_mask_tokens at random, bounded above by self.num_mask_tokens
target_num_mask_tokens = random.randint(1, self.num_mask_tokens)
if len(target_tokens) > target_num_mask_tokens:
random.shuffle(target_tokens)
target_tokens = target_tokens[:target_num_mask_tokens]
# sort by shape_i and key_i
target_tokens = sorted(target_tokens, key=lambda x: x[0]*100+x[1])
        else:  # use exactly num_mask_tokens
            if len(target_tokens) > self.num_mask_tokens:
                # keep only the last num_mask_tokens entries
target_tokens = target_tokens[-self.num_mask_tokens:]
tuples = []
meta_infos = []
layer_list = ['heading', 'subheading', 'body']
for mask_i, (shape_i, key_i, key, value) in enumerate(target_tokens):
if self.mask_type == 'cm3':
mask_token = self.mask_tokens[mask_i]
elif self.mask_type == 'mask_aug':
mask_token = self.mask_aug_token
else:
raise ValueError(f'Invalid mask type {self.mask_type}')
# <one-1><decimal0-1><decimal1-2>
if '<' in value:
num_token = value.count('<')
else:
num_token = value.count(' ') + 1
if shape_i == -1:
if key in ['bgcaption']:
json_example['layers']['bglayer']['bgcaption'] = mask_token
elif key in ['objcaption']:
json_example['layers']['objlayer']['objcaption'] = mask_token
elif key in ['flag']:
json_example['layers']['objlayer']['flag'] = mask_token
else:
json_example[key] = mask_token
else:
curlayer = layer_list[shape_i]
json_example['layers']['textlayer'][curlayer] = mask_token
tuples.append((mask_token, value, num_token))
meta_infos.append((shape_i,key))
if return_meta:
return json_example, tuples, meta_infos
else:
return json_example, tuples
# unused; originally used for rendering
def is_font_exists(font_name):
font_list = font_manager.findSystemFonts()
# print("\nfont_list: ",font_list)
for font in font_list:
if font_name.lower() in font.lower():
return True
return False
def print_info(msg):
    print(Fore.GREEN + "[INFO] " + msg + Style.RESET_ALL)
def print_warning(msg):
    print(Fore.YELLOW + "[WARNING] " + msg + Style.RESET_ALL)
def print_error(msg):
    print(Fore.RED + "[ERROR] " + msg + Style.RESET_ALL)
def load_feature(path):
with open(path) as f:
content = f.read()
content = json.loads(content)
names = [content[str(i)] for i in range(len(content))]
return ClassLabel(num_classes= len(names), names=names)
def get_quantizer(version='v1', update_vocab=False, **kwargs):
""" if kwargs.pop('separate_alpha', False): # useless
kwargs['n_visual_tokens'] *= 2 """
if version == 'v4':
quantizer = QuantizerV4(**kwargs)
else:
raise NotImplementedError
return quantizer
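# Minimal smoke test (illustrative only; the values below are hypothetical and not
# part of the original pipeline):
if __name__ == '__main__':
    q = get_quantizer(version='v4', quant=True, simplify_json=True)
    token = q.quantize(0.25, 'x')  # 0.25 falls on bin edge 64 of 257 -> '<pos-64>'
    value = q.dequantize(token)    # maps the bin index back to 0.25
    print_info(f'{token} -> {value}')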