# Tzktz's picture
# Upload 7664 files
# 6fc683c verified
import pickle
import math
import argparse
import glob
import logging
from pathlib import Path
from tqdm import tqdm
import unicodedata
from transformers import BertTokenizer, RobertaTokenizer, XLMRobertaTokenizer
from s2s_ft.tokenization_unilm import UnilmTokenizer
from s2s_ft.tokenization_minilm import MinilmTokenizer
# Configure root logging once at import time: INFO level, timestamped records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
# Maps the --model_type command-line value to the tokenizer class to load.
# The key order is kept as-is because the keys are joined into the
# --model_type help text.
TOKENIZER_CLASSES = {
    'bert': BertTokenizer,
    'minilm': MinilmTokenizer,
    'roberta': RobertaTokenizer,
    'unilm': UnilmTokenizer,
    'xlm-roberta': XLMRobertaTokenizer,
}
def read_traces_from_file(file_name):
    """Load every sample stored in a trace pickle file.

    The file is read as a sequence of pickled objects: first a metadata
    mapping whose ``"num_samples"`` entry gives the number of samples, then
    that many samples, one pickled object each. Anything stored after the
    announced number of samples is ignored.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading;
    only call this on trace files from a trusted source.

    Args:
        file_name: Path of the trace pickle file.

    Returns:
        A list with the unpickled samples, in file order.
    """
    with open(file_name, "rb") as trace_file:
        header = pickle.load(trace_file)
        return [pickle.load(trace_file) for _ in range(header["num_samples"])]
def get_best_sequence(sample, eos_id, pad_id, length_penalty=None, alpha=None, expect=None, min_len=None):
    """Pick the best-scoring hypothesis from a decoding trace and rebuild it.

    The trace is walked frame by frame. A slot is a candidate end point when
    its token is ``eos_id``/``pad_id`` or when it sits in the last useful
    frame. The candidate with the highest (optionally length-adjusted) score
    wins, and its token sequence is recovered by following back-pointers to
    frame 0.

    Args:
        sample: Mapping with three per-frame lists indexed ``[frame][slot]``:
            ``"scores"``, ``"wids"`` (token ids) and ``"ptrs"``, where
            ``ptrs[f][slot]`` is the slot in frame ``f - 1`` that this slot
            extends. Presumably produced by beam search -- confirm against
            the code that writes the trace.
        eos_id: Token id marking the end of a sequence.
        pad_id: Token id used for padding.
        length_penalty: If truthy, adjusts each candidate score linearly by
            length: ``+ length_penalty * length``, or, when ``expect`` is
            truthy, ``- length_penalty * |length - expect|``.
        alpha: If truthy (and ``length_penalty`` is not), divides the score
            by ``((5 + length) / 6) ** alpha``.
        expect: Expected target length, used only with ``length_penalty``.
        min_len: Accepted for interface compatibility; currently not used.

    Returns:
        The list of token ids of the selected hypothesis (including its final
        token), or ``[]`` when the trace has no usable candidate, e.g. it is
        empty or consists only of padding frames.
    """
    scores = sample["scores"]
    wids_list = sample["wids"]
    ptrs = sample["ptrs"]

    end_ids = (eos_id, pad_id)

    # The last useful frame is the first one in which every slot has already
    # ended; if there is none, it is the final frame of the trace.
    last_frame_id = len(scores) - 1
    for i, wids in enumerate(wids_list):
        if all(wid in end_ids for wid in wids):
            last_frame_id = i
            break
    # Step back over frames made only of padding. The lower bound keeps an
    # empty or all-padding trace from running off the front of the list
    # (which used to end in an IndexError).
    while last_frame_id >= 0 and all(wid == pad_id for wid in wids_list[last_frame_id]):
        last_frame_id -= 1

    max_score = -math.inf
    frame_id = -1
    pos_in_frame = -1

    for fid in range(last_frame_id + 1):
        length = fid + 1
        is_last_frame = fid == last_frame_id
        for i, wid in enumerate(wids_list[fid]):
            s = scores[fid][i]
            if s >= 0:
                # skip paddings (slots with a non-negative score)
                continue
            if not (is_last_frame or wid in end_ids):
                # Unfinished hypothesis in a non-final frame: not a candidate.
                continue
            if length_penalty:
                if expect:
                    s -= length_penalty * math.fabs(length - expect)
                else:
                    s += length_penalty * length
            elif alpha:
                s = s / math.pow((5 + length) / 6.0, alpha)
            # Strict comparison: on ties the earliest candidate is kept.
            if s > max_score:
                max_score = s
                frame_id = fid
                pos_in_frame = i

    if frame_id == -1:
        return []

    # Follow the back-pointers from the winning slot down to frame 0; tokens
    # are collected last-to-first, then reversed.
    seq = [wids_list[frame_id][pos_in_frame]]
    for fid in range(frame_id, 0, -1):
        pos_in_frame = ptrs[fid][pos_in_frame]
        seq.append(wids_list[fid - 1][pos_in_frame])
    seq.reverse()
    return seq
