# MNIST-Generator / app.py
# Last commit d09b148 by sdpetrides: "Update app.py to load the GAN model"
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import gradio as gr
class DCGAN_Generator(nn.Module):
    """Transposed-convolution generator that turns a latent tensor into an image.

    The network is five ConvTranspose2d -> BatchNorm2d -> LeakyReLU(0.2)
    stages followed by a final ConvTranspose2d and a Tanh, so every output
    value lies in [-1, 1].

    With a 1x1 spatial input the spatial size grows
    1 -> 5 -> 9 -> 12 -> 24 -> 26 -> 28, i.e. an input of shape
    (N, 100, 1, 1) yields an output of shape (N, 1, 28, 28).

    Attribute names (conv1..conv6, bn1..bn5, relu1..relu5, tanh1) are the
    keys of the state dict, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        slope = 0.2  # negative slope shared by every LeakyReLU below

        # Stage 1: 100 -> 256 channels, kernel 5
        self.conv1 = nn.ConvTranspose2d(100, 256, 5)
        self.bn1 = nn.BatchNorm2d(256)
        self.relu1 = nn.LeakyReLU(negative_slope=slope)

        # Stage 2: 256 -> 256 channels, kernel 5
        self.conv2 = nn.ConvTranspose2d(256, 256, 5)
        self.bn2 = nn.BatchNorm2d(256)
        self.relu2 = nn.LeakyReLU(negative_slope=slope)

        # Stage 3: 256 -> 128 channels, kernel 4
        self.conv3 = nn.ConvTranspose2d(256, 128, 4)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.LeakyReLU(negative_slope=slope)

        # Stage 4: 128 -> 64 channels, kernel 2 with stride 2 (doubles H and W)
        self.conv4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.LeakyReLU(negative_slope=slope)

        # Stage 5: 64 -> 32 channels, kernel 3
        self.conv5 = nn.ConvTranspose2d(64, 32, 3)
        self.bn5 = nn.BatchNorm2d(32)
        self.relu5 = nn.LeakyReLU(negative_slope=slope)

        # Output head: 32 -> 1 channel, kernel 3, squashed by tanh
        self.conv6 = nn.ConvTranspose2d(32, 1, 3)
        self.tanh1 = nn.Tanh()

    def forward(self, x):
        """Run ``x`` through all stages and return the tanh-squashed result."""
        stages = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
            (self.conv4, self.bn4, self.relu4),
            (self.conv5, self.bn5, self.relu5),
        )
        for upsample, normalise, activate in stages:
            x = activate(normalise(upsample(x)))
        return self.tanh1(self.conv6(x))
# Location of the trained generator weights, relative to the working directory.
MODEL_PATH = './gan_mnist_generator_20.pt'

# Build the generator and restore its weights; map_location forces every
# tensor in the checkpoint onto the CPU regardless of where it was saved.
model = DCGAN_Generator()
checkpoint = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)
del checkpoint  # the weights now live in `model`; drop the extra reference

# NOTE(review): the model is never switched to eval() here, so its BatchNorm
# layers normalise with per-batch statistics at inference time — confirm that
# this is intended before changing it, since it affects the generated images.
def run_generative_model(use_seed="False", seed=42):
    """Sample one image from the generator and return it as a PIL image.

    Args:
        use_seed: ``"True"`` (or the boolean ``True``) to seed torch's global
            RNG before sampling, making the output reproducible. Any other
            value leaves the RNG untouched.
        seed: Seed used when ``use_seed`` is enabled. It is truncated to an
            integer because UI sliders may deliver a float.

    Returns:
        A ``PIL.Image.Image`` built from a 2-D ``uint8`` array.
    """
    if str(use_seed) == "True":
        torch.random.manual_seed(int(seed))

    # Run generator model on a single latent vector.
    noise = torch.randn(1, 100, 1, 1)
    with torch.no_grad():
        im = model(noise).cpu()

    # Process image
    im = torch.squeeze(im)  # drop the singleton batch/channel dimensions

    # Map [-1, 1] onto [0, 255]. The previous `im * 128 + 128` sent 1.0 to 256,
    # which does not fit in uint8 and could turn saturated white pixels black.
    # Rounding plus an explicit clamp keeps every value inside the valid range.
    im = (im + 1.0) * 127.5
    im = im.round().clamp(0, 255).to(torch.uint8)

    return Image.fromarray(im.numpy())
# Input widgets, in the order run_generative_model expects its arguments:
# a "True"/"False" toggle for seeding, then the seed value itself.
seed_toggle = gr.Radio(["True", "False"], value="False")
seed_slider = gr.Slider(0, 100, value=42)

# Wire the generator function to the widgets; the result is shown as an image.
demo = gr.Interface(
    fn=run_generative_model,
    inputs=[seed_toggle, seed_slider],
    outputs="image",
)

# Start the web app and request a public share link.
demo.launch(share=True)