mnist-i-jepa / demo /demo.py
cmaeti's picture
Upload 3 files
31ed87f verified
raw
history blame
1.94 kB
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
# Define the IJEPAModel (the architecture must match the checkpoint loaded below)
class IJEPAModel(nn.Module):
    """Small convolutional encoder followed by a linear classification head.

    The encoder applies two ``Conv2d -> ReLU -> MaxPool2d`` stages, flattens
    the result and projects it to ``feature_dim`` features; the classifier
    maps those features to ``num_classes`` logits.

    The attribute names ``encoder`` / ``classifier`` and the order of the
    layers inside ``encoder`` determine the ``state_dict`` keys, so they must
    stay as they are for previously saved weights to keep loading.

    Args:
        feature_dim: Size of the feature vector produced by the encoder.
        in_channels: Number of channels of the input images.
        num_classes: Number of logits produced by the classifier.
        image_size: Height and width of the (square) input images. Each of
            the two pooling stages halves it, rounding down.

    Raises:
        ValueError: If ``image_size`` is too small to survive both pooling
            stages (i.e. smaller than 4).
    """

    def __init__(
        self,
        feature_dim=128,
        *,
        in_channels=1,
        num_classes=10,
        image_size=28,
    ):
        super().__init__()

        # Two 2x2/stride-2 poolings shrink each spatial side by a factor of 4
        # (the 3x3 convolutions use padding=1 and keep the size unchanged).
        pooled_side = image_size // 4
        if pooled_side < 1:
            raise ValueError(
                f"image_size must be at least 4, got {image_size}"
            )

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            # With the defaults this is 64 * 7 * 7, as in the original model.
            nn.Linear(64 * pooled_side * pooled_side, feature_dim),
        )
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return the classifier logits for a batch of images.

        Args:
            x: Batch of images; the layers above expect it laid out as
                ``(batch, in_channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalized scores.
        """
        features = self.encoder(x)
        return self.classifier(features)
# Load the model
model = IJEPAModel()
# map_location="cpu" lets a checkpoint that was saved from a GPU run load on
# a machine without CUDA; the model is used on the CPU below.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source (or pass weights_only=True on PyTorch versions that
# support it).
state_dict = torch.load("mnist-i-jepa.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()  # Set the model to evaluation mode

# Preprocess the input image (convert to grayscale, resize, normalize)
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Ensure a single channel
    transforms.Resize((28, 28)),  # Resize to MNIST dimensions
    transforms.ToTensor(),  # Convert to a float tensor in [0, 1]
    # Maps [0, 1] to [-1, 1].
    # NOTE(review): presumably matches the normalization used in training;
    # confirm against the training script.
    transforms.Normalize((0.5,), (0.5,)),
])

# Load the test image; the context manager closes the underlying file once
# the pixel data has been converted to a tensor.
with Image.open("test_digit.jpg") as pil_img:
    img = transform(pil_img).unsqueeze(0)  # Add batch dimension

# Predict the digit
with torch.no_grad():  # Disable gradient computation for inference
    output = model(img)
    # Index of the largest logit along the class dimension.
    predicted = output.argmax(dim=1)

# Display the result
predicted_digit = predicted.item()
print(f"Predicted digit: {predicted_digit}")

# Optionally, display the (preprocessed) image
plt.imshow(img.squeeze(), cmap='gray')
plt.title(f"Predicted digit: {predicted_digit}")
plt.show()