import pytest
from pathlib import Path

from src.models import EvalResult, FullEvalResult

cur_fp = Path(__file__)

# Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
# 24.05
# | Task     | dev | test |
# | -------- | --- | ---- |
# | Long-Doc | 4   | 11   |
# | QA       | 54  | 53   |
#
# 24.04
# | Task     | test |
# | -------- | ---- |
# | Long-Doc | 15   |
# | QA       | 13   |
NUM_QA_BENCHMARKS_24_05 = 53
NUM_DOC_BENCHMARKS_24_05 = 11
NUM_QA_BENCHMARKS_24_04 = 13
NUM_DOC_BENCHMARKS_24_04 = 15


def test_eval_result():
    eval_result = EvalResult(
        eval_name="eval_name",
        retrieval_model="bge-m3",
        reranking_model="NoReranking",
        results=[
            {
                "domain": "law",
                "lang": "en",
                "dataset": "lex_files_500K-600K",
                "value": 0.45723
            }
        ],
        task="qa",
        metric="ndcg_at_3",
        timestamp="2024-05-14T03:09:08Z",
        revision="1e243f14bd295ccdea7a118fe847399d",
        is_anonymous=True,
    )


@pytest.mark.parametrize(
    'file_path',
    [
        "AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json",
        "AIR-Bench_24.05/bge-m3/NoReranker/results.json"
    ])
def test_full_eval_result_init_from_json_file(file_path):
    json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
    full_eval_result = FullEvalResult.init_from_json_file(json_fp)
    assert json_fp.parents[0].stem == full_eval_result.reranking_model
    assert json_fp.parents[1].stem == full_eval_result.retrieval_model
    assert len(full_eval_result.results) == 70


@pytest.mark.parametrize(
    'file_path, task, expected_num_results',
    [
        ("AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json", "qa", NUM_QA_BENCHMARKS_24_04),
        ("AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json", "long-doc", NUM_DOC_BENCHMARKS_24_04),
        ("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "qa", NUM_QA_BENCHMARKS_24_05),
        ("AIR-Bench_24.05/bge-m3/NoReranker/results.json", "long-doc", NUM_DOC_BENCHMARKS_24_05),
    ])
def test_full_eval_result_to_dict(file_path, task, expected_num_results):
    json_fp = cur_fp.parents[1] / "toydata/eval_results/" / file_path
    full_eval_result = FullEvalResult.init_from_json_file(json_fp)
    result_dict_list = full_eval_result.to_dict(task)
    assert len(result_dict_list) == 1
    result = result_dict_list[0]
    attr_list = frozenset([
        'eval_name',
        'Retrieval Method',
        'Reranking Model',
        'Retrieval Model LINK',
        'Reranking Model LINK',
        'Revision',
        'Submission Date',
        'Anonymous Submission'])
    result_cols = list(result.keys())
    assert len(result_cols) == (expected_num_results + len(attr_list))
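

# A minimal sketch (not part of the original suite) documenting the fixture
# layout the tests above rely on, i.e.
#   toydata/eval_results/<AIR-Bench version>/<retrieval model>/<reranking model>/results.json
# The test name and the glob-based check are illustrative additions; the two
# asserted paths are the same fixture paths used in the parametrized tests above.
def test_toydata_layout_matches_convention():
    eval_results_dir = cur_fp.parents[1] / "toydata/eval_results"
    found = sorted(
        p.relative_to(eval_results_dir).as_posix()
        for p in eval_results_dir.glob("*/*/*/results.json")
    )
    # The fixtures exercised by the tests above must be present ...
    assert "AIR-Bench_24.04/bge-m3/jina-reranker-v2-base-multilingual/results.json" in found
    assert "AIR-Bench_24.05/bge-m3/NoReranker/results.json" in found
    # ... and every fixture is expected to sit under an AIR-Bench_<version> directory.
    for rel_path in found:
        assert rel_path.startswith("AIR-Bench_")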