cmaeti committed on
Commit
31ed87f
·
verified ·
1 Parent(s): c555a05

Upload 3 files

Browse files
Files changed (3) hide show
  1. demo/demo.py +58 -0
  2. demo/test_digit.jpg +0 -0
  3. mnist-i-jepa.pth +3 -0
demo/demo.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import matplotlib.pyplot as plt
7
+
8
+ # Define the IJEPAModel (same as before)
9
class IJEPAModel(nn.Module):
    """Small convolutional encoder followed by a linear classifier.

    The encoder runs two ``Conv2d -> ReLU -> MaxPool2d`` stages, flattens the
    result and projects it to ``feature_dim`` features; ``classifier`` maps
    those features to ``num_classes`` logits.

    The last encoder layer expects ``64 * 7 * 7`` inputs, which is what a
    1-channel 28x28 image yields after the two stride-2 poolings.

    Args:
        feature_dim: Width of the encoder output fed to the classifier.
        num_classes: Number of output logits (defaults to 10).
    """

    def __init__(self, feature_dim=128, num_classes=10):
        super().__init__()
        # The position of each layer inside the Sequential determines its
        # state_dict key (encoder.0, encoder.3, encoder.7), so the order must
        # stay as-is for saved checkpoints to keep loading.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, feature_dim),
        )
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Encode ``x`` and return the classifier logits."""
        features = self.encoder(x)
        return self.classifier(features)
28
+
29
# Load the model
from pathlib import Path  # local import: only this script section needs it

# In this commit demo.py and test_digit.jpg live in demo/, while the
# checkpoint sits one directory above. Anchor both paths to this file so the
# script works regardless of the current working directory.
DEMO_DIR = Path(__file__).resolve().parent
MODEL_PATH = DEMO_DIR.parent / "mnist-i-jepa.pth"
IMAGE_PATH = DEMO_DIR / "test_digit.jpg"

model = IJEPAModel()
# map_location="cpu" lets a checkpoint saved on a GPU load on CPU-only hosts.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you
# trust (consider weights_only=True on PyTorch versions that support it).
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()  # Set the model to evaluation mode

# Preprocess the input image (grayscale, resize, tensor, normalize)
# NOTE(review): the (0.5, 0.5) normalization presumably mirrors what was used
# at training time -- confirm against the training script.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Ensure a single channel
    transforms.Resize((28, 28)),  # Size the encoder's final Linear expects
    transforms.ToTensor(),  # Convert to a float tensor
    transforms.Normalize((0.5,), (0.5,)),  # Shift/scale pixel values
])

# Load the test image; the context manager closes the file handle once the
# pixels have been converted to a tensor.
with Image.open(IMAGE_PATH) as pil_img:
    img = transform(pil_img).unsqueeze(0)  # Add batch dimension

# Predict the digit
with torch.no_grad():  # Disable gradient computation for inference
    output = model(img)
    _, predicted = torch.max(output, 1)

# Display the result
predicted_digit = predicted.item()
print(f"Predicted digit: {predicted_digit}")

# Optionally, display the (preprocessed) image
plt.imshow(img.squeeze().numpy(), cmap='gray')
plt.title(f"Predicted digit: {predicted_digit}")
plt.show()
demo/test_digit.jpg ADDED
mnist-i-jepa.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78383f1bc50559f90179f5bb7488998e5cf2643e8f4de5159e30fd9e26488188
3
+ size 1689980