# efficientnetv2-m-e6 / inference.py
# Uploaded by Thouph ("Upload 3 files", commit 02370bc), 1.32 kB.
import time
from PIL import Image
from timm.data import resolve_data_config
import torch
from torchvision.transforms import transforms
# Inference settings: point these at your own checkpoint, tag list and images.
MODEL_PATH = 'path/to/model.pth'
TAGS_PATH = "tags.txt"
IMAGES = ["your_image_here.jpg", "your_second_image_here.jpg"]
TOP_K = 10


def load_model(path, device):
    """Load a fully pickled model from *path*, move it to *device*, set eval mode.

    ``map_location`` lets a checkpoint saved on GPU load on a CPU-only host.
    ``.to(device)`` keeps the model on the same device the inputs are sent to.

    ``weights_only=False`` is required on PyTorch >= 2.6. From that version on,
    ``torch.load`` defaults to ``weights_only=True`` and refuses to unpickle a
    whole ``nn.Module``. Unpickling can execute arbitrary code, so only load
    checkpoints from a trusted source.
    """
    model = torch.load(path, map_location=device, weights_only=False)
    model.to(device)
    model.eval()
    return model


def build_transform():
    """Return the preprocessing pipeline applied to every input image.

    Resize to 224x224, convert to a tensor, then normalize with the ImageNet
    mean/std values commonly used with torchvision/timm models.
    """
    # NOTE(review): size and mean/std are hard-coded. timm's
    # resolve_data_config({}, model=model) can report the model's own defaults;
    # confirm these values match how the checkpoint was trained.
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def load_categories(path):
    """Read one tag per line from *path*, strip surrounding whitespace, return sorted.

    Sorting is done so that index ``i`` names model output ``i``. This assumes
    the model was trained with its tags in sorted order (TODO confirm).
    Assumes the file is UTF-8 encoded.
    """
    with open(path, "r", encoding="utf-8") as f:
        return sorted(line.strip() for line in f)


def top_k_predictions(probabilities, categories, k=TOP_K):
    """Return the ``k`` highest-scoring ``(tag, probability)`` pairs, best first.

    *probabilities* is expected to be a 1-D tensor with one score per entry of
    *categories*. ``k`` is clamped to the number of scores, so a model with
    fewer than ``k`` outputs does not make ``torch.topk`` raise.
    """
    k = min(k, probabilities.numel())
    values, indices = torch.topk(probabilities, k)
    return [(categories[idx], prob) for idx, prob in zip(indices.tolist(), values.tolist())]


def classify_image(model, transform, path, device):
    """Run *model* on the image at *path* and return its sigmoid probabilities.

    Sigmoid is applied element-wise, so each tag gets an independent score in
    [0, 1]. Assumes the model returns a (batch, num_tags) tensor (TODO confirm).
    """
    # The context manager closes the underlying file once the pixels are used.
    with Image.open(path) as img:
        tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        out = model(tensor)
    return torch.sigmoid(out[0])  # first (only) batch element


def main():
    """Classify each image in IMAGES; print its top tags and the time taken."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = load_model(MODEL_PATH, device)
    transform = build_transform()
    categories = load_categories(TAGS_PATH)
    for item in IMAGES:
        start = time.perf_counter()  # monotonic clock, suited to measuring intervals
        probabilities = classify_image(model, transform, item, device)
        print(probabilities.shape)
        for tag, prob in top_k_predictions(probabilities, categories):
            print(tag, prob)
        end = time.perf_counter()
        print(f'Executed in {end - start} seconds')
        print("\n\n", end="")


if __name__ == "__main__":
    main()