Spaces:
Sleeping
Sleeping
import pickle | |
import math | |
import argparse | |
import glob | |
import logging | |
from pathlib import Path | |
from tqdm import tqdm | |
import unicodedata | |
from transformers import BertTokenizer, RobertaTokenizer, XLMRobertaTokenizer | |
from s2s_ft.tokenization_unilm import UnilmTokenizer | |
from s2s_ft.tokenization_minilm import MinilmTokenizer | |
# Configure the root logger once at import time: timestamped records at INFO
# level and above.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Maps each accepted --model_type value to the tokenizer class that main()
# loads with from_pretrained().
TOKENIZER_CLASSES = {
    'bert': BertTokenizer,
    'minilm': MinilmTokenizer,
    'roberta': RobertaTokenizer,
    'unilm': UnilmTokenizer,
    'xlm-roberta': XLMRobertaTokenizer,
}
def read_traces_from_file(file_name):
    """Load every pickled sample stored in a trace file.

    The file is read as a sequence of consecutive pickles: first a metadata
    dict whose "num_samples" entry gives the number of samples, then that
    many sample objects, one pickle each.

    Args:
        file_name: Path of the trace file to read.

    Returns:
        A list with the samples, in the order they appear in the file.

    Raises:
        KeyError: If the metadata dict has no "num_samples" entry.
        EOFError: If the file ends before all announced samples were read.
    """
    # SECURITY: unpickling can execute arbitrary code. Only open trace files
    # that come from a trusted source.
    with open(file_name, "rb") as fin:
        num_samples = pickle.load(fin)["num_samples"]
        return [pickle.load(fin) for _ in range(num_samples)]
def get_best_sequence(sample, eos_id, pad_id, length_penalty=None, alpha=None, expect=None, min_len=None):
    """Pick the best-scoring hypothesis from a search trace and return its ids.

    Args:
        sample: Dict with the keys "scores", "wids" and "ptrs". Each value is
            a list of frames (one per decoding step) and each frame is
            indexed by position: ``scores[f][i]`` is the score of the entry,
            ``wids[f][i]`` its word id, and ``ptrs[f][i]`` the position of
            its predecessor in frame ``f - 1``.
        eos_id: Word id of the end-of-sequence token.
        pad_id: Word id of the padding token.
        length_penalty: If truthy, rescales each candidate score. With
            ``expect`` the score is reduced by
            ``length_penalty * abs(length - expect)``; without it the score
            is increased by ``length_penalty * length``.
        alpha: Only used when ``length_penalty`` is falsy. The score is
            divided by ``((5 + length) / 6.0) ** alpha``.
        expect: Expected target length, see ``length_penalty``.
        min_len: Accepted for interface compatibility; currently unused.

    Returns:
        The word ids of the selected hypothesis, from the first frame up to
        and including the selected entry. An empty list is returned when the
        trace contains no usable candidate (including an empty trace or one
        that consists solely of padding).
    """
    scores = sample["scores"]
    wids_list = sample["wids"]
    ptrs = sample["ptrs"]
    end_ids = (eos_id, pad_id)

    # The search is over at the first frame in which every entry is EOS/PAD;
    # without such a frame the last recorded frame is used.
    last_frame_id = len(scores) - 1
    for i, wids in enumerate(wids_list):
        if all(wid in end_ids for wid in wids):
            last_frame_id = i
            break
    # Step back over frames that hold nothing but padding. The lower bound
    # keeps an empty or all-padding trace from indexing past the first frame.
    while last_frame_id >= 0 and all(wid == pad_id for wid in wids_list[last_frame_id]):
        last_frame_id -= 1

    max_score = -math.inf
    frame_id = -1
    pos_in_frame = -1

    for fid in range(last_frame_id + 1):
        length = fid + 1
        for i, wid in enumerate(wids_list[fid]):
            s = scores[fid][i]
            if s >= 0:
                # skip paddings (marked by a non-negative score)
                continue
            # Candidates are entries that ended (EOS/PAD) plus everything
            # still alive in the final frame.
            if wid not in end_ids and fid != last_frame_id:
                continue
            if length_penalty:
                if expect:
                    s -= length_penalty * math.fabs(length - expect)
                else:
                    s += length_penalty * length
            elif alpha:
                s = s / math.pow((5 + length) / 6.0, alpha)
            if s > max_score:
                max_score = s
                frame_id = fid
                pos_in_frame = i

    if frame_id == -1:
        return []

    # Follow the back-pointers from the winning entry to the first frame.
    seq = [wids_list[frame_id][pos_in_frame]]
    for fid in range(frame_id, 0, -1):
        pos_in_frame = ptrs[fid][pos_in_frame]
        seq.append(wids_list[fid - 1][pos_in_frame])
    seq.reverse()
    return seq
