import argparse
from pathlib import Path
from pprint import pprint

import numpy as np
import torch
import tqdm
from nltk import word_tokenize
from torch.nn.utils.rnn import pad_sequence

import scrl.utils as utils
from scrl.config import load_config
from scrl.eval_metrics import compute_token_f1
from scrl.model import labels_to_summary
from scrl.training import setup_and_train


def evaluate_validation_reward(args, manager, model, tokenizer, reward_generator, dataset):
    """Compute the mean reward of the model's greedy (argmax) summaries on a validation set."""
    device = args.device
    idx_range = list(range(len(dataset)))
    dataset_indices = list(utils.batchify(idx_range, args.batch_size))
    rewards = []
    for i, indices in enumerate(dataset_indices):
        if args.max_val_steps is not None and i >= args.max_val_steps:
            break
        batch = dataset[indices]
        input_ids = batch["input_ids"]
        input_ids = pad_sequence(
            [torch.tensor(ids) for ids in input_ids], batch_first=True
        )
        logits = model(input_ids.to(device))
        probs = torch.softmax(logits, dim=2)
        # Greedy label per token, converted into an extractive summary and scored.
        argmax_labels = torch.argmax(logits, dim=2).to(device)
        argmax_summaries = labels_to_summary(input_ids, argmax_labels, tokenizer)
        argmax_rewards, _ = reward_generator(batch["document"], argmax_summaries)
        rewards += argmax_rewards
    avg_reward = np.mean(rewards)
    return avg_reward


def evaluate_validation_dataset(args, manager, model, tokenizer, reward_generator, dataset_path):
    """Evaluate the model on a labelled validation dataset using token-level F1
    against the reference summaries."""
    f1_scores = []
    dataset = list(utils.read_jsonl(dataset_path))
    dump_data = []
    for item in tqdm.tqdm(dataset):
        src = item["text"]
        tgts = item["summaries"]
        input_ids = torch.tensor(tokenizer([src])["input_ids"]).to(args.device)
        logits = model(input_ids)
        argmax_labels = torch.argmax(logits, dim=2)
        pred = labels_to_summary(input_ids, argmax_labels, tokenizer)[0]
        pred_tokens = word_tokenize(pred)
        src_tokens = word_tokenize(src)
        item_scores = []
        # Score the prediction against every reference summary (lowercased),
        # then average over the references.
        for tgt in tgts:
            tgt_tokens = word_tokenize(tgt)
            pred_tokens = [t.lower() for t in pred_tokens]
            tgt_tokens = [t.lower() for t in tgt_tokens]
            token_f1 = compute_token_f1(
                tgt_tokens, pred_tokens, use_counts=True
            )
            item_scores.append(token_f1)
        if args.dump:
            probs = torch.softmax(logits, dim=2)[0].detach().tolist()
            dump_item = {
                "probs": probs,
                "source": src,
                "target": tgts[0],
                "f1-score": item_scores[0],
                "pred_summary": pred,
                "pred_labels": argmax_labels[0].tolist(),
            }
            dump_data.append(dump_item)
        item_score = np.mean(item_scores)
        f1_scores.append(item_score)
    score = np.mean(f1_scores)
    if args.dump:
        dataset_name = dataset_path.name.split(".jsonl")[0]
        dump_dir = manager.dir / f"dump-{dataset_name}"
        dump_dir.mkdir(exist_ok=True)
        utils.write_jsonl(
            dump_data,
            dump_dir / f"step-{manager.step}.jsonl",
            "w"
        )
    return score


def evaluate(args, manager, model, tokenizer, reward_generator, holdout_data):
    """Evaluation hook called during training: tracks the validation reward and,
    if validation datasets are configured, token-F1 scores, saving the best
    checkpoint for each metric."""
    step = manager.step
    val_reward = evaluate_validation_reward(args, manager, model, tokenizer, reward_generator, holdout_data)
    reward_path = manager.dir / "val_rewards.jsonl"
    if reward_path.exists():
        reward_results = list(utils.read_jsonl(reward_path))
        prev_max = max([x["score"] for x in reward_results])
    else:
        reward_results = []
        prev_max = 0
    if val_reward > prev_max:
        manager.save_model(model, step, "best_val_reward")
    reward_results.append({"step": step, "score": val_reward})
    utils.write_jsonl(reward_results, reward_path, "w")
    if args.verbose:
        print("Validation Rewards:")
        pprint(reward_results)
        print()

    # Only used if validation datasets are specified in the config.
    for val_data_path in args.validation_datasets:
        val_data_path = Path(val_data_path)
        dataset_name = val_data_path.name.split(".jsonl")[0]
        dataset_score = evaluate_validation_dataset(
            args, manager, model, tokenizer, reward_generator, val_data_path
        )
        result_path = Path(manager.dir / f"val_data_results.{dataset_name}.jsonl")
        if result_path.exists():
            dataset_results = list(utils.read_jsonl(result_path))
            prev_max = max([x["score"] for x in dataset_results])
        else:
            dataset_results = []
            prev_max = 0
        if dataset_score > prev_max:
            manager.save_model(model, step, f"best_on_{dataset_name}")
        dataset_results.append({"step": step, "score": dataset_score})
        utils.write_jsonl(dataset_results, result_path, "w")
        if args.verbose:
            print(f"Validation Dataset Results for {dataset_name}:")
            pprint(dataset_results)
            print()


def main(args):
    utils.set_random_seed(0)
    setup_and_train(args, eval_func=evaluate)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path to JSON config file")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--dump", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument(
        "--fresh",
        action="store_true",
        help="delete model directory and start from scratch"
    )
    main(load_config(parser.parse_args()))
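
# Example invocation (a sketch, not taken from the repository; the script name
# and config path are assumptions and depend on how the project is laid out):
#
#   python train.py --config config/example.json --device cuda --verbose
#
# --dump additionally writes per-example predictions, labels and token
# probabilities to the model directory at each evaluation step; --fresh
# deletes an existing model directory and starts training from scratch.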