|
import torch |
|
from torchmetrics import Metric |
|
|
|
class MyAccuracy(Metric):
    """
    Accuracy metric customized for handling sequences with padding.

    Positions whose label equals ``ignore_index`` (default ``-100``, the
    conventional padding label used by e.g. HuggingFace token classifiers)
    are excluded from both the numerator and the denominator.

    Attributes:
        total (torch.Tensor): Total number of non-padding elements seen.
        correct (torch.Tensor): Number of correctly predicted non-padding
            elements.
    """

    def __init__(self, ignore_index: int = -100):
        """
        Args:
            ignore_index (int): Label value that marks padding positions to
                be excluded from the accuracy computation. Defaults to -100.
        """
        super().__init__()
        self.ignore_index = ignore_index
        # dist_reduce_fx='sum' makes the states sum correctly across
        # distributed processes.
        self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum')
        self.add_state('correct', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, logits: torch.Tensor, labels: torch.Tensor, num_labels: int) -> None:
        """
        Accumulate correct/total counts from one batch.

        Args:
            logits (torch.Tensor): Model predictions; flattened to shape
                ``(-1, num_labels)`` before the argmax.
            labels (torch.Tensor): Ground truth labels; flattened to 1-D.
                Entries equal to ``self.ignore_index`` are ignored.
            num_labels (int): Number of unique labels (size of the class
                dimension of ``logits``).
        """
        flattened_targets = labels.view(-1)
        active_logits = logits.view(-1, num_labels)
        # 'dim' is the canonical torch keyword ('axis' is a NumPy-compat alias).
        flattened_predictions = torch.argmax(active_logits, dim=1)

        # Mask selecting only the non-padding positions.
        active_accuracy = flattened_targets != self.ignore_index
        ac_labels = torch.masked_select(flattened_targets, active_accuracy)
        predictions = torch.masked_select(flattened_predictions, active_accuracy)

        self.correct += torch.sum(ac_labels == predictions)
        self.total += torch.numel(ac_labels)

    def compute(self) -> torch.Tensor:
        """
        Compute the accuracy as correct / total over all updates.

        Returns:
            torch.Tensor: Scalar accuracy in [0, 1]. NaN if no non-padding
            elements have been seen yet (total == 0).
        """
        return self.correct.float() / self.total.float()