sdpetrides commited on
Commit
d09b148
·
1 Parent(s): e430625

Update app.py to load the GAN model

Browse files
Files changed (1) hide show
  1. app.py +47 -3
app.py CHANGED
@@ -1,19 +1,63 @@
1
  import torch
 
2
 
3
  from PIL import Image
4
 
5
  import numpy as np
6
  import gradio as gr
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  MODEL_PATH = './gan_mnist_generator_20.pt'
9
- model = torch.load(MODEL_PATH)
 
 
 
 
 
 
10
 
11
- def run_generative_model(use_seed="False", seed=43):
12
  if use_seed == "True":
13
  torch.random.manual_seed(seed)
14
 
15
  # Run generator model
16
- noise = torch.randn(1, nz, 1, 1, device=device)
17
  with torch.no_grad():
18
  im = model(noise).detach().cpu()
19
 
 
1
  import torch
2
+ import torch.nn as nn
3
 
4
  from PIL import Image
5
 
6
  import numpy as np
7
  import gradio as gr
8
 
9
+
10
+ class DCGAN_Generator(nn.Module):
11
+ def __init__(self):
12
+ super(DCGAN_Generator, self).__init__()
13
+
14
+ self.conv1 = nn.ConvTranspose2d(100, 256, 5)
15
+ self.bn1 = nn.BatchNorm2d(256)
16
+ self.relu1 = nn.LeakyReLU(negative_slope=0.2)
17
+
18
+ self.conv2 = nn.ConvTranspose2d(256, 256, 5)
19
+ self.bn2 = nn.BatchNorm2d(256)
20
+ self.relu2 = nn.LeakyReLU(negative_slope=0.2)
21
+
22
+ self.conv3 = nn.ConvTranspose2d(256, 128, 4)
23
+ self.bn3 = nn.BatchNorm2d(128)
24
+ self.relu3 = nn.LeakyReLU(negative_slope=0.2)
25
+
26
+ self.conv4 = nn.ConvTranspose2d(128, 64, 2, 2)
27
+ self.bn4 = nn.BatchNorm2d(64)
28
+ self.relu4 = nn.LeakyReLU(negative_slope=0.2)
29
+
30
+ self.conv5 = nn.ConvTranspose2d(64, 32, 3)
31
+ self.bn5 = nn.BatchNorm2d(32)
32
+ self.relu5 = nn.LeakyReLU(negative_slope=0.2)
33
+
34
+ self.conv6 = nn.ConvTranspose2d(32, 1, 3)
35
+ self.tanh1 = nn.Tanh()
36
+
37
+ def forward(self, x):
38
+ x = self.relu1(self.bn1(self.conv1(x)))
39
+ x = self.relu2(self.bn2(self.conv2(x)))
40
+ x = self.relu3(self.bn3(self.conv3(x)))
41
+ x = self.relu4(self.bn4(self.conv4(x)))
42
+ x = self.relu5(self.bn5(self.conv5(x)))
43
+
44
+ return self.tanh1(self.conv6(x))
45
+
46
  MODEL_PATH = './gan_mnist_generator_20.pt'
47
+ model = DCGAN_Generator()
48
+ model.load_state_dict(
49
+ torch.load(
50
+ MODEL_PATH,
51
+ map_location=torch.device('cpu')
52
+ )
53
+ )
54
 
55
+ def run_generative_model(use_seed="False", seed=42):
56
  if use_seed == "True":
57
  torch.random.manual_seed(seed)
58
 
59
  # Run generator model
60
+ noise = torch.randn(1, 100, 1, 1)
61
  with torch.no_grad():
62
  im = model(noise).detach().cpu()
63