File size: 1,174 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch


class CrossEntropyLossSoft(torch.nn.modules.loss._Loss):
    """Soft-label cross entropy for inplace distillation (from US-Net).

    Both inputs are treated as unnormalised logits: ``target`` is turned
    into a probability distribution with a softmax and ``output`` into
    log-probabilities with a log-softmax, both over dimension 1.
    """

    def forward(self, output, target):
        """Return the batch-mean soft-target cross entropy as a scalar tensor.

        Args:
            output: Student logits. Must be 2-D, since ``torch.bmm`` needs
                3-D operands after the extra axis is inserted; presumably
                ``(batch, num_classes)``.
            target: Teacher logits with the same shape as ``output``. They
                are not detached here, so gradients flow into them unless
                the caller detaches beforehand.

        Returns:
            A 0-dim tensor: ``-mean_n(sum_c softmax(target) * log_softmax(output))``.
        """
        # Teacher distribution as one row vector per sample: (N, 1, C).
        soft_labels = target.softmax(dim=1).unsqueeze(1)
        # Student log-probabilities as one column vector per sample: (N, C, 1).
        student_log_prob = output.log_softmax(dim=1).unsqueeze(2)
        # Batched dot product -> (N, 1, 1), i.e. sum_c p_c * log q_c per sample.
        per_sample = torch.bmm(soft_labels, student_log_prob)
        return -per_sample.mean()

class CrossEntropyLossSoft2(torch.nn.modules.loss._Loss):
    """Masked soft-label cross entropy for inplace distillation (from US-Net).

    Like a plain soft-target cross entropy, but computed over the last
    dimension and averaged only over the positions selected (or weighted)
    by ``loss_mask``.
    """

    def forward(self, output, target, loss_mask):
        """Return the mask-weighted mean soft-target cross entropy.

        Args:
            output: Student logits; the class axis is the last dimension.
            target: Teacher logits with the same shape as ``output``. They
                are not detached here, so gradients flow into them unless
                the caller detaches beforehand.
            loss_mask: Per-position mask/weights whose element count equals
                the number of positions in ``output`` (all dimensions but
                the last). It may be non-contiguous.

        Returns:
            A 0-dim tensor: ``sum(ce * mask) / sum(mask)``. When the mask
            sums to zero the result is ``0`` instead of NaN.
        """
        target_prob = torch.nn.functional.softmax(target, dim=-1)
        output_log_prob = torch.nn.functional.log_softmax(output, dim=-1)

        # (..., 1, C) @ (..., C, 1) -> (..., 1, 1); squeeze to one value per
        # position: -sum_c p_c * log q_c.
        cross_entropy_loss = -torch.matmul(
            target_prob.unsqueeze(-2), output_log_prob.unsqueeze(-1)
        ).squeeze(-1).squeeze(-1)

        # reshape (not view) so a non-contiguous mask, e.g. a transposed
        # one, is accepted as well.
        flat_mask = loss_mask.reshape(-1)
        num_loss = flat_mask.sum()
        masked_sum = torch.sum(cross_entropy_loss.reshape(-1) * flat_mask)

        # An all-zero mask would give 0 / 0 = NaN and poison the training
        # step; divide by 1 in that case so the loss (and its gradient) is 0.
        safe_num_loss = torch.where(
            num_loss == 0, torch.ones_like(num_loss), num_loss
        )
        return masked_sum / safe_num_loss