LINC-BIT's picture
Upload 1912 files
b84549f verified
import torch
class CrossEntropyLossSoft(torch.nn.modules.loss._Loss):
""" inplace distillation for image classification (from US-Net) """
def forward(self, output, target):
target = torch.nn.functional.softmax(target, dim=1)
output_log_prob = torch.nn.functional.log_softmax(output, dim=1)
target = target.unsqueeze(1)
output_log_prob = output_log_prob.unsqueeze(2)
cross_entropy_loss = -torch.bmm(target, output_log_prob).mean()
return cross_entropy_loss
class CrossEntropyLossSoft2(torch.nn.modules.loss._Loss):
""" inplace distillation for image classification (from US-Net) """
def forward(self, output, target, loss_mask):
target = torch.nn.functional.softmax(target, dim=-1)
output_log_prob = torch.nn.functional.log_softmax(output, dim=-1)
target = target.unsqueeze(-2)
output_log_prob = output_log_prob.unsqueeze(-1)
cross_entropy_loss = -torch.matmul(target, output_log_prob).squeeze(-1).squeeze(-1)
num_loss = loss_mask.sum()
loss = torch.sum(cross_entropy_loss.view(-1) * loss_mask.view(-1)) / num_loss
return loss