Spaces:
Runtime error
Runtime error
File size: 3,571 Bytes
4b8361a 17a2a7d 4b8361a 557fb53 4b8361a 0030bc6 557fb53 4b8361a 557fb53 4b8361a 557fb53 4b8361a 0030bc6 557fb53 4b8361a 1c22425 17a2a7d 1c22425 0030bc6 17a2a7d 557fb53 0030bc6 557fb53 0030bc6 557fb53 0030bc6 557fb53 e748bc2 17a2a7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import torch.nn as nn
import torch
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from functools import cache
import pandas as pd
import matplotlib.pyplot as plt
import io
from PIL import Image
class LabelWeightedBCELoss(nn.Module):
    """
    Binary cross-entropy loss with a weight per label.

    Every element of ``input`` is treated as an independent binary
    probability. The element-wise loss is multiplied by ``label_weights``
    (ordinary broadcasting) before the reduction is applied.

    Args:
        label_weights: Weights multiplied into the element-wise loss.
        reduction: ``"mean"`` or ``"sum"``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported names.
    """

    def __init__(self, label_weights: torch.Tensor, reduction="mean"):
        super().__init__()
        reductions = {"mean": torch.mean, "sum": torch.sum}
        if reduction not in reductions:
            # Fail at construction time instead of leaving ``self.reduction``
            # unset and crashing with AttributeError inside ``forward``.
            raise ValueError(
                f"Unsupported reduction {reduction!r}; "
                f"expected one of {sorted(reductions)}"
            )
        # A buffer follows the module through ``.to(device)`` / ``.cuda()``;
        # ``persistent=False`` keeps it out of ``state_dict`` so existing
        # checkpoints are unaffected.
        self.register_buffer(
            "label_weights", torch.as_tensor(label_weights), persistent=False
        )
        self.reduction = reductions[reduction]

    def _log(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``log(x)`` clamped below at -100 so ``log(0)`` stays finite."""
        return torch.clamp_min(torch.log(x), -100)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the reduced, label-weighted binary cross-entropy.

        Args:
            input: Predicted probabilities (``log(input)`` and
                ``log(1 - input)`` are taken directly, so values are
                presumably expected in [0, 1]).
            target: Target values, broadcastable against ``input``.
        """
        losses = -self.label_weights * (
            target * self._log(input) + (1 - target) * self._log(1 - input)
        )
        return self.reduction(losses)
# TODO: Implement a one-hot encoding helper.
def calculate_metrics(
    pred, target, threshold=0.5, prefix="", multi_label=True
) -> dict[str, torch.Tensor]:
    """Compute macro precision, recall and F1 plus accuracy for a batch.

    In multi-label mode a target entry counts as positive when it is > 0.0
    and a prediction counts as positive when it is > ``threshold``. Otherwise
    ``pred`` is soft-maxed over dim 1 and both tensors are reduced to class
    indices with ``argmax(1)``.

    Args:
        pred: Prediction tensor (detached and moved to CPU here).
        target: Target tensor (detached and moved to CPU here).
        threshold: Cut-off applied to ``pred`` in multi-label mode only.
        prefix: String prepended to every key of the returned dict.
        multi_label: Selects the thresholding branch over the argmax branch.

    Returns:
        Dict with keys ``precision``, ``recall``, ``f1`` and ``accuracy``
        (each prefixed), whose values are float32 tensors.
    """
    target_np = target.detach().cpu().numpy()
    pred_cpu = pred.detach().cpu()

    if multi_label:
        y_true = np.array(target_np > 0.0, dtype=float)
        y_pred = np.array(pred_cpu.numpy() > threshold, dtype=float)
    else:
        y_true = target_np.argmax(1)
        y_pred = nn.functional.softmax(pred_cpu, dim=1).numpy().argmax(1)

    # Keyword arguments shared by the three macro-averaged scorers.
    scorer_kwargs = dict(
        y_true=y_true, y_pred=y_pred, zero_division=0, average="macro"
    )
    scores = {
        "precision": precision_score(**scorer_kwargs),
        "recall": recall_score(**scorer_kwargs),
        "f1": f1_score(**scorer_kwargs),
        "accuracy": accuracy_score(y_true=y_true, y_pred=y_pred),
    }

    result = {}
    for name, value in scores.items():
        result[prefix + name] = torch.tensor(value, dtype=torch.float32)
    return result
class EarlyStopping:
    """Track a monitored value and signal when training should stop.

    A step counts as a non-improvement when the new value is greater than or
    equal to the value from the *previous* step (not the best value seen).
    """

    def __init__(self, patience=0):
        self.patience = patience
        self.last_measure = np.inf
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        """Record ``val``; return True once the non-improving streak exceeds ``patience``."""
        did_not_improve = self.last_measure <= val
        streak = self.consecutive_increase + 1 if did_not_improve else 0
        self.consecutive_increase = streak
        self.last_measure = val
        return streak > self.patience
def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
    """Build ``id -> label`` and ``label -> id`` lookups from a label list.

    Ids are the list positions rendered as strings ("0", "1", ...).
    """
    id2label = {}
    label2id = {}
    for position, label in enumerate(labels):
        identifier = str(position)
        id2label[identifier] = label
        label2id[label] = identifier
    return id2label, label2id
def compute_hf_metrics(eval_pred):
    """Return the accuracy of the arg-max predictions in ``eval_pred``.

    ``eval_pred.predictions`` is reduced with ``argmax`` over axis 1 and
    compared against ``eval_pred.label_ids``.

    NOTE(review): the result is a bare score, not a dict. If this function is
    handed to a Hugging Face ``Trainer`` as ``compute_metrics``, a dict of
    named metrics is presumably expected -- confirm against the caller.
    """
    scores = eval_pred.predictions
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy_score(y_true=eval_pred.label_ids, y_pred=predicted_ids)
@cache
def get_dance_mapping(mapping_file: str) -> dict[str, str]:
    """Load the ``id -> name`` mapping from a CSV file.

    The CSV must contain ``id`` and ``name`` columns. Results are memoised
    per ``mapping_file``, so every caller receives the same dict object and
    should treat it as read-only.

    Raises:
        KeyError: If the ``id`` or ``name`` column is missing.
    """
    mapping_df = pd.read_csv(mapping_file)
    # Zip the two columns directly: ``iterrows`` builds a Series per row,
    # which is slow and upcasts mixed numeric columns to a common dtype.
    return dict(zip(mapping_df["id"], mapping_df["name"]))
def plot_to_image(figure) -> np.ndarray:
    """Render ``figure`` to an in-memory PNG and return the decoded pixels.

    The figure is closed before returning (also when rendering fails), so it
    must not be used after this call.

    Args:
        figure: The matplotlib figure to render.

    Returns:
        The decoded PNG as a NumPy array.
    """
    buf = io.BytesIO()
    try:
        # Save the figure that was passed in; ``plt.savefig`` would save
        # whichever figure pyplot currently considers active.
        figure.savefig(buf, format="png")
    finally:
        # Closing the figure prevents it from being displayed directly
        # inside the notebook.
        plt.close(figure)
    buf.seek(0)
    with Image.open(buf) as image:
        # Convert while the image is open so the pixel data is fully read.
        return np.array(image)
|