File size: 3,571 Bytes
4b8361a
 
 
 
17a2a7d
 
 
 
 
4b8361a
557fb53
4b8361a
0030bc6
 
 
 
557fb53
 
4b8361a
 
 
 
 
 
 
 
557fb53
 
4b8361a
 
 
557fb53
 
 
4b8361a
 
 
0030bc6
 
 
557fb53
 
 
4b8361a
1c22425
17a2a7d
 
1c22425
0030bc6
17a2a7d
 
 
557fb53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0030bc6
 
 
 
 
 
557fb53
0030bc6
 
557fb53
0030bc6
 
 
 
557fb53
 
 
 
 
 
 
 
 
 
 
 
e748bc2
17a2a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch.nn as nn
import torch
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from functools import cache
import pandas as pd
import matplotlib.pyplot as plt
import io
from PIL import Image


class LabelWeightedBCELoss(nn.Module):
    """
    Binary Cross Entropy loss that assumes each float in the final dimension is a binary probability distribution.
    Allows for the weighing of each probability distribution wrt loss.
    """

    def __init__(self, label_weights: torch.Tensor, reduction="mean"):
        super().__init__()
        self.label_weights = label_weights

        match reduction:
            case "mean":
                self.reduction = torch.mean
            case "sum":
                self.reduction = torch.sum

    def _log(self, x: torch.Tensor) -> torch.Tensor:
        return torch.clamp_min(torch.log(x), -100)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        losses = -self.label_weights * (
            target * self._log(input) + (1 - target) * self._log(1 - input)
        )
        return self.reduction(losses)


# TODO: Code a onehot


def calculate_metrics(
    pred, target, threshold=0.5, prefix="", multi_label=True
) -> dict[str, torch.Tensor]:
    target = target.detach().cpu().numpy()
    pred = pred.detach().cpu()
    if not multi_label:
        pred = nn.functional.softmax(pred, dim=1)
    pred = pred.numpy()
    params = {
        "y_true": np.array(target > 0.0, dtype=float)
        if multi_label
        else target.argmax(1),
        "y_pred": np.array(pred > threshold, dtype=float)
        if multi_label
        else pred.argmax(1),
        "zero_division": 0,
        "average": "macro",
    }
    metrics = {
        "precision": precision_score(**params),
        "recall": recall_score(**params),
        "f1": f1_score(**params),
        "accuracy": accuracy_score(y_true=params["y_true"], y_pred=params["y_pred"]),
    }
    return {
        prefix + k: torch.tensor(v, dtype=torch.float32) for k, v in metrics.items()
    }


class EarlyStopping:
    def __init__(self, patience=0):
        self.patience = patience
        self.last_measure = np.inf
        self.consecutive_increase = 0

    def step(self, val) -> bool:
        if self.last_measure <= val:
            self.consecutive_increase += 1
        else:
            self.consecutive_increase = 0
        self.last_measure = val

        return self.patience < self.consecutive_increase


def get_id_label_mapping(labels: list[str]) -> tuple[dict, dict]:
    id2label = {str(i): label for i, label in enumerate(labels)}
    label2id = {label: str(i) for i, label in enumerate(labels)}

    return id2label, label2id


def compute_hf_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return accuracy_score(y_true=eval_pred.label_ids, y_pred=predictions)


@cache
def get_dance_mapping(mapping_file: str) -> dict[str, str]:
    mapping_df = pd.read_csv(mapping_file)
    return {row["id"]: row["name"] for _, row in mapping_df.iterrows()}


def plot_to_image(figure) -> np.ndarray:
    """Converts the matplotlib plot specified by 'figure' to a PNG image and
    returns it. The supplied figure is closed and inaccessible after this call."""
    # Save the plot to a PNG in memory.
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    # Closing the figure prevents it from being displayed directly inside
    # the notebook.
    plt.close(figure)
    buf.seek(0)
    return np.array(Image.open(buf))