"""BERT finetuning runner.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import os | |
import json | |
import glob | |
import logging | |
import argparse | |
import math | |
from tqdm import tqdm | |
import numpy as np | |
import torch | |
import random | |
import pickle | |
from s2s_ft.modeling_decoding import BertForSeq2SeqDecoder, BertConfig | |
from transformers.tokenization_bert import whitespace_tokenize | |
import s2s_ft.s2s_loader as seq2seq_loader | |
from s2s_ft.utils import load_and_cache_examples | |
from transformers import \ | |
BertTokenizer, RobertaTokenizer, XLMRobertaTokenizer, ElectraTokenizer | |
from s2s_ft.tokenization_unilm import UnilmTokenizer | |
from s2s_ft.tokenization_minilm import MinilmTokenizer | |
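
# Map each supported --model_type value to its tokenizer class.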
TOKENIZER_CLASSES = {
    'bert': BertTokenizer,
    'minilm': MinilmTokenizer,
    'roberta': RobertaTokenizer,
    'unilm': UnilmTokenizer,
    'xlm-roberta': XLMRobertaTokenizer,
    'electra': ElectraTokenizer,
}
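
# Minimal tokenizer that splits on whitespace only, for inputs that are
# already tokenized (see --tokenized_input).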
class WhitespaceTokenizer(object):
    def tokenize(self, text):
        return whitespace_tokenize(text)

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)
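
# Merge WordPiece sub-tokens back into whole words, e.g.
# ['play', '##ing', 'chess'] -> ['playing', 'chess'].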
def detokenize(tk_list):
    r_list = []
    for tk in tk_list:
        if tk.startswith('##') and len(r_list) > 0:
            r_list[-1] = r_list[-1] + tk[2:]
        else:
            r_list.append(tk)
    return r_list
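
# Print text with non-ASCII characters dropped, for consoles that cannot
# render Unicode.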
def ascii_print(text):
    text = text.encode("ascii", "ignore").decode("ascii")
    print(text)

def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(TOKENIZER_CLASSES.keys()))
    parser.add_argument("--model_path", default=None, type=str, required=True,
                        help="Path to the model checkpoint.")
    parser.add_argument("--config_path", default=None, type=str,
                        help="Path to config.json for the model.")
    parser.add_argument("--tokenizer_name", default=None, type=str, required=True,
                        help="Name or path of the pretrained tokenizer.")
    parser.add_argument("--max_seq_length", default=512, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. "
                             "Sequences longer than this will be truncated, and sequences shorter "
                             "than this will be padded.")

    # Decoding parameters
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit.")
    parser.add_argument('--no_cuda', action='store_true',
                        help="Do not use CUDA for decoding even when it is available.")
    parser.add_argument("--input_file", type=str, help="Input file.")
    parser.add_argument('--subset', type=int, default=0,
                        help="Decode only the first N examples of the input dataset (0 = all).")
    parser.add_argument("--output_file", type=str, help="Output file.")
    parser.add_argument("--split", type=str, default="",
                        help="Data split (train/val/test).")
    parser.add_argument('--tokenized_input', action='store_true',
                        help="Whether the input is already tokenized.")
    parser.add_argument('--seed', type=int, default=123,
                        help="Random seed for initialization (a non-positive value draws a random seed).")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--batch_size', type=int, default=4,
                        help="Batch size for decoding.")
    parser.add_argument('--beam_size', type=int, default=1,
                        help="Beam size for beam search.")
    parser.add_argument('--length_penalty', type=float, default=0,
                        help="Length penalty for beam search.")
    parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
    parser.add_argument('--forbid_ignore_word', type=str, default=None,
                        help="Words (separated by '|') to ignore when forbidding duplicate n-grams.")
    parser.add_argument("--min_len", default=1, type=int,
                        help="Minimum length of the generated sequence.")
    parser.add_argument('--need_score_traces', action='store_true')
    parser.add_argument('--ngram_size', type=int, default=3)
    parser.add_argument('--mode', default="s2s",
                        choices=["s2s", "l2r", "both"])
    parser.add_argument('--max_tgt_length', type=int, default=128,
                        help="Maximum length of the target sequence.")
    parser.add_argument('--s2s_special_token', action='store_true',
                        help="New special tokens ([S2S_SEP]/[S2S_CLS]) of S2S.")
    parser.add_argument('--s2s_add_segment', action='store_true',
                        help="Additional segment embeddings for the encoder of S2S.")
    parser.add_argument('--s2s_share_segment', action='store_true',
                        help="Share segment embeddings for the encoder of S2S (used with --s2s_add_segment).")
    parser.add_argument('--pos_shift', action='store_true',
                        help="Use position shift for fine-tuning.")
    parser.add_argument("--cache_dir", default=None, type=str,
                        help="Directory in which to cache pre-trained models downloaded from S3.")
    args = parser.parse_args()
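
    # Validate decoding arguments before any heavy setup.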
    if args.need_score_traces and args.beam_size <= 1:
        raise ValueError(
            "Score trace is only available for beam search with beam size > 1.")
    if args.max_tgt_length >= args.max_seq_length - 2:
        raise ValueError(
            "--max_tgt_length must be smaller than --max_seq_length - 2 "
            "to leave room for the special tokens.")
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count() if not args.no_cuda else 0
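
    # Seed all RNGs for reproducible decoding; a non-positive --seed draws
    # and logs a random seed instead.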
    if args.seed > 0:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(args.seed)
    else:
        random_seed = random.randint(0, 10000)
        logger.info("Set random seed as: {}".format(random_seed))
        random.seed(random_seed)
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(random_seed)
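
    # Build the tokenizer and a token -> id vocabulary for the decoding
    # pipeline; RoBERTa and XLM-RoBERTa expose their vocabularies differently.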
    tokenizer = TOKENIZER_CLASSES[args.model_type].from_pretrained(
        args.tokenizer_name, do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)

    if args.model_type == "roberta":
        vocab = tokenizer.encoder
    elif args.model_type == "xlm-roberta":
        vocab = {}
        for tk_id in range(len(tokenizer)):
            tk = tokenizer._convert_id_to_token(tk_id)
            vocab[tk] = tk_id
    else:
        vocab = tokenizer.vocab

    if hasattr(tokenizer, 'model_max_length'):
        tokenizer.model_max_length = args.max_seq_length
    elif hasattr(tokenizer, 'max_len'):
        tokenizer.max_len = args.max_seq_length
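
    # UniLM-style decoding uses [MASK] to trigger the next prediction and
    # [SEP] as both the start-of-sequence and end-of-sequence token.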
    mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
        [tokenizer.mask_token, tokenizer.sep_token, tokenizer.sep_token])

    forbid_ignore_set = None
    if args.forbid_ignore_word:
        w_list = []
        for w in args.forbid_ignore_word.split('|'):
            if w.startswith('[') and w.endswith(']'):
                w_list.append(w.upper())
            else:
                w_list.append(w)
        forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
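
    # args.model_path may be a glob pattern; decode with every checkpoint
    # directory it matches.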
    logger.info("Model path: %s", args.model_path)
    found_checkpoint_flag = False
    for model_recover_path in glob.glob(args.model_path):
        if not os.path.isdir(model_recover_path):
            continue
        logger.info("***** Recover model: %s *****", model_recover_path)
        config_file = args.config_path if args.config_path else os.path.join(model_recover_path, "config.json")
        logger.info("Read decoding config from: %s", config_file)
        config = BertConfig.from_json_file(config_file)

        bi_uni_pipeline = []
        bi_uni_pipeline.append(seq2seq_loader.Preprocess4Seq2seqDecoder(
            list(vocab.keys()), tokenizer.convert_tokens_to_ids, args.max_seq_length,
            max_tgt_length=args.max_tgt_length, pos_shift=args.pos_shift,
            source_type_id=config.source_type_id, target_type_id=config.target_type_id,
            cls_token=tokenizer.cls_token, sep_token=tokenizer.sep_token, pad_token=tokenizer.pad_token))

        found_checkpoint_flag = True
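
        # Load the fine-tuned decoder weights together with the decoding-time
        # hyperparameters (beam size, length penalty, n-gram blocking, ...).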
        model = BertForSeq2SeqDecoder.from_pretrained(
            model_recover_path, config=config, mask_word_id=mask_word_id, search_beam_size=args.beam_size,
            length_penalty=args.length_penalty, eos_id=eos_word_ids, sos_id=sos_word_id,
            forbid_duplicate_ngrams=args.forbid_duplicate_ngrams, forbid_ignore_set=forbid_ignore_set,
            ngram_size=args.ngram_size, min_len=args.min_len, mode=args.mode,
            max_position_embeddings=args.max_seq_length, pos_shift=args.pos_shift,
        )

        if args.fp16:
            model.half()
        model.to(device)
        if n_gpu > 1:
            model = torch.nn.DataParallel(model)

        torch.cuda.empty_cache()
        model.eval()
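
        # Budget the input: reserve two positions for [CLS]/[SEP] plus room
        # for the generated target; --pos_shift frees one position.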
        next_i = 0
        max_src_length = args.max_seq_length - 2 - args.max_tgt_length
        if args.pos_shift:
            max_src_length += 1

        to_pred = load_and_cache_examples(
            args.input_file, tokenizer, local_rank=-1,
            cached_features_file=None, shuffle=False, eval_mode=True)
        input_lines = []
        for line in to_pred:
            input_lines.append(tokenizer.convert_ids_to_tokens(line.source_ids)[:max_src_length])
        if args.subset > 0:
            logger.info("Decoding subset: %d", args.subset)
            input_lines = input_lines[:args.subset]
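
        # Sort inputs by decreasing length so each batch holds similarly sized
        # sequences (less padding); the enumerate index is kept so outputs can
        # be restored to the original order.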
        input_lines = sorted(list(enumerate(input_lines)),
                             key=lambda x: -len(x[1]))
        output_lines = [""] * len(input_lines)
        score_trace_list = [None] * len(input_lines)
        total_batch = math.ceil(len(input_lines) / args.batch_size)

        with tqdm(total=total_batch) as pbar:
            batch_count = 0
            first_batch = True
            while next_i < len(input_lines):
                _chunk = input_lines[next_i:next_i + args.batch_size]
                buf_id = [x[0] for x in _chunk]
                buf = [x[1] for x in _chunk]
                next_i += args.batch_size
                batch_count += 1
                max_a_len = max([len(x) for x in buf])
                instances = []
                for instance in [(x, max_a_len) for x in buf]:
                    for proc in bi_uni_pipeline:
                        instances.append(proc(instance))
                with torch.no_grad():
                    batch = seq2seq_loader.batch_list_to_batch_tensors(
                        instances)
                    batch = [
                        t.to(device) if t is not None else None for t in batch]
                    input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
                    traces = model(input_ids, token_type_ids,
                                   position_ids, input_mask, task_idx=task_idx, mask_qkv=mask_qkv)
                    if args.beam_size > 1:
                        traces = {k: v.tolist() for k, v in traces.items()}
                        output_ids = traces['pred_seq']
                    else:
                        output_ids = traces.tolist()
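
                    # Convert predicted ids back to text: stop at the first
                    # [SEP]/[PAD] token, then undo subword tokenization.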
                    for i in range(len(buf)):
                        w_ids = output_ids[i]
                        output_buf = tokenizer.convert_ids_to_tokens(w_ids)
                        output_tokens = []
                        for t in output_buf:
                            if t in (tokenizer.sep_token, tokenizer.pad_token):
                                break
                            output_tokens.append(t)
                        if args.model_type == "roberta" or args.model_type == "xlm-roberta":
                            output_sequence = tokenizer.convert_tokens_to_string(output_tokens)
                        else:
                            output_sequence = ' '.join(detokenize(output_tokens))
                        if '\n' in output_sequence:
                            output_sequence = " [X_SEP] ".join(output_sequence.split('\n'))
                        output_lines[buf_id[i]] = output_sequence
                        if first_batch or batch_count % 50 == 0:
                            logger.info("{} = {}".format(buf_id[i], output_sequence))
                        if args.need_score_traces:
                            score_trace_list[buf_id[i]] = {
                                'scores': traces['scores'][i], 'wids': traces['wids'][i], 'ptrs': traces['ptrs'][i]}
                pbar.update(1)
                first_batch = False
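
        # Write one decoded sequence per input example, restored to the
        # original input order.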
        if args.output_file:
            fn_out = args.output_file
        else:
            fn_out = model_recover_path + '.' + args.split
        with open(fn_out, "w", encoding="utf-8") as fout:
            for line in output_lines:
                fout.write(line)
                fout.write("\n")
        if args.need_score_traces:
            with open(fn_out + ".trace.pickle", "wb") as fout_trace:
                pickle.dump(
                    {"version": 0.0, "num_samples": len(input_lines)}, fout_trace)
                for x in score_trace_list:
                    pickle.dump(x, fout_trace)
    if not found_checkpoint_flag:
        logger.info("Model checkpoint not found!")


if __name__ == "__main__":
    main()