import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define the Generator architecture
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to an image.

    A latent tensor of shape (N, latent_dim, 1, 1) is upsampled by five
    transposed convolutions to (N, img_channels, 64, 64). The final ``Tanh``
    keeps pixel values in [-1, 1].

    Args:
        latent_dim: size of the latent vector (input channels of layer 1).
        img_channels: number of channels in the generated image.
        feature_dim: base channel width; hidden layers use multiples of it.
    """

    def __init__(self, latent_dim=100, img_channels=3, feature_dim=64):
        super().__init__()
        self.latent_dim = latent_dim
        self.model = nn.Sequential(
            # 1x1 -> 4x4
            nn.ConvTranspose2d(latent_dim, feature_dim * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_dim * 8),
            nn.ReLU(True),
            # 4x4 -> 8x8
            nn.ConvTranspose2d(feature_dim * 8, feature_dim * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_dim * 4),
            nn.ReLU(True),
            # 8x8 -> 16x16
            nn.ConvTranspose2d(feature_dim * 4, feature_dim * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_dim * 2),
            nn.ReLU(True),
            # 16x16 -> 32x32
            nn.ConvTranspose2d(feature_dim * 2, feature_dim, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_dim),
            nn.ReLU(True),
            # 32x32 -> 64x64, squashed to [-1, 1]
            nn.ConvTranspose2d(feature_dim, img_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map latent tensor ``z`` of shape (N, latent_dim, 1, 1) to images."""
        return self.model(z)

    def generate_latent_space(self, batch_size, *, device=None):
        """Sample standard-normal latent vectors for this generator.

        Args:
            batch_size: number of latent vectors to draw.
            device: where to allocate the noise. Defaults to the device that
                holds this module's parameters, so the noise always matches
                the model even if it was moved after construction.

        Returns:
            A tensor of shape (batch_size, latent_dim, 1, 1).
        """
        if device is None:
            device = next(self.parameters()).device
        return torch.randn(batch_size, self.latent_dim, 1, 1, device=device)


# Instantiate the generator and load pre-trained weights
latent_dim = 100
generator = Generator(latent_dim=latent_dim)
# Make sure you have uploaded your pre-trained model file "generator.pth" to your Space
# weights_only=True: a state dict only needs tensors, so refuse to unpickle
# arbitrary Python objects from the checkpoint file.
generator.load_state_dict(
    torch.load("generator.pth", map_location=device, weights_only=True)
)
generator.to(device)
generator.eval()

# Built once and reused for every request instead of per call.
_to_pil = transforms.ToPILImage()


# Function to generate a face image
def generate_face():
    """Generate one random face and return it as a PIL image."""
    with torch.no_grad():
        # Generate a random latent vector and produce an image
        z = generator.generate_latent_space(1)
        generated_image = generator(z)
    # Drop the batch dimension: (1, C, H, W) -> (C, H, W), on CPU for PIL.
    generated_image = generated_image.cpu().squeeze(0)
    # Denormalize the image (from [-1, 1] to [0, 1]); the clamp is defensive.
    generated_image = (generated_image * 0.5 + 0.5).clamp(0.0, 1.0)
    # Convert the tensor to a PIL Image
    return _to_pil(generated_image)


# Set up the Gradio interface
demo = gr.Interface(
    fn=generate_face,
    inputs=[],  # No inputs – each button press generates a new image
    outputs="image",
    title="CelebA GAN Face Generator",
    description="Generates a face image using a pre-trained GAN on the CelebA dataset.",
)

if __name__ == "__main__":
    demo.launch()