Commit
Β·
a2f941d
1
Parent(s):
acc872d
update
Browse files- data/answer.enc +0 -0
- main.py +17 -52
data/answer.enc
ADDED
The diff for this file is too large to render.
See raw diff
|
|
main.py
CHANGED
@@ -14,7 +14,7 @@ def load_and_decrypt_qrel(secret_key):
|
|
14 |
decrypted_data = cipher.decrypt(encrypted_data).decode("utf-8")
|
15 |
raw_data = json.loads(decrypted_data)
|
16 |
|
17 |
-
#
|
18 |
qrel_dict = defaultdict(lambda: defaultdict(dict))
|
19 |
for dataset, records in raw_data.items():
|
20 |
for item in records:
|
@@ -22,18 +22,15 @@ def load_and_decrypt_qrel(secret_key):
|
|
22 |
qrel_dict[dataset][qid][cid] = score
|
23 |
return qrel_dict
|
24 |
except Exception as e:
|
25 |
-
raise ValueError(f"Failed to decrypt answer file: {str(e)}")
|
26 |
|
27 |
def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
|
28 |
return int(any(item in relevant_ids for item in corpus_top_100_list[:k]))
|
29 |
|
30 |
def ndcg_at_k(corpus_top_100_list, rel_dict, k):
|
31 |
all_items = list(dict.fromkeys(corpus_top_100_list + list(rel_dict.keys())))
|
32 |
-
|
33 |
y_true = [rel_dict.get(item, 0) for item in all_items]
|
34 |
-
|
35 |
y_score = [len(all_items) - i for i in range(len(all_items))]
|
36 |
-
|
37 |
return ndcg_score([y_true], [y_score], k=k)
|
38 |
|
39 |
def evaluate(pred_data, qrel_dict):
|
@@ -43,7 +40,6 @@ def evaluate(pred_data, qrel_dict):
|
|
43 |
continue
|
44 |
|
45 |
recall_1, ndcg_10, ndcg_100 = [], [], []
|
46 |
-
|
47 |
for item in queries:
|
48 |
qid = item["query_id"]
|
49 |
corpus_top_100_list = item["corpus_top_100_list"].split(",")
|
@@ -63,15 +59,17 @@ def evaluate(pred_data, qrel_dict):
|
|
63 |
|
64 |
return results
|
65 |
|
66 |
-
# ==== Gradio Wrapper ====
|
67 |
def process_json(file):
|
68 |
try:
|
69 |
pred_data = json.load(open(file))
|
70 |
except Exception as e:
|
71 |
-
return f"Invalid JSON format: {str(e)}"
|
|
|
|
|
|
|
|
|
72 |
|
73 |
try:
|
74 |
-
secret_key = os.getenv("SECRET_KEY")
|
75 |
qrel_dict = load_and_decrypt_qrel(secret_key)
|
76 |
except Exception as e:
|
77 |
return str(e)
|
@@ -80,68 +78,35 @@ def process_json(file):
|
|
80 |
metrics = evaluate(pred_data, qrel_dict)
|
81 |
return json.dumps(metrics, indent=2)
|
82 |
except Exception as e:
|
83 |
-
return f"Error during evaluation: {str(e)}"
|
84 |
-
|
85 |
-
# ==== Launch Gradio App ====
|
86 |
-
# def main_gradio():
|
87 |
-
# example_json = '''{
|
88 |
-
# "Google_WIT": [
|
89 |
-
# {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},
|
90 |
-
# {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}
|
91 |
-
# ],
|
92 |
-
# "MSCOCO": [
|
93 |
-
# {"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."},
|
94 |
-
# {"query_id": "2", "corpus_top_100_list": "90, 19, 3, ..."}
|
95 |
-
# ]
|
96 |
-
# "OVEN": [
|
97 |
-
# {"query_id": "3", "corpus_top_100_list": "11, 15, 22, ..."}
|
98 |
-
# ]
|
99 |
-
# "VisualNews": [
|
100 |
-
# {"query_id": "3", "corpus_top_100_list": "101, 35, 22, ..."}
|
101 |
-
# ]
|
102 |
-
# }'''
|
103 |
-
# gr.Interface(
|
104 |
-
# fn=process_json,
|
105 |
-
# inputs=gr.File(label="Upload Retrieval Result (JSON)"),
|
106 |
-
# outputs=gr.Textbox(label="Results"),
|
107 |
-
# title="Automated Evaluation of MixBench",
|
108 |
-
# description="Upload a prediction JSON to evaluate Recall@1, NDCG@10, and NDCG@100 against encrypted qrels.\n\nExample input:\n" + example_json
|
109 |
-
# ).launch(share=True)
|
110 |
def main_gradio():
|
111 |
example_json_html = (
|
112 |
-
'{<br>'
|
113 |
' "Google_WIT": [<br>'
|
114 |
' {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},<br>'
|
115 |
' {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}<br>'
|
116 |
' ],<br>'
|
117 |
' "MSCOCO": [<br>'
|
118 |
-
' {"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."}
|
119 |
-
' {"query_id": "2", "corpus_top_100_list": "90, 19, 3, ..."}<br>'
|
120 |
-
' ],<br>'
|
121 |
-
' "OVEN": [<br>'
|
122 |
-
' {"query_id": "3", "corpus_top_100_list": "11, 15, 22, ..."}<br>'
|
123 |
-
' ],<br>'
|
124 |
-
' "VisualNews": [<br>'
|
125 |
-
' {"query_id": "3", "corpus_top_100_list": "101, 35, 22, ..."}<br>'
|
126 |
' ]<br>'
|
127 |
-
'}'
|
128 |
)
|
129 |
|
130 |
gr.Interface(
|
131 |
fn=process_json,
|
132 |
inputs=gr.File(label="Upload Retrieval Result (JSON)"),
|
133 |
-
outputs=gr.Textbox(label="Results"),
|
134 |
-
title="Automated Evaluation of MixBench",
|
135 |
description=(
|
136 |
"Please upload your model's retrieval result on MixBench (in JSON format) to automatically evaluate its performance.<br><br>"
|
137 |
"For each subset (e.g., <code>MSCOCO</code>, <code>Google_WIT</code>, <code>VisualNews</code>, <code>OVEN</code>), "
|
138 |
-
"we
|
139 |
"- <strong>Recall@1</strong><br>"
|
140 |
"- <strong>NDCG@10</strong><br>"
|
141 |
"- <strong>NDCG@100</strong><br><br>"
|
142 |
-
"Expected input JSON format:<br><br>"
|
143 |
-
|
144 |
-
"<br>For reference query IDs, see the "
|
145 |
"<a href='https://huggingface.co/datasets/mixed-modality-search/MixBench2025/viewer/Google_WIT/mixed_corpus' target='_blank'>MixBench2025 dataset viewer</a>."
|
146 |
)
|
147 |
).launch(share=True)
|
|
|
14 |
decrypted_data = cipher.decrypt(encrypted_data).decode("utf-8")
|
15 |
raw_data = json.loads(decrypted_data)
|
16 |
|
17 |
+
# Convert to: dataset -> query_id -> {corpus_id: score}
|
18 |
qrel_dict = defaultdict(lambda: defaultdict(dict))
|
19 |
for dataset, records in raw_data.items():
|
20 |
for item in records:
|
|
|
22 |
qrel_dict[dataset][qid][cid] = score
|
23 |
return qrel_dict
|
24 |
except Exception as e:
|
25 |
+
raise ValueError(f"β Failed to decrypt answer file: {str(e)}")
|
26 |
|
27 |
def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
|
28 |
return int(any(item in relevant_ids for item in corpus_top_100_list[:k]))
|
29 |
|
30 |
def ndcg_at_k(corpus_top_100_list, rel_dict, k):
|
31 |
all_items = list(dict.fromkeys(corpus_top_100_list + list(rel_dict.keys())))
|
|
|
32 |
y_true = [rel_dict.get(item, 0) for item in all_items]
|
|
|
33 |
y_score = [len(all_items) - i for i in range(len(all_items))]
|
|
|
34 |
return ndcg_score([y_true], [y_score], k=k)
|
35 |
|
36 |
def evaluate(pred_data, qrel_dict):
|
|
|
40 |
continue
|
41 |
|
42 |
recall_1, ndcg_10, ndcg_100 = [], [], []
|
|
|
43 |
for item in queries:
|
44 |
qid = item["query_id"]
|
45 |
corpus_top_100_list = item["corpus_top_100_list"].split(",")
|
|
|
59 |
|
60 |
return results
|
61 |
|
|
|
62 |
def process_json(file):
|
63 |
try:
|
64 |
pred_data = json.load(open(file))
|
65 |
except Exception as e:
|
66 |
+
return f"β Invalid JSON format: {str(e)}"
|
67 |
+
|
68 |
+
secret_key = os.getenv("SECRET_KEY")
|
69 |
+
if not secret_key:
|
70 |
+
return "β SECRET_KEY environment variable not set. Please configure it in your Hugging Face Space."
|
71 |
|
72 |
try:
|
|
|
73 |
qrel_dict = load_and_decrypt_qrel(secret_key)
|
74 |
except Exception as e:
|
75 |
return str(e)
|
|
|
78 |
metrics = evaluate(pred_data, qrel_dict)
|
79 |
return json.dumps(metrics, indent=2)
|
80 |
except Exception as e:
|
81 |
+
return f"β Error during evaluation: {str(e)}"
|
82 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
def main_gradio():
|
84 |
example_json_html = (
|
85 |
+
'<pre><code>{<br>'
|
86 |
' "Google_WIT": [<br>'
|
87 |
' {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},<br>'
|
88 |
' {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}<br>'
|
89 |
' ],<br>'
|
90 |
' "MSCOCO": [<br>'
|
91 |
+
' {"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."}<br>'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
' ]<br>'
|
93 |
+
'}</code></pre>'
|
94 |
)
|
95 |
|
96 |
gr.Interface(
|
97 |
fn=process_json,
|
98 |
inputs=gr.File(label="Upload Retrieval Result (JSON)"),
|
99 |
+
outputs=gr.Textbox(label="Evaluation Results"),
|
100 |
+
title="π Automated Evaluation of MixBench",
|
101 |
description=(
|
102 |
"Please upload your model's retrieval result on MixBench (in JSON format) to automatically evaluate its performance.<br><br>"
|
103 |
"For each subset (e.g., <code>MSCOCO</code>, <code>Google_WIT</code>, <code>VisualNews</code>, <code>OVEN</code>), "
|
104 |
+
"we compute:<br>"
|
105 |
"- <strong>Recall@1</strong><br>"
|
106 |
"- <strong>NDCG@10</strong><br>"
|
107 |
"- <strong>NDCG@100</strong><br><br>"
|
108 |
+
"Expected input JSON format:<br><br>" + example_json_html +
|
109 |
+
"<br>To find valid query IDs, see the "
|
|
|
110 |
"<a href='https://huggingface.co/datasets/mixed-modality-search/MixBench2025/viewer/Google_WIT/mixed_corpus' target='_blank'>MixBench2025 dataset viewer</a>."
|
111 |
)
|
112 |
).launch(share=True)
|