"""Gradio app that scores MixBench retrieval submissions against an
encrypted qrel (relevance-judgment) file.

Expected upload format:  dataset -> [{"query_id": ..., "corpus_top_100_list":
"id1, id2, ..."}].  Reported metrics per dataset: Recall@1, NDCG@10, NDCG@100.
"""

import json
import os
from collections import defaultdict

import gradio as gr
import numpy as np
from cryptography.fernet import Fernet
from sklearn.metrics import ndcg_score


def load_and_decrypt_qrel(secret_key):
    """Decrypt ``data/answer.enc`` with *secret_key* (a Fernet key string).

    Returns the relevance judgments reshaped as
    ``dataset -> query_id -> {corpus_id: score}``.

    Raises:
        ValueError: with a user-facing message on any failure (missing
            file, wrong key, malformed JSON).  The original exception is
            chained for debugging.
    """
    try:
        with open("data/answer.enc", "rb") as enc_file:
            encrypted_data = enc_file.read()
        cipher = Fernet(secret_key.encode())
        decrypted_data = cipher.decrypt(encrypted_data).decode("utf-8")
        raw_data = json.loads(decrypted_data)
        # Convert the flat record list to: dataset -> query_id -> {corpus_id: score}
        qrel_dict = defaultdict(lambda: defaultdict(dict))
        for dataset, records in raw_data.items():
            for item in records:
                qrel_dict[dataset][item["query_id"]][item["corpus_id"]] = item["score"]
        return qrel_dict
    except Exception as e:
        # Boundary handler: anything that goes wrong here is surfaced to the
        # UI as a single message; chain the cause instead of discarding it.
        raise ValueError(f"❌ Failed to decrypt answer file: {str(e)}") from e


def recall_at_k(corpus_top_100_list, relevant_ids, k=1):
    """Return 1 if any of the top-*k* retrieved ids is relevant, else 0."""
    relevant = set(relevant_ids)  # O(1) membership instead of list scans
    return int(any(item in relevant for item in corpus_top_100_list[:k]))


def ndcg_at_k(corpus_top_100_list, rel_dict, k):
    """NDCG@k of a ranked id list against ``{corpus_id: graded relevance}``.

    Relevant ids that were never retrieved are appended after the ranked
    list (i.e. at the worst ranks) so the ideal DCG covers the full
    relevant set.
    """
    if not rel_dict:
        # No relevant documents: NDCG is conventionally 0 (sklearn would
        # emit a divide-by-zero path / warning on an all-zero gain vector).
        return 0.0
    # dict.fromkeys dedupes while preserving rank order.
    all_items = list(dict.fromkeys(corpus_top_100_list + list(rel_dict.keys())))
    if len(all_items) < 2:
        # sklearn's ndcg_score rejects single-document inputs; a lone item
        # that is relevant constitutes a perfect ranking.
        return 1.0
    y_true = [rel_dict.get(item, 0) for item in all_items]
    # Strictly decreasing pseudo-scores encode the ranking order for sklearn.
    y_score = [len(all_items) - i for i in range(len(all_items))]
    return ndcg_score([y_true], [y_score], k=k)


def evaluate(pred_data, qrel_dict):
    """Score predictions per dataset.

    Args:
        pred_data: ``dataset -> [{"query_id": str,
            "corpus_top_100_list": "id1, id2, ..."}]``.
        qrel_dict: output of :func:`load_and_decrypt_qrel`.

    Returns:
        ``dataset -> {"Recall@1": pct, "NDCG@10": pct, "NDCG@100": pct}``.
        Datasets absent from the qrel, or with no queries, are skipped.
    """
    results = {}
    for dataset, queries in pred_data.items():
        if dataset not in qrel_dict:
            continue
        recall_1, ndcg_10, ndcg_100 = [], [], []
        for item in queries:
            qid = item["query_id"]
            ranked = [x.strip() for x in item["corpus_top_100_list"].split(",") if x.strip()]
            rel_dict = qrel_dict[dataset].get(qid, {})
            relevant_ids = [cid for cid, score in rel_dict.items() if score > 0]
            recall_1.append(recall_at_k(ranked, relevant_ids, 1))
            ndcg_10.append(ndcg_at_k(ranked, rel_dict, 10))
            ndcg_100.append(ndcg_at_k(ranked, rel_dict, 100))
        if not recall_1:
            # Guard: np.mean([]) is NaN (plus a RuntimeWarning).
            continue
        results[dataset] = {
            "Recall@1": round(np.mean(recall_1) * 100, 2),
            "NDCG@10": round(np.mean(ndcg_10) * 100, 2),
            "NDCG@100": round(np.mean(ndcg_100) * 100, 2),
        }
    return results


