Spaces:
Sleeping
Sleeping
| from typing import Dict | |
| import torch | |
| from torchmetrics import functional as FM | |
def classification_metrics(
        preds: torch.Tensor,
        target: torch.Tensor,
        num_classes: int,
        average: str = 'macro',
        task: str = 'multiclass') -> Dict[str, torch.Tensor]:
    """
    Compute F1, precision and recall for a classification task.

    All three metrics are computed by ``torchmetrics.functional`` with the
    same settings, so the returned values are directly comparable.

    Parameters
    ----------
    preds : torch.Tensor
        Model predictions (logits, probabilities or labels) in a form that
        torchmetrics accepts for the chosen ``task``.
    target : torch.Tensor
        Ground-truth labels.
    num_classes : int
        Number of classes. It is also forwarded as ``num_labels`` so that
        ``task='multilabel'`` works; torchmetrics requires ``num_labels``
        for that task.
    average : str, default 'macro'
        How per-class / per-label scores are reduced (e.g. 'micro',
        'macro', 'weighted', 'none'), passed straight to torchmetrics.
    task : str, default 'multiclass'
        Task type passed to torchmetrics: 'binary', 'multiclass' or
        'multilabel'.

    Returns
    -------
    Dict[str, torch.Tensor]
        Mapping with keys ``'f1'``, ``'precision'`` and ``'recall'``.
    """
    # Shared settings for all three metrics. torchmetrics reads
    # ``num_classes`` for multiclass and ``num_labels`` for multilabel, and
    # ignores whichever does not apply, so passing both supports every task.
    common = dict(
        preds=preds,
        target=target,
        task=task,
        num_classes=num_classes,
        num_labels=num_classes,
        average=average,
    )
    return dict(
        f1=FM.f1_score(**common),
        precision=FM.precision(**common),
        recall=FM.recall(**common),
    )