etweedy commited on
Commit
688bf34
·
1 Parent(s): a73fef6

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +3 -29
  2. mnist_model.pth +3 -0
  3. mnist_model_weights.pth +3 -0
app.py CHANGED
@@ -2,35 +2,9 @@ import torch
2
  from torch import nn
3
  import gradio as gr
4
 
5
- # Define the custom CNN model class that was trained on the MNIST data
6
- class CNN(nn.Module):
7
- """
8
- A custom CNN class. The network has: (1) a convolution layer with 1 input channel and 16 output channels with ReLU activation and 2x2 max-pooling, (2) a second convolution layer with 16 input channels and 32 output channels with ReLU activation and 2x2 max-pooling, and (3) a linear output layer with 10 outputs.
9
- """
10
- def __init__(self):
11
- super(CNN,self).__init__()
12
- self.conv1 = nn.Sequential(
13
- nn.Conv2d(1,16,5,stride=1,padding=2),
14
- nn.ReLU(),
15
- nn.MaxPool2d(kernel_size=2),
16
- )
17
- self.conv2 = nn.Sequential(
18
- nn.Conv2d(16,32,5,1,2),
19
- nn.ReLU(),
20
- nn.MaxPool2d(2),
21
- )
22
- self.out = nn.Linear(32*7*7,10)
23
-
24
- # Forward propogation method
25
- def forward(self,x):
26
- x=self.conv1(x)
27
- x=self.conv2(x)
28
- x = x.view(-1,32*7*7)
29
- return self.out(x)
30
-
31
- # Initialize an instance and load in the saved state_dict for the trained model
32
- model = CNN()
33
- model.load_state_dict(torch.load('mnist2.pkl',map_location=torch.device('cpu')))
34
  model.eval()
35
 
36
  # Prediction function
 
2
  from torch import nn
3
  import gradio as gr
4
 
5
+ # Load the model and then the post-training state_dict
6
+ model = torch.load('mnist_model.pth',map_location=torch.device('cpu'))
7
+ model.load_state_dict(torch.load('mnist_model_weights.pth',map_location=torch.device('cpu')))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  model.eval()
9
 
10
  # Prediction function
mnist_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:170a7a2185e474af9dd12e1232e28a9ea6003451a49b6dbc225d98f89c072a0e
3
+ size 13103247
mnist_model_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4cd377c11e267b84a351cebf48d4ea0a768b05e5502e34a4bf2854f55bd7d19
3
+ size 13100943