""" train_digit_classifier.py A fully documented training script for a convolutional neural network (CNN) classifier trained on MNIST + EMNIST digits + blank images. Author: Deep Shah License: GPL-3.0 """ import numpy as np import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader, Dataset, TensorDataset from sklearn.model_selection import train_test_split import os # ---------------------------------------------------------------------- # 1. Reproducibility Setup # ---------------------------------------------------------------------- # Set fixed seeds to make results deterministic (important for debugging and reproducibility) torch.manual_seed(42) np.random.seed(42) # ---------------------------------------------------------------------- # 2. Device Selection # ---------------------------------------------------------------------- # Automatically use GPU if available; fallback to CPU otherwise device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"[INFO] Using device: {device}") # ---------------------------------------------------------------------- # 3. EMNIST Loader (Custom Dataset class) # ---------------------------------------------------------------------- class EMNISTDigitsDataset(Dataset): """ A PyTorch-compatible wrapper for the EMNIST digits dataset loaded via TensorFlow Datasets. Ensures data is shaped correctly and optionally transformed. """ def __init__(self, split="train", transform=None): import tensorflow_datasets as tfds ds = tfds.load("emnist/digits", split=split, as_supervised=True) self.images = [] self.labels = [] for image, label in tfds.as_numpy(ds): if image.ndim == 2: image = image[..., np.newaxis] elif image.ndim == 4 and image.shape[0] == 1: image = image[0] self.images.append(image) self.labels.append(label) self.images = np.array(self.images, dtype=np.float32) / 255.0 # Normalize to [0,1] self.labels = np.array(self.labels, dtype=np.int64) self.transform = transform def __len__(self): return len(self.images) def __getitem__(self, idx): image = self.images[idx] label = self.labels[idx] if self.transform: image = self.transform(torch.tensor(image.transpose(2, 0, 1))).transpose(1, 2).numpy() return torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32), torch.tensor(label, dtype=torch.long) # ---------------------------------------------------------------------- # 4. Data Augmentation Strategy # ---------------------------------------------------------------------- # We use a modest augmentation strategy to improve generalization train_transform = transforms.Compose([ transforms.ToPILImage(), transforms.RandomRotation(10), # Handle slanted handwriting transforms.RandomAffine(degrees=0, scale=(0.9, 1.1), translate=(0.1, 0.1)), # Simulate slight distortions transforms.ToTensor() ]) # ---------------------------------------------------------------------- # 5. 

# ----------------------------------------------------------------------
# 5. Load Datasets (MNIST + EMNIST + Blank)
# ----------------------------------------------------------------------
# Load MNIST
mnist_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True)
mnist_images = mnist_dataset.data.numpy().astype(np.float32) / 255.0
mnist_images = mnist_images[..., np.newaxis]
mnist_labels = mnist_dataset.targets.numpy()

# Load EMNIST
emnist_dataset = EMNISTDigitsDataset(split="train", transform=None)
emnist_images = emnist_dataset.images
emnist_labels = emnist_dataset.labels

# Create blank (all-black) 28x28 images, labeled with class 10
x_blank = np.zeros((5000, 28, 28, 1), dtype=np.float32)
y_blank = np.full((5000,), 10, dtype=np.int64)

# Combine all datasets
x_combined = np.concatenate([mnist_images, emnist_images, x_blank], axis=0)
y_combined = np.concatenate([mnist_labels, emnist_labels, y_blank], axis=0)

# Shuffle for randomness
indices = np.random.permutation(len(x_combined))
x_combined = x_combined[indices]
y_combined = y_combined[indices]

# ----------------------------------------------------------------------
# 6. Train/Validation Split
# ----------------------------------------------------------------------
x_train, x_val, y_train, y_val = train_test_split(
    x_combined, y_combined, test_size=0.1, random_state=42
)

# Convert to PyTorch format
train_dataset = TensorDataset(
    torch.tensor(x_train.transpose(0, 3, 1, 2), dtype=torch.float32),
    torch.tensor(y_train, dtype=torch.long)
)
val_dataset = TensorDataset(
    torch.tensor(x_val.transpose(0, 3, 1, 2), dtype=torch.float32),
    torch.tensor(y_val, dtype=torch.long)
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# ----------------------------------------------------------------------
# 7. CNN Architecture
# ----------------------------------------------------------------------
class CNN(nn.Module):
    """
    This CNN is designed to:
      - Use 3 convolutional blocks with increasing depth (32 -> 64 -> 128)
      - Use BatchNorm to stabilize training
      - Use Dropout to prevent overfitting
      - Flatten and use 2 dense layers to classify
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1),    # Small receptive field
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),   # Slightly deeper
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.1),                   # Helps regularize
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.1)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 11)                 # 0-9 digits + blank (class 10)
        )

    def forward(self, x):
        return self.classifier(self.features(x))

model = CNN().to(device)
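
# ----------------------------------------------------------------------
# (Optional) Architecture sanity check
# ----------------------------------------------------------------------
# A minimal sketch, not part of the original training flow: it runs a tiny
# dummy batch through the untrained network to confirm that the flattened
# feature size matches the first Linear layer (128 * 7 * 7) and that the
# head emits 11 logits (digits 0-9 plus the blank class).
model.eval()  # avoid polluting BatchNorm running stats; the training loop calls model.train()
with torch.no_grad():
    _dummy_batch = torch.zeros(2, 1, 28, 28, device=device)
    _logits = model(_dummy_batch)
assert _logits.shape == (2, 11), f"Unexpected output shape: {_logits.shape}"
_n_params = sum(p.numel() for p in model.parameters())
print(f"[INFO] Model emits {_logits.shape[1]} logits per image, {_n_params:,} parameters")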

# ----------------------------------------------------------------------
# 8. Training Configuration
# ----------------------------------------------------------------------
# CrossEntropyLoss is standard for multi-class classification
criterion = nn.CrossEntropyLoss()

# Adam is used because it handles noisy gradients well and converges quickly
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ReduceLROnPlateau reduces the LR when the validation loss plateaus (adaptive control)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.2, patience=2, min_lr=1e-6
)

# Early stopping is used to prevent overfitting and wasted training
patience = 5
patience_counter = 0
best_val_loss = float("inf")
best_model_state = None

# ----------------------------------------------------------------------
# 9. Training Loop
# ----------------------------------------------------------------------
for epoch in range(1, 51):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        # Apply data augmentation on the CPU, then move the batch to the device
        images = torch.stack([train_transform(img) for img in images])
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    train_acc = 100 * correct / total
    train_loss = running_loss / len(train_loader)

    # ----------------
    # Validation phase
    # ----------------
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    val_acc = 100 * val_correct / val_total
    val_loss /= len(val_loader)

    print(f"Epoch {epoch:02d}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, "
          f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.2f}%")

    # Adjust the learning rate if the validation loss has plateaued
    scheduler.step(val_loss)

    # Save the best model state (deep-copied so later updates don't overwrite it)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("[INFO] Early stopping triggered.")
            break

# Load the best model
model.load_state_dict(best_model_state)

# Save PyTorch weights
torch.save(model.state_dict(), "mnist_emnist_blank_cnn_v1.pth")
print("[INFO] Model weights saved as mnist_emnist_blank_cnn_v1.pth")

# Convert to TorchScript so the model can be deployed without the Python class
# definition (e.g., behind an inference service such as the Hugging Face Inference API)
model.eval()
example_input = torch.randn(1, 1, 28, 28).to(device)
scripted_model = torch.jit.trace(model, example_input)
scripted_model.save("mnist_emnist_blank_cnn_v1.pt")
print("[INFO] TorchScript model saved as mnist_emnist_blank_cnn_v1.pt")

# ONNX export
# We move to CPU just for export (then restore the device).
prev_device = next(model.parameters()).device
try:
    model_cpu = model.to("cpu").eval()
    dummy = torch.randn(1, 1, 28, 28)  # match input shape
    onnx_path = "mnist_emnist_blank_cnn_v1.onnx"
    torch.onnx.export(
        model_cpu,
        dummy,
        onnx_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["logits"],
        dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}},
    )
    print(f"[INFO] ONNX model saved as {onnx_path}")
finally:
    model.to(prev_device).eval()  # restore original device
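
# ----------------------------------------------------------------------
# (Optional) ONNX export verification
# ----------------------------------------------------------------------
# A minimal sketch that assumes the `onnxruntime` package is installed; it is
# not required by the training script itself. It compares ONNX Runtime logits
# against the PyTorch model on a random input to catch export mistakes early.
try:
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    probe = torch.randn(1, 1, 28, 28)
    with torch.no_grad():
        torch_logits = model.to("cpu").eval()(probe).numpy()
    onnx_logits = session.run(["logits"], {"input": probe.numpy()})[0]
    max_diff = float(np.abs(torch_logits - onnx_logits).max())
    print(f"[INFO] Max |PyTorch - ONNX| logit difference: {max_diff:.6f}")
except ImportError:
    print("[WARN] onnxruntime not installed; skipping ONNX verification.")
finally:
    model.to(prev_device).eval()  # keep the model on its original device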