Spaces:
Runtime error
Runtime error
Commit
·
17a2a7d
1
Parent(s):
ad4c4e2
added evaulations
Browse files- models/training_environment.py +55 -9
- models/utils.py +29 -2
models/training_environment.py
CHANGED
|
@@ -1,10 +1,16 @@
|
|
| 1 |
import importlib
|
| 2 |
-
from models.utils import calculate_metrics
|
| 3 |
-
|
| 4 |
from abc import ABC, abstractmethod
|
| 5 |
import pytorch_lightning as pl
|
|
|
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class TrainingEnvironment(pl.LightningModule):
|
|
@@ -27,8 +33,8 @@ class TrainingEnvironment(pl.LightningModule):
|
|
| 27 |
config["training_environment"].get("loggers", {})
|
| 28 |
)
|
| 29 |
self.config = config
|
| 30 |
-
self.has_multi_label_predictions = (
|
| 31 |
-
|
| 32 |
)
|
| 33 |
self.save_hyperparameters(
|
| 34 |
{
|
|
@@ -44,6 +50,8 @@ class TrainingEnvironment(pl.LightningModule):
|
|
| 44 |
) -> torch.Tensor:
|
| 45 |
features, labels = batch
|
| 46 |
outputs = self.model(features)
|
|
|
|
|
|
|
| 47 |
loss = self.criterion(outputs, labels)
|
| 48 |
metrics = calculate_metrics(
|
| 49 |
outputs,
|
|
@@ -62,6 +70,8 @@ class TrainingEnvironment(pl.LightningModule):
|
|
| 62 |
):
|
| 63 |
x, y = batch
|
| 64 |
preds = self.model(x)
|
|
|
|
|
|
|
| 65 |
metrics = calculate_metrics(
|
| 66 |
preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
|
| 67 |
)
|
|
@@ -71,12 +81,48 @@ class TrainingEnvironment(pl.LightningModule):
|
|
| 71 |
def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
|
| 72 |
x, y = batch
|
| 73 |
preds = self.model(x)
|
| 74 |
-
self.
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
def configure_optimizers(self):
|
| 82 |
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
|
|
|
| 1 |
import importlib
|
| 2 |
+
from models.utils import calculate_metrics, plot_to_image, get_dance_mapping
|
| 3 |
+
import numpy as np
|
| 4 |
from abc import ABC, abstractmethod
|
| 5 |
import pytorch_lightning as pl
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
| 9 |
+
from sklearn.metrics import (
|
| 10 |
+
roc_auc_score,
|
| 11 |
+
confusion_matrix,
|
| 12 |
+
ConfusionMatrixDisplay,
|
| 13 |
+
)
|
| 14 |
|
| 15 |
|
| 16 |
class TrainingEnvironment(pl.LightningModule):
|
|
|
|
| 33 |
config["training_environment"].get("loggers", {})
|
| 34 |
)
|
| 35 |
self.config = config
|
| 36 |
+
self.has_multi_label_predictions = not (
|
| 37 |
+
type(criterion).__name__ == "CrossEntropyLoss"
|
| 38 |
)
|
| 39 |
self.save_hyperparameters(
|
| 40 |
{
|
|
|
|
| 50 |
) -> torch.Tensor:
|
| 51 |
features, labels = batch
|
| 52 |
outputs = self.model(features)
|
| 53 |
+
if self.has_multi_label_predictions:
|
| 54 |
+
outputs = nn.functional.sigmoid(outputs)
|
| 55 |
loss = self.criterion(outputs, labels)
|
| 56 |
metrics = calculate_metrics(
|
| 57 |
outputs,
|
|
|
|
| 70 |
):
|
| 71 |
x, y = batch
|
| 72 |
preds = self.model(x)
|
| 73 |
+
if self.has_multi_label_predictions:
|
| 74 |
+
preds = nn.functional.sigmoid(preds)
|
| 75 |
metrics = calculate_metrics(
|
| 76 |
preds, y, prefix="val/", multi_label=self.has_multi_label_predictions
|
| 77 |
)
|
|
|
|
| 81 |
def test_step(self, batch: tuple[torch.Tensor, torch.TensorType], batch_index: int):
|
| 82 |
x, y = batch
|
| 83 |
preds = self.model(x)
|
| 84 |
+
if self.has_multi_label_predictions:
|
| 85 |
+
preds = nn.functional.sigmoid(preds)
|
| 86 |
+
metrics = calculate_metrics(
|
| 87 |
+
preds, y, prefix="test/", multi_label=self.has_multi_label_predictions
|
| 88 |
+
)
|
| 89 |
+
if not self.has_multi_label_predictions:
|
| 90 |
+
preds = nn.functional.softmax(preds, dim=1)
|
| 91 |
+
y = y.detach().cpu().numpy()
|
| 92 |
+
preds = preds.detach().cpu().numpy()
|
| 93 |
+
# ROC-auc score
|
| 94 |
+
try:
|
| 95 |
+
metrics["test/roc_auc_score"] = torch.tensor(
|
| 96 |
+
roc_auc_score(y, preds), dtype=torch.float32
|
| 97 |
+
)
|
| 98 |
+
except ValueError:
|
| 99 |
+
# If there is only one class, roc_auc_score will throw an error
|
| 100 |
+
pass
|
| 101 |
+
|
| 102 |
+
pass
|
| 103 |
+
self.log_dict(metrics, prog_bar=True)
|
| 104 |
+
# Create confusion matrix
|
| 105 |
+
|
| 106 |
+
preds = preds.argmax(axis=1)
|
| 107 |
+
y = y.argmax(axis=1)
|
| 108 |
+
cm = confusion_matrix(
|
| 109 |
+
preds, y, normalize="all", labels=np.arange(len(self.config["dance_ids"]))
|
| 110 |
)
|
| 111 |
+
if hasattr(self, "test_cm"):
|
| 112 |
+
self.test_cm += cm
|
| 113 |
+
else:
|
| 114 |
+
self.test_cm = cm
|
| 115 |
+
|
| 116 |
+
def on_test_end(self):
|
| 117 |
+
dance_ids = sorted(self.config["dance_ids"])
|
| 118 |
+
np.fill_diagonal(self.test_cm, 0)
|
| 119 |
+
cm = self.test_cm / self.test_cm.max()
|
| 120 |
+
ConfusionMatrixDisplay(cm, display_labels=dance_ids).plot()
|
| 121 |
+
image = plot_to_image(plt.gcf())
|
| 122 |
+
image = torch.tensor(image, dtype=torch.uint8)
|
| 123 |
+
image = image.permute(2, 0, 1)
|
| 124 |
+
self.logger.experiment.add_image("test/confusion_matrix", image, 0)
|
| 125 |
+
delattr(self, "test_cm")
|
| 126 |
|
| 127 |
def configure_optimizers(self):
|
| 128 |
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
|
models/utils.py
CHANGED
|
@@ -2,6 +2,11 @@ import torch.nn as nn
|
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class LabelWeightedBCELoss(nn.Module):
|
|
@@ -38,10 +43,13 @@ def calculate_metrics(
|
|
| 38 |
) -> dict[str, torch.Tensor]:
|
| 39 |
target = target.detach().cpu().numpy()
|
| 40 |
pred = pred.detach().cpu()
|
| 41 |
-
|
|
|
|
| 42 |
pred = pred.numpy()
|
| 43 |
params = {
|
| 44 |
-
"y_true": target
|
|
|
|
|
|
|
| 45 |
"y_pred": np.array(pred > threshold, dtype=float)
|
| 46 |
if multi_label
|
| 47 |
else pred.argmax(1),
|
|
@@ -85,3 +93,22 @@ def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
|
|
| 85 |
def compute_hf_metrics(eval_pred):
|
| 86 |
predictions = np.argmax(eval_pred.predictions, axis=1)
|
| 87 |
return accuracy_score(y_true=eval_pred.label_ids, y_pred=predictions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
|
| 5 |
+
from functools import cache
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import io
|
| 9 |
+
from PIL import Image
|
| 10 |
|
| 11 |
|
| 12 |
class LabelWeightedBCELoss(nn.Module):
|
|
|
|
| 43 |
) -> dict[str, torch.Tensor]:
|
| 44 |
target = target.detach().cpu().numpy()
|
| 45 |
pred = pred.detach().cpu()
|
| 46 |
+
if not multi_label:
|
| 47 |
+
pred = nn.functional.softmax(pred, dim=1)
|
| 48 |
pred = pred.numpy()
|
| 49 |
params = {
|
| 50 |
+
"y_true": np.array(target > 0.0, dtype=float)
|
| 51 |
+
if multi_label
|
| 52 |
+
else target.argmax(1),
|
| 53 |
"y_pred": np.array(pred > threshold, dtype=float)
|
| 54 |
if multi_label
|
| 55 |
else pred.argmax(1),
|
|
|
|
| 93 |
def compute_hf_metrics(eval_pred):
|
| 94 |
predictions = np.argmax(eval_pred.predictions, axis=1)
|
| 95 |
return accuracy_score(y_true=eval_pred.label_ids, y_pred=predictions)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@cache
|
| 99 |
+
def get_dance_mapping(mapping_file: str) -> dict[str, str]:
|
| 100 |
+
mapping_df = pd.read_csv(mapping_file)
|
| 101 |
+
return {row["id"]: row["name"] for _, row in mapping_df.iterrows()}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def plot_to_image(figure) -> np.ndarray:
|
| 105 |
+
"""Converts the matplotlib plot specified by 'figure' to a PNG image and
|
| 106 |
+
returns it. The supplied figure is closed and inaccessible after this call."""
|
| 107 |
+
# Save the plot to a PNG in memory.
|
| 108 |
+
buf = io.BytesIO()
|
| 109 |
+
plt.savefig(buf, format="png")
|
| 110 |
+
# Closing the figure prevents it from being displayed directly inside
|
| 111 |
+
# the notebook.
|
| 112 |
+
plt.close(figure)
|
| 113 |
+
buf.seek(0)
|
| 114 |
+
return np.array(Image.open(buf))
|