# File size: 4,321 Bytes
# Revision: 59e40e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from pathlib import Path

import click
import tqdm

from carvekit.utils.image_utils import ALLOWED_SUFFIXES
from carvekit.utils.pool_utils import batch_generator, thread_pool_processing
from carvekit.web.schemas.config import MLConfig
from carvekit.web.utils.init_utils import init_interface
from carvekit.utils.fs_utils import save_file


@click.command(
    "removebg",
    help="Performs background removal on specified photos using console interface.",
)
@click.option("-i", required=True, type=str, help="Path to input file or dir")
@click.option("-o", default="none", type=str, help="Path to output file or dir")
@click.option("--pre", default="none", type=str, help="Preprocessing method")
@click.option("--post", default="fba", type=str, help="Postprocessing method.")
@click.option("--net", default="tracer_b7", type=str, help="Segmentation Network")
@click.option(
    "--recursive",
    default=False,
    type=bool,
    help="Enables recursive search for images in a folder",
)
@click.option(
    "--batch_size",
    default=10,
    type=int,
    help="Batch Size for list of images to be loaded to RAM",
)
@click.option(
    "--batch_size_seg",
    default=5,
    type=int,
    help="Batch size for list of images to be processed by segmentation network",
)
@click.option(
    "--batch_size_mat",
    default=1,
    type=int,
    help="Batch size for list of images to be processed by matting network",
)
@click.option(
    "--seg_mask_size",
    default=640,
    type=int,
    help="The size of the input image for the segmentation neural network.",
)
@click.option(
    "--matting_mask_size",
    default=2048,
    type=int,
    help="The size of the input image for the matting neural network.",
)
@click.option(
    "--trimap_dilation",
    default=30,
    type=int,
    help="The size of the offset radius from the object mask in "
    "pixels when forming an unknown area",
)
@click.option(
    "--trimap_erosion",
    default=5,
    type=int,
    help="The number of iterations of erosion that the object's "
    "mask will be subjected to before forming an unknown area",
)
@click.option(
    "--trimap_prob_threshold",
    default=231,
    type=int,
    help="Probability threshold at which the prob_filter "
    "and prob_as_unknown_area operations will be "
    "applied",
)
@click.option("--device", default="cpu", type=str, help="Processing Device.")
@click.option(
    "--fp16", default=False, type=bool, help="Enables mixed precision processing."
)
def removebg(
    i: str,
    o: str,
    pre: str,
    post: str,
    net: str,
    recursive: bool,
    batch_size: int,
    batch_size_seg: int,
    batch_size_mat: int,
    seg_mask_size: int,
    matting_mask_size: int,
    device: str,
    fp16: bool,
    trimap_dilation: int,
    trimap_erosion: int,
    trimap_prob_threshold: int,
):
    """Remove the background from a single image or from every image in a folder.

    The explicit ``help=`` passed to ``click.command`` takes precedence over
    this docstring, so the CLI help output is unaffected by it.

    Args:
        i: Path to the input file or directory.
        o: Path to the output file or directory. It is wrapped in ``Path`` and
            forwarded unchanged to ``save_file``; how the default ``"none"``
            is interpreted is decided there (not visible in this module).
        pre: Preprocessing method name, forwarded to ``MLConfig``.
        post: Postprocessing method name, forwarded to ``MLConfig``.
        net: Segmentation network name, forwarded to ``MLConfig``.
        recursive: When ``i`` is a directory, search it recursively.
        batch_size: Number of images handed to the interface per iteration.
        batch_size_seg: Batch size for the segmentation network.
        batch_size_mat: Batch size for the matting network.
        seg_mask_size: Input image size for the segmentation network.
        matting_mask_size: Input image size for the matting network.
        device: Processing device, forwarded to ``MLConfig``.
        fp16: Enables mixed precision processing.
        trimap_dilation: Offset radius (pixels) used when forming the unknown area.
        trimap_erosion: Number of erosion iterations applied to the object mask.
        trimap_prob_threshold: Probability threshold for the trimap filters.
    """
    out_path = Path(o)
    input_path = Path(i)
    if input_path.is_dir():
        candidates = input_path.rglob("*.*") if recursive else input_path.glob("*.*")
        # Keep only supported image types and skip files whose name contains
        # "_bg_removed" (presumably results of a previous run — confirm
        # against save_file's naming).
        all_images = [
            path
            for path in candidates
            if path.suffix.lower() in ALLOWED_SUFFIXES
            and "_bg_removed" not in path.name
        ]
    else:
        # Anything that is not a directory is treated as a single input image.
        all_images = [input_path]

    interface_config = MLConfig(
        segmentation_network=net,
        preprocessing_method=pre,
        postprocessing_method=post,
        device=device,
        batch_size_seg=batch_size_seg,
        batch_size_matting=batch_size_mat,
        seg_mask_size=seg_mask_size,
        matting_mask_size=matting_mask_size,
        fp16=fp16,
        trimap_dilation=trimap_dilation,
        trimap_erosion=trimap_erosion,
        trimap_prob_threshold=trimap_prob_threshold,
    )

    interface = init_interface(interface_config)

    # Ceiling division: a trailing partial batch is still one iteration, so it
    # must be counted or the progress bar total is too small (e.g. 0 for
    # 5 images with batch_size=10).
    total_batches = (len(all_images) + batch_size - 1) // batch_size

    for image_batch in tqdm.tqdm(
        batch_generator(all_images, n=batch_size),
        total=total_batches,
        desc="Removing background",
        unit=" image batch",
        colour="blue",
    ):
        images_without_background = interface(image_batch)  # Remove background
        # Write each result next to its source path index-by-index.
        thread_pool_processing(
            lambda x: save_file(out_path, image_batch[x], images_without_background[x]),
            range(len(image_batch)),
        )  # Drop images to fs


# Script entry point: invoke the click command when this module is executed
# directly rather than imported.
if __name__ == "__main__":
    removebg()