| import torch | |
| from torch import nn | |
def get_loss(name):
    """Build the margin-based loss head selected by ``name``.

    Args:
        name: Either ``"cosface"`` or ``"arcface"`` (exact, case-sensitive
            match).

    Returns:
        A ``CosFace`` or ``ArcFace`` instance constructed with its default
        scale and margin.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    if name == "cosface":
        return CosFace()
    if name == "arcface":
        return ArcFace()
    # Name the offending value so a misconfigured run is easy to diagnose.
    raise ValueError(
        "unknown loss name {!r}; expected 'cosface' or 'arcface'".format(name)
    )
class CosFace(nn.Module):
    """Additive cosine-margin head.

    Subtracts the margin ``m`` from the entry in each sample's label column
    and multiplies the whole matrix by the scale ``s``. Samples whose label
    is ``-1`` receive no margin.
    """

    def __init__(self, s=64.0, m=0.40):
        super().__init__()
        self.s = s
        self.m = m

    def forward(self, cosine, label):
        """Apply the margin to ``cosine`` and return the scaled result.

        ``cosine`` is modified in place by the margin subtraction; the
        returned tensor is a new one holding the scaled values.

        Assumes ``cosine`` is 2-D (samples x classes) and ``label`` holds one
        int64 class index per sample -- TODO confirm against the caller.
        """
        # Rows that carry a real label; -1 marks rows to leave untouched.
        kept = torch.where(label != -1)[0]

        # Dense per-row margin: zero everywhere except at the label column.
        margin = torch.zeros(
            kept.shape[0], cosine.shape[1], device=cosine.device
        )
        margin.scatter_(1, label[kept][:, None], self.m)

        cosine[kept] -= margin
        return cosine * self.s
class ArcFace(nn.Module):
    """Additive angular-margin head.

    Converts each cosine to an angle, adds the margin ``m`` to the angle in
    each sample's label column, converts back to a cosine and multiplies by
    the scale ``s``. Samples whose label is ``-1`` receive no margin.
    """

    def __init__(self, s=64.0, m=0.5):
        super().__init__()
        self.s = s
        self.m = m

    def forward(self, cosine: torch.Tensor, label):
        """Apply the angular margin to ``cosine`` in place and return it.

        The returned tensor is the same object as ``cosine``; every step
        operates in place.

        Assumes ``cosine`` is 2-D (samples x classes) and ``label`` holds one
        int64 class index per sample -- TODO confirm against the caller.
        """
        # Rows that carry a real label; -1 marks rows to leave untouched.
        kept = torch.where(label != -1)[0]

        # Dense per-row margin: zero everywhere except at the label column.
        margin = torch.zeros(
            kept.shape[0], cosine.shape[1], device=cosine.device
        )
        margin.scatter_(1, label[kept][:, None], self.m)

        # Rounding error can leave a cosine marginally outside [-1, 1], where
        # arccos is undefined and returns NaN. Clamp first so one such value
        # cannot poison the loss; in-range inputs are unaffected.
        cosine.clamp_(-1.0, 1.0)
        cosine.acos_()
        # NOTE(review): angle + m may exceed pi for very dissimilar pairs;
        # that is left as in the original.
        cosine[kept] += margin
        cosine.cos_().mul_(self.s)
        return cosine