|
import gradio as gr |
|
import json |
|
import os |
|
import numpy as np |
|
from cryptography.fernet import Fernet |
|
from collections import defaultdict |
|
from sklearn.metrics import ndcg_score |
|
|
|
def load_and_decrypt_qrel(secret_key, path="data/answer.enc"):
    """Decrypt the hidden ground-truth file and index it by subset and query.

    The decrypted payload is JSON mapping each subset name to a list of
    records carrying ``query_id``, ``corpus_id`` and ``score``.

    Args:
        secret_key: Fernet key as a text string (encoded to bytes here).
        path: Location of the encrypted answer file. Defaults to the path
            the app has always used.

    Returns:
        Nested ``defaultdict`` so that
        ``qrel[dataset][query_id][corpus_id] == score``.

    Raises:
        ValueError: if the file cannot be read, decrypted, or parsed, or if a
            record lacks one of the expected keys. The underlying exception is
            chained as ``__cause__`` so the real failure stays visible in
            tracebacks.
    """
    try:
        with open(path, "rb") as enc_file:
            encrypted_data = enc_file.read()

        cipher = Fernet(secret_key.encode())
        raw_data = json.loads(cipher.decrypt(encrypted_data).decode("utf-8"))

        qrel_dict = defaultdict(lambda: defaultdict(dict))
        for dataset, records in raw_data.items():
            for item in records:
                qid, cid, score = item["query_id"], item["corpus_id"], item["score"]
                qrel_dict[dataset][qid][cid] = score
    except Exception as e:
        # Keep the user-facing message, but chain the cause instead of losing it.
        raise ValueError(f"β Failed to decrypt answer file: {str(e)}") from e

    return qrel_dict
|
|
|
def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
    """Return 1 if any of the first ``k`` ranked IDs is relevant, else 0.

    Args:
        corpus_top_100_list: Corpus IDs in ranked order (best first).
        relevant_ids: Container of IDs considered relevant for the query.
        k: How many leading positions of the ranking to inspect.

    Returns:
        int: ``1`` on a hit within the top ``k``, otherwise ``0``.
    """
    head = corpus_top_100_list[:k]
    for candidate in head:
        if candidate in relevant_ids:
            return 1
    return 0
|
|
|
def ndcg_at_k(corpus_top_100_list, rel_dict, k):
    """Return NDCG@k of a submitted ranking against graded relevance labels.

    Uses linear gain (the raw relevance score) and a ``1 / log2(rank + 1)``
    discount, the same formula as ``sklearn.metrics.ndcg_score``. Only IDs
    actually present in the submitted ranking can earn gain. Relevant
    documents the submission omitted count only towards the ideal DCG.
    (Previously they were appended right after the submitted IDs, so a
    ranking shorter than ``k`` was credited for documents it never
    returned.)

    Args:
        corpus_top_100_list: Corpus IDs in ranked order (best first).
            Repeated IDs are counted once, at their first position.
        rel_dict: Mapping corpus ID -> relevance score for this query.
        k: Rank cutoff (``None`` means no cutoff).

    Returns:
        float: NDCG@k, within [0, 1] for non-negative relevance scores.
        Returns 0.0 when the query has no positive relevance mass, including
        an empty ``rel_dict``.
    """
    # Drop duplicate IDs while keeping first-occurrence order, then cut at k.
    ranked = list(dict.fromkeys(corpus_top_100_list))[:k]
    gains = np.array([rel_dict.get(cid, 0) for cid in ranked], dtype=float)

    # The ideal ranking orders every labelled document by relevance, best first.
    ideal = np.sort(np.array(list(rel_dict.values()), dtype=float))[::-1][:k]

    def _dcg(g):
        # Rank r (1-based) is discounted by log2(r + 1).
        return float(np.sum(g / np.log2(np.arange(2, g.size + 2))))

    idcg = _dcg(ideal)
    if idcg <= 0:
        return 0.0
    return _dcg(gains) / idcg
|
|
|
def _parse_ranking(raw):
    """Normalize a submitted ranking into a list of non-empty corpus-ID strings.

    Accepts the documented comma-separated string form (``"5, 2, 8"``) and,
    in addition, a JSON list of IDs. Blank entries are dropped.
    """
    parts = raw.split(",") if isinstance(raw, str) else raw
    return [str(x).strip() for x in parts if str(x).strip()]


