import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf

import sys
sys.path.append('../')

from tracker.config import CONFIG
from tracker.model.cutie import CUTIE
from tracker.inference.inference_core import InferenceCore
from tracker.utils.mask_mapper import MaskMapper
from tools.painter import mask_painter

class BaseTracker:
    def __init__(self, cutie_checkpoint, device) -> None:
        """
        cutie_checkpoint: path to the Cutie model checkpoint
        device: model device (e.g. 'cuda' or 'cpu')
        """
        config = OmegaConf.create(CONFIG)

        # initialise Cutie
        network = CUTIE(config).to(device).eval()
        model_weights = torch.load(cutie_checkpoint, map_location=device)
        network.load_weights(model_weights)

        # initialise InferenceCore
        self.tracker = InferenceCore(network, config)
        self.device = device

        # changeable properties
        self.mapper = MaskMapper()
        self.initialised = False

    def resize_mask(self, mask):
        # the mask transform is applied AFTER the mapper, so the mask has to be
        # post-processed here (originally handled in eval.py): resize so that
        # the shorter side equals self.size
        # NOTE: self.size is never set in __init__; set it on the instance
        # before calling this helper
        h, w = mask.shape[-2:]
        min_hw = min(h, w)
        return F.interpolate(mask, (int(h / min_hw * self.size), int(w / min_hw * self.size)),
                             mode='nearest')

    def image_to_torch(self, frame: np.ndarray, device: str = 'cuda'):
        # frame: (H, W, 3) uint8 numpy array -> (3, H, W) float tensor in [0, 1]
        frame = frame.transpose(2, 0, 1)
        frame = torch.from_numpy(frame).float().to(device, non_blocking=True) / 255
        return frame

    def track(self, frame, first_frame_annotation=None):
        """
        Input:
            frame: (H, W, 3) numpy array, RGB
            first_frame_annotation: (H, W) numpy array of object ids,
                required on the first frame only
        Output:
            mask: (H, W) numpy array, object-id mask
            logit: (H, W) numpy array; currently the same object-id mask
                is returned twice (no probability map)
            painted_image: (H, W, 3) numpy array, frame with masks overlaid
        """
        if first_frame_annotation is not None:  # first frame mask
            # initialisation: register the object labels with the mapper
            mask, labels = self.mapper.convert_mask(first_frame_annotation)
            mask = torch.Tensor(mask).to(self.device)
        else:
            mask = None
            labels = None

        # prepare inputs
        frame_tensor = self.image_to_torch(frame, self.device)

        # track one frame
        probs = self.tracker.step(frame_tensor, mask, labels)  # (num_objects+1, H, W) probability map, channel 0 is background

        # convert the probability map to an object-id mask
        out_mask = torch.argmax(probs, dim=0)
        out_mask = out_mask.detach().cpu().numpy().astype(np.uint8)

        # map internal labels back to the original object ids
        final_mask = np.zeros_like(out_mask)
        for k, v in self.mapper.remappings.items():
            final_mask[out_mask == v] = k

        # paint each tracked object onto the frame
        num_objs = final_mask.max()
        painted_image = frame
        for obj in range(1, num_objs + 1):
            if not np.any(final_mask == obj):
                continue
            painted_image = mask_painter(painted_image, (final_mask == obj).astype('uint8'), mask_color=obj + 1)

        return final_mask, final_mask, painted_image

    def clear_memory(self):
        self.tracker.clear_memory()
        self.mapper.clear_labels()
        torch.cuda.empty_cache()
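

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): drives the tracker
# over a video frame by frame. The checkpoint path, video path, mask file,
# and the use of OpenCV for decoding are illustrative assumptions, not
# things this module ships with.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import cv2  # assumed available, only for this sketch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tracker = BaseTracker('checkpoints/cutie-base.pth', device)  # hypothetical checkpoint path

    # (H, W) uint8 mask for the first frame, e.g. exported from an interactive
    # segmenter; 0 is background, 1..N are object ids (hypothetical file)
    first_frame_annotation = cv2.imread('first_frame_mask.png', cv2.IMREAD_GRAYSCALE)

    cap = cv2.VideoCapture('input.mp4')  # hypothetical video
    frame_idx = 0
    while True:
        ok, frame_bgr = cap.read()
        if not ok:
            break
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)  # track() expects RGB
        if frame_idx == 0:
            # the first call must carry the annotation so the mapper can
            # register object labels and the memory can be initialised
            mask, _, painted = tracker.track(frame_rgb, first_frame_annotation)
        else:
            mask, _, painted = tracker.track(frame_rgb)
        frame_idx += 1

    cap.release()
    tracker.clear_memory()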