"""BERT finetuning runner.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import logging import glob import json import argparse import math import string from multiprocessing import Pool, cpu_count from tqdm import tqdm, trange from pathlib import Path import numpy as np # pip install py-rouge import rouge import time import tempfile import shutil # pip install pyrouge from evaluations.bs_pyrouge import Rouge155 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO) logger = logging.getLogger(__name__) parser = argparse.ArgumentParser() # Required parameters parser.add_argument("--gold", type=str, help="Gold output file.") parser.add_argument("--pred", type=str, help="Input prediction file.") parser.add_argument("--split", type=str, default="", help="Data split (train/dev/test).") parser.add_argument("--save_best", action='store_true', help="Save best epoch.") parser.add_argument("--only_eval_best", action='store_true', help="Only evaluate best epoch.") parser.add_argument("--trunc_len", type=int, default=0, help="Truncate line by the maximum length.") default_process_count = max(1, cpu_count() - 1) parser.add_argument("--processes", type=int, default=default_process_count, help="Number of processes to use (default %(default)s)") parser.add_argument("--perl", action='store_true', help="Using the perl script.") parser.add_argument('--lazy_eval', action='store_true', help="Skip evaluation if the .rouge file exists.") args = parser.parse_args() evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2, limit_length=False, apply_avg=True, weight_factor=1.2) def test_rouge(cand, ref): temp_dir = tempfile.mkdtemp() candidates = cand references = ref assert len(candidates) == len(references) cnt = len(candidates) current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) tmp_dir = os.path.join(temp_dir, "rouge-tmp-{}".format(current_time)) if not os.path.isdir(tmp_dir): os.mkdir(tmp_dir) os.mkdir(tmp_dir + "/candidate") os.mkdir(tmp_dir + "/reference") try: for i in range(cnt): if len(references[i]) < 1: continue with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", encoding="utf-8") as f: f.write(candidates[i]) with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", encoding="utf-8") as f: f.write(references[i]) r = Rouge155(temp_dir=temp_dir) r.model_dir = tmp_dir + "/reference/" r.system_dir = tmp_dir + "/candidate/" r.model_filename_pattern = 'ref.#ID#.txt' r.system_filename_pattern = r'cand.(\d+).txt' rouge_results = r.convert_and_evaluate() print(rouge_results) results_dict = r.output_to_dict(rouge_results) finally: if os.path.isdir(tmp_dir): shutil.rmtree(tmp_dir) return results_dict def rouge_results_to_str(results_dict): return ">> ROUGE-F(1/2/l): {:.2f}/{:.2f}/{:.2f}\nROUGE-R(1/2/3/l): {:.2f}/{:.2f}/{:.2f}\n".format( results_dict["rouge_1_f_score"] * 100, results_dict["rouge_2_f_score"] * 100, results_dict["rouge_l_f_score"] * 100, results_dict["rouge_1_recall"] * 100, results_dict["rouge_2_recall"] * 100, results_dict["rouge_l_recall"] * 100 ) def count_tokens(tokens): counter = {} for t in tokens: if t in counter.keys(): counter[t] += 1 else: counter[t] = 1 return counter def get_f1(text_a, text_b): tokens_a = text_a.lower().split() tokens_b = text_b.lower().split() if len(tokens_a) == 0 or len(tokens_b) == 0: return 1 if len(tokens_a) == len(tokens_b) else 0 set_a = count_tokens(tokens_a) set_b = 
# Verbatim rewrites for special tokens (kept empty here; parentheses are
# normalized separately in process_eval).
_tok_dict = {}


def _is_digit(w):
    """Return True if w looks like a (possibly comma-grouped) number."""
    for ch in w:
        if not (ch.isdigit() or ch == ','):
            return False
    return True


def fix_tokenization(text):
    """Undo common tokenizer splits (contractions, numbers, acronyms, hyphens)."""
    input_tokens = text.split()
    output_tokens = []
    i = 0
    prev_dash = False
    while i < len(input_tokens):
        tok = input_tokens[i]
        flag_prev_dash = False
        if tok in _tok_dict.keys():
            output_tokens.append(_tok_dict[tok])
            i += 1
        elif tok == "'" and len(output_tokens) > 0 and output_tokens[-1].endswith("n") and i < len(input_tokens) - 1 and input_tokens[i + 1] == "t":
            # don ' t -> do n't
            output_tokens[-1] = output_tokens[-1][:-1]
            output_tokens.append("n't")
            i += 2
        elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[i + 1] in ("s", "d", "ll"):
            # he ' s -> he 's
            output_tokens.append("'" + input_tokens[i + 1])
            i += 2
        elif tok == "." and i < len(input_tokens) - 2 and input_tokens[i + 1] == "." and input_tokens[i + 2] == ".":
            # . . . -> ...
            output_tokens.append("...")
            i += 3
        elif tok == "," and len(output_tokens) > 0 and _is_digit(output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit(input_tokens[i + 1]):
            # $ 3 , 000 -> $ 3,000
            output_tokens[-1] += ',' + input_tokens[i + 1]
            i += 2
        elif tok == "." and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and input_tokens[i + 1].isdigit():
            # 3 . 03 -> 3.03
            output_tokens[-1] += '.' + input_tokens[i + 1]
            i += 2
        elif tok == "." and len(output_tokens) > 0 and len(output_tokens[-1]) == 1 and output_tokens[-1].isupper() and i < len(input_tokens) - 2 and len(input_tokens[i + 1]) == 1 and input_tokens[i + 1].isupper() and input_tokens[i + 2] == '.':
            # U . N . -> U.N.
            k = i + 3
            while k + 2 < len(input_tokens):
                if len(input_tokens[k + 1]) == 1 and input_tokens[k + 1].isupper() and input_tokens[k + 2] == '.':
                    k += 2
                else:
                    break
            output_tokens[-1] += ''.join(input_tokens[i:k])
            # Jump past the merged span; advancing by 2 would re-read the
            # trailing '.' and emit it a second time.
            i = k
        elif tok == "-" and 0 < i < len(input_tokens) - 1 and len(output_tokens) > 0 and output_tokens[-1][-1] not in string.punctuation and input_tokens[i + 1][0] not in string.punctuation:
            # well - known -> well-known. NOTE: some branch must set
            # flag_prev_dash for the prev_dash merge below to ever fire; this
            # hyphen-joining rule is a reconstruction of the intended behavior.
            output_tokens[-1] += '-'
            flag_prev_dash = True
            i += 1
        elif prev_dash and len(output_tokens) > 0 and tok[0] not in string.punctuation:
            # Glue the token following a merged hyphen: "well-" + "known".
            output_tokens[-1] += tok
            i += 1
        else:
            output_tokens.append(tok)
            i += 1
        prev_dash = flag_prev_dash
    return " ".join(output_tokens)


def process_eval(eval_fn):
    gold_list = []
    with open(args.gold, "r", encoding="utf-8") as f_in:
        for l in f_in:
            line = l.strip()
            gold_list.append(line)

    pred_list = []
    with open(eval_fn, "r", encoding="utf-8") as f_in:
        for l in f_in:
            buf = []
            sentence = fix_tokenization(l.strip()).replace(
                "(", " -LRB- ").replace(")", " -RRB- ")
            # Collapse the double spaces introduced by the replacements above.
            while "  " in sentence:
                sentence = sentence.replace("  ", " ")
            buf.append(sentence)
            if args.trunc_len:
                num_left = args.trunc_len
                trunc_list = []
                for bit in buf:
                    tk_list = bit.split()
                    n = min(len(tk_list), num_left)
                    trunc_list.append(' '.join(tk_list[:n]))
                    num_left -= n
                    if num_left <= 0:
                        break
            else:
                trunc_list = buf
            line = "\n".join(trunc_list)
            pred_list.append(line)

    # Keep the post-processed predictions next to the raw output file.
    with open(eval_fn + '.post', 'w', encoding='utf-8') as f_out:
        for l in pred_list:
            f_out.write(l.strip())
            f_out.write('\n')

    # rouge scores
    if len(pred_list) < len(gold_list):
        # evaluate only the subset that has predictions
        gold_list = gold_list[:len(pred_list)]
    assert len(pred_list) == len(gold_list)
    if args.perl:
        scores = test_rouge(pred_list, gold_list)
    else:
        scores = evaluator.get_scores(pred_list, [[it] for it in gold_list])
    return eval_fn, scores
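
# Worked examples for fix_tokenization (made-up inputs; outputs follow the
# rules above):
#
#   "they don ' t know"    -> "they do n't know"
#   "it costs $ 3 , 000 ." -> "it costs $ 3,000 ."
#   "the U . N . said"     -> "the U.N. said"
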
def main():
    if args.perl:
        eval_fn_list = list(glob.glob(args.pred))
    else:
        eval_fn_list = [eval_fn for eval_fn in glob.glob(args.pred) if not(
            args.lazy_eval and Path(eval_fn + ".rouge").exists())]
    # Never score the derived .post/.rouge files themselves.
    eval_fn_list = list(filter(lambda fn: not(
        fn.endswith('.post') or fn.endswith('.rouge')), eval_fn_list))

    if args.only_eval_best:
        best_epoch_dict = {}
        for dir_path in set(Path(fn).parent for fn in eval_fn_list):
            fn_save = os.path.join(dir_path, 'save_best.dev')
            if Path(fn_save).exists():
                with open(fn_save, 'r') as f_in:
                    __, o_name, __ = f_in.read().strip().split('\n')
                epoch = o_name.split('.')[1]
                best_epoch_dict[dir_path] = epoch
        new_eval_fn_list = []
        for fn in eval_fn_list:
            dir_path = Path(fn).parent
            if dir_path in best_epoch_dict:
                if Path(fn).name.split('.')[1] == best_epoch_dict[dir_path]:
                    new_eval_fn_list.append(fn)
        eval_fn_list = new_eval_fn_list

    logger.info("***** Evaluation: %s *****", ','.join(eval_fn_list))
    # At least one worker, so an empty glob does not crash Pool().
    num_pool = max(1, min(args.processes, len(eval_fn_list)))
    p = Pool(num_pool)
    r_list = p.imap_unordered(process_eval, eval_fn_list)
    r_list = sorted([(fn, scores) for fn, scores in r_list],
                    key=lambda x: x[0])
    rg2_dict = {}
    for fn, scores in r_list:
        print(fn)
        if args.perl:
            print(rouge_results_to_str(scores))
        else:
            rg2_dict[fn] = scores['rouge-2']['f']
            print("ROUGE-1: {}\tROUGE-2: {}\n".format(
                scores['rouge-1']['f'], scores['rouge-2']['f']))
            with open(fn + ".rouge", 'w') as f_out:
                f_out.write(json.dumps(
                    {'rg1': scores['rouge-1']['f'],
                     'rg2': scores['rouge-2']['f']}))
    p.close()
    p.join()

    if args.save_best:
        # Find the best output file per directory by ROUGE-2 F1.
        group_dict = {}
        for k, v in rg2_dict.items():
            d_name, o_name = Path(k).parent, Path(k).name
            if (d_name not in group_dict) or (v > group_dict[d_name][1]):
                group_dict[d_name] = (o_name, v)
        # Compare with the previously saved best and overwrite only on improvement.
        for k, v in group_dict.items():
            fn = os.path.join(k, 'save_best.' + args.split)
            o_name_s, rst_s = v
            should_save = True
            if Path(fn).exists():
                with open(fn, 'r') as f_in:
                    rst_f = float(f_in.read().strip().split('\n')[-1])
                if rst_s <= rst_f:
                    should_save = False
            if should_save:
                with open(fn, 'w') as f_out:
                    f_out.write('{0}\n{1}\n{2}\n'.format(k, o_name_s, rst_s))
                logger.info("Should save: {}".format(json.dumps(v, indent=2)))


if __name__ == "__main__":
    main()
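
# Example invocation (script name and paths are illustrative):
#
#   python eval_rouge.py --gold data/dev.gold.txt \
#       --pred "outputs/model.*.dev" --split dev --save_best --trunc_len 75
#
# With --save_best, the best epoch per output directory (by ROUGE-2 F1) is
# recorded in save_best.dev; a later run with --only_eval_best then evaluates
# only that epoch's output file in each directory.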