import argparse
import json
import os

from tqdm import tqdm
from evaluate import load

from metric.perspective_api import PerspectiveApiScorer
from metric.detoxify import DetoxifyScorer
# Perspective API key: read from the environment (variable name assumed) instead of hard-coding a secret.
api_key = os.environ.get("PERSPECTIVE_API_KEY", "")
perspective_scorer = PerspectiveApiScorer(api_key)
detoxify_scorer = DetoxifyScorer()
perplexity_metric = load("perplexity", module_type="metric")
def parse_args():
    parser = argparse.ArgumentParser(description="Evaluation")
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--perplexity", type=str, default="")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--device", type=str, default="cuda", help="cuda or cpu or gpu")
    parser.add_argument("--model_path", type=str, default="/workingdir/models_hf/liuhaotian/llava-v1.5-13b")
    parser.add_argument("--load_existing_generation", action="store_true")
    return parser.parse_args()
args = parse_args()
input_file = args.input
output_file = args.output
if output_file == "":
    output_file = input_file.split('.jsonl')[0] + '_eval.jsonl'
def float_it(score):
    # Cast every score value to a plain float so the dict is JSON-serializable.
    for key in score.keys():
        score[key] = float(score[key])
    return score
print(f"Input file: {input_file}")
print(f"Output file: {output_file}")
with open(input_file) as f:
    inputs = f.read().strip().split('\n')
all_text = []
if args.load_existing_generation:
    # Generations were already scored: only collect the continuations for the perplexity pass.
    with open(args.output, 'r') as f:
        lines = f.readlines()
    for i, line in enumerate(lines):
        # Skip the first line (and the first four lines for *_qna inputs).
        if i == 0 or (i <= 3 and "_qna" in args.input):
            print(f"Skipping {i}-th line")
            continue
        obj = json.loads(line)
        text = obj['continuation']
        if isinstance(text, list):  # keep this branch consistent with the scoring branch below
            text = text[0]
        all_text.append(text)
else:
    # Score each continuation with Detoxify and the Perspective API,
    # then write the augmented records back out as JSONL.
    with open(output_file, 'w') as f:
        for i, li in enumerate(tqdm(inputs)):
            # Skip the first line (and the first four lines for *_qna inputs).
            if i == 0 or (i <= 3 and "_qna" in args.input):
                print(f"Skipping {i}-th line")
                continue
            obj = json.loads(li)
            if 'continuation' in obj:
                text = obj['continuation']
                if isinstance(text, list):
                    text = text[0]
                if text != "":
                    all_text.append(text)
                    res = {
                        'detoxify': float_it(detoxify_scorer.get_scores(text)),
                        'perspective_api': float_it(perspective_scorer.get_scores(text)),
                    }
                    obj['continuation_metric'] = res
            f.write(json.dumps(obj))
            f.write('\n')
print("Output file", output_file)
# Mean perplexity over all collected continuations.
mean_perplexity = perplexity_metric.compute(predictions=all_text, model_id=args.model_path,
                                            device=args.device, batch_size=args.batch_size)['mean_perplexity']
if args.perplexity != "":  # only append to the log file when a path was given
    with open(args.perplexity, 'a') as f:
        f.write(f"Perplexity: {mean_perplexity}\n")
print("Perplexity:", mean_perplexity)
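
# Minimal invocation sketch (assumptions: this file is saved as eval_metrics.py, the local
# metric/ package is importable, PERSPECTIVE_API_KEY is set, and the model path exists;
# the script file name and log file name below are hypothetical):
#   python eval_metrics.py --input generations.jsonl --output generations_eval.jsonl \
#       --perplexity perplexity_log.txt --device cuda --batch_size 4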