Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| from pathlib import Path | |
| import pytest | |
| from src.models import EvalResult, FullEvalResult | |
# Path of this test module; toy-data locations are resolved relative to it.
cur_fp = Path(__file__)

# Benchmark counts per AIR-Bench release.
# Ref: https://github.com/AIR-Bench/AIR-Bench/blob/4b27b8a8f2047a963805fcf6fb9d74be51ec440c/docs/available_tasks.md
#
# 24.05
# | Task     | dev | test |
# | -------- | --- | ---- |
# | Long-Doc | 4   | 11   |
# | QA       | 54  | 53   |
#
# 24.04
# | Task     | test |
# | -------- | ---- |
# | Long-Doc | 15   |
# | QA       | 13   |

# Release 24.05 -- values match the "test" column of the table above.
NUM_QA_BENCHMARKS_24_05 = 53
NUM_DOC_BENCHMARKS_24_05 = 11

# Release 24.04 -- values match the "test" column of the table above.
NUM_QA_BENCHMARKS_24_04 = 13
NUM_DOC_BENCHMARKS_24_04 = 15
def test_eval_result():
    """Construct an ``EvalResult`` from a full set of keyword arguments.

    There are no assertions: the test passes as long as the constructor
    accepts these fields without raising.
    """
    # A single per-dataset score entry for the ``results`` list.
    score_entry = {
        "domain": "law",
        "lang": "en",
        "dataset": "lex_files_500K-600K",
        "value": 0.45723,
    }
    init_kwargs = {
        "eval_name": "eval_name",
        "retrieval_model": "bge-m3",
        "reranking_model": "NoReranking",
        "results": [score_entry],
        "task": "qa",
        "metric": "ndcg_at_3",
        "timestamp": "2024-05-14T03:09:08Z",
        "revision": "1e243f14bd295ccdea7a118fe847399d",
        "is_anonymous": True,
    }
    EvalResult(**init_kwargs)
def test_full_eval_result_init_from_json_file(file_path):
    """Load a toy results file and check the fields populated from it.

    Args:
        file_path: Location of the JSON file relative to
            ``toydata/eval_results/``, which sits one level above the
            directory containing this module. The value is supplied from
            outside this function (parametrization or fixture).
    """
    toy_results_dir = cur_fp.parents[1].joinpath("toydata/eval_results/")
    json_fp = toy_results_dir / file_path

    loaded = FullEvalResult.init_from_json_file(json_fp)

    # The two directory levels directly above the JSON file are expected to
    # carry the reranking model name (nearest) and retrieval model name (next).
    reranker_dir_name = json_fp.parents[0].stem
    assert reranker_dir_name == loaded.reranking_model
    retriever_dir_name = json_fp.parents[1].stem
    assert retriever_dir_name == loaded.retrieval_model

    assert len(loaded.results) == 70
def test_full_eval_result_to_dict(file_path, task, expected_num_results):
    """Check the number of columns in the record returned by ``to_dict``.

    Args:
        file_path: Location of the JSON file relative to
            ``toydata/eval_results/``, which sits one level above the
            directory containing this module.
        task: Task identifier passed straight through to ``to_dict``.
        expected_num_results: Number of columns expected in the record in
            addition to the fixed metadata columns listed below.
    """
    toy_results_dir = cur_fp.parents[1].joinpath("toydata/eval_results/")
    loaded = FullEvalResult.init_from_json_file(toy_results_dir / file_path)

    records = loaded.to_dict(task)
    assert len(records) == 1
    record = records[0]

    # Metadata columns counted on top of ``expected_num_results``.
    metadata_columns = frozenset(
        {
            "eval_name",
            "Retrieval Method",
            "Reranking Model",
            "Retrieval Model LINK",
            "Reranking Model LINK",
            "Revision",
            "Submission Date",
            "Anonymous Submission",
        }
    )

    column_names = [*record.keys()]
    expected_total = expected_num_results + len(metadata_columns)
    assert len(column_names) == expected_total