import torch
import torchvision.transforms as T
from typing import Optional

from src.dataset import generate_image
from src.models import CrossAttentionClassifier, VGGLikeEncode


class CrossAttentionInference:
    """Run pairwise inference with a trained ``CrossAttentionClassifier``.

    Loads model weights from ``model_path``, then generates random image
    pairs via ``generate_image`` and predicts a binary label for each pair.
    """

    def __init__(
        self,
        model_path: str,
        shape_params: Optional[dict] = None,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """
        Args:
            model_path: Path to a saved ``state_dict`` of a
                ``CrossAttentionClassifier``.
            shape_params: Keyword arguments forwarded to ``generate_image``;
                ``None`` (or an empty dict) means use that function's defaults.
            device: Device the model runs on (CPU by default).
        """
        # Both falsy inputs (None and {}) collapse to an empty kwargs dict,
        # matching the original branch behavior.
        self.shape_params = shape_params or {}
        self.device = device

        # Encoder hyperparameters must match those used at training time,
        # otherwise load_state_dict below will raise on mismatched keys/shapes.
        self.encoder = VGGLikeEncode(
            in_channels=1, out_channels=128, feature_dim=32, apply_pooling=False
        )
        self.model = CrossAttentionClassifier(encoder=self.encoder)

        # SECURITY: torch.load is pickle-based; weights_only=True restricts
        # unpickling to tensors/primitives so an untrusted checkpoint cannot
        # execute arbitrary code. A plain state_dict loads identically.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model.to(device)

        # Normalize [0, 1] pixel values to [-1, 1]; presumably mirrors the
        # training-time preprocessing — confirm against the training pipeline.
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=(0.5,), std=(0.5,)),
        ])

    def pil_to_tensor(self, img) -> torch.Tensor:
        """Convert a PIL image to a normalized, batched tensor on ``self.device``.

        Returns a tensor of shape ``(1, C, H, W)``.
        """
        return self.transform(img).unsqueeze(0).to(self.device)

    def predict_random_pair(self):
        """Generate a random image pair and classify it.

        Returns:
            Tuple ``(predicted_label, (img1, img2))`` where
            ``predicted_label`` is 0 or 1 and the images are the raw
            PIL inputs (useful for visualization).
        """
        img1, _ = generate_image(**self.shape_params)
        img2, _ = generate_image(**self.shape_params)

        img1_tensor = self.pil_to_tensor(img1)
        img2_tensor = self.pil_to_tensor(img2)

        # Inference only — no autograd graph needed.
        with torch.no_grad():
            logits, _ = self.model(img1_tensor, img2_tensor)
            # Single-logit binary classification: sigmoid + 0.5 threshold.
            preds = (torch.sigmoid(logits) > 0.5).float()

        predicted_label = int(preds.item())
        return predicted_label, (img1, img2)