Towsif7's picture
first commit
59e40e1
from pathlib import Path
import click
import tqdm
from carvekit.utils.image_utils import ALLOWED_SUFFIXES
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
from carvekit.web.schemas.config import MLConfig
from carvekit.web.utils.init_utils import init_interface
from carvekit.utils.fs_utils import save_file
@click.command(
    "removebg",
    help="Performs background removal on specified photos using console interface.",
)
@click.option("-i", required=True, type=str, help="Path to input file or dir")
@click.option("-o", default="none", type=str, help="Path to output file or dir")
@click.option("--pre", default="none", type=str, help="Preprocessing method")
@click.option("--post", default="fba", type=str, help="Postprocessing method.")
@click.option("--net", default="tracer_b7", type=str, help="Segmentation Network")
@click.option(
    "--recursive",
    default=False,
    type=bool,
    help="Enables recursive search for images in a folder",
)
@click.option(
    "--batch_size",
    default=10,
    type=int,
    help="Batch Size for list of images to be loaded to RAM",
)
@click.option(
    "--batch_size_seg",
    default=5,
    type=int,
    help="Batch size for list of images to be processed by segmentation network",
)
@click.option(
    "--batch_size_mat",
    default=1,
    type=int,
    help="Batch size for list of images to be processed by matting network",
)
@click.option(
    "--seg_mask_size",
    default=640,
    type=int,
    help="The size of the input image for the segmentation neural network.",
)
@click.option(
    "--matting_mask_size",
    default=2048,
    type=int,
    help="The size of the input image for the matting neural network.",
)
@click.option(
    "--trimap_dilation",
    default=30,
    type=int,
    help="The size of the offset radius from the object mask in "
    "pixels when forming an unknown area",
)
@click.option(
    "--trimap_erosion",
    default=5,
    type=int,
    help="The number of iterations of erosion that the object's "
    "mask will be subjected to before forming an unknown area",
)
@click.option(
    "--trimap_prob_threshold",
    default=231,
    type=int,
    help="Probability threshold at which the prob_filter "
    "and prob_as_unknown_area operations will be "
    "applied",
)
@click.option("--device", default="cpu", type=str, help="Processing Device.")
@click.option(
    "--fp16", default=False, type=bool, help="Enables mixed precision processing."
)
def removebg(
    i: str,
    o: str,
    pre: str,
    post: str,
    net: str,
    recursive: bool,
    batch_size: int,
    batch_size_seg: int,
    batch_size_mat: int,
    seg_mask_size: int,
    matting_mask_size: int,
    device: str,
    fp16: bool,
    trimap_dilation: int,
    trimap_erosion: int,
    trimap_prob_threshold: int,
) -> None:
    """Remove the background from one image or from every image in a folder.

    The command help shown by ``--help`` comes from the explicit ``help=``
    argument of ``click.command`` above, not from this docstring.

    Args:
        i: Path to the input file or directory.
        o: Path to the output file or directory. It is wrapped in ``Path``
            and handed to ``save_file`` unchanged; how the default ``"none"``
            is interpreted is decided by ``save_file`` (not visible here).
        pre: Preprocessing method name, forwarded to ``MLConfig``.
        post: Postprocessing method name, forwarded to ``MLConfig``.
        net: Segmentation network name, forwarded to ``MLConfig``.
        recursive: When the input is a directory, search it recursively.
        batch_size: Number of image paths passed to the interface per batch.
        batch_size_seg: Batch size for the segmentation network.
        batch_size_mat: Batch size for the matting network.
        seg_mask_size: Input image size for the segmentation network.
        matting_mask_size: Input image size for the matting network.
        device: Processing device, forwarded to ``MLConfig``.
        fp16: Enables mixed precision processing.
        trimap_dilation: Trimap dilation radius, forwarded to ``MLConfig``.
        trimap_erosion: Trimap erosion iterations, forwarded to ``MLConfig``.
        trimap_prob_threshold: Trimap probability threshold, forwarded to
            ``MLConfig``.
    """
    out_path = Path(o)
    input_path = Path(i)

    if input_path.is_dir():
        candidates = input_path.rglob("*.*") if recursive else input_path.glob("*.*")
        # Keep only supported image suffixes. Names containing "_bg_removed"
        # are skipped -- presumably results of an earlier run; confirm against
        # save_file's naming scheme.
        all_images = [
            path
            for path in candidates
            if path.suffix.lower() in ALLOWED_SUFFIXES
            and "_bg_removed" not in path.name
        ]
    else:
        # Anything that is not a directory is treated as a single image path.
        all_images = [input_path]

    interface_config = MLConfig(
        segmentation_network=net,
        preprocessing_method=pre,
        postprocessing_method=post,
        device=device,
        batch_size_seg=batch_size_seg,
        batch_size_matting=batch_size_mat,
        seg_mask_size=seg_mask_size,
        matting_mask_size=matting_mask_size,
        fp16=fp16,
        trimap_dilation=trimap_dilation,
        trimap_erosion=trimap_erosion,
        trimap_prob_threshold=trimap_prob_threshold,
    )
    interface = init_interface(interface_config)

    # Ceiling division: a trailing partial batch is still one iteration, so the
    # progress bar total must count it (flooring under-reported the total).
    total_batches = -(-len(all_images) // batch_size)

    for image_batch in tqdm.tqdm(
        batch_generator(all_images, n=batch_size),
        total=total_batches,
        desc="Removing background",
        unit=" image batch",
        colour="blue",
    ):
        images_without_background = interface(image_batch)  # Remove background
        thread_pool_processing(
            lambda x: save_file(out_path, image_batch[x], images_without_background[x]),
            range(len(image_batch)),
        )  # Drop images to fs
if __name__ == "__main__":
    # Run the click command when this file is executed as a script; click
    # parses the options from the process command line.
    removebg()