# Source snapshot: revision b84549f (1,174 bytes).
import torch
class CrossEntropyLossSoft(torch.nn.modules.loss._Loss):
    """Soft-target cross-entropy for in-place distillation (adapted from US-Net).

    Computes ``-sum_c softmax(target)_c * log_softmax(output)_c`` over the
    class dimension (dim 1). The result is reduced according to
    ``self.reduction``, which ``_Loss`` sets ('mean' by default).
    """

    def forward(self, output, target):
        """Return the soft cross-entropy between ``output`` and ``target`` logits.

        Args:
            output: student logits of shape ``(N, C)`` or ``(N, C, d1, ...)``.
            target: teacher logits, broadcast-compatible with ``output``. It
                is passed through softmax here but NOT detached, so gradients
                flow into it unless the caller detaches it.

        Returns:
            Per-position losses of shape ``(N, d1, ...)`` when
            ``reduction='none'``. Otherwise their sum or mean as a 0-dim tensor.

        Raises:
            ValueError: if ``self.reduction`` is not 'none', 'mean' or 'sum'.
        """
        target_prob = torch.nn.functional.softmax(target, dim=1)
        output_log_prob = torch.nn.functional.log_softmax(output, dim=1)
        # Elementwise product + sum over classes equals the (N,1,C) @ (N,C,1)
        # bmm formulation for 2-D inputs, and also works with extra dims.
        cross_entropy_loss = -(target_prob * output_log_prob).sum(dim=1)
        if self.reduction == 'none':
            return cross_entropy_loss
        if self.reduction == 'sum':
            return cross_entropy_loss.sum()
        if self.reduction == 'mean':
            return cross_entropy_loss.mean()
        raise ValueError(f"{self.reduction} is not a valid value for reduction")
class CrossEntropyLossSoft2(torch.nn.modules.loss._Loss):
    """Masked soft-target cross-entropy for distillation (adapted from US-Net).

    Both ``output`` (student logits) and ``target`` (teacher logits) are
    normalized over their LAST dimension. The per-position cross-entropy
    ``-sum_c softmax(target)_c * log_softmax(output)_c`` is then averaged
    using ``loss_mask`` as weights. ``self.reduction`` is not consulted;
    the result is always the mask-weighted mean.
    """

    def forward(self, output, target, loss_mask):
        """Return the mask-weighted mean soft cross-entropy.

        Args:
            output: student logits; the last dimension indexes classes.
            target: teacher logits, broadcast-compatible with ``output``. It
                is NOT detached here, so gradients flow into it unless the
                caller detaches it.
            loss_mask: per-position weights (bool, int or float) with as many
                elements as ``output.shape[:-1]``. It need not be contiguous.

        Returns:
            0-dim tensor ``sum(ce * mask) / sum(mask)``. It is 0 when the mask
            sums to exactly zero, where the unguarded division would be NaN.
        """
        target_prob = torch.nn.functional.softmax(target, dim=-1)
        output_log_prob = torch.nn.functional.log_softmax(output, dim=-1)
        cross_entropy_loss = -(target_prob * output_log_prob).sum(dim=-1)
        # reshape (not view) so non-contiguous masks such as mask[:, 1:] work.
        masked_sum = torch.sum(cross_entropy_loss.reshape(-1) * loss_mask.reshape(-1))
        num_loss = loss_mask.sum()
        # An all-zero mask would give 0/0 -> NaN loss and NaN gradients. Divide
        # by 1 instead; any non-zero mask sum is left untouched.
        num_loss = torch.where(num_loss == 0, torch.ones_like(num_loss), num_loss)
        return masked_sum / num_loss