metadata
license: gpl-3.0
language:
  - en
metrics:
  - accuracy
pipeline_tag: image-classification
tags:
  - digits
  - cnn
  - mnist
  - emnist
  - pytorch
  - handwriting-recognition
  - onnx

Digit & Blank Image Classifier (PyTorch CNN)

A high-accuracy convolutional neural network trained to classify handwritten digits from the MNIST and EMNIST Digits datasets, and additionally to detect blank images (unfilled boxes) as a distinct class. The model is trained with PyTorch and exported in TorchScript (.pt) and ONNX (.onnx) formats for reliable and portable inference.


License & Attribution

This model is licensed under the AGPL-3.0 license to comply with the Plom Project licensing requirements.

Developed as part of the Plom Project

Authors & Credits:

  • Model: Deep Shah, Undergraduate Research Assistant, UBC
  • Supervision: Prof. Andrew Rechnitzer and Prof. Colin B. MacDonald
  • Project: The Plom Project GitLab

Overview

  • Input: 1×28×28 grayscale image
  • Output: Integer class prediction:
    • 0–9: Digits
    • 10: Blank image
  • Architecture: CNN with three convolutional layers (BatchNorm, ReLU, max pooling, dropout) followed by two fully connected layers
  • Model Format: TorchScript (.pt), ONNX (.onnx)
  • Training Dataset: Combined MNIST, EMNIST Digits, and 5000 synthetic blank images

Dataset Details

Datasets Used:

  • MNIST – 28×28 handwritten digits (0–9), 60,000 training images
  • EMNIST Digits – 28×28 digits extracted from handwritten characters, 240,000+ training samples
  • Blank Images – 5,000 synthetic all-black 28×28 images, labeled as class 10 to simulate unfilled regions

Preprocessing:

  • Normalized pixel values to [0, 1]
  • Converted images to channel-first format (N, C, H, W)
  • Combined and shuffled datasets
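
The preprocessing steps above can be sketched with NumPy. This is a minimal illustration, not the code from train_digit_classifier.py: the helper names `preprocess` and `combine_and_shuffle` are hypothetical, the arrays are random stand-ins for the real datasets, and reusing seed 42 for the shuffle is an assumption.

```python
import numpy as np

def preprocess(images_uint8: np.ndarray) -> np.ndarray:
    """Scale (N, 28, 28) uint8 images to float32 in [0, 1] with shape (N, 1, 28, 28)."""
    scaled = images_uint8.astype(np.float32) / 255.0
    return scaled[:, np.newaxis, :, :]  # add the channel axis (channel-first)

def combine_and_shuffle(parts, seed=42):
    """Concatenate (images, labels) pairs and shuffle them with one shared permutation."""
    images = np.concatenate([p[0] for p in parts], axis=0)
    labels = np.concatenate([p[1] for p in parts], axis=0)
    order = np.random.default_rng(seed).permutation(len(images))
    return images[order], labels[order]

# Stand-in data: 8 random "digit" images and 4 all-black blank images (class 10).
rng = np.random.default_rng(0)
digits = rng.integers(0, 256, size=(8, 28, 28), dtype=np.uint8)
digit_labels = rng.integers(0, 10, size=8)
blanks = np.zeros((4, 28, 28), dtype=np.uint8)
blank_labels = np.full(4, 10)

x, y = combine_and_shuffle([
    (preprocess(digits), digit_labels),
    (preprocess(blanks), blank_labels),
])
print(x.shape, x.dtype, y.shape)
```

Shuffling images and labels with the same permutation keeps each label attached to its image.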

Data Augmentation

To improve generalization and robustness to handwriting variation:

  • RandomRotation(±10°)
  • RandomAffine: scale (0.9–1.1), translate (±10%)

These transformations simulate handwritten noise and variation in real student submissions.


Model Architecture

Input: (1, 28, 28)
↓ Conv2D(1 → 32) + BatchNorm + ReLU
↓ Conv2D(32 → 64) + BatchNorm + ReLU
↓ MaxPool2d(2x2) + Dropout(0.1)
↓ Conv2D(64 → 128) + BatchNorm + ReLU
↓ MaxPool2d(2x2) + Dropout(0.1)
↓ Flatten
↓ Linear(128*7*7 → 128) + BatchNorm + ReLU + Dropout(0.2)
↓ Linear(128 → 11)
→ Output: class logits (digits 0–9, blank = 10)
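
The layer diagram could be written as a PyTorch module roughly as follows. The class name `DigitBlankCNN` is hypothetical, and the 3×3 kernels with padding 1 are an assumption chosen so that two 2×2 poolings reduce 28×28 to the 7×7 feature maps implied by `Linear(128*7*7 → 128)`; the use of plain `nn.Dropout` rather than `nn.Dropout2d` is likewise assumed.

```python
import torch
from torch import nn

class DigitBlankCNN(nn.Module):
    """Sketch of the described architecture (kernel size and padding are assumed)."""

    def __init__(self, num_classes: int = 11):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.1),            # 28x28 -> 14x14
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(0.1),            # 14x14 -> 7x7
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, num_classes),                 # logits for digits 0-9 and blank
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))

model = DigitBlankCNN().eval()
with torch.no_grad():
    logits = model(torch.zeros(2, 1, 28, 28))
print(logits.shape)  # torch.Size([2, 11])
```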

Training Configuration

| Hyperparameter | Value |
|----------------|-------|
| Optimizer | Adam (lr=0.001) |
| Loss Function | CrossEntropyLoss |
| Scheduler | ReduceLROnPlateau |
| Early Stopping | Patience = 5 |
| Epochs | Max 50 |
| Batch Size | 64 |
| Device | CPU or CUDA |
| Random Seed | 42 |
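
To show how these settings fit together, here is a minimal training-loop sketch. It uses a tiny stand-in model and random data so that it runs quickly. Monitoring validation loss for both the scheduler and early stopping is an assumption, as are the scheduler's default factor and patience.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(42)

# Stand-in model and data (the real script trains the CNN on MNIST/EMNIST/blank images).
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 11))
train_ds = TensorDataset(torch.rand(256, 1, 28, 28), torch.randint(0, 11, (256,)))
loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_x, val_y = torch.rand(64, 1, 28, 28), torch.randint(0, 11, (64,))

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

max_epochs, patience = 50, 5
best_val, bad_epochs, epochs_run = float("inf"), 0, 0

for epoch in range(max_epochs):
    model.train()
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        val_loss = criterion(model(val_x), val_y).item()
    scheduler.step(val_loss)      # lower the learning rate when validation loss plateaus
    epochs_run += 1

    if val_loss < best_val:
        best_val, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:   # early stopping after 5 epochs without improvement
            break

print(f"stopped after {epochs_run} epochs, best validation loss {best_val:.4f}")
```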

Evaluation Results

| Metric | Value |
|--------|-------|
| Test Accuracy | 99.73% |
| Blank Image Accuracy | 100.00% |

All 5,000 blank images were correctly classified.
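
Metrics of this kind can be computed from true and predicted labels as sketched below. The function name `accuracy_report` and the toy labels are illustrative only; how the reported figures were actually computed is not described here.

```python
import numpy as np

BLANK = 10  # class index used for blank images

def accuracy_report(y_true, y_pred):
    """Return overall accuracy and accuracy restricted to blank-labeled samples."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    overall = float((y_true == y_pred).mean())
    blank_mask = y_true == BLANK
    blank = float((y_pred[blank_mask] == BLANK).mean()) if blank_mask.any() else float("nan")
    return {"overall": overall, "blank": blank}

# Toy example: five samples, one digit misclassified, both blanks correct.
report = accuracy_report([3, 10, 7, 10, 1], [3, 10, 1, 10, 1])
print(report)
```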


Inference Examples

1. TorchScript (PyTorch)

import torch

# Load TorchScript model
model = torch.jit.load("mnist_emnist_blank_cnn_v1.pt")
model.eval()

# Dummy input (1 image, 1 channel, 28x28), scaled to [0, 1] like the training data
img = torch.rand(1, 1, 28, 28)

# Predict
with torch.no_grad():
    out = model(img)
    predicted = out.argmax(dim=1).item()

print("Predicted class:", predicted)

2. ONNX (ONNX Runtime)

import onnxruntime as ort
import numpy as np

# Load ONNX model
session = ort.InferenceSession("mnist_emnist_blank_cnn_v1.onnx", providers=["CPUExecutionProvider"])

# Dummy input, scaled to [0, 1] like the training data
img = np.random.rand(1, 1, 28, 28).astype(np.float32)

# Predict
outputs = session.run(None, {"input": img})
predicted = int(outputs[0].argmax(axis=1)[0])

print("Predicted class:", predicted)

If the prediction is 10, the model considers the image to be blank (no digits present).
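
When a confidence value is useful alongside the class, the raw logits from either runtime can be passed through a softmax. The helper below is a hypothetical sketch (the name `interpret` is not part of the released files) and works on a single 11-element logit vector.

```python
import numpy as np

BLANK = 10

def interpret(logits):
    """Map one 11-element logit vector to a label ("0".."9" or "blank") and its probability."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max()              # subtract the max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()
    cls = int(probs.argmax())
    label = "blank" if cls == BLANK else str(cls)
    return label, float(probs[cls])

# Toy logits that favour the blank class.
example = np.zeros(11)
example[BLANK] = 5.0
label, confidence = interpret(example)
print(label, round(confidence, 3))
```

With the ONNX example above, `outputs[0][0]` would be the vector to pass in; with TorchScript, `out[0].numpy()`.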


Included Files

  • train_digit_classifier.py: Training script with full documentation
  • mnist_emnist_blank_cnn_v1.pth: Final trained model weights
  • mnist_emnist_blank_cnn_v1.pt: TorchScript export for deployment
  • mnist_emnist_blank_cnn_v1.onnx: ONNX export for deployment
  • requirements.txt: Required dependencies for training or inference

Intended Use

This model was designed to support the Plom Project’s student ID digit detection system, helping automatically identify handwritten digits (and detect blank/unfilled boxes) from scanned exam sheets.

It may also be adapted for other handwritten digit classification tasks or real-time blank field detection applications.
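
As an illustration of the student ID use case, per-box predictions could be combined into an ID string, with any blank box marking the ID as incomplete. This is a hypothetical sketch and not Plom's actual logic: the function name `assemble_id` and the default ID length of 8 are assumptions.

```python
BLANK = 10

def assemble_id(box_predictions, expected_length=8):
    """Join per-box class predictions into an ID string.

    Returns None when the number of boxes is wrong or any box is predicted blank,
    so the caller can route that sheet to manual review.
    """
    if len(box_predictions) != expected_length:
        return None
    if any(p == BLANK for p in box_predictions):
        return None
    return "".join(str(p) for p in box_predictions)

complete = assemble_id([1, 2, 3, 4, 5, 6, 7, 8])
incomplete = assemble_id([1, 2, 10, 4, 5, 6, 7, 8])
print(complete, incomplete)
```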