Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
class WeightedCrossEntropy(nn.CrossEntropyLoss):
    """Cross-entropy loss whose class weighting can be toggled per call.

    The parent loss is always built with ``reduction='none'`` so that the
    per-element losses are available; the reduction to a scalar is done here,
    either as a class-weighted mean or as a plain mean.

    Args:
        weights: 1-D tensor (or anything ``torch.as_tensor`` accepts) holding
            one weight per class; indexed by the integer class targets.
        **pytorch_ce_loss_args: extra keyword arguments forwarded to
            ``nn.CrossEntropyLoss`` (``reduction`` is fixed to ``'none'`` and
            must not be passed).
    """

    def __init__(self, weights, **pytorch_ce_loss_args) -> None:
        super().__init__(reduction='none', **pytorch_ce_loss_args)
        # Registered as a buffer so ``module.to(device)`` moves it together
        # with the module; non-persistent so the state_dict layout does not
        # change for existing checkpoints.
        self.register_buffer('weights', torch.as_tensor(weights), persistent=False)

    def __call__(self, outputs, targets, to_weight=True):
        """Return the scalar loss for ``outputs`` against class-index ``targets``.

        Args:
            outputs: logits, as accepted by ``nn.CrossEntropyLoss``.
            targets: integer class indices; entries equal to
                ``self.ignore_index`` are excluded from the reduction.
            to_weight: if true, return the class-weighted mean
                ``sum(w[t] * loss) / sum(w[t])``; otherwise the unweighted
                mean over the non-ignored elements.

        Returns:
            A 0-dim tensor holding the reduced loss.
        """
        loss = super().__call__(outputs, targets)
        # Ignored targets already get a zero loss from the parent; they must
        # also be left out of the normaliser, and must not be used to index
        # the weight vector (the default ignore_index of -100 is out of range).
        valid = targets != self.ignore_index
        if to_weight:
            safe_targets = targets.masked_fill(~valid, 0)
            class_weights = self.weights.to(device=loss.device)
            sample_weights = class_weights[safe_targets] * valid
            return (loss * sample_weights).sum() / sample_weights.sum()
        return loss.sum() / valid.sum()
if __name__ == '__main__':
    # Self-check: the custom loss must agree with PyTorch's built-in
    # weighted and unweighted cross-entropy on a random batch.
    num_samples, num_classes = 10, 5
    logits = torch.randn(num_samples, num_classes)
    labels = torch.randint(0, num_classes, (num_samples,))
    class_weights = torch.tensor([1., 2., 3., 4., 5.])

    reference_weighted = nn.CrossEntropyLoss(weight=class_weights)(logits, labels)
    reference_plain = nn.CrossEntropyLoss()(logits, labels)

    custom = WeightedCrossEntropy(class_weights)
    weighted_loss = custom(logits, labels, to_weight=True)
    plain_loss = custom(logits, labels, to_weight=False)

    assert torch.allclose(reference_weighted, weighted_loss)
    assert torch.allclose(reference_plain, plain_loss)
    print(weighted_loss, plain_loss)