import torch
import evaluate
import random
from unimernet.common.registry import registry
from unimernet.tasks.base_task import BaseTask
from unimernet.common.dist_utils import main_process
import os.path as osp
import json
import numpy as np
from torchtext.data import metrics
from rapidfuzz.distance import Levenshtein


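# Evaluation task for UniMERNet: runs generation on the validation split,
# saves per-sample predictions, and reports BLEU, normalized edit distance,
# and token accuracy.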
class UniMERNet_Train(BaseTask):

    def __init__(self, temperature, do_sample, top_p, evaluate, report_metric=True, agg_metric="edit_distance"):
        super(UniMERNet_Train, self).__init__()

        self.temperature = temperature
        self.do_sample = do_sample
        self.top_p = top_p
        self.evaluate = evaluate
        self.agg_metric = agg_metric
        self.report_metric = report_metric

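    # Build the task from the run config: generation settings come from
    # run_cfg.generate_cfg, metric/reporting options from run_cfg itself.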
    @classmethod
    def setup_task(cls, cfg):
        run_cfg = cfg.run_cfg
        generate_cfg = run_cfg.generate_cfg

        temperature = generate_cfg.get("temperature", 0.2)
        do_sample = generate_cfg.get("do_sample", False)
        top_p = generate_cfg.get("top_p", 0.95)

        evaluate = run_cfg.evaluate
        report_metric = run_cfg.get("report_metric", True)
        agg_metric = run_cfg.get("agg_metric", "edit_distance")

        return cls(
            temperature=temperature,
            do_sample=do_sample,
            top_p=top_p,
            evaluate=evaluate,
            report_metric=report_metric,
            agg_metric=agg_metric,
        )

    def valid_step(self, model, samples):
        results = []

        image, text = samples["image"], samples["text_input"]
        preds = model.generate(
            samples,
            temperature=self.temperature,
            do_sample=self.do_sample,
            top_p=self.top_p,
        )

        pred_tokens = preds["pred_tokens"]
        pred_strs = preds["pred_str"]
        pred_ids = preds["pred_ids"]  # [b, n-1]

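        # Tokenize the ground-truth strings with the model's tokenizer and drop
        # the first token so truth_ids lines up with pred_ids ([b, n-1]).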
        truth_inputs = model.tokenizer.tokenize(text)
        truth_ids = truth_inputs["input_ids"][:, 1:]
        truth_tokens = model.tokenizer.detokenize(truth_inputs["input_ids"])
        truth_strs = model.tokenizer.token2str(truth_inputs["input_ids"])

        ids = samples["id"]

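        # For each sample: pad the shorter of the predicted/ground-truth id
        # sequences with the pad token, then score token accuracy only at
        # positions where at least one side holds a real (non-pad) token.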
        for pred_token, pred_str, pred_id, truth_token, truth_str, truth_id, id_ in zip(
            pred_tokens, pred_strs, pred_ids, truth_tokens, truth_strs, truth_ids, ids
        ):
            pred_id = pred_id.tolist()
            truth_id = truth_id.tolist()
            shape_diff = len(pred_id) - len(truth_id)
            if shape_diff < 0:
                pred_id = pred_id + [model.tokenizer.pad_token_id] * (-shape_diff)
            else:
                truth_id = truth_id + [model.tokenizer.pad_token_id] * shape_diff
            pred_id, truth_id = torch.LongTensor(pred_id), torch.LongTensor(truth_id)
            mask = torch.logical_or(pred_id != model.tokenizer.pad_token_id, truth_id != model.tokenizer.pad_token_id)
            tok_acc = (pred_id == truth_id)[mask].float().mean().item()
            this_item = {
                "pred_token": pred_token,
                "pred_str": pred_str,
                "truth_str": truth_str,
                "truth_token": truth_token,
                "token_acc": tok_acc,
                "id": id_,
            }
            results.append(this_item)
        return results

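    # Persist per-sample results for this split/epoch (deduplicated by "id")
    # and, if report_metric is set, fold them into aggregate metrics.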
    def after_evaluation(self, val_result, split_name, epoch, **kwargs):
        eval_result_file = self.save_result(
            result=val_result,
            result_dir=registry.get_path("result_dir"),
            filename="{}_epoch{}".format(split_name, epoch),
            remove_duplicate="id",
        )

        if self.report_metric:
            metrics = self._report_metrics(
                eval_result_file=eval_result_file, split_name=split_name
            )
        else:
            metrics = {"agg_metrics": 0.0}

        return metrics

    def _report_metrics(self, eval_result_file, split_name):
        with open(eval_result_file) as f:
            results = json.load(f)

        edit_dists = []
        all_pred_tokens = []
        all_truth_tokens = []
        all_pred_strs = []
        all_truth_strs = []
        token_accs = []
        for result in results:
            pred_token, pred_str, truth_token, truth_str, tok_acc = (
                result["pred_token"],
                result["pred_str"],
                result["truth_token"],
                result["truth_str"],
                result["token_acc"],
            )
            if len(truth_str) > 0:
                norm_edit_dist = Levenshtein.normalized_distance(pred_str, truth_str)
                edit_dists.append(norm_edit_dist)
            all_pred_tokens.append(pred_token)
            all_truth_tokens.append([truth_token])
            all_pred_strs.append(pred_str)
            all_truth_strs.append(truth_str)
            token_accs.append(tok_acc)

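        # Corpus-level BLEU via the HuggingFace `evaluate` package; a random
        # experiment_id keeps concurrent evaluation runs from clashing on the
        # metric's on-disk cache.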
        # bleu_score = metrics.bleu_score(all_pred_tokens, all_truth_tokens)
        bleu = evaluate.load("bleu", keep_in_memory=True, experiment_id=random.randint(1, 10**8))
        bleu_results = bleu.compute(predictions=all_pred_strs, references=all_truth_strs)
        bleu_score = bleu_results["bleu"]

        edit_distance = np.mean(edit_dists)
        token_accuracy = np.mean(token_accs)

        eval_ret = {"bleu": bleu_score, "edit_distance": edit_distance, "token_accuracy": token_accuracy}

        log_stats = {split_name: {k: v for k, v in eval_ret.items()}}
        with open(
            osp.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(log_stats) + "\n")

        coco_res = {k: v for k, v in eval_ret.items()}
        # agg_metrics = sum([v for v in eval_ret.values()])
        if "edit" in self.agg_metric.lower():  # edit_distance
            agg_metrics = (1 - edit_distance) * 100
        elif "bleu" in self.agg_metric.lower():  # bleu_score
            agg_metrics = bleu_score * 100
        elif "token" in self.agg_metric.lower():  # token_accuracy
            agg_metrics = token_accuracy * 100
        else:
            raise ValueError(f"Invalid metrics: '{self.agg_metric}'")

        coco_res["agg_metrics"] = agg_metrics
        return coco_res
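

# A minimal usage sketch (not from the upstream repo): it assumes the run
# config is an OmegaConf-style object exposing the keys read in setup_task
# (run_cfg.generate_cfg.*, run_cfg.evaluate, ...); the real YAML layout of
# UniMERNet's configs may differ.
if __name__ == "__main__":
    from omegaconf import OmegaConf  # assumed dependency for this sketch only

    cfg = OmegaConf.create(
        {
            "run_cfg": {
                "evaluate": True,
                "report_metric": True,
                "agg_metric": "edit_distance",
                "generate_cfg": {"temperature": 0.2, "do_sample": False, "top_p": 0.95},
            }
        }
    )
    task = UniMERNet_Train.setup_task(cfg)
    print(task.temperature, task.do_sample, task.top_p, task.agg_metric)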