# Svane20 — commit d3ba20a: Updated model to use PyTorch instead of ONNX.
import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
import numpy as np
from models import SwinMattingModel
class Pipeline:
    """Load a SwinMattingModel checkpoint and run single-image inference.

    Inputs are resized to 512x512, converted to a tensor and normalized with
    the ImageNet channel mean/std before being fed to the model. The model
    output is clipped to [0, 1] and returned as a NumPy array with size-1
    dimensions removed.
    """

    def __init__(self, model_name: str):
        """Build the preprocessing pipeline and load the model.

        Args:
            model_name: Checkpoint stem; weights are read from
                ``models/{model_name}.pt`` (a path relative to the current
                working directory).

        Raises:
            RuntimeError: If the checkpoint keys do not match the model.
            Errors from ``torch.load`` (e.g. a missing checkpoint file) are
            propagated unchanged.
        """
        self.transforms = Compose(
            [
                Resize(size=(512, 512)),
                ToTensor(),
                Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): not read anywhere in this class; presumably kept for
        # external callers or left over from an earlier export path — verify
        # before removing.
        self.is_torch_script = self.device.type == 'cpu'
        self.checkpoint = f"models/{model_name}.pt"
        self.model = self._load_model()
        self._log_device_info()

    def inference(self, image):
        """Predict an output map for a single image.

        Args:
            image: Input accepted by ``self.transforms``. Because ``Resize``
                runs before ``ToTensor``, this is presumably a PIL image —
                confirm against callers.

        Returns:
            NumPy array of model outputs clipped to [0, 1], with the batch
            dimension and any other size-1 dimensions squeezed out.

        Raises:
            RuntimeError: If ``self.model`` is ``None``.
        """
        if self.model is None:
            raise RuntimeError(
                "Model is not loaded. Assign self.model = self._load_model() first."
            )

        # Add a batch dimension of 1 and move the input to the model's device.
        tensor = self.transforms(image).unsqueeze(0).to(self.device)
        with torch.inference_mode():
            output = self.model(tensor)

        output = output.detach().cpu().numpy()
        output = np.clip(output, a_min=0, a_max=1)
        # Drop the batch axis (must be size 1), then any remaining size-1 axes.
        return np.squeeze(output, axis=0).squeeze()

    def _load_pytorch_model(self):
        """Build the SwinMattingModel, load its weights and prepare it for inference.

        Returns:
            The model on ``self.device``, in eval mode.
        """
        model = SwinMattingModel({
            "encoder": {
                "model_name": "microsoft/swin-small-patch4-window7-224"
            },
            "decoder": {
                "use_attn": True,
                "refine_channels": 16
            }
        })
        self._load_checkpoint(model)
        model.to(self.device)
        model.eval()
        return model

    def _load_model(self):
        """Return the inference-ready model (on ``self.device``, eval mode)."""
        # _load_pytorch_model already handles device placement and eval mode.
        return self._load_pytorch_model()

    def _load_checkpoint(self, model):
        """Load the state dict at ``self.checkpoint`` into *model* in place.

        Raises:
            RuntimeError: If the checkpoint lacks keys the model expects, has
                keys the model does not define, or has tensors whose shapes
                differ from the model's.
        """
        # weights_only=True restricts unpickling to tensors and plain containers.
        checkpoint = torch.load(self.checkpoint, map_location="cpu", weights_only=True)
        # strict=False makes load_state_dict return the key mismatches instead
        # of raising first; with the default strict=True the checks below
        # could never run. Shape mismatches still raise inside load_state_dict.
        missing_keys, unexpected_keys = model.load_state_dict(checkpoint, strict=False)
        if missing_keys:
            raise RuntimeError(f"Missing keys in checkpoint: {missing_keys}")
        if unexpected_keys:
            raise RuntimeError(f"Unexpected keys in checkpoint: {unexpected_keys}")

    def _log_device_info(self):
        """Print the CUDA device name when running on GPU; print nothing on CPU."""
        if self.device.type == 'cuda':
            print(f"Hardware: {torch.cuda.get_device_name(torch.cuda.current_device())}")