Spaces:
Running
on
L4
Running
on
L4
| # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| """Loss functions.""" | |
| from torch_utils import training_stats | |
| from R3GAN.Trainer import AdversarialTraining | |
| import torch | |
| #---------------------------------------------------------------------------- | |
class R3GANLoss:
    """Loss adapter that routes gradient accumulation to an R3GAN ``AdversarialTraining`` trainer.

    The trainer performs the actual gradient accumulation. This class only builds
    the input preprocessor handed to the trainer and reports the values the trainer
    returns through ``training_stats.report``.
    """

    def __init__(self, G, D, augment_pipe=None):
        """Create the trainer and the preprocessing callable.

        Args:
            G: Presumably the generator network; forwarded unchanged to
                ``AdversarialTraining``.
            D: Presumably the discriminator network; forwarded unchanged to
                ``AdversarialTraining``.
            augment_pipe: Optional callable. When given, the preprocessor casts its
                input to ``torch.float32``, applies ``augment_pipe``, and casts the
                result back to the input's original dtype. When ``None``, the
                preprocessor returns its input unchanged.
        """
        self.trainer = AdversarialTraining(G, D)

        if augment_pipe is None:
            self.preprocessor = lambda x: x
            return

        def augment(x):
            # Run the augmentation pipe in float32, then restore the caller's dtype
            # so downstream code sees tensors of the same precision it passed in.
            return augment_pipe(x.to(torch.float32)).to(x.dtype)

        self.preprocessor = augment

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gamma, gain):
        """Accumulate gradients for one training phase and report its statistics.

        Args:
            phase: ``'G'`` calls ``AccumulateGeneratorGradients``; ``'D'`` calls
                ``AccumulateDiscriminatorGradients``. Any other value does nothing.
            real_img: Real images; forwarded to the trainer in both phases.
            real_c: Presumably conditioning labels for ``real_img``; forwarded to
                the trainer in both phases.
            gen_z: Presumably latent codes for the generator; forwarded to the
                trainer in both phases.
            gamma: Forwarded to the trainer in the ``'D'`` phase only (presumably
                the weight of the R1/R2 penalties it returns).
            gain: Forwarded to the trainer in both phases.
        """
        if phase == 'G':
            g_loss, logits = self.trainer.AccumulateGeneratorGradients(
                gen_z, real_img, real_c, gain, self.preprocessor)
            stats = (
                ('Loss/scores/fake', logits),
                ('Loss/signs/fake', logits.sign()),
                ('Loss/G/loss', g_loss),
            )
        elif phase == 'D':
            d_loss, logits, r1, r2 = self.trainer.AccumulateDiscriminatorGradients(
                gen_z, real_img, real_c, gamma, gain, self.preprocessor)
            stats = (
                ('Loss/scores/real', logits),
                ('Loss/signs/real', logits.sign()),
                ('Loss/D/loss', d_loss),
                ('Loss/r1_penalty', r1),
                ('Loss/r2_penalty', r2),
            )
        else:
            # Unrecognized phase: nothing to accumulate or report.
            return

        # Report in the same order as the trainer outputs listed above.
        for key, value in stats:
            training_stats.report(key, value)
| #---------------------------------------------------------------------------- |