|
""" |
|
train_digit_classifier.py |
|
|
|
A fully documented training script for a convolutional neural network (CNN) |
|
classifier trained on MNIST + EMNIST digits + blank images. |
|
|
|
Author: Deep Shah |
|
License: GPL-3.0 |
|
""" |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import torchvision |
|
import torchvision.transforms as transforms |
|
from torch.utils.data import DataLoader, Dataset, TensorDataset |
|
from sklearn.model_selection import train_test_split |
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Reproducibility: pin the NumPy and PyTorch global RNGs to the same fixed seed.
np.random.seed(42)
torch.manual_seed(42)

# Device selection: use CUDA when PyTorch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"[INFO] Using device: {device}")
|
|
|
|
|
|
|
|
|
|
|
class EMNISTDigitsDataset(Dataset):
    """
    A PyTorch-compatible wrapper for the EMNIST digits dataset loaded via TensorFlow Datasets.

    The requested split is read into memory once, in ``__init__``:

    - ``self.images``: float32 array, one channel-last (H, W, C) image per row,
      pixel values divided by 255.
    - ``self.labels``: int64 array of class indices.

    Images are stored exactly as TFDS delivers them; no orientation correction
    is applied here. ``__getitem__`` serves channel-first (C, H, W) tensors.
    """

    def __init__(self, split="train", transform=None):
        """
        Load one split of ``emnist/digits`` into NumPy arrays.

        Args:
            split: Split name forwarded to ``tfds.load``.
            transform: Optional callable applied in ``__getitem__`` to the
                channel-first float32 tensor of a single image. It must return
                a tensor (or something ``torch.as_tensor`` accepts).
        """
        # Imported here rather than at module level, so TensorFlow Datasets is
        # only required when this class is actually instantiated.
        import tensorflow_datasets as tfds

        ds = tfds.load("emnist/digits", split=split, as_supervised=True)
        images = []
        labels = []
        for image, label in tfds.as_numpy(ds):
            # Normalise every sample to three dimensions (H, W, C).
            if image.ndim == 2:
                image = image[..., np.newaxis]
            elif image.ndim == 4 and image.shape[0] == 1:
                image = image[0]
            images.append(image)
            labels.append(label)
        self.images = np.array(images, dtype=np.float32) / 255.0
        self.labels = np.array(labels, dtype=np.int64)
        self.transform = transform

    def __len__(self):
        """Return the number of samples held in memory."""
        return len(self.images)

    def __getitem__(self, idx):
        """
        Return ``(image, label)`` for sample ``idx``.

        Returns:
            image: float32 tensor in channel-first (C, H, W) layout, after the
                optional transform.
            label: scalar ``torch.long`` tensor.
        """
        # torch.tensor copies, so the returned tensor never aliases the cached array.
        image = torch.tensor(self.images[idx].transpose(2, 0, 1), dtype=torch.float32)
        if self.transform is not None:
            # Stay channel-first end to end. The earlier implementation swapped
            # H and W on the transformed tensor and then applied an HWC->CHW
            # transpose to it, which produced an (H, C, W) tensor whenever a
            # transform was set.
            image = torch.as_tensor(self.transform(image), dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label
|
|
|
|
|
|
|
|
|
|
|
|
|
# Training-time augmentation, applied to one image at a time: convert to a PIL
# image, apply a random rotation and a random shift/zoom, convert back to a tensor.
train_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        # Rotation angle drawn from [-10, 10] degrees.
        transforms.RandomRotation(degrees=10),
        # No extra rotation here; shift by up to 10% per axis, scale in [0.9, 1.1].
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
    ]
)
|
|
|
|
|
|
|
|
|
|
|
|
|
# MNIST training split, stored under ./data (download=True fetches it if missing).
mnist_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True)
mnist_labels = mnist_dataset.targets.numpy()
# Pixels as float32 divided by 255, with a trailing channel axis appended.
mnist_images = np.expand_dims(
    mnist_dataset.data.numpy().astype(np.float32) / 255.0, axis=-1
)
|
|
|
|
|
# EMNIST digits (train split), read through the TFDS-backed wrapper class.
emnist_dataset = EMNISTDigitsDataset(split="train", transform=None)

# The TFDS catalog documents its EMNIST images as flipped horizontally and
# rotated 90 degrees relative to the usual upright orientation, and recommends a
# transpose to correct it. MNIST from torchvision is upright, so swap the height
# and width axes here; otherwise the combined training set mixes two orientations
# of the same digits. (A view is enough: np.concatenate copies the data later.)
emnist_images = emnist_dataset.images.transpose(0, 2, 1, 3)
emnist_labels = emnist_dataset.labels
|
|
|
|
|
# Synthetic "blank" class: 5000 all-zero 28x28x1 images, all labelled 10.
y_blank = np.full(5000, 10, dtype=np.int64)
x_blank = np.zeros((y_blank.size, 28, 28, 1), dtype=np.float32)
|
|
|
|
|
# Merge the three sources along the sample axis and shuffle images and labels
# with one shared permutation, so every image stays paired with its label.
indices = np.random.permutation(len(mnist_images) + len(emnist_images) + len(x_blank))
x_combined = np.concatenate((mnist_images, emnist_images, x_blank))[indices]
y_combined = np.concatenate((mnist_labels, emnist_labels, y_blank))[indices]
|
|
|
|
|
|
|
|
|
|
|
def _to_tensor_dataset(images, labels):
    """Wrap channel-last image and label arrays as a channel-first TensorDataset."""
    return TensorDataset(
        torch.tensor(images.transpose(0, 3, 1, 2), dtype=torch.float32),
        torch.tensor(labels, dtype=torch.long),
    )


# Hold out 10% of the combined data for validation; the split is seeded.
x_train, x_val, y_train, y_val = train_test_split(
    x_combined,
    y_combined,
    random_state=42,
    test_size=0.1,
)