def detokenize(tk_list):
    """Glue ``##``-prefixed continuation pieces onto the token before them.

    A token that starts with ``'##'`` is appended (without the two leading
    characters) to the previous output token. If there is no previous output
    token, it is kept unchanged as a token of its own.

    Args:
        tk_list: Iterable of string tokens.

    Returns:
        A new list with the merged tokens.
    """
    merged = []
    for piece in tk_list:
        is_continuation = piece.startswith('##')
        if is_continuation and merged:
            merged[-1] += piece[2:]
            continue
        merged.append(piece)
    return merged
def simple_postprocess(tk_list):
    """Trim a run of one repeated punctuation token at the end of the list.

    As long as the list has more than four tokens and its last four tokens
    are all the same single character whose Unicode category starts with
    ``'P'``, the last three tokens are dropped.

    Args:
        tk_list: List of string tokens.

    Returns:
        The input list itself when nothing was trimmed, otherwise a shorter
        copy.
    """
    while True:
        if not tk_list or len(tk_list) <= 4:
            break
        tail = tk_list[-1]
        if len(tail) != 1:
            break
        if not unicodedata.category(tail).startswith('P'):
            break
        if any(tok != tail for tok in tk_list[-4:]):
            break
        # Keep one copy of the repeated mark for every four seen.
        tk_list = tk_list[:-3]
    return tk_list
# def include_unk(line):
# return " UNK ".join(line.split('<unk>')).strip()
def main(args):
    """Turn decoding traces into text files, one output line per sample.

    For every path matched by the ``args.input`` glob pattern that has a
    sibling ``<path>.trace.pickle`` file, the best sequence of each stored
    sample is selected, converted back to tokens, cut at the first EOS/PAD
    token, joined into a line of text and written to ``<path>.<suffix>``.
    The suffix concatenates a tag for each of the length_penalty, expect,
    alpha and min_len options that is set.

    Args:
        args: Parsed command-line namespace; reads ``model_type``,
            ``tokenizer_name``, ``do_lower_case``, ``cache_dir``, ``input``,
            ``alpha``, ``length_penalty``, ``expect`` and ``min_len``.
    """
    tokenizer_cls = TOKENIZER_CLASSES[args.model_type]
    tokenizer = tokenizer_cls.from_pretrained(
        args.tokenizer_name,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir or None,
    )

    # The separator token doubles as the end-of-sequence marker here.
    eos_token, pad_token = tokenizer.sep_token, tokenizer.pad_token
    eos_id, pad_id = tokenizer.convert_tokens_to_ids([eos_token, pad_token])

    logger.info("*********************************************")
    logger.info(" EOS TOKEN = {}, ID = {}".format(eos_token, eos_id))
    logger.info(" PAD TOKEN = {}, ID = {}".format(pad_token, pad_id))
    logger.info("*********************************************")

    stop_tokens = (eos_token, pad_token)
    uses_string_decoder = args.model_type in ("roberta", "xlm-roberta")

    for input_file in tqdm(glob.glob(args.input)):
        trace_path = input_file + '.trace.pickle'
        if not Path(trace_path).exists():
            continue
        print(input_file)

        results = []
        for sample in read_traces_from_file(trace_path):
            word_ids = get_best_sequence(
                sample, eos_id, pad_id,
                length_penalty=args.length_penalty, alpha=args.alpha,
                expect=args.expect, min_len=args.min_len)

            # Keep only the tokens that come before the first EOS/PAD.
            kept = []
            for token in tokenizer.convert_ids_to_tokens(word_ids):
                if token in stop_tokens:
                    break
                kept.append(token)

            if uses_string_decoder:
                words = tokenizer.convert_tokens_to_string(kept).split(' ')
                output_text = " ".join(simple_postprocess(words))
                # Each sample must stay on one output line.
                output_text = output_text.replace('\n', " [X_SEP] ")
            else:
                output_text = " ".join(simple_postprocess(detokenize(kept)))
            results.append(output_text)

        # Output name: "<input>." followed by a tag for every option in use.
        suffix_parts = [
            tag + str(value)
            for tag, value in (
                ('lenp', args.length_penalty),
                ('exp', args.expect),
                ('alp', args.alpha),
                ('minl', args.min_len),
            )
            if value
        ]
        fn_out = input_file + '.' + ''.join(suffix_parts)
        with open(fn_out, "w", encoding="utf-8") as fout:
            fout.writelines(line + "\n" for line in results)
        logger.info("Output file = [%s]" % fn_out)
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser()
    # The value is expanded with glob, so it may be a pattern. It is required
    # because, when omitted, main() would pass None to glob.glob() and fail
    # with a TypeError.
    parser.add_argument("--input", type=str, required=True, help="Input file.")
    # `choices` makes argparse reject an unknown model type with a usage
    # message instead of a KeyError raised later inside main().
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        choices=list(TOKENIZER_CLASSES.keys()),
                        help="Model type selected in the list: " + ", ".join(TOKENIZER_CLASSES.keys()))
    # Score-adjustment options forwarded to get_best_sequence(); the ones
    # that are set are also encoded in the output file name.
    parser.add_argument("--alpha", default=None, type=float)
    parser.add_argument("--length_penalty", default=None, type=float)
    parser.add_argument("--expect", default=None, type=float,
                        help="Expectation of target length.")
    parser.add_argument("--min_len", default=None, type=int)
    # tokenizer_name
    parser.add_argument("--tokenizer_name", default=None, type=str, required=True,
                        help="tokenizer name")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--cache_dir", default=None, type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    args = parser.parse_args()
    main(args)