|
import torch as t |
|
|
|
|
|
def corn_loss(logits, y_train, num_classes):
    """Compute the CORN loss for rank-consistent ordinal regression.

    Implements the loss described in the manuscript 'Deep Neural Networks
    for Rank Consistent Ordinal Regression based on Conditional
    Probabilities'.

    For every binary task ``k`` in ``range(num_classes - 1)`` only the
    examples whose label is greater than ``k - 1`` take part; their binary
    target is ``label > k``. The binary cross-entropy terms of all tasks
    are summed and divided by the total number of (example, task) pairs
    that took part.

    Parameters
    ----------
    logits : torch.tensor, shape=(num_examples, num_classes-1)
        Outputs of the CORN layer.
    y_train : torch.tensor, shape=(num_examples)
        Torch tensor containing the class labels.
    num_classes : int
        Number of unique class labels (class labels should start at 0).

    Returns
    -------
    loss : torch.tensor
        A torch.tensor containing a single loss value. If no example takes
        part in any task (for instance an empty batch), a zero-valued loss
        that is still connected to ``logits`` in the autograd graph is
        returned instead of dividing by zero.

    Examples
    --------
    >>> import torch
    >>> from coral_pytorch.losses import corn_loss
    >>> # Consider 8 training examples
    >>> _ = torch.manual_seed(123)
    >>> X_train = torch.rand(8, 99)
    >>> y_train = torch.tensor([0, 1, 2, 2, 2, 3, 4, 4])
    >>> NUM_CLASSES = 5
    >>> # What would normally live in a model's __init__:
    >>> corn_net = torch.nn.Linear(99, NUM_CLASSES-1)
    >>> # What would normally live in a model's forward(X_train):
    >>> logits = corn_net(X_train)
    >>> logits.shape
    torch.Size([8, 4])
    >>> corn_loss(logits, y_train, NUM_CLASSES)
    tensor(0.7127, grad_fn=<DivBackward0>)

    https://github.com/Raschka-research-group/coral-pytorch/blob/c6ab93afd555a6eac708c95ae1feafa15f91c5aa/coral_pytorch/losses.py
    """
    num_examples = 0
    losses = 0.0

    for task_index in range(num_classes - 1):
        # Conditional training subset for this task: labels > task_index - 1.
        label_mask = y_train > task_index - 1
        train_labels = (y_train[label_mask] > task_index).to(t.int64)

        # Nothing to learn from for this task; skip it entirely.
        if len(train_labels) < 1:
            continue

        num_examples += len(train_labels)
        pred = logits[label_mask, task_index]

        # logsigmoid(x) is log(sigmoid(x)); logsigmoid(x) - x is
        # log(1 - sigmoid(x)). Evaluate logsigmoid once and reuse it.
        log_sig = t.nn.functional.logsigmoid(pred)
        losses += -t.sum(
            log_sig * train_labels + (log_sig - pred) * (1 - train_labels)
        )

    if num_examples == 0:
        # Avoid 0.0 / 0 (ZeroDivisionError). Multiplying the logits' sum by
        # zero yields a zero loss that remains differentiable w.r.t. logits.
        return logits.sum() * 0.0

    return losses / num_examples
|
|
|
|
|
def corn_label_from_logits(logits):
    """Return the predicted rank label from logits of a CORN-trained network.

    The logits are passed through a sigmoid, multiplied cumulatively along
    the last dimension, and the predicted label is the number of cumulative
    products that exceed 0.5.

    Parameters
    ----------
    logits : torch.tensor, shape=(n_examples, n_classes-1)
        Torch tensor consisting of logits returned by the neural net.
        The task dimension is the last one, so a single example of shape
        (n_classes-1,) or additional leading batch dimensions also work.

    Returns
    -------
    labels : torch.tensor, shape=(n_examples)
        Integer tensor containing the predicted rank (class) labels. It has
        the shape of ``logits`` without its last dimension.

    Examples
    --------
    >>> import torch
    >>> # 2 training examples, 5 classes
    >>> logits = torch.tensor([[14.152, -6.1942, 0.47710, 0.96850],
    ...                        [65.667, 0.303, 11.500, -4.524]])
    >>> corn_label_from_logits(logits)
    tensor([1, 3])

    https://github.com/Raschka-research-group/coral-pytorch/blob/c6ab93afd555a6eac708c95ae1feafa15f91c5aa/coral_pytorch/dataset.py
    """
    probas = t.sigmoid(logits)
    # Chain the per-task probabilities along the task (last) dimension.
    # dim=-1 equals dim=1 for 2-D input and also accepts 1-D input.
    probas = t.cumprod(probas, dim=-1)
    predict_levels = probas > 0.5
    # Count how many thresholds were passed; that count is the label.
    predicted_labels = t.sum(predict_levels, dim=-1)
    return predicted_labels
|
|