# cels/src/inference.py
# Source: alexandraroze — commit "solution" (50bd1fc)
import torch
import torchvision.transforms as T
from typing import Optional
from src.dataset import generate_image
from src.models import CrossAttentionClassifier, VGGLikeEncode
class CrossAttentionInference:
    """Inference wrapper around a trained CrossAttentionClassifier.

    Loads model weights from disk, builds the image preprocessing
    pipeline, and provides a helper to classify a randomly generated
    pair of images.
    """

    def __init__(
        self,
        model_path: str,
        shape_params: Optional[dict] = None,
        device: torch.device = torch.device("cpu"),
    ):
        # Keyword arguments forwarded to generate_image(); empty when unset.
        self.shape_params = shape_params or {}
        self.device = device

        # Encoder configuration must match the one used at training time.
        self.encoder = VGGLikeEncode(
            in_channels=1,
            out_channels=128,
            feature_dim=32,
            apply_pooling=False,
        )
        self.model = CrossAttentionClassifier(encoder=self.encoder)

        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (consider weights_only=True).
        checkpoint = torch.load(model_path, map_location=device)
        self.model.load_state_dict(checkpoint)
        self.model.eval()
        self.model.to(device)

        # PIL image -> tensor normalized to roughly [-1, 1].
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=(0.5,), std=(0.5,)),
        ])

    def pil_to_tensor(self, img):
        """Convert a PIL image to a batched, normalized tensor on the target device."""
        tensor = self.transform(img)
        return tensor.unsqueeze(0).to(self.device)

    def predict_random_pair(self):
        """Generate two random images and classify the pair.

        Returns:
            tuple: (predicted label as int 0/1, (img1, img2) as the
            original generated images).
        """
        first, _ = generate_image(**self.shape_params)
        second, _ = generate_image(**self.shape_params)
        first_tensor = self.pil_to_tensor(first)
        second_tensor = self.pil_to_tensor(second)
        with torch.no_grad():
            logits, _ = self.model(first_tensor, second_tensor)
        # Sigmoid -> binary decision at the 0.5 threshold.
        probability = torch.sigmoid(logits)
        predicted_label = int((probability > 0.5).float().item())
        return predicted_label, (first, second)