yu-val-weiss
commited on
Commit
·
0a5e4ab
1
Parent(s):
995725e
use numpy code for simplicity
Browse files- blimp.py +11 -17
- requirements.txt +2 -2
blimp.py
CHANGED
|
@@ -18,6 +18,7 @@ from typing import Optional
|
|
| 18 |
|
| 19 |
import datasets
|
| 20 |
import evaluate
|
|
|
|
| 21 |
import torch
|
| 22 |
from evaluate import logging
|
| 23 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
@@ -250,24 +251,18 @@ class Blimp(evaluate.Metric):
|
|
| 250 |
sent_type="bad",
|
| 251 |
)
|
| 252 |
|
| 253 |
-
#
|
| 254 |
-
|
| 255 |
-
accuracy = correct / len(good_probs)
|
| 256 |
-
results[category] = accuracy
|
| 257 |
|
|
|
|
| 258 |
phenom_results[phenom].append(accuracy)
|
| 259 |
|
| 260 |
-
phenom_term_averages = {
|
| 261 |
-
term: sum(accuracies) / len(accuracies)
|
| 262 |
-
for term, accuracies in phenom_results.items()
|
| 263 |
-
}
|
| 264 |
-
# Calculate overall accuracy
|
| 265 |
-
overall_accuracy = sum(results.values()) / len(results)
|
| 266 |
-
|
| 267 |
return {
|
| 268 |
"by_uid": results,
|
| 269 |
-
"accuracy":
|
| 270 |
-
"by_phenomenon":
|
|
|
|
|
|
|
| 271 |
}
|
| 272 |
|
| 273 |
|
|
@@ -307,12 +302,11 @@ def get_batch_probabilities(
|
|
| 307 |
|
| 308 |
if batch_size > 1:
|
| 309 |
# mask padding tokens
|
| 310 |
-
|
| 311 |
-
token_log_probs *= mask
|
| 312 |
|
| 313 |
# sum log probabilities
|
| 314 |
sequence_log_probs = token_log_probs.sum(dim=1)
|
| 315 |
|
| 316 |
-
probs.
|
| 317 |
|
| 318 |
-
return probs
|
|
|
|
| 18 |
|
| 19 |
import datasets
|
| 20 |
import evaluate
|
| 21 |
+
import numpy as np
|
| 22 |
import torch
|
| 23 |
from evaluate import logging
|
| 24 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
| 251 |
sent_type="bad",
|
| 252 |
)
|
| 253 |
|
| 254 |
+
# compute accuracy (mean of instances where good prob > bad prob)
|
| 255 |
+
accuracy = np.mean(good_probs > bad_probs)
|
|
|
|
|
|
|
| 256 |
|
| 257 |
+
results[category] = accuracy
|
| 258 |
phenom_results[phenom].append(accuracy)
|
| 259 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
return {
|
| 261 |
"by_uid": results,
|
| 262 |
+
"accuracy": np.mean(list(results.values())),
|
| 263 |
+
"by_phenomenon": {
|
| 264 |
+
term: np.mean(acc) for term, acc in phenom_results.items()
|
| 265 |
+
},
|
| 266 |
}
|
| 267 |
|
| 268 |
|
|
|
|
| 302 |
|
| 303 |
if batch_size > 1:
|
| 304 |
# mask padding tokens
|
| 305 |
+
token_log_probs.masked_fill_(labels == tokenizer.pad_token_id, 0.0)
|
|
|
|
| 306 |
|
| 307 |
# sum log probabilities
|
| 308 |
sequence_log_probs = token_log_probs.sum(dim=1)
|
| 309 |
|
| 310 |
+
probs.append(sequence_log_probs.cpu().numpy())
|
| 311 |
|
| 312 |
+
return np.concatenate(probs)
|
requirements.txt
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
git+https://github.com/huggingface/evaluate@5aa3982a9a8c86e506860e381d428a64b0cce73b
|
| 2 |
torch
|
| 3 |
-
|
| 4 |
-
|
|
|
|
| 1 |
git+https://github.com/huggingface/evaluate@5aa3982a9a8c86e506860e381d428a64b0cce73b
|
| 2 |
torch
|
| 3 |
+
transformers
|
| 4 |
+
numpy
|