# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Loss functions."""
from torch_utils import training_stats
from R3GAN.Trainer import AdversarialTraining
import torch
#----------------------------------------------------------------------------
class R3GANLoss:
    """Adversarial loss wrapper around R3GAN's AdversarialTraining.

    The outer training loop calls accumulate_gradients() once per phase
    ('G' or 'D'); this class forwards the work to the trainer and reports
    the relativistic logits, losses, and gradient penalties through
    training_stats for monitoring.
    """

    def __init__(self, G, D, augment_pipe=None):
        """Build the trainer and the (optional) augmentation preprocessor.

        Args:
            G: generator module, forwarded to AdversarialTraining.
            D: discriminator module, forwarded to AdversarialTraining.
            augment_pipe: optional callable applied to images before the
                discriminator sees them. When given, inputs are cast to
                float32 for the pipe and cast back to their original dtype.
        """
        self.trainer = AdversarialTraining(G, D)
        # PEP 8 (E731): use def instead of assigning a lambda.
        if augment_pipe is not None:
            def preprocess(x):
                # The augment pipe runs in float32; restore the caller's dtype.
                return augment_pipe(x.to(torch.float32)).to(x.dtype)
        else:
            def preprocess(x):
                # No augmentation configured: identity pass-through.
                return x
        self.preprocessor = preprocess

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gamma, gain):
        """Accumulate gradients for one training phase and report statistics.

        Args:
            phase: 'G' (generator step) or 'D' (discriminator step). Any
                other value is a silent no-op, matching the original code.
            real_img: batch of real images.
            real_c: conditioning labels for the real batch.
            gen_z: latent codes for the generated batch.
            gamma: R1/R2 gradient-penalty weight (D phase only).
            gain: scaling applied to the loss before backward.
        """
        if phase == 'G':
            adversarial_loss, relativistic_logits = self.trainer.AccumulateGeneratorGradients(
                gen_z, real_img, real_c, gain, self.preprocessor)
            training_stats.report('Loss/scores/fake', relativistic_logits)
            training_stats.report('Loss/signs/fake', relativistic_logits.sign())
            training_stats.report('Loss/G/loss', adversarial_loss)
        # The phases are mutually exclusive, so elif (not a second if) suffices.
        elif phase == 'D':
            adversarial_loss, relativistic_logits, r1_penalty, r2_penalty = \
                self.trainer.AccumulateDiscriminatorGradients(
                    gen_z, real_img, real_c, gamma, gain, self.preprocessor)
            training_stats.report('Loss/scores/real', relativistic_logits)
            training_stats.report('Loss/signs/real', relativistic_logits.sign())
            training_stats.report('Loss/D/loss', adversarial_loss)
            training_stats.report('Loss/r1_penalty', r1_penalty)
            training_stats.report('Loss/r2_penalty', r2_penalty)
#----------------------------------------------------------------------------