Spaces:
Sleeping
Sleeping
File size: 2,019 Bytes
2f0c9bf d09b148 2f0c9bf d09b148 2f0c9bf d09b148 2f0c9bf d09b148 2f0c9bf d09b148 2f0c9bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import gradio as gr
class DCGAN_Generator(nn.Module):
    """DCGAN-style generator that turns latent noise into a one-channel image.

    Six transposed convolutions (no padding) progressively enlarge the input.
    For a ``(N, 100, 1, 1)`` latent tensor the spatial size grows
    1 -> 5 -> 9 -> 12 -> 24 -> 26 -> 28, giving a ``(N, 1, 28, 28)`` output.
    Each of the first five stages is followed by BatchNorm and a LeakyReLU
    (slope 0.2). The final stage is squashed into ``[-1, 1]`` by ``tanh``.

    The attribute names (``conv1``..``conv6``, ``bn1``..``bn5``) are the keys
    of the saved ``state_dict``; renaming them breaks checkpoint loading.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 100 -> 256 channels, 1x1 -> 5x5
        self.conv1 = nn.ConvTranspose2d(100, 256, 5)
        self.bn1 = nn.BatchNorm2d(256)
        self.relu1 = nn.LeakyReLU(negative_slope=0.2)
        # Stage 2: 256 -> 256 channels, 5x5 -> 9x9
        self.conv2 = nn.ConvTranspose2d(256, 256, 5)
        self.bn2 = nn.BatchNorm2d(256)
        self.relu2 = nn.LeakyReLU(negative_slope=0.2)
        # Stage 3: 256 -> 128 channels, 9x9 -> 12x12
        self.conv3 = nn.ConvTranspose2d(256, 128, 4)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.LeakyReLU(negative_slope=0.2)
        # Stage 4: 128 -> 64 channels, stride 2 doubles 12x12 -> 24x24
        self.conv4 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.bn4 = nn.BatchNorm2d(64)
        self.relu4 = nn.LeakyReLU(negative_slope=0.2)
        # Stage 5: 64 -> 32 channels, 24x24 -> 26x26
        self.conv5 = nn.ConvTranspose2d(64, 32, 3)
        self.bn5 = nn.BatchNorm2d(32)
        self.relu5 = nn.LeakyReLU(negative_slope=0.2)
        # Output stage: 32 -> 1 channel, 26x26 -> 28x28, then tanh
        self.conv6 = nn.ConvTranspose2d(32, 1, 3)
        self.tanh1 = nn.Tanh()

    def forward(self, x):
        """Run the latent tensor *x* through all stages and return the image."""
        hidden_stages = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
            (self.conv4, self.bn4, self.relu4),
            (self.conv5, self.bn5, self.relu5),
        )
        for upsample, normalize, activate in hidden_stages:
            x = activate(normalize(upsample(x)))
        return self.tanh1(self.conv6(x))
# Checkpoint holding the generator's state_dict (loaded below with
# load_state_dict). NOTE(review): torch.load unpickles the file, so only
# point this at trusted checkpoints.
MODEL_PATH = './gan_mnist_generator_20.pt'

model = DCGAN_Generator()
model.load_state_dict(
    # map_location remaps any device-tagged tensors onto the CPU so the
    # checkpoint loads on a machine without a GPU.
    torch.load(
        MODEL_PATH,
        map_location=torch.device('cpu'),
    )
)
# Put BatchNorm layers into inference mode. Without this they normalise with
# per-call batch statistics instead of the learned running statistics, and
# every request would keep updating those running statistics (torch.no_grad
# does not prevent buffer updates).
model.eval()
def run_generative_model(use_seed="False", seed=42):
    """Sample one image from the global generator ``model``.

    Args:
        use_seed: ``"True"`` (as sent by the UI radio button) or ``True`` to
            draw the latent noise from a generator seeded with *seed*, which
            makes the image reproducible. Any other value draws fresh noise
            from torch's global RNG.
        seed: Seed used when *use_seed* is enabled. It is passed through
            ``int`` because a UI slider may deliver a float.

    Returns:
        A ``PIL.Image.Image`` built from the generator output rescaled to
        8-bit values in ``[0, 255]``. This assumes the generator emits a
        ``(1, 1, H, W)`` tensor in ``[-1, 1]``; TODO confirm if the model
        changes.
    """
    if use_seed in ("True", True):
        # A private generator gives a reproducible sample without reseeding
        # torch's global RNG for the rest of the process.
        rng = torch.Generator().manual_seed(int(seed))
    else:
        rng = None  # use torch's global RNG

    noise = torch.randn(1, 100, 1, 1, generator=rng)
    with torch.no_grad():
        im = model(noise).cpu()

    # Drop all size-1 axes (batch and channel) to get a 2-D (H, W) array.
    im = torch.squeeze(im)
    # Linearly map [-1, 1] onto [0, 255]. Clamp and round before the cast:
    # the previous ``im * 128 + 128`` produced 256 when tanh saturated at 1.0,
    # which does not fit in uint8.
    im = ((im + 1) * 127.5).clamp(0, 255).round().to(torch.uint8)
    return Image.fromarray(im.numpy())
# Web UI: the radio button and slider are passed positionally to
# run_generative_model as (use_seed, seed); the returned PIL image is shown.
demo = gr.Interface(
    fn=run_generative_model,
    inputs=[
        # Whether to seed the noise; the function compares against "True".
        gr.Radio(choices=["True", "False"], value="False"),
        # Seed value, used only when the radio above is "True".
        gr.Slider(minimum=0, maximum=100, value=42),
    ],
    outputs="image",
)

# share=True additionally requests a publicly shareable link for the app.
demo.launch(share=True)