Spaces:
Sleeping
Sleeping
# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file provides the OCR evaluation script: it compares OCR results with
# the quoted keywords of each prompt and reports precision, recall, F-measure
# and accuracy per method.
# ------------------------------------------
| import os | |
| import re | |
| import copy | |
# Ground-truth keywords, keyed by evaluation subset name. Each list receives
# one entry (a list of keywords) per prompt of that subset.
gts = {
    subset: []
    for subset in (
        'ChineseDrawText',
        'DrawBenchText',
        'DrawTextCreative',
        'LAIONEval4000',
        'OpenLibraryEval500',
        'TMDBEval500',
    )
}

# Running totals per method: number of scored files ('cnt'), precision ('p'),
# recall ('r'), F-measure ('f') and accuracy ('acc'). Every method gets its
# own independent counter dict.
results = {
    method_name: dict.fromkeys(('cnt', 'p', 'r', 'f', 'acc'), 0)
    for method_name in (
        'stablediffusion',
        'textdiffuser',
        'controlnet',
        'deepfloyd',
    )
}
def get_key_words(text: str):
    """Return the words that appear inside single quotes in *text*.

    Each ``'...'`` span is matched non-greedily and split on whitespace, so a
    quoted phrase contributes one entry per word.

    NOTE: quotes are paired strictly left to right, so an apostrophe (as in
    "it's") is treated as an opening or closing quote.

    Args:
        text: Prompt text whose keywords are enclosed in single quotes.

    Returns:
        A flat list of words in order of appearance; empty when *text*
        contains no quoted span.
    """
    words = []
    # Non-greedy so that several quoted spans in one prompt stay separate.
    for quoted in re.findall(r"'(.*?)'", text):
        words.extend(quoted.split())
    return words
# Load the ground truth: every entry ``<name>`` of the MARIOEval root is read
# from ``<name>/<name>.txt``, one prompt per line, and its quoted keywords are
# appended to ``gts[<name>]``.
mario_eval_root = '/path/to/MARIOEval'
files = os.listdir(mario_eval_root)
for file in files:
    # ``with`` closes the prompt file deterministically once it has been read.
    with open(os.path.join(mario_eval_root, file, f'{file}.txt')) as prompt_file:
        lines = prompt_file.readlines()
    for line in lines:
        # Lower-cased here so that matching against OCR output is
        # case-insensitive.
        line = line.strip().lower()
        gts[file].append(get_key_words(line))
print(gts['ChineseDrawText'][:10])
def get_p_r_acc(method, pred, gt):
    """Score predicted OCR words against the ground-truth keywords.

    Both inputs are stripped and lower-cased, then matched as multisets: a
    predicted word counts as a hit only while an unmatched ground-truth copy
    of the same word remains.

    Args:
        method: Name of the generation method. Kept for call-site
            compatibility; it does not influence the result.
        pred: Iterable of predicted word strings (e.g. the raw lines of an
            OCR result file).
        gt: Iterable of ground-truth keyword strings.

    Returns:
        A tuple ``(precision, recall, acc)``. ``precision`` and ``recall``
        are hits divided by the number of predicted / ground-truth words
        (plus ``1e-8`` so empty inputs yield ``0.0`` instead of raising).
        ``acc`` is ``1`` when the predicted and ground-truth words are the
        same multiset, otherwise ``0``.
    """
    pred = [word.strip().lower() for word in pred]
    gt = [word.strip().lower() for word in gt]

    # Consume ground-truth words as they are matched so duplicates are only
    # credited as many times as they occur in the ground truth.
    unmatched_gt = list(gt)
    matched = 0
    for word in pred:
        if word in unmatched_gt:
            unmatched_gt.remove(word)
            matched += 1

    precision = matched / (len(pred) + 1e-8)
    recall = matched / (len(gt) + 1e-8)

    # Compare the word lists themselves rather than their concatenation:
    # joining without a separator makes ['ab', 'c'] and ['a', 'bc'] look
    # identical. Empty entries (e.g. blank OCR lines) are ignored here, as
    # they never affected a concatenation-based comparison either.
    pred_words = sorted(word for word in pred if word)
    gt_words = sorted(word for word in gt if word)
    acc = 1 if pred_words == gt_words else 0

    return precision, recall, acc
# Score every OCR result file and accumulate per-method totals. File names
# are split on '_' into exactly four fields:
# <method>_<dataset>_<prompt_index>_<image_index>.
ocr_result_root = '/path/to/MaskTextSpotterV3/tools/ocr_result'
files = os.listdir(ocr_result_root)
print(len(files))
for file in files:
    method, dataset, prompt_index, image_index = file.strip().split('_')
    # ``with`` closes the OCR file deterministically once it has been read.
    with open(os.path.join(ocr_result_root, file)) as ocr_file:
        ocrs = ocr_file.readlines()
    # Compare the recognised words with the keywords of the prompt this
    # image was generated from.
    p, r, acc = get_p_r_acc(method, ocrs, gts[dataset][int(prompt_index)])
    results[method]['cnt'] += 1
    results[method]['p'] += p
    results[method]['r'] += r
    results[method]['acc'] += acc
# Turn the accumulated sums into per-method averages and derive the F-measure.
for method in results:
    stats = results[method]
    if stats['cnt'] == 0:
        # No OCR files were scored for this method: keep its zero scores
        # instead of dividing by zero.
        continue
    stats['p'] /= stats['cnt']
    stats['r'] /= stats['cnt']
    # Harmonic mean of averaged precision and recall; the epsilon keeps the
    # expression defined when both are zero.
    stats['f'] = 2 * stats['p'] * stats['r'] / (stats['p'] + stats['r'] + 1e-8)
    stats['acc'] /= stats['cnt']
print(results)