Spaces:
Build error
Build error
Commit
·
d9697ef
1
Parent(s):
199c85f
added postprocessing
Browse files- app/sam/postprocess.py +21 -0
app/sam/postprocess.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import numpy.typing as npt
|
| 5 |
+
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
from kornia.morphology import erosion, dilation
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def clean_mask_torch(mask: Tensor) -> Tensor:
|
| 11 |
+
kernel = torch.ones(2, 2).to(mask.device)
|
| 12 |
+
if len(mask.shape) == 2:
|
| 13 |
+
mask = mask[None, None, :, :]
|
| 14 |
+
if mask.dtype == torch.bool:
|
| 15 |
+
mask = mask.int()
|
| 16 |
+
return dilation(erosion(mask, kernel), kernel)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def clean_mask_np(mask: npt.NDArray) -> npt.NDArray:
|
| 20 |
+
kernel = np.ones((2, 2), np.uint8)
|
| 21 |
+
return cv2.dilate(cv2.erode(mask, kernel, iterations=1), kernel, iterations=1)
|