File size: 2,809 Bytes
6fc683c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
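# This file evaluates OCR results of generated images against the MARIOEval keyword ground truth (precision, recall, F-measure and accuracy per method).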
# ------------------------------------------

import os
import re
import copy

# Ground-truth keyword lists, one list per MARIOEval sub-benchmark. Each entry
# appended later is the keyword list extracted from one prompt line.
gts = {
    benchmark: []
    for benchmark in (
        'ChineseDrawText',
        'DrawBenchText',
        'DrawTextCreative',
        'LAIONEval4000',
        'OpenLibraryEval500',
        'TMDBEval500',
    )
}

# Per-method accumulators: number of OCR result files processed ('cnt'),
# summed/averaged precision ('p'), recall ('r'), F-measure ('f') and accuracy ('acc').
results = {
    method_name: {'cnt':0, 'p':0, 'r':0, 'f':0, 'acc':0}
    for method_name in ('stablediffusion', 'textdiffuser', 'controlnet', 'deepfloyd')
}

def get_key_words(text: str):
    """Extract the keywords enclosed in single quotes from a prompt.

    Every ``'...'``-delimited span (matched non-greedily) is split on
    whitespace, so a quoted phrase contributes one keyword per word.

    Args:
        text: A prompt line, e.g. ``"a sign that says 'hello world'"``.

    Returns:
        list[str]: The keywords in order of appearance; an empty list when the
        text contains no quoted span.
    """
    words = []
    # re.findall returns [] when nothing matches, so no emptiness guard is needed.
    for match in re.findall(r"'(.*?)'", text):  # find the keywords enclosed by ''
        words.extend(match.split())
    return words


# load gt: each MARIOEval sub-directory <name> holds <name>/<name>.txt with one
# prompt per line; keep the lower-cased quoted keywords of every prompt, in file order.
gt_root = '/path/to/MARIOEval'
files = os.listdir(gt_root)
for file in files:
    gt_dir = os.path.join(gt_root, file)
    # Skip stray non-directory entries (e.g. .DS_Store) that hold no prompt file.
    if not os.path.isdir(gt_dir):
        continue
    # Context manager closes the handle deterministically; explicit encoding
    # avoids depending on the platform's default locale.
    with open(os.path.join(gt_dir, f'{file}.txt'), encoding='utf-8') as fin:
        lines = fin.readlines()
    for line in lines:
        line = line.strip().lower()
        gts[file].append(get_key_words(line))
print(gts['ChineseDrawText'][:10])


def get_p_r_acc(method, pred, gt):
    """Score one image's OCR output against the prompt's ground-truth keywords.

    Both lists are stripped and lower-cased before comparison. Matching is done
    on word multisets: each ground-truth word can be matched by at most one
    predicted word, and vice versa.

    Args:
        method: Name of the generation method; not used in the computation
            (kept so existing callers continue to work).
        pred: Predicted words, e.g. raw lines read from an OCR result file.
        gt: Ground-truth keywords for the prompt.

    Returns:
        tuple: ``(p, r, acc)`` where ``p`` is matched / len(pred), ``r`` is
        matched / len(gt) (each 0 when its list is empty), and ``acc`` is 1 if
        the non-empty predicted words equal the non-empty ground-truth words as
        a multiset, else 0.
    """
    from collections import Counter

    pred = [w.strip().lower() for w in pred]
    gt = [w.strip().lower() for w in gt]

    # Multiset intersection: per word, min(count in pred, count in gt).
    matched = sum((Counter(pred) & Counter(gt)).values())

    # The 1e-8 term keeps an empty list from dividing by zero (score becomes 0).
    p = matched / (len(pred) + 1e-8)
    r = matched / (len(gt) + 1e-8)

    # Compare word multisets directly rather than concatenated strings, which
    # would let different segmentations collide (['ab', 'c'] vs ['a', 'bc']
    # both join to 'abc'). Empty strings are ignored, as they would be in a join.
    acc = int(sorted(w for w in pred if w) == sorted(w for w in gt if w))

    return p, r, acc


# Score every OCR result file. File names are expected to look like
# <method>_<dataset>_<prompt index>_<image index>; each line of a file is
# treated as one predicted word.
ocr_root = '/path/to/MaskTextSpotterV3/tools/ocr_result'
files = os.listdir(ocr_root)
print(len(files))

for file in files:
    method, dataset, prompt_index, image_index = file.strip().split('_')
    # Context manager closes the handle; explicit encoding avoids locale dependence.
    with open(os.path.join(ocr_root, file), encoding='utf-8') as fin:
        ocrs = fin.readlines()
    p, r, acc = get_p_r_acc(method, ocrs, gts[dataset][int(prompt_index)])
    results[method]['cnt'] += 1
    results[method]['p'] += p
    results[method]['r'] += r
    results[method]['acc'] += acc

# Average the per-file sums, then derive the F-measure from the mean
# precision and recall.
for method, stats in results.items():
    # A method with no OCR result files has cnt == 0 and all sums zero;
    # dividing by 1 reports 0 instead of raising ZeroDivisionError.
    n = stats['cnt'] or 1
    stats['p'] /= n
    stats['r'] /= n
    stats['f'] = 2 * stats['p'] * stats['r'] / (stats['p'] + stats['r'] + 1e-8)
    stats['acc'] /= n

print(results)