import typing as tp

import flashy
import torch
from torch import autograd


class Balancer:
    """Loss balancer.

    The loss balancer combines losses together to compute gradients for the backward pass.
    Given `y = f(...)`, and a number of losses `l1(y, ...)`, `l2(y, ...)`, with `...`
    not having any dependence on `f`, the balancer can efficiently normalize the partial gradients
    `d l1 / d y`, `d l2 / d y` before summing them in order to achieve a desired ratio between
    the losses. For instance if `weights = {'l1': 2, 'l2': 1}`, 66% of the gradient
    going into `f(...)` will come from `l1` on average, and 33% from `l2`. This allows for an easy
    interpretation of the weights even if the intrinsic scale of `l1`, `l2` ... is unknown.

    Noting `g1 = d l1 / d y`, etc., the balanced gradient `G` will be
    (with `avg` an exponential moving average over the updates),

        G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)

    If `balance_grads` is False, this is deactivated, and instead the gradient will just be the
    standard sum of the partial gradients with the given weights.

    A call to the backward method of the balancer will compute the partial gradients,
    combining all the losses and potentially rescaling the gradients,
    which can help stabilize the training and reason about multiple losses with varying scales.
    The obtained gradient with respect to `y` is then back-propagated to `f(...)`.

    Expected usage:

        weights = {'loss_a': 1, 'loss_b': 4}
        balancer = Balancer(weights, ...)
        losses: dict = {}
        losses['loss_a'] = compute_loss_a(x, y)
        losses['loss_b'] = compute_loss_b(x, y)
        if model.training:
            effective_loss = balancer.backward(losses, x)

    Args:
        weights (dict[str, float]): Weight coefficient for each loss. The balancer expects the keys
            of the losses passed to the backward method to match the keys of `weights`, so that
            each provided loss is assigned its weight.
        balance_grads (bool): Whether to rescale gradients so that weights reflect the fraction of the
            overall gradient, rather than a constant multiplier.
        total_norm (float): Reference norm when rescaling gradients, ignored otherwise.
        ema_decay (float): EMA decay for averaging the norms.
        per_batch_item (bool): Whether to compute the averaged norm per batch item or not. This only holds
            when rescaling the gradients.
        epsilon (float): Epsilon value for numerical stability.
        monitor (bool): If True, stores in `self.metrics` the relative ratio between the norm of the gradients
            coming from each loss, when calling `backward()`.
    """
    def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True, total_norm: float = 1.,
                 ema_decay: float = 0.999, per_batch_item: bool = True, epsilon: float = 1e-12,
                 monitor: bool = False):
        self.weights = weights
        self.per_batch_item = per_batch_item
        self.total_norm = total_norm or 1.
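        # EMA averager used to smooth the per-loss gradient norms over updates.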
        self.averager = flashy.averager(ema_decay or 1.)
        self.epsilon = epsilon
        self.monitor = monitor
        self.balance_grads = balance_grads
        self._metrics: tp.Dict[str, tp.Any] = {}

    @property
    def metrics(self):
        return self._metrics

    def backward(self, losses: tp.Dict[str, torch.Tensor], input: torch.Tensor) -> torch.Tensor:
        """Compute the backward and return the effective train loss, i.e. the loss obtained from
        computing the effective weights. If `balance_grads` is True, the effective weights
        are the ones that need to be applied to each gradient to respect the desired relative
        scale of gradients coming from each loss.

        Args:
            losses (Dict[str, torch.Tensor]): dictionary with the same keys as `self.weights`.
            input (torch.Tensor): the input of the losses, typically the output of the model.
                This should be the single point of dependence between the losses
                and the model being trained.
        """
        norms = {}
        grads = {}
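        # Partial gradient of each loss with respect to `input`, together with its norm.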
        for name, loss in losses.items():
            grad, = autograd.grad(loss, [input], retain_graph=True)
            if self.per_batch_item:
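                # L2 norm over all dimensions but the batch one, then averaged over the batch.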
                dims = tuple(range(1, grad.dim()))
                norm = grad.norm(dim=dims, p=2).mean()
            else:
                norm = grad.norm(p=2)
            norms[name] = norm
            grads[name] = grad

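        # When norms are averaged per batch item, weight the distributed average by the batch size.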
        count = 1
        if self.per_batch_item:
            count = len(grad)

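        # Smooth the norms with the EMA averager, then average them across distributed workers.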
        avg_norms = flashy.distrib.average_metrics(self.averager(norms), count)
        total = sum(avg_norms.values())

        self._metrics = {}
        if self.monitor:
            for k, v in avg_norms.items():
                self._metrics[f'ratio_{k}'] = v / total

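        # Normalize the weights of the provided losses to obtain the desired share of each gradient.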
        total_weights = sum([self.weights[k] for k in avg_norms])
        assert total_weights > 0.
        desired_ratios = {k: w / total_weights for k, w in self.weights.items()}

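        # Accumulate the partial gradients, rescaled if `balance_grads` is set, along with the effective loss.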
        out_grad = torch.zeros_like(input)
        effective_loss = torch.tensor(0., device=input.device, dtype=input.dtype)
        for name, avg_norm in avg_norms.items():
            if self.balance_grads:
                scale = desired_ratios[name] * self.total_norm / (self.epsilon + avg_norm)
            else:
                scale = self.weights[name]
            out_grad.add_(grads[name], alpha=scale)
            effective_loss += scale * losses[name].detach()

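        # Single backward pass through the model, using the combined gradient at `input`.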
        input.backward(out_grad)
        return effective_loss
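

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: a toy "model output" with two
    # hypothetical losses ('l1', 'l2'), assuming a single-process (non-distributed) run so that
    # `flashy` falls back to local averaging. Shapes and loss definitions are made up.
    x = torch.zeros(4, 8, requires_grad=True)
    y = x * 2 + 1  # stand-in for `f(...)`, the single point of dependence on the model
    losses = {'l1': y.abs().mean(), 'l2': (y ** 2).mean()}
    balancer = Balancer(weights={'l1': 2, 'l2': 1}, monitor=True)
    effective_loss = balancer.backward(losses, y)
    # After the call, `x.grad` holds the balanced gradient back-propagated through `y`.
    print(effective_loss.item(), balancer.metrics)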