RAG_Eval / tests /test_metrics.py
Rom89823974978's picture
Resolved tests issues
79bdbbe
import pytest
import numpy as np
from evaluation.metrics import (
precision_at_k,
recall_at_k,
mean_reciprocal_rank,
average_precision,
rag_score,
bleu,
rouge_l,
bert_score,
qags,
fact_score,
ragas_f,
)
def test_retrieval_metrics_simple():
    """Check the retrieval metrics against hand-computed values.

    The ranked list is d1..d4 and the relevant set is {d2, d4, d5}, so the
    hits sit at ranks 2 and 4 while d5 is never retrieved.
    """
    ranking = ["d1", "d2", "d3", "d4"]
    gold = {"d2", "d4", "d5"}

    def near(value):
        # Same relative tolerance for every comparison in this test.
        return pytest.approx(value, rel=1e-6)

    # Top-2 = [d1, d2] holds one hit; extending to top-3 adds no new hit.
    assert precision_at_k(ranking, gold, 2) == near(0.5)
    assert precision_at_k(ranking, gold, 3) == near(1 / 3)

    # One of three relevant docs is found by rank 2, two of three by rank 4.
    assert recall_at_k(ranking, gold, 2) == near(1 / 3)
    assert recall_at_k(ranking, gold, 4) == near(2 / 3)

    # First relevant document (d2) is at rank 2.
    assert mean_reciprocal_rank(ranking, gold) == near(0.5)

    # Expected value equals (1/2 + 2/4) / 3, i.e. precision at each hit
    # divided by the size of the relevant set.
    # NOTE(review): presumably that is how average_precision normalises --
    # confirm against evaluation.metrics.
    assert average_precision(ranking, gold) == near(1 / 3)
def test_rag_score_harmonic_mean():
    """Check that rag_score matches the harmonic mean of its two inputs.

    The first case compares against 2 / (1/a + 1/b) for a=0.8, b=0.6.
    The second case expects a zero component to yield an overall score of 0.
    """
    scores = {"retrieval_f1": 0.8, "generation_bleu": 0.6}
    target = 2.0 / (1 / 0.8 + 1 / 0.6)
    assert rag_score(scores) == pytest.approx(target, rel=1e-6)

    # A relative tolerance scales with the expected value, so it is a no-op
    # when the expectation is 0.0 (only pytest's default abs=1e-12 would
    # apply).  Use an explicit absolute tolerance for the zero case instead.
    scores_zero = {"retrieval_f1": 0.0, "generation_bleu": 0.6}
    assert rag_score(scores_zero) == pytest.approx(0.0, abs=1e-6)
@pytest.mark.parametrize(
    ("preds", "refs", "expected_min"),
    [
        (["Hello world"], ["Hello world"], 0.0),
        (["Some text"], ["Different text"], 0.0),
    ],
)
def test_generation_metrics_fallback(preds, refs, expected_min):
    """Check that bleu, rouge_l and bert_score each return the expected float.

    NOTE(review): despite its name, ``expected_min`` is compared for
    approximate equality, not used as a lower bound.  Every current case
    expects 0.0 -- presumably the value the metrics return when their
    backing libraries are unavailable; confirm against evaluation.metrics.
    """
    # Evaluate all three metrics up front, in the original call order,
    # before making any assertion.
    observed = [metric(preds, refs) for metric in (bleu, rouge_l, bert_score)]

    for score in observed:
        assert isinstance(score, float)
        # NOTE(review): with an expected value of 0.0 the relative tolerance
        # has no effect; pytest's default absolute tolerance applies.
        assert score == pytest.approx(expected_min, rel=1e-6)
@pytest.mark.parametrize(
    ("preds", "refs", "ctxs", "expected"),
    [
        (["A"], ["A"], ["ctx"], 0.0),
        (["B"], ["C"], [""], 0.0),
    ],
)
def test_qags_factscore_ragas_f_fallback(preds, refs, ctxs, expected):
    """Check that qags, fact_score and ragas_f all return ``expected``.

    qags and fact_score receive predictions and references; ragas_f
    additionally receives the contexts.  Both current cases expect 0.0.
    NOTE(review): presumably this is the fallback value used when the
    underlying scorers are unavailable -- confirm against evaluation.metrics.
    """
    # One shared approximate target; the metrics are still called one at a
    # time, in the original order.
    want = pytest.approx(expected, rel=1e-6)

    assert qags(preds, refs) == want
    assert fact_score(preds, refs) == want
    assert ragas_f(preds, refs, ctxs) == want