GAN for MNIST Digit Generation

This repository contains a Generative Adversarial Network (GAN) trained on the MNIST dataset to generate realistic handwritten digits. The model was trained as part of a Generative AI course.

Model Details

  • Model Type: GAN
  • Dataset: MNIST (handwritten digits)
  • Generator Input: Latent vector of size 100
  • Output: 28x28 grayscale images
  • Framework: PyTorch
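The card does not include the `Generator` class itself, so its architecture is unknown. The sketch below is a hypothetical fully connected generator that matches only the interface stated above: a 100-dimensional latent vector in, a 28x28 single-channel image out. The class name `MLPGenerator`, the layer widths, the LeakyReLU activations, and the final `Tanh` (which assumes images were normalized to [-1, 1]) are all assumptions. The pretrained `gan_mnist.pth` weights will not load into this class.

```python
import torch
import torch.nn as nn


class MLPGenerator(nn.Module):
    """Hypothetical generator: maps a latent vector to a 1x28x28 image.

    Layer sizes and activations are assumptions, not the repository's design.
    """

    def __init__(self, latent_dim: int = 100, img_size: int = 28):
        super().__init__()
        self.img_size = img_size
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, img_size * img_size),
            nn.Tanh(),  # assumes training images were scaled to [-1, 1]
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # (batch, latent_dim) -> (batch, 1, 28, 28)
        return self.net(z).view(-1, 1, self.img_size, self.img_size)


gen = MLPGenerator(latent_dim=100)
with torch.no_grad():
    fake = gen(torch.randn(16, 100))
print(fake.shape)  # torch.Size([16, 1, 28, 28])
```

Whatever the real layers are, the input and output shapes should match this sketch: `(batch, 100)` in and `(batch, 1, 28, 28)` out.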

Training Details

  • Optimizer: Adam
  • Learning Rate: 0.0002
  • Beta1: 0.5
  • Epochs: 50
  • Batch Size: 64
  • Weight Decay: 0.0001
  • Logging: Weights & Biases
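The optimizer settings above map onto `torch.optim.Adam` as shown in this minimal sketch of one training step. Beta2 is not listed, so the sketch assumes PyTorch's default of 0.999. The small stand-in networks, the binary cross-entropy loss, and the random tensors used in place of an MNIST batch are also assumptions for illustration. The step follows the standard non-saturating GAN update: the discriminator first, then the generator.

```python
import torch
import torch.nn as nn

LATENT_DIM = 100
BATCH_SIZE = 64

# Hypothetical stand-in networks; the repository's architectures are not documented here.
generator = nn.Sequential(
    nn.Linear(LATENT_DIM, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 28 * 28), nn.Tanh(),
)
discriminator = nn.Sequential(
    nn.Linear(28 * 28, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1),  # raw logit; paired with BCEWithLogitsLoss
)

# Optimizer settings from the card; beta2 = 0.999 is an assumed (PyTorch default) value.
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999), weight_decay=0.0001)
opt_g = torch.optim.Adam(generator.parameters(), **adam_kwargs)
opt_d = torch.optim.Adam(discriminator.parameters(), **adam_kwargs)
criterion = nn.BCEWithLogitsLoss()


def train_step(real: torch.Tensor) -> tuple[float, float]:
    """One GAN update on a batch of flattened real images in [-1, 1]."""
    batch = real.size(0)
    ones = torch.ones(batch, 1)
    zeros = torch.zeros(batch, 1)

    # Discriminator: push real -> 1 and fake -> 0 (fake is detached so G is not updated).
    fake = generator(torch.randn(batch, LATENT_DIM))
    loss_d = criterion(discriminator(real), ones) + criterion(discriminator(fake.detach()), zeros)
    opt_d.zero_grad()
    loss_d.backward()
    opt_d.step()

    # Generator: non-saturating loss, i.e. make D classify the fakes as real.
    loss_g = criterion(discriminator(fake), ones)
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
    return loss_d.item(), loss_g.item()


# Random tensors stand in for one MNIST batch of 64 flattened images.
real_batch = torch.rand(BATCH_SIZE, 28 * 28) * 2 - 1
d_loss, g_loss = train_step(real_batch)
print(f"D loss {d_loss:.3f}  G loss {g_loss:.3f}")
```

In a full run this step would loop over a `DataLoader` with `batch_size=64` for 50 epochs. Logging to Weights & Biases (for example `wandb.log({"loss_d": d_loss, "loss_g": g_loss})`) is left out to keep the sketch dependency-free.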

Usage

Loading the Model

To load the trained model, use the following code snippet:

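import torch
from gan import Generator  # Generator class from the gan module

latent_dim = 100
generator = Generator(latent_dim)
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
generator.load_state_dict(torch.load("./gan_mnist.pth", map_location="cpu"))
generator.eval()

# Generate 16 samples from standard-normal latent vectors (no gradients needed)
with torch.no_grad():
    z = torch.randn(16, latent_dim)
    samples = generator(z)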
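To view the 16 samples as a single picture, they can be tiled into a 4x4 grid. This sketch assumes the generator returns a tensor of shape `(16, 1, 28, 28)` with values in [-1, 1] (a `Tanh` output). If the real model uses a different range or shape, the rescaling and reshape must change. `to_grid` is a hypothetical helper, not part of the repository.

```python
import torch


def to_grid(samples: torch.Tensor, nrow: int = 4) -> torch.Tensor:
    """Tile (N, 1, H, W) images in [-1, 1] into one (nrow*H, ncol*W) image in [0, 1].

    Assumes N is divisible by nrow (the number of grid rows).
    """
    n, _, h, w = samples.shape
    ncol = n // nrow
    imgs = (samples.detach().cpu().clamp(-1, 1) + 1) / 2  # [-1, 1] -> [0, 1]
    imgs = imgs.view(nrow, ncol, h, w)        # image i goes to row i // ncol, column i % ncol
    imgs = imgs.permute(0, 2, 1, 3)           # (row, H, col, W)
    return imgs.reshape(nrow * h, ncol * w)   # stitch pixels row by row


grid = to_grid(torch.rand(16, 1, 28, 28) * 2 - 1)
print(grid.shape)  # torch.Size([112, 112])
```

With the loading code above, `to_grid(samples)` yields a 112x112 array that can be displayed with `matplotlib.pyplot.imshow(grid.numpy(), cmap="gray")`. If torchvision is installed, `torchvision.utils.make_grid` provides the same tiling with padding options.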

Example Results

(Image: generated images)

License

This project is licensed under the MIT License.

Model ID: hussamalafandi/GAN_MNIST