def process_json(file):
    """Gradio callback: load the uploaded predictions, decrypt the qrel,
    and return the evaluation results (or a ❌-prefixed error string)."""
    if file is None:
        return "❌ Invalid JSON format: no file uploaded"
    # gradio may hand back a plain path or a tempfile wrapper with .name.
    path = getattr(file, "name", file)
    try:
        with open(path, encoding="utf-8") as f:  # original leaked the handle
            pred_data = json.load(f)
    except Exception as e:
        return f"❌ Invalid JSON format: {str(e)}"
    secret_key = os.getenv("SECRET_KEY")
    if not secret_key:
        return "❌ SECRET_KEY environment variable not set. Please configure it in your Hugging Face Space."
    try:
        qrel_dict = load_and_decrypt_qrel(secret_key)
    except Exception as e:
        return str(e)
    try:
        metrics = evaluate(pred_data, qrel_dict)
        return json.dumps(metrics, indent=2)
    except Exception as e:
        return f"❌ Error during evaluation: {str(e)}"


def main_gradio():
    """Launch the Gradio evaluation interface.

    NOTE(review): the HTML inside these string literals was stripped by a
    sanitizer in the checked-in file (leaving unterminated strings); the
    <pre>/<br> markup below is reconstructed from the surviving fragments —
    confirm against the deployed Space.
    """
    example_json_html = (
        "<pre>{<br>"
        '  "Google_WIT": [<br>'
        '    {"query_id": "1", "corpus_top_100_list": "5, 2, 8, ..."},<br>'
        '    {"query_id": "2", "corpus_top_100_list": "90, 13, 3, ..."}<br>'
        "  ],<br>"
        '  "MSCOCO": [<br>'
        '    {"query_id": "1", "corpus_top_100_list": "122, 35, 22, ..."}<br>'
        "  ],<br>"
        '  "OVEN": [<br>'
        '    {"query_id": "1", "corpus_top_100_list": "11, 15, 22, ..."}<br>'
        "  ],<br>"
        '  "VisualNews": [<br>'
        '    {"query_id": "1", "corpus_top_100_list": "101, 35, 77, ..."}<br>'
        "  ]<br>"
        "}</pre>"
    )
    gr.Interface(
        fn=process_json,
        inputs=gr.File(label="Upload Retrieval Result (JSON)"),
        outputs=gr.Textbox(label="Evaluation Results"),
        title="🔍 Automated Evaluation of MixBench",
        description=(
            "Please upload your model's retrieval result on MixBench (in JSON "
            "format) to automatically evaluate its performance.<br><br>"
            "For each subset (e.g., MSCOCO, Google_WIT, VisualNews, OVEN), "
            "we compute:<br>"
            "- Recall@1<br>"
            "- NDCG@10<br>"
            "- NDCG@100<br><br>"
            "Expected input JSON format:<br><br>"
            + example_json_html
            # The original likely wrapped the viewer name in an <a href=...>
            # link; the URL was lost in the corruption — TODO restore it.
            + "<br>To find valid query IDs, see the "
            "MixBench2025 dataset viewer."
        ),
    ).launch(share=True)


if __name__ == "__main__":
    main_gradio()