import gradio as gr
import json
import os
import numpy as np
from cryptography.fernet import Fernet
from collections import defaultdict
from sklearn.metrics import ndcg_score


def load_and_decrypt_qrel(secret_key):
    """Load the encrypted qrels file and decrypt it with the given Fernet key."""
    try:
        with open("data/answer.enc", "rb") as enc_file:
            encrypted_data = enc_file.read()
        cipher = Fernet(secret_key.encode())
        decrypted_data = cipher.decrypt(encrypted_data).decode("utf-8")
        raw_data = json.loads(decrypted_data)
        # Convert to: dataset -> query_id -> {corpus_id: score}
        qrel_dict = defaultdict(lambda: defaultdict(dict))
        for dataset, records in raw_data.items():
            for item in records:
                qid, cid, score = item["query_id"], item["corpus_id"], item["score"]
                qrel_dict[dataset][qid][cid] = score
        return qrel_dict
    except Exception as e:
        raise ValueError(f"❌ Failed to decrypt answer file: {str(e)}")
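
# A minimal sketch of how a file like data/answer.enc could be produced. This
# is an assumption for illustration (the actual generation script is not part
# of this Space); the decrypted payload is expected to be JSON of the form
# {dataset: [{"query_id": ..., "corpus_id": ..., "score": ...}, ...]}:
#
#     key = Fernet.generate_key()  # store this value as SECRET_KEY
#     qrels = {"MSCOCO": [{"query_id": "3", "corpus_id": "122", "score": 1}]}
#     token = Fernet(key).encrypt(json.dumps(qrels).encode("utf-8"))
#     with open("data/answer.enc", "wb") as f:
#         f.write(token)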


def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
    """Return 1 if any relevant corpus ID appears in the top-k results, else 0."""
    return int(any(item in relevant_ids for item in corpus_top_100_list[:k]))
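
# Example with hypothetical IDs (not from MixBench):
# recall_at_k(["b", "a"], {"a"}, k=1) == 0, while
# recall_at_k(["b", "a"], {"a"}, k=2) == 1.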


def ndcg_at_k(corpus_top_100_list, rel_dict, k):
    """Compute NDCG@k, appending unretrieved relevant IDs so they count as misses."""
    # Union of retrieved IDs (in rank order) and any relevant IDs not retrieved.
    all_items = list(dict.fromkeys(corpus_top_100_list + list(rel_dict.keys())))
    y_true = [rel_dict.get(item, 0) for item in all_items]
    # Descending pseudo-scores reproduce the submitted ranking; unretrieved
    # relevant items end up ranked last.
    y_score = [len(all_items) - i for i in range(len(all_items))]
    return ndcg_score([y_true], [y_score], k=k)
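
# Worked example with hypothetical IDs (not from MixBench): given
# rel_dict = {"a": 1} and corpus_top_100_list = ["b", "a"], the relevant item
# sits at rank 2, so DCG = 1/log2(3) while the ideal DCG is 1/log2(2) = 1,
# giving ndcg_at_k(["b", "a"], {"a": 1}, k=2) ≈ 0.63.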


def evaluate(pred_data, qrel_dict):
    """Score each dataset's predictions against the decrypted qrels."""
    results = {}
    for dataset, queries in pred_data.items():
        if dataset not in qrel_dict:
            continue
        recall_1, ndcg_10, ndcg_100 = [], [], []
        for item in queries:
            qid = item["query_id"]
            # The ranked list is submitted as a comma-separated string.
            corpus_top_100_list = item["corpus_top_100_list"].split(",")
            corpus_top_100_list = [x.strip() for x in corpus_top_100_list if x.strip()]
            rel_dict = qrel_dict[dataset].get(qid, {})
            relevant_ids = [cid for cid, score in rel_dict.items() if score > 0]
            recall_1.append(recall_at_k(corpus_top_100_list, relevant_ids, 1))
            ndcg_10.append(ndcg_at_k(corpus_top_100_list, rel_dict, 10))
            ndcg_100.append(ndcg_at_k(corpus_top_100_list, rel_dict, 100))
        # Report each metric as a percentage averaged over the dataset's queries.
        results[dataset] = {
            "Recall@1": round(np.mean(recall_1) * 100, 2),
            "NDCG@10": round(np.mean(ndcg_10) * 100, 2),
            "NDCG@100": round(np.mean(ndcg_100) * 100, 2),
        }
    return results
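
# Quick sanity check with toy inputs (hypothetical IDs, not from MixBench):
#
#     qrels = {"MSCOCO": {"3": {"122": 1}}}
#     preds = {"MSCOCO": [{"query_id": "3", "corpus_top_100_list": "122, 35"}]}
#     evaluate(preds, qrels)
#     # -> {"MSCOCO": {"Recall@1": 100.0, "NDCG@10": 100.0, "NDCG@100": 100.0}}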


def process_json(file):
    """Gradio handler: parse the uploaded predictions file and run evaluation."""
    try:
        with open(file) as f:
            pred_data = json.load(f)
    except Exception as e:
        return f"❌ Invalid JSON format: {str(e)}"
    secret_key = os.getenv("SECRET_KEY")
    if not secret_key:
        return "❌ SECRET_KEY environment variable not set. Please configure it in your Hugging Face Space."
    try:
        qrel_dict = load_and_decrypt_qrel(secret_key)
    except Exception as e:
        return str(e)
    try:
        metrics = evaluate(pred_data, qrel_dict)
        return json.dumps(metrics, indent=2)
    except Exception as e:
        return f"❌ Error during evaluation: {str(e)}"


def main_gradio():
    # HTML-rendered example of the expected submission format, shown in the UI.
    example_json_html = (
        '<pre><code>{<br>'
        '&nbsp;&nbsp;"Google_WIT": [<br>'
        '&nbsp;&nbsp;&nbsp;&nbsp;{"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},<br>'
        '&nbsp;&nbsp;&nbsp;&nbsp;{"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}<br>'
        '&nbsp;&nbsp;],<br>'
        '&nbsp;&nbsp;"MSCOCO": [<br>'
        '&nbsp;&nbsp;&nbsp;&nbsp;{"query_id": "3", "corpus_top_100_list": "122, 35, 22, ..."}<br>'
        '&nbsp;&nbsp;]<br>'
        '}</code></pre>'
    )
    gr.Interface(
        fn=process_json,
        inputs=gr.File(label="Upload Retrieval Result (JSON)"),
        outputs=gr.Textbox(label="Evaluation Results"),
        title="🔍 Automated Evaluation of MixBench",
        description=(
            "Please upload your model's retrieval results on MixBench (in JSON format) to automatically evaluate its performance.<br><br>"
            "For each subset (e.g., <code>MSCOCO</code>, <code>Google_WIT</code>, <code>VisualNews</code>, <code>OVEN</code>), "
            "we compute:<br>"
            "- <strong>Recall@1</strong><br>"
            "- <strong>NDCG@10</strong><br>"
            "- <strong>NDCG@100</strong><br><br>"
            "Expected input JSON format:<br><br>" + example_json_html +
            "<br>To find valid query IDs, see the "
            "<a href='https://huggingface.co/datasets/mixed-modality-search/MixBench2025/viewer/Google_WIT/mixed_corpus' target='_blank'>MixBench2025 dataset viewer</a>."
        ),
    ).launch(share=True)  # share=True is ignored when running inside a Hugging Face Space


if __name__ == "__main__":
    main_gradio()