"""Adapters around R3GAN.Networks that expose the attribute and call
conventions the surrounding training loop expects from its networks
(z_dim, c_dim, img_resolution, forward(x, c))."""

import copy

import torch
import torch.nn as nn

import R3GAN.Networks
class Generator(nn.Module):
    def __init__(self, *args, **kw):
        super().__init__()
        # Drop the keys that only the training loop understands before
        # forwarding the remaining kwargs to the R3GAN backbone.
        config = copy.deepcopy(kw)
        del config['FP16Stages']
        del config['c_dim']
        del config['img_resolution']
        # A non-zero c_dim enables class conditioning in the backbone.
        if kw['c_dim'] != 0:
            config['ConditionDimension'] = kw['c_dim']
        self.Model = R3GAN.Networks.Generator(*args, **config)
        # Attributes the training loop reads off the wrapper.
        self.z_dim = kw['NoiseDimension']
        self.c_dim = kw['c_dim']
        self.img_resolution = kw['img_resolution']
        # Run the selected stages in bfloat16 (despite the 'FP16' key name).
        for x in kw['FP16Stages']:
            self.Model.MainLayers[x].DataType = torch.bfloat16

    def forward(self, x, c):
        return self.Model(x, c)
class Discriminator(nn.Module):
    def __init__(self, *args, **kw):
        super().__init__()
        # Same key filtering as the Generator wrapper above.
        config = copy.deepcopy(kw)
        del config['FP16Stages']
        del config['c_dim']
        del config['img_resolution']
        if kw['c_dim'] != 0:
            config['ConditionDimension'] = kw['c_dim']
        self.Model = R3GAN.Networks.Discriminator(*args, **config)
        # Run the selected stages in bfloat16 (despite the 'FP16' key name).
        for x in kw['FP16Stages']:
            self.Model.MainLayers[x].DataType = torch.bfloat16

    def forward(self, x, c):
        return self.Model(x, c)
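

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). The keyword values below are
# illustrative assumptions: R3GAN.Networks.Generator will typically require
# its own architecture kwargs (stage widths, block counts, ...), which this
# wrapper forwards untouched. Treat this as a demo of the wrapper's
# shape/attribute contract under those assumptions, not a drop-in script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    backbone_config = {}  # hypothetical placeholder for R3GAN backbone kwargs
    G = Generator(
        NoiseDimension=64,   # assumed latent size
        c_dim=0,             # 0 = unconditional; ConditionDimension is not set
        img_resolution=32,   # assumed output resolution
        FP16Stages=[],       # no stages switched to bfloat16 in this demo
        **backbone_config,
    )
    z = torch.randn(4, G.z_dim)
    images = G(z, None)      # second argument is the condition; unused here
    print(images.shape)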