Spaces:
Runtime error
Runtime error
| from typing import List, Optional | |
| import torch | |
| from torch import nn | |
| from torch.nn.functional import ( | |
| smooth_l1_loss, | |
| ) | |
def flatten_CHW(im: torch.Tensor) -> torch.Tensor:
    """Collapse every axis after the first into one.

    Turns a ``(B, C, H, W)`` tensor (or any tensor with a leading batch axis)
    into ``(B, -1)``; the batch axis is kept as-is.
    """
    return im.reshape(im.shape[0], -1)
def stddev(x: torch.Tensor) -> torch.Tensor:
    """Root-mean-square of each row of a flattened batch.

    Args:
        x: ``(B, -1)`` tensor. Rows are assumed to already have zero mean,
            in which case the RMS computed here equals the standard deviation.

    Returns:
        ``(B,)`` tensor with one value per row.
    """
    mean_square = (x * x).mean(dim=-1)
    return mean_square.sqrt()
def gram_matrix(input_: torch.Tensor) -> torch.Tensor:
    """Batched Gram matrix of channel features, normalized by ``C * N``.

    Args:
        input_: ``(B, C, ...)`` tensor; every axis after the channel axis is
            flattened into ``N`` feature positions.

    Returns:
        ``(B, C, C)`` tensor ``F @ F^T / (C * N)`` where ``F`` is ``(B, C, N)``.
    """
    B, C = input_.shape[:2]
    # reshape (not view) so non-contiguous inputs, e.g. transposed or
    # permuted tensors, are accepted instead of raising.
    features = input_.reshape(B, C, -1)
    N = features.shape[-1]
    G = torch.bmm(features, features.transpose(1, 2))  # B x C x C
    return G.div(C * N)
class ColorTransferLoss(nn.Module):
    """Penalize the gram matrix difference between StyleGAN2's ToRGB outputs.

    At construction the Gram matrix of each reference tensor in ``init_rgbs``
    is computed and cached. ``forward`` computes the Gram matrices of new
    tensors the same way and sums their smooth-L1 distance to the cached ones.

    Args:
        init_rgbs: sequence of reference tensors, each ``(B, C, ...)``.
            Presumably ``(B, 3, H, W)`` ToRGB outputs — TODO confirm with caller.
        scale_rgb: if True, each tensor is divided by its per-sample RMS
            (computed once from ``init_rgbs``) before taking the Gram matrix;
            otherwise the divisor is the scalar 1.
    """

    def __init__(
        self,
        init_rgbs,
        scale_rgb: bool = False,
    ):
        super().__init__()
        with torch.no_grad():
            init_feats = [x.detach() for x in init_rgbs]
            # stddev() returns shape (B,). Reshape it to (B, 1, ..., 1) so the
            # division scales each sample; a bare (B,) would broadcast against
            # the trailing (width) axis instead.
            # NOTE(review): these tensors live in plain lists, not registered
            # buffers, so Module.to()/cuda() will not move them.
            self.stds = [
                stddev(flatten_CHW(rgb)).reshape((-1,) + (1,) * (rgb.dim() - 1))
                if scale_rgb
                else 1
                for rgb in init_feats
            ]
            self.grams = [
                gram_matrix(rgb / std) for rgb, std in zip(init_feats, self.stds)
            ]

    def forward(self, rgbs: List[torch.Tensor], level: Optional[int] = None):
        """Sum of smooth-L1 distances between Gram matrices of ``rgbs`` and the references.

        Args:
            rgbs: tensors to compare, matched by position with ``init_rgbs``.
            level: compare only the first ``level`` entries; ``None`` means
                all cached references.

        Returns:
            The summed loss (a tensor), or the int ``0`` if nothing is compared.
        """
        if level is None:
            level = len(self.grams)
        loss = 0
        for rgb, std, ref_gram in zip(rgbs[:level], self.stds[:level], self.grams[:level]):
            G = gram_matrix(rgb / std)
            loss = loss + smooth_l1_loss(G, ref_gram)
        return loss