Check edge cases
Browse files- L3Score.py +5 -4
L3Score.py
CHANGED
|
@@ -109,10 +109,9 @@ class L3Score(evaluate.Metric):
|
|
| 109 |
"""Optional: download external resources useful to compute the scores"""
|
| 110 |
pass
|
| 111 |
|
| 112 |
-
def _verify_input(self,
|
| 113 |
"""Verify the input parameters"""
|
| 114 |
|
| 115 |
-
print(provider)
|
| 116 |
if provider not in PROVIDER_WITH_TOP_LOGPROBS:
|
| 117 |
raise ValueError(
|
| 118 |
"Provider must offer top_logprobs to use this metric, pick from {}".format(
|
|
@@ -120,6 +119,8 @@ class L3Score(evaluate.Metric):
|
|
| 120 |
)
|
| 121 |
)
|
| 122 |
|
|
|
|
|
|
|
| 123 |
|
| 124 |
def _get_llm(self, model, api_key):
|
| 125 |
"""Get the LLM"""
|
|
@@ -137,10 +138,10 @@ class L3Score(evaluate.Metric):
|
|
| 137 |
model="gpt-4o-mini",
|
| 138 |
):
|
| 139 |
"""Returns the scores"""
|
|
|
|
| 140 |
|
| 141 |
-
print("Inside compute")
|
| 142 |
# Check whether llm can be initialized
|
| 143 |
-
self._verify_input(
|
| 144 |
|
| 145 |
# Initialize the LLM
|
| 146 |
llm = self._get_llm(model, api_key)
|
|
|
|
| 109 |
"""Optional: download external resources useful to compute the scores"""
|
| 110 |
pass
|
| 111 |
|
| 112 |
+
def _verify_input(self, questions, predictions, references, provider):
|
| 113 |
"""Verify the input parameters"""
|
| 114 |
|
|
|
|
| 115 |
if provider not in PROVIDER_WITH_TOP_LOGPROBS:
|
| 116 |
raise ValueError(
|
| 117 |
"Provider must offer top_logprobs to use this metric, pick from {}".format(
|
|
|
|
| 119 |
)
|
| 120 |
)
|
| 121 |
|
| 122 |
+
assert len(questions) == len(predictions) == len(references), "Questions, predictions and references must have the same length"
|
| 123 |
+
|
| 124 |
|
| 125 |
def _get_llm(self, model, api_key):
|
| 126 |
"""Get the LLM"""
|
|
|
|
| 138 |
model="gpt-4o-mini",
|
| 139 |
):
|
| 140 |
"""Returns the scores"""
|
| 141 |
+
print(questions,predictions,references)
|
| 142 |
|
|
|
|
| 143 |
# Check whether llm can be initialized
|
| 144 |
+
self._verify_input(questions, predictions, references, provider)
|
| 145 |
|
| 146 |
# Initialize the LLM
|
| 147 |
llm = self._get_llm(model, api_key)
|