Spaces:
Sleeping
Sleeping
import torch | |
import torchvision.transforms as T | |
from typing import Optional | |
from src.dataset import generate_image | |
from src.models import CrossAttentionClassifier, VGGLikeEncode | |
class CrossAttentionInference:
    """Load a trained ``CrossAttentionClassifier`` and classify image pairs.

    The wrapper builds the encoder/classifier, restores weights from a
    checkpoint file, and exposes a helper that generates two images with
    ``generate_image`` and returns the model's binary prediction for them.
    """

    def __init__(
        self,
        model_path: str,
        shape_params: Optional[dict] = None,
        device: torch.device = torch.device("cpu"),
    ):
        """Build the model, load its weights and prepare the preprocessing.

        Args:
            model_path: Path of the checkpoint passed to ``torch.load``; the
                loaded object is handed to ``load_state_dict``.
            shape_params: Keyword arguments forwarded to ``generate_image``
                on every call. ``None`` or an empty mapping means "no extra
                arguments".
            device: Device the weights are mapped to and on which the model
                and input tensors are placed.
        """
        # Any falsy value (None, {}) collapses to a fresh empty dict.
        self.shape_params = shape_params if shape_params else {}
        self.device = device

        self.encoder = VGGLikeEncode(
            in_channels=1,
            out_channels=128,
            feature_dim=32,
            apply_pooling=False,
        )
        self.model = CrossAttentionClassifier(encoder=self.encoder)

        # NOTE(review): depending on the installed torch version's
        # ``weights_only`` default, torch.load may unpickle arbitrary
        # objects -- only point this at trusted checkpoints.
        state_dict = torch.load(model_path, map_location=device)
        self.model.load_state_dict(state_dict)
        # Inference only: switch to eval mode before moving to the device.
        self.model.eval()
        self.model.to(device)

        # Single-channel normalisation (one mean/std entry), matching the
        # encoder's in_channels=1.
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=(0.5,), std=(0.5,)),
        ])

    def pil_to_tensor(self, img) -> torch.Tensor:
        """Preprocess one image into a batched tensor on ``self.device``.

        Args:
            img: Input accepted by ``self.transform`` (presumably a PIL
                image, given the method name -- confirm against
                ``generate_image``).

        Returns:
            The transformed image with a new leading dimension of size 1,
            moved to ``self.device``.
        """
        transformed = self.transform(img)
        batched = transformed.unsqueeze(0)
        return batched.to(self.device)

    def predict_random_pair(self) -> tuple:
        """Generate two images and predict a binary label for the pair.

        Returns:
            ``(predicted_label, (img1, img2))`` where ``predicted_label`` is
            ``1`` when ``sigmoid(logits) > 0.5`` and ``0`` otherwise, and the
            images are the untransformed first elements returned by
            ``generate_image``.
        """
        # generate_image returns a pair; only its first element is used.
        img1, _ = generate_image(**self.shape_params)
        img2, _ = generate_image(**self.shape_params)

        img1_tensor = self.pil_to_tensor(img1)
        img2_tensor = self.pil_to_tensor(img2)

        with torch.no_grad():
            # The model returns a pair; the second element is ignored here.
            logits, _ = self.model(img1_tensor, img2_tensor)
            preds = (torch.sigmoid(logits) > 0.5).float()

        # .item() requires exactly one element, so this assumes the model
        # emits a single logit for a batch of one -- TODO confirm.
        predicted_label = int(preds.item())
        return predicted_label, (img1, img2)