Towsif7's picture
firrst commit
59e40e1
"""
Source url: https://github.com/OPHoperHPO/image-background-remove-tool
Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO].
License: Apache License 2.0
"""
import cv2
import numpy as np
from PIL import Image
def prob_filter(mask: Image.Image, prob_threshold=231) -> Image.Image:
    """
    Binarizes a predicted mask around a probability threshold.

    Pixels strictly greater than ``prob_threshold`` become 255; every other
    pixel becomes 0.

    Args:
        mask: Predicted object mask; must be in "L" mode.
        prob_threshold: Pixel values at or below this threshold are
            set to 0 (treated as background).

    Raises:
        ValueError: if the mask is not in "L" mode.

    Returns:
        The thresholded mask as a new "L" mode image.
    """
    if mask.mode != "L":
        raise ValueError("Input mask has wrong color mode.")
    # noinspection PyTypeChecker
    pixels = np.array(mask)
    confident = pixels > prob_threshold
    # uint8 scalars keep the result dtype at uint8 so it maps back to "L".
    binary = np.where(confident, np.uint8(255), np.uint8(0))
    return Image.fromarray(binary).convert("L")
def prob_as_unknown_area(
    trimap: Image.Image, mask: Image.Image, prob_threshold=255
) -> Image.Image:
    """
    Flags uncertain pixels of the segmentation mask as unknown in the trimap.

    Every pixel whose mask value is non-zero and at or below
    ``prob_threshold`` is set to 127 in the returned trimap; all other
    trimap pixels are kept as they are.

    Args:
        trimap: Generated trimap; must be in "L" mode.
        mask: Predicted object mask; must be in "L" mode.
        prob_threshold: Upper bound (inclusive) of mask values that are
            marked as unknown.

    Raises:
        ValueError: if the mask or the trimap is not in "L" mode.

    Returns:
        The updated trimap as a new "L" mode image.
    """
    if not (mask.mode == "L" and trimap.mode == "L"):
        raise ValueError("Input mask has wrong color mode.")
    # noinspection PyTypeChecker
    probabilities = np.array(mask)
    # noinspection PyTypeChecker
    updated = np.array(trimap)
    uncertain = (probabilities > 0) & (probabilities <= prob_threshold)
    updated[uncertain] = 127
    return Image.fromarray(updated).convert("L")
def post_erosion(trimap: Image.Image, erosion_iters=1) -> Image.Image:
    """
    Erodes the trimap's non-unknown area and marks the removed band as unknown.

    Pixels that are already unknown (127) are zeroed before eroding, so the
    erosion only acts on the remaining non-zero pixels. Every such pixel
    that the erosion wipes out is set to 127 in the result; pixels that
    were already 127 are left as they are.

    Args:
        trimap: Generated trimap; must be in "L" mode.
        erosion_iters: Number of erosion iterations (3x3 kernel) applied
            before forming the unknown area. Values <= 0 skip the erosion
            and return the trimap's pixels unchanged.

    Raises:
        ValueError: if the trimap is not in "L" mode.

    Returns:
        Generated trimap for image, as a new "L" mode image.
    """
    if trimap.mode != "L":
        raise ValueError("Input mask has wrong color mode.")
    # noinspection PyTypeChecker
    trimap_array = np.array(trimap)
    if erosion_iters > 0:
        # Work on a copy with the unknown area removed so that only the
        # remaining non-zero pixels take part in the erosion.
        without_unknown_area = trimap_array.copy()
        without_unknown_area[without_unknown_area == 127] = 0
        erosion_kernel = np.ones((3, 3), np.uint8)
        eroded = cv2.erode(
            without_unknown_area, erosion_kernel, iterations=erosion_iters
        )
        # Pixels that were non-zero before the erosion but are zero after it
        # form the new unknown band. Requiring `without_unknown_area > 0`
        # already excludes original zeros, so no extra masking pass is needed.
        trimap_array[np.logical_and(eroded == 0, without_unknown_area > 0)] = 127
    return Image.fromarray(trimap_array).convert("L")