def evaluate(pred_data, qrel_dict):
    """Score every subset in ``pred_data`` that has ground truth in ``qrel_dict``.

    Args:
        pred_data: Mapping subset name -> list of prediction records. Each
            record has a ``query_id`` and a ``corpus_top_100_list``, given
            either as a comma-separated string of corpus IDs or as a list.
        qrel_dict: Nested mapping subset -> query_id -> corpus_id -> relevance
            score, as built by ``load_and_decrypt_qrel``.

    Returns:
        dict mapping each scored subset to its ``Recall@1``, ``NDCG@10`` and
        ``NDCG@100``, as percentages rounded to two decimals. Subsets missing
        from ``qrel_dict``, or with no prediction records, are omitted.
    """
    results = {}

    for dataset, queries in pred_data.items():
        if dataset not in qrel_dict or not queries:
            # Unknown subset, or nothing to average. np.mean([]) would give
            # NaN, which json.dumps writes as the non-standard token ``NaN``.
            continue

        # NOTE(review): only the queries present in the submission are
        # averaged; queries it omits are not counted as misses.
        dataset_qrels = qrel_dict[dataset]
        recall_1, ndcg_10, ndcg_100 = [], [], []

        for item in queries:
            ranking = _parse_ranking(item["corpus_top_100_list"])
            rel_dict = dataset_qrels.get(item["query_id"], {})
            relevant_ids = [cid for cid, score in rel_dict.items() if score > 0]

            recall_1.append(recall_at_k(ranking, relevant_ids, 1))
            ndcg_10.append(ndcg_at_k(ranking, rel_dict, 10))
            ndcg_100.append(ndcg_at_k(ranking, rel_dict, 100))

        results[dataset] = {
            "Recall@1": round(np.mean(recall_1) * 100, 2),
            "NDCG@10": round(np.mean(ndcg_10) * 100, 2),
            "NDCG@100": round(np.mean(ndcg_100) * 100, 2),
        }

    return results
|
|
|
def process_json(file):
    """Gradio callback: evaluate an uploaded prediction file and report metrics.

    Args:
        file: The uploaded file, either as a path (``str`` / ``os.PathLike``)
            or as an object exposing its on-disk path via ``.name``. May be
            ``None``, e.g. when the user submits without uploading anything.

    Returns:
        str: Pretty-printed JSON of per-subset metrics on success. Otherwise a
        human-readable error message; the function never raises, so the UI
        always shows something.
    """
    if file is None:
        return "β No file uploaded. Please upload a JSON file."

    # Accept plain paths as-is. Otherwise fall back to the ``.name`` attribute
    # carried by temp-file style upload objects.
    if isinstance(file, (str, os.PathLike)):
        path = file
    else:
        path = getattr(file, "name", file)

    try:
        # ``with`` closes the handle; utf-8-sig also tolerates a leading BOM.
        with open(path, encoding="utf-8-sig") as fh:
            pred_data = json.load(fh)
    except Exception as e:
        return f"β Invalid JSON format: {str(e)}"

    if not isinstance(pred_data, dict):
        return (
            "β Invalid JSON format: expected a top-level object mapping "
            "subset names to lists of predictions."
        )

    secret_key = os.getenv("SECRET_KEY")
    if not secret_key:
        return "β SECRET_KEY environment variable not set. Please configure it in your Hugging Face Space."

    try:
        qrel_dict = load_and_decrypt_qrel(secret_key)
    except Exception as e:
        return str(e)

    try:
        metrics = evaluate(pred_data, qrel_dict)
        return json.dumps(metrics, indent=2)
    except Exception as e:
        return f"β Error during evaluation: {str(e)}"
|
|
|
def main_gradio():
    """Build the MixBench evaluation UI around ``process_json`` and launch it.

    The interface takes a single uploaded JSON file and shows the text
    returned by ``process_json``. The description embeds an HTML example of
    the expected submission layout. The app is launched with ``share=True``.
    """
    # HTML snippet illustrating the expected submission layout.
    example_rows = [
        '<pre><code>{<br>',
        ' "Google_WIT": [<br>',
        ' {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},<br>',
        ' {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}<br>',
        ' ],<br>',
        ' "MSCOCO": [<br>',
        ' {"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."}<br>',
        ' ]<br>',
        '}</code></pre>',
    ]
    example_json_html = "".join(example_rows)

    intro = (
        "Please upload your model's retrieval result on MixBench (in JSON format) to automatically evaluate its performance.<br><br>"
        "For each subset (e.g., <code>MSCOCO</code>, <code>Google_WIT</code>, <code>VisualNews</code>, <code>OVEN</code>), "
        "we compute:<br>"
        "- <strong>Recall@1</strong><br>"
        "- <strong>NDCG@10</strong><br>"
        "- <strong>NDCG@100</strong><br><br>"
        "Expected input JSON format:<br><br>"
    )
    outro = (
        "<br>To find valid query IDs, see the "
        "<a href='https://huggingface.co/datasets/mixed-modality-search/MixBench2025/viewer/Google_WIT/mixed_corpus' target='_blank'>MixBench2025 dataset viewer</a>."
    )

    demo = gr.Interface(
        fn=process_json,
        inputs=gr.File(label="Upload Retrieval Result (JSON)"),
        outputs=gr.Textbox(label="Evaluation Results"),
        title="π Automated Evaluation of MixBench",
        description=intro + example_json_html + outro,
    )
    demo.launch(share=True)
|
|
|
# Start the Gradio app only when this file is executed directly, not when it
# is imported as a module.
if __name__ == "__main__":
    main_gradio()
|
|