cmaeti commited on
Commit
3659ca5
·
verified ·
1 Parent(s): 20f90d2

Upload 2 files

Browse files
Files changed (2) hide show
  1. demo/app.py +61 -0
  2. demo/test_digit.jpg +0 -0
demo/app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+
6
+ # Define the model architecture
7
+ class Net(nn.Module):
8
+ def __init__(self):
9
+ super(Net, self).__init__()
10
+ self.conv1 = nn.Conv2d(1, 32, 3, 1)
11
+ self.conv2 = nn.Conv2d(32, 64, 3, 1)
12
+ self.dropout1 = nn.Dropout(0.25)
13
+ self.dropout2 = nn.Dropout(0.5)
14
+ self.fc1 = nn.Linear(9216, 128)
15
+ self.fc2 = nn.Linear(128, 10)
16
+
17
+ def forward(self, x):
18
+ x = self.conv1(x)
19
+ x = torch.relu(x)
20
+ x = self.conv2(x)
21
+ x = torch.relu(x)
22
+ x = torch.max_pool2d(x, 2)
23
+ x = self.dropout1(x)
24
+ x = torch.flatten(x, 1)
25
+ x = self.fc1(x)
26
+ x = torch.relu(x)
27
+ x = self.dropout2(x)
28
+ x = self.fc2(x)
29
+ output = torch.log_softmax(x, dim=1)
30
+ return output
31
+
32
+ # Load the trained model
33
+ model = Net()
34
+ #model.load_state_dict(torch.load('mnist-cnn.pth')) # Load weights
35
+ model.load_state_dict(torch.load('mnist-cnn.pth', weights_only=True)) # Load weights
36
+
37
+ # Set the model to evaluation mode
38
+ model.eval()
39
+
40
+ # Function to load and preprocess the image
41
+ def preprocess_image(image_path):
42
+ img = Image.open(image_path).convert('L') # Convert to grayscale
43
+ transform = transforms.Compose([
44
+ transforms.Resize((28, 28)),
45
+ transforms.ToTensor(),
46
+ transforms.Normalize((0.1307,), (0.3081,))
47
+ ])
48
+ img_tensor = transform(img).unsqueeze(0) # Add batch dimension
49
+ return img_tensor
50
+
51
+ # Load and preprocess a new image
52
+ image_path = "test_digit.jpg" # Replace with your image file path
53
+ input_image = preprocess_image(image_path)
54
+
55
+ # Make the prediction
56
+ with torch.no_grad():
57
+ outputs = model(input_image)
58
+ predicted_class = torch.argmax(outputs, dim=1).item()
59
+
60
+ # Output the predicted digit
61
+ print(f"Predicted digit: {predicted_class}")
demo/test_digit.jpg ADDED