|
import torch |
|
from torchvision.transforms import Compose, Resize, ToTensor, Normalize |
|
import numpy as np |
|
|
|
from models import SwinMattingModel |
|
|
|
|
|
class Pipeline:
    """Inference wrapper around a ``SwinMattingModel`` checkpoint.

    Input images are resized to 512x512, converted to tensors and normalized
    with the ImageNet mean/std before being passed to the model. The model is
    placed on CUDA when available, otherwise on CPU.
    """

    def __init__(self, model_name: str):
        """Build the preprocessing chain and load ``models/<model_name>.pt``.

        Args:
            model_name: File stem of the checkpoint inside ``models/``.

        Raises:
            RuntimeError: If the checkpoint's keys do not match the model.
        """
        self.transforms = Compose(
            [
                Resize(size=(512, 512)),
                ToTensor(),
                # Standard ImageNet channel statistics.
                Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ],
        )
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE(review): not read anywhere in this class; kept because external
        # callers may depend on it.
        self.is_torch_script = self.device.type == 'cpu'
        self.checkpoint = f"models/{model_name}.pt"
        self.model = self._load_model()
        self._log_device_info()

    def inference(self, image):
        """Run the model on a single image and return its output as a NumPy array.

        Args:
            image: Anything ``self.transforms`` accepts (presumably a PIL image
                -- TODO confirm against callers).

        Returns:
            The model output clipped to ``[0, 1]``, with the batch dimension
            and any other size-1 dimensions removed.

        Raises:
            RuntimeError: If ``self.model`` is ``None``.
        """
        if self.model is None:
            raise RuntimeError("Model is not loaded. Create a new Pipeline to load it.")

        # Add a batch dimension of 1 and move to the model's device.
        batch = self.transforms(image).unsqueeze(0).to(self.device)
        with torch.inference_mode():
            output = self.model(batch)

        output = np.clip(output.detach().cpu().numpy(), a_min=0, a_max=1)
        # Drop the batch axis explicitly (errors if it is not size 1), then
        # any remaining singleton axes such as a single output channel.
        return np.squeeze(output, axis=0).squeeze()

    def _load_pytorch_model(self):
        """Instantiate ``SwinMattingModel``, load weights, move to device, set eval mode."""
        model = SwinMattingModel({
            "encoder": {
                "model_name": "microsoft/swin-small-patch4-window7-224"
            },
            "decoder": {
                "use_attn": True,
                "refine_channels": 16
            }
        })
        self._load_checkpoint(model)
        model.to(self.device)
        model.eval()
        return model

    def _load_model(self):
        """Return the ready-to-use model.

        ``_load_pytorch_model`` already moves the model to ``self.device`` and
        switches it to eval mode, so nothing further is done here.
        """
        return self._load_pytorch_model()

    def _load_checkpoint(self, model):
        """Load the state dict at ``self.checkpoint`` into ``model``.

        Args:
            model: The ``torch.nn.Module`` to receive the weights.

        Raises:
            RuntimeError: If the checkpoint has missing or unexpected keys
                (the message lists them), or if a tensor shape does not match.
        """
        # Load on CPU first; the caller moves the model to its device afterwards.
        state_dict = torch.load(self.checkpoint, map_location="cpu", weights_only=True)

        # strict=False so the mismatch checks below actually run; with the
        # default strict=True, load_state_dict raises before returning.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

        problems = []
        if missing_keys:
            problems.append(f"missing keys: {list(missing_keys)}")
        if unexpected_keys:
            problems.append(f"unexpected keys: {list(unexpected_keys)}")
        if problems:
            raise RuntimeError("Checkpoint does not match model: " + "; ".join(problems))

    def _log_device_info(self):
        """Print the name of the active CUDA device; does nothing on CPU."""
        if self.device.type == 'cuda':
            print(f"Hardware: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
|