def detokenize(tk_list):
    """Glue sub-word pieces back onto the token that precedes them.

    A token starting with '##' is appended, without the two marker
    characters, to the previous output token. A '##' token with nothing
    before it is kept unchanged.

    Args:
        tk_list: Iterable of token strings.

    Returns:
        A new list of merged token strings.
    """
    merged = []
    for tk in tk_list:
        if merged and tk.startswith('##'):
            merged[-1] += tk[2:]
        else:
            merged.append(tk)
    return merged
def simple_postprocess(tk_list):
    """Trim a run of repeated punctuation at the end of a token list.

    While the list has more than four tokens and its last four tokens are
    the same single punctuation character (Unicode category 'P*'), the last
    three tokens are removed.

    Args:
        tk_list: List of token strings. A falsy value is returned unchanged.

    Returns:
        The input list itself when nothing was trimmed, otherwise a shortened
        copy.
    """
    # truncate duplicate punctuations
    while (tk_list
           and len(tk_list) > 4
           # unicodedata.category() only accepts a single character.
           and len(tk_list[-1]) == 1
           and unicodedata.category(tk_list[-1]).startswith('P')
           and all(it == tk_list[-1] for it in tk_list[-4:])):
        tk_list = tk_list[:-3]
    return tk_list
# def include_unk(line): | |
# return " UNK ".join(line.split('<unk>')).strip() | |
def main(args):
    """Decode the best sequence of every trace and write one text file per input.

    For each path matched by the glob pattern ``args.input`` that has a
    companion ``<path>.trace.pickle`` file, the best sequence of every sample
    is selected with get_best_sequence(), converted to text and written, one
    sequence per line, to ``<path>.`` followed by a tag for each rescoring
    option that is set ('lenp', 'exp', 'alp', 'minl' plus the value). Paths
    without a trace file are skipped.

    Args:
        args: Parsed command-line namespace. Reads ``model_type``,
            ``tokenizer_name``, ``do_lower_case``, ``cache_dir``, ``input``,
            ``alpha``, ``length_penalty``, ``expect`` and ``min_len``.
    """
    tokenizer = TOKENIZER_CLASSES[args.model_type].from_pretrained(
        args.tokenizer_name, do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)

    eos_token = tokenizer.sep_token
    pad_token = tokenizer.pad_token
    eos_id, pad_id = tokenizer.convert_tokens_to_ids([eos_token, pad_token])

    logger.info("*********************************************")
    logger.info(" EOS TOKEN = {}, ID = {}".format(eos_token, eos_id))
    logger.info(" PAD TOKEN = {}, ID = {}".format(pad_token, pad_id))
    logger.info("*********************************************")

    # These model types rebuild text through the tokenizer itself; all others
    # go through detokenize().
    use_tokenizer_to_string = args.model_type in ("roberta", "xlm-roberta")

    for input_file in tqdm(glob.glob(args.input)):
        trace_file = input_file + '.trace.pickle'
        if not Path(trace_file).exists():
            continue
        print(input_file)
        samples = read_traces_from_file(trace_file)

        results = []
        for s in samples:
            word_ids = get_best_sequence(
                s, eos_id, pad_id, alpha=args.alpha,
                length_penalty=args.length_penalty, expect=args.expect,
                min_len=args.min_len)
            tokens = tokenizer.convert_ids_to_tokens(word_ids)

            # Keep only what precedes the first EOS/PAD token.
            buf = []
            for t in tokens:
                if t in (eos_token, pad_token):
                    break
                buf.append(t)

            if use_tokenizer_to_string:
                output_text = " ".join(simple_postprocess(
                    tokenizer.convert_tokens_to_string(buf).split(' ')))
                # The output holds one sequence per line, so newlines inside
                # a sequence are replaced by a marker.
                if '\n' in output_text:
                    output_text = " [X_SEP] ".join(output_text.split('\n'))
            else:
                output_text = " ".join(simple_postprocess(detokenize(buf)))

            results.append(output_text)

        # Encode the active rescoring options in the output file name.
        fn_out = input_file + '.'
        if args.length_penalty:
            fn_out += 'lenp' + str(args.length_penalty)
        if args.expect:
            fn_out += 'exp' + str(args.expect)
        if args.alpha:
            fn_out += 'alp' + str(args.alpha)
        if args.min_len:
            fn_out += 'minl' + str(args.min_len)

        with open(fn_out, "w", encoding="utf-8") as fout:
            fout.writelines(line + "\n" for line in results)
        logger.info("Output file = [%s]", fn_out)
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    # Required: main() passes this value to glob.glob(), which raises a
    # TypeError when it is None.
    parser.add_argument("--input", type=str, required=True, help="Input file.")
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(TOKENIZER_CLASSES.keys()))
    parser.add_argument("--alpha", default=None, type=float)
    parser.add_argument("--length_penalty", default=None, type=float)
    parser.add_argument("--expect", default=None, type=float,
                        help="Expectation of target length.")
    parser.add_argument("--min_len", default=None, type=int)
    # tokenizer_name
    parser.add_argument("--tokenizer_name", default=None, type=str, required=True,
                        help="tokenizer name")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--cache_dir", default=None, type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    args = parser.parse_args()

    main(args)