Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # Copyright (c) Meta Platforms, Inc. All Rights Reserved | |
| from typing import Tuple | |
| import numpy as np | |
| import torch | |
| from .clip import load as clip_load | |
| from detectron2.utils.comm import get_local_rank, synchronize | |
| def expand_box( | |
| x1: float, | |
| y1: float, | |
| x2: float, | |
| y2: float, | |
| expand_ratio: float = 1.0, | |
| max_h: int = None, | |
| max_w: int = None, | |
| ): | |
| cx = 0.5 * (x1 + x2) | |
| cy = 0.5 * (y1 + y2) | |
| w = x2 - x1 | |
| h = y2 - y1 | |
| w = w * expand_ratio | |
| h = h * expand_ratio | |
| box = [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h] | |
| if max_h is not None: | |
| box[1] = max(0, box[1]) | |
| box[3] = min(max_h - 1, box[3]) | |
| if max_w is not None: | |
| box[0] = max(0, box[0]) | |
| box[2] = min(max_w - 1, box[2]) | |
| return [int(b) for b in box] | |
| def mask2box(mask: torch.Tensor): | |
| # use naive way | |
| row = torch.nonzero(mask.sum(dim=0))[:, 0] | |
| if len(row) == 0: | |
| return None | |
| x1 = row.min() | |
| x2 = row.max() | |
| col = np.nonzero(mask.sum(dim=1))[:, 0] | |
| y1 = col.min() | |
| y2 = col.max() | |
| return x1, y1, x2 + 1, y2 + 1 | |
| def crop_with_mask( | |
| image: torch.Tensor, | |
| mask: torch.Tensor, | |
| bbox: torch.Tensor, | |
| fill: Tuple[float, float, float] = (0, 0, 0), | |
| expand_ratio: float = 1.0, | |
| ): | |
| l, t, r, b = expand_box(*bbox, expand_ratio) | |
| _, h, w = image.shape | |
| l = max(l, 0) | |
| t = max(t, 0) | |
| r = min(r, w) | |
| b = min(b, h) | |
| new_image = torch.cat( | |
| [image.new_full((1, b - t, r - l), fill_value=val) for val in fill] | |
| ) | |
| mask_bool = mask.bool() | |
| return image[:, t:b, l:r] * mask[None, t:b, l:r] + (~ mask_bool[None, t:b, l:r]) * new_image, mask[None, t:b, l:r] | |
| def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True): | |
| rank = get_local_rank() | |
| if rank == 0: | |
| # download on rank 0 only | |
| model, _ = clip_load(model, mask_prompt_depth=mask_prompt_depth, device="cpu") | |
| synchronize() | |
| if rank != 0: | |
| model, _ = clip_load(model, mask_prompt_depth=mask_prompt_depth, device="cpu") | |
| synchronize() | |
| if frozen: | |
| for param in model.parameters(): | |
| param.requires_grad = False | |
| return model | |