train_dataset = _to_tensor_dataset(x_train, y_train)
val_dataset = _to_tensor_dataset(x_val, y_val)

# Only the training loader reshuffles each epoch.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)
|
|
|
|
|
|
|
|
|
|
|
class CNN(nn.Module):
    """
    Convolutional classifier for single-channel 28x28 images producing 11 logits.

    Layout:
        - ``features``: conv(1->32) and conv(32->64), each followed by BatchNorm
          and ReLU, then 2x2 max-pool and dropout(0.1); conv(64->128) with
          BatchNorm and ReLU, then another 2x2 max-pool and dropout(0.1).
        - ``classifier``: flatten, Linear(128*7*7 -> 128) with BatchNorm, ReLU
          and dropout(0.2), then Linear(128 -> 11).
    """

    def __init__(self):
        super().__init__()

        def conv_unit(c_in, c_out):
            # 3x3 convolution with padding 1 (spatial size preserved) -> BatchNorm -> ReLU.
            return [nn.Conv2d(c_in, c_out, 3, padding=1), nn.BatchNorm2d(c_out), nn.ReLU()]

        def downsample(drop_prob):
            # Halve the spatial resolution, then apply dropout.
            return [nn.MaxPool2d(2, 2), nn.Dropout(drop_prob)]

        # The position of each layer inside the Sequential determines its
        # state-dict key (features.0, features.3, ...), so the order matters.
        self.features = nn.Sequential(
            *conv_unit(1, 32),
            *conv_unit(32, 64),
            *downsample(0.1),
            *conv_unit(64, 128),
            *downsample(0.1),
        )

        # Two 2x2 pools take 28x28 down to 7x7, hence 128 * 7 * 7 inputs.
        head_layers = [
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 11),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the raw class logits for a batch ``x``."""
        feature_maps = self.features(x)
        return self.classifier(feature_maps)
|
|
|
# Build the network and place it on the selected device.
model = CNN()
model = model.to(device)

# Cross-entropy over the raw logits.
criterion = nn.CrossEntropyLoss()

# Adam with a starting learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

# Multiply the learning rate by 0.2 after 2 epochs without a lower monitored
# value, never going below 1e-6.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.2,
    patience=2,
    min_lr=1e-6,
)

# Early-stopping bookkeeping: stop after `patience` consecutive epochs without
# a new best validation loss, remembering the weights of the best epoch.
best_model_state = None
best_val_loss = float("inf")
patience_counter = 0
patience = 5
|
|
|
|
|
|
|
|
|
|
|
# Main loop: up to 50 epochs of train -> validate -> LR schedule -> early stopping.
for epoch in range(1, 51):
    # ---------------- Training pass ----------------
    model.train()
    running_loss = 0
    correct = 0
    total = 0

    for images, labels in train_loader:
        # Augment every sample first, then move the whole batch to the device
        # once. The loader yields CPU tensors, so this avoids the previous
        # device -> CPU -> device round trip for each individual image.
        images = torch.stack([train_transform(image) for image in images])
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_acc = 100 * correct / total
    # Mean of the per-batch mean losses.
    train_loss = running_loss / len(train_loader)

    # ---------------- Validation pass ----------------
    model.eval()
    val_loss = 0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    val_acc = 100 * val_correct / val_total
    val_loss /= len(val_loader)

    print(f"Epoch {epoch:02d}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, "
          f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%")

    # The plateau scheduler monitors the validation loss.
    scheduler.step(val_loss)

    # ---------------- Early stopping ----------------
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() hands back references to the live parameter and buffer
        # tensors, which later optimizer steps modify in place. Clone them so
        # the snapshot really holds this epoch's weights.
        best_model_state = {
            key: value.detach().clone() for key, value in model.state_dict().items()
        }
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("[INFO] Early stopping triggered.")
            break
|
|
|
|
|
# Restore the best weights recorded during training before saving anything.
if best_model_state is None:
    # Only reachable when no epoch produced a validation loss below the initial
    # +inf (for example, every loss was NaN). Fail with a clear message rather
    # than somewhere inside load_state_dict(None).
    raise RuntimeError(
        "No best model state was recorded during training; nothing to save."
    )
model.load_state_dict(best_model_state)

# Plain state-dict checkpoint.
torch.save(model.state_dict(), "mnist_emnist_blank_cnn_v1.pth")
print("[INFO] Model weights saved as mnist_emnist_blank_cnn_v1.pth")
|
|
|
|
|
# TorchScript export: trace the model in eval mode with one random
# 1x1x28x28 input placed on the training device, then save the traced module.
model.eval()
example_input = torch.randn(1, 1, 28, 28)
example_input = example_input.to(device)
scripted_model = torch.jit.trace(model, example_inputs=example_input)
scripted_model.save("mnist_emnist_blank_cnn_v1.pt")
print("[INFO] TorchScript model saved as mnist_emnist_blank_cnn_v1.pt")
|
|
|
|
|
|
|
# ONNX export: run on the CPU with a dynamic batch axis, then put the model
# back on whatever device it was on, even if the export raises.
prev_device = next(model.parameters()).device
onnx_path = "mnist_emnist_blank_cnn_v1.onnx"
export_options = dict(
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["logits"],
    # Axis 0 of both the input and the output is a named, variable batch axis.
    dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}},
)
try:
    # Module.to() moves the model in place; model_cpu is the same object as model.
    model_cpu = model.to("cpu").eval()
    dummy = torch.randn(1, 1, 28, 28)
    torch.onnx.export(model_cpu, dummy, onnx_path, **export_options)
    print(f"[INFO] ONNX model saved as {onnx_path}")
finally:
    model.to(prev_device).eval()
|
|