Spaces:
Runtime error
Runtime error
Update
Browse files- anime-girl.jpg +0 -0
- app.py +4 -5
- rembg/_version.py +3 -3
- rembg/bg.py +14 -1
- rembg/commands/b_command.py +161 -0
- rembg/commands/s_command.py +58 -11
- rembg/session_factory.py +4 -2
- rembg/sessions/base.py +24 -2
- rembg/sessions/dis_anime.py +49 -0
- rembg/sessions/dis_general_use.py +49 -0
- rembg/sessions/sam.py +8 -4
- rembg/sessions/silueta.py +4 -2
- rembg/sessions/u2net.py +4 -2
- rembg/sessions/u2net_cloth_seg.py +4 -2
- rembg/sessions/u2net_human_seg.py +4 -2
- rembg/sessions/u2netp.py +4 -2
anime-girl.jpg
ADDED
|
app.py
CHANGED
|
@@ -5,7 +5,7 @@ import gradio as gr
|
|
| 5 |
import os
|
| 6 |
import cv2
|
| 7 |
|
| 8 |
-
def inference(file,
|
| 9 |
im = cv2.imread(file, cv2.IMREAD_COLOR)
|
| 10 |
cv2.imwrite(os.path.join("input.png"), im)
|
| 11 |
|
|
@@ -20,7 +20,6 @@ def inference(file, af, mask, model):
|
|
| 20 |
output = remove(
|
| 21 |
input,
|
| 22 |
session = new_session(model),
|
| 23 |
-
alpha_matting_erode_size = af,
|
| 24 |
only_mask = (True if mask == "Mask only" else False)
|
| 25 |
)
|
| 26 |
|
|
@@ -38,7 +37,6 @@ gr.Interface(
|
|
| 38 |
inference,
|
| 39 |
[
|
| 40 |
gr.inputs.Image(type="filepath", label="Input"),
|
| 41 |
-
gr.inputs.Slider(10, 25, default=10, label="Alpha matting erode size"),
|
| 42 |
gr.inputs.Radio(
|
| 43 |
[
|
| 44 |
"Default",
|
|
@@ -55,10 +53,11 @@ gr.Interface(
|
|
| 55 |
"u2net_cloth_seg",
|
| 56 |
"silueta",
|
| 57 |
"isnet-general-use",
|
|
|
|
| 58 |
"sam",
|
| 59 |
],
|
| 60 |
type="value",
|
| 61 |
-
default="
|
| 62 |
label="Models"
|
| 63 |
),
|
| 64 |
],
|
|
@@ -66,6 +65,6 @@ gr.Interface(
|
|
| 66 |
title=title,
|
| 67 |
description=description,
|
| 68 |
article=article,
|
| 69 |
-
examples=[["lion.png",
|
| 70 |
enable_queue=True
|
| 71 |
).launch()
|
|
|
|
| 5 |
import os
|
| 6 |
import cv2
|
| 7 |
|
| 8 |
+
def inference(file, mask, model):
|
| 9 |
im = cv2.imread(file, cv2.IMREAD_COLOR)
|
| 10 |
cv2.imwrite(os.path.join("input.png"), im)
|
| 11 |
|
|
|
|
| 20 |
output = remove(
|
| 21 |
input,
|
| 22 |
session = new_session(model),
|
|
|
|
| 23 |
only_mask = (True if mask == "Mask only" else False)
|
| 24 |
)
|
| 25 |
|
|
|
|
| 37 |
inference,
|
| 38 |
[
|
| 39 |
gr.inputs.Image(type="filepath", label="Input"),
|
|
|
|
| 40 |
gr.inputs.Radio(
|
| 41 |
[
|
| 42 |
"Default",
|
|
|
|
| 53 |
"u2net_cloth_seg",
|
| 54 |
"silueta",
|
| 55 |
"isnet-general-use",
|
| 56 |
+
"isnet-anime",
|
| 57 |
"sam",
|
| 58 |
],
|
| 59 |
type="value",
|
| 60 |
+
default="isnet-general-use",
|
| 61 |
label="Models"
|
| 62 |
),
|
| 63 |
],
|
|
|
|
| 65 |
title=title,
|
| 66 |
description=description,
|
| 67 |
article=article,
|
| 68 |
+
examples=[["lion.png", "Default", "u2net"], ["girl.jpg", "Default", "u2net"], ["anime-girl.jpg", "Default", "isnet-anime"]],
|
| 69 |
enable_queue=True
|
| 70 |
).launch()
|
rembg/_version.py
CHANGED
|
@@ -23,9 +23,9 @@ def get_keywords():
|
|
| 23 |
# setup.py/versioneer.py will grep for the variable names, so they must
|
| 24 |
# each be defined on a line of their own. _version.py will just call
|
| 25 |
# get_keywords().
|
| 26 |
-
git_refnames = " (HEAD -> main)"
|
| 27 |
-
git_full = "
|
| 28 |
-
git_date = "2023-
|
| 29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
| 30 |
return keywords
|
| 31 |
|
|
|
|
| 23 |
# setup.py/versioneer.py will grep for the variable names, so they must
|
| 24 |
# each be defined on a line of their own. _version.py will just call
|
| 25 |
# get_keywords().
|
| 26 |
+
git_refnames = " (HEAD -> main, tag: v2.0.43)"
|
| 27 |
+
git_full = "848a38e4cc5cf41522974dea00848596105b1dfa"
|
| 28 |
+
git_date = "2023-06-02 09:20:57 -0300"
|
| 29 |
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
| 30 |
return keywords
|
| 31 |
|
rembg/bg.py
CHANGED
|
@@ -11,7 +11,7 @@ from cv2 import (
|
|
| 11 |
getStructuringElement,
|
| 12 |
morphologyEx,
|
| 13 |
)
|
| 14 |
-
from PIL import Image
|
| 15 |
from PIL.Image import Image as PILImage
|
| 16 |
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
| 17 |
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
|
@@ -19,6 +19,7 @@ from pymatting.util.util import stack_images
|
|
| 19 |
from scipy.ndimage import binary_erosion
|
| 20 |
|
| 21 |
from .session_factory import new_session
|
|
|
|
| 22 |
from .sessions.base import BaseSession
|
| 23 |
|
| 24 |
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
|
@@ -113,6 +114,15 @@ def apply_background_color(img: PILImage, color: Tuple[int, int, int, int]) -> P
|
|
| 113 |
return colored_image
|
| 114 |
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
def remove(
|
| 117 |
data: Union[bytes, PILImage, np.ndarray],
|
| 118 |
alpha_matting: bool = False,
|
|
@@ -138,6 +148,9 @@ def remove(
|
|
| 138 |
else:
|
| 139 |
raise ValueError("Input type {} is not supported.".format(type(data)))
|
| 140 |
|
|
|
|
|
|
|
|
|
|
| 141 |
if session is None:
|
| 142 |
session = new_session("u2net", *args, **kwargs)
|
| 143 |
|
|
|
|
| 11 |
getStructuringElement,
|
| 12 |
morphologyEx,
|
| 13 |
)
|
| 14 |
+
from PIL import Image, ImageOps
|
| 15 |
from PIL.Image import Image as PILImage
|
| 16 |
from pymatting.alpha.estimate_alpha_cf import estimate_alpha_cf
|
| 17 |
from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
|
|
|
|
| 19 |
from scipy.ndimage import binary_erosion
|
| 20 |
|
| 21 |
from .session_factory import new_session
|
| 22 |
+
from .sessions import sessions_class
|
| 23 |
from .sessions.base import BaseSession
|
| 24 |
|
| 25 |
kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3))
|
|
|
|
| 114 |
return colored_image
|
| 115 |
|
| 116 |
|
| 117 |
+
def fix_image_orientation(img: PILImage) -> PILImage:
|
| 118 |
+
return ImageOps.exif_transpose(img)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def download_models() -> None:
|
| 122 |
+
for session in sessions_class:
|
| 123 |
+
session.download_models()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
def remove(
|
| 127 |
data: Union[bytes, PILImage, np.ndarray],
|
| 128 |
alpha_matting: bool = False,
|
|
|
|
| 148 |
else:
|
| 149 |
raise ValueError("Input type {} is not supported.".format(type(data)))
|
| 150 |
|
| 151 |
+
# Fix image orientation
|
| 152 |
+
img = fix_image_orientation(img)
|
| 153 |
+
|
| 154 |
if session is None:
|
| 155 |
session = new_session("u2net", *args, **kwargs)
|
| 156 |
|
rembg/commands/b_command.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import io
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
from typing import IO
|
| 7 |
+
|
| 8 |
+
import click
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
from ..bg import remove
|
| 12 |
+
from ..session_factory import new_session
|
| 13 |
+
from ..sessions import sessions_names
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@click.command(
|
| 17 |
+
name="b",
|
| 18 |
+
help="for a byte stream as input",
|
| 19 |
+
)
|
| 20 |
+
@click.option(
|
| 21 |
+
"-m",
|
| 22 |
+
"--model",
|
| 23 |
+
default="u2net",
|
| 24 |
+
type=click.Choice(sessions_names),
|
| 25 |
+
show_default=True,
|
| 26 |
+
show_choices=True,
|
| 27 |
+
help="model name",
|
| 28 |
+
)
|
| 29 |
+
@click.option(
|
| 30 |
+
"-a",
|
| 31 |
+
"--alpha-matting",
|
| 32 |
+
is_flag=True,
|
| 33 |
+
show_default=True,
|
| 34 |
+
help="use alpha matting",
|
| 35 |
+
)
|
| 36 |
+
@click.option(
|
| 37 |
+
"-af",
|
| 38 |
+
"--alpha-matting-foreground-threshold",
|
| 39 |
+
default=240,
|
| 40 |
+
type=int,
|
| 41 |
+
show_default=True,
|
| 42 |
+
help="trimap fg threshold",
|
| 43 |
+
)
|
| 44 |
+
@click.option(
|
| 45 |
+
"-ab",
|
| 46 |
+
"--alpha-matting-background-threshold",
|
| 47 |
+
default=10,
|
| 48 |
+
type=int,
|
| 49 |
+
show_default=True,
|
| 50 |
+
help="trimap bg threshold",
|
| 51 |
+
)
|
| 52 |
+
@click.option(
|
| 53 |
+
"-ae",
|
| 54 |
+
"--alpha-matting-erode-size",
|
| 55 |
+
default=10,
|
| 56 |
+
type=int,
|
| 57 |
+
show_default=True,
|
| 58 |
+
help="erode size",
|
| 59 |
+
)
|
| 60 |
+
@click.option(
|
| 61 |
+
"-om",
|
| 62 |
+
"--only-mask",
|
| 63 |
+
is_flag=True,
|
| 64 |
+
show_default=True,
|
| 65 |
+
help="output only the mask",
|
| 66 |
+
)
|
| 67 |
+
@click.option(
|
| 68 |
+
"-ppm",
|
| 69 |
+
"--post-process-mask",
|
| 70 |
+
is_flag=True,
|
| 71 |
+
show_default=True,
|
| 72 |
+
help="post process the mask",
|
| 73 |
+
)
|
| 74 |
+
@click.option(
|
| 75 |
+
"-bgc",
|
| 76 |
+
"--bgcolor",
|
| 77 |
+
default=None,
|
| 78 |
+
type=(int, int, int, int),
|
| 79 |
+
nargs=4,
|
| 80 |
+
help="Background color (R G B A) to replace the removed background with",
|
| 81 |
+
)
|
| 82 |
+
@click.option("-x", "--extras", type=str)
|
| 83 |
+
@click.option(
|
| 84 |
+
"-o",
|
| 85 |
+
"--output_specifier",
|
| 86 |
+
type=str,
|
| 87 |
+
help="printf-style specifier for output filenames (e.g. 'output-%d.png'))",
|
| 88 |
+
)
|
| 89 |
+
@click.argument(
|
| 90 |
+
"image_width",
|
| 91 |
+
type=int,
|
| 92 |
+
)
|
| 93 |
+
@click.argument(
|
| 94 |
+
"image_height",
|
| 95 |
+
type=int,
|
| 96 |
+
)
|
| 97 |
+
def rs_command(
|
| 98 |
+
model: str,
|
| 99 |
+
extras: str,
|
| 100 |
+
image_width: int,
|
| 101 |
+
image_height: int,
|
| 102 |
+
output_specifier: str,
|
| 103 |
+
**kwargs
|
| 104 |
+
) -> None:
|
| 105 |
+
try:
|
| 106 |
+
kwargs.update(json.loads(extras))
|
| 107 |
+
except Exception:
|
| 108 |
+
pass
|
| 109 |
+
|
| 110 |
+
session = new_session(model)
|
| 111 |
+
bytes_per_img = image_width * image_height * 3
|
| 112 |
+
|
| 113 |
+
if output_specifier:
|
| 114 |
+
output_dir = os.path.dirname(
|
| 115 |
+
os.path.abspath(os.path.expanduser(output_specifier))
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
if not os.path.isdir(output_dir):
|
| 119 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 120 |
+
|
| 121 |
+
def img_to_byte_array(img: Image) -> bytes:
|
| 122 |
+
buff = io.BytesIO()
|
| 123 |
+
img.save(buff, format="PNG")
|
| 124 |
+
return buff.getvalue()
|
| 125 |
+
|
| 126 |
+
async def connect_stdin_stdout():
|
| 127 |
+
loop = asyncio.get_event_loop()
|
| 128 |
+
reader = asyncio.StreamReader()
|
| 129 |
+
protocol = asyncio.StreamReaderProtocol(reader)
|
| 130 |
+
|
| 131 |
+
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
|
| 132 |
+
w_transport, w_protocol = await loop.connect_write_pipe(
|
| 133 |
+
asyncio.streams.FlowControlMixin, sys.stdout
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
|
| 137 |
+
return reader, writer
|
| 138 |
+
|
| 139 |
+
async def main():
|
| 140 |
+
reader, writer = await connect_stdin_stdout()
|
| 141 |
+
|
| 142 |
+
idx = 0
|
| 143 |
+
while True:
|
| 144 |
+
try:
|
| 145 |
+
img_bytes = await reader.readexactly(bytes_per_img)
|
| 146 |
+
if not img_bytes:
|
| 147 |
+
break
|
| 148 |
+
|
| 149 |
+
img = Image.frombytes("RGB", (image_width, image_height), img_bytes)
|
| 150 |
+
output = remove(img, session=session, **kwargs)
|
| 151 |
+
|
| 152 |
+
if output_specifier:
|
| 153 |
+
output.save((output_specifier % idx), format="PNG")
|
| 154 |
+
else:
|
| 155 |
+
writer.write(img_to_byte_array(output))
|
| 156 |
+
|
| 157 |
+
idx += 1
|
| 158 |
+
except asyncio.IncompleteReadError:
|
| 159 |
+
break
|
| 160 |
+
|
| 161 |
+
asyncio.run(main())
|
rembg/commands/s_command.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
| 1 |
import json
|
| 2 |
-
|
|
|
|
|
|
|
| 3 |
|
| 4 |
import aiohttp
|
| 5 |
import click
|
|
|
|
| 6 |
import uvicorn
|
| 7 |
from asyncer import asyncify
|
| 8 |
from fastapi import Depends, FastAPI, File, Form, Query
|
|
@@ -70,6 +73,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
| 70 |
"url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
|
| 71 |
},
|
| 72 |
openapi_tags=tags_metadata,
|
|
|
|
| 73 |
)
|
| 74 |
|
| 75 |
app.add_middleware(
|
|
@@ -83,10 +87,10 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
| 83 |
class CommonQueryParams:
|
| 84 |
def __init__(
|
| 85 |
self,
|
| 86 |
-
model:
|
| 87 |
-
str, Query(regex=r"(" + "|".join(sessions_names) + ")")
|
| 88 |
-
] = Query(
|
| 89 |
description="Model to use when processing image",
|
|
|
|
|
|
|
| 90 |
),
|
| 91 |
a: bool = Query(default=False, description="Enable Alpha Matting"),
|
| 92 |
af: int = Query(
|
|
@@ -128,10 +132,10 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
| 128 |
class CommonQueryPostParams:
|
| 129 |
def __init__(
|
| 130 |
self,
|
| 131 |
-
model:
|
| 132 |
-
str, Form(regex=r"(" + "|".join(sessions_names) + ")")
|
| 133 |
-
] = Form(
|
| 134 |
description="Model to use when processing image",
|
|
|
|
|
|
|
| 135 |
),
|
| 136 |
a: bool = Form(default=False, description="Enable Alpha Matting"),
|
| 137 |
af: int = Form(
|
|
@@ -190,13 +194,18 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
| 190 |
only_mask=commons.om,
|
| 191 |
post_process_mask=commons.ppm,
|
| 192 |
bgcolor=commons.bgc,
|
| 193 |
-
**kwargs
|
| 194 |
),
|
| 195 |
media_type="image/png",
|
| 196 |
)
|
| 197 |
|
| 198 |
@app.on_event("startup")
|
| 199 |
def startup():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
if threads is not None:
|
| 201 |
from anyio import CapacityLimiter
|
| 202 |
from anyio.lowlevel import RunVar
|
|
@@ -204,7 +213,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
| 204 |
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
|
| 205 |
|
| 206 |
@app.get(
|
| 207 |
-
path="/",
|
| 208 |
tags=["Background Removal"],
|
| 209 |
summary="Remove from URL",
|
| 210 |
description="Removes the background from an image obtained by retrieving an URL.",
|
|
@@ -221,7 +230,7 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
| 221 |
return await asyncify(im_without_bg)(file, commons)
|
| 222 |
|
| 223 |
@app.post(
|
| 224 |
-
path="/",
|
| 225 |
tags=["Background Removal"],
|
| 226 |
summary="Remove from Stream",
|
| 227 |
description="Removes the background from an image sent within the request itself.",
|
|
@@ -235,4 +244,42 @@ def s_command(port: int, log_level: str, threads: int) -> None:
|
|
| 235 |
):
|
| 236 |
return await asyncify(im_without_bg)(file, commons) # type: ignore
|
| 237 |
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
+
import os
|
| 3 |
+
import webbrowser
|
| 4 |
+
from typing import Optional, Tuple, cast
|
| 5 |
|
| 6 |
import aiohttp
|
| 7 |
import click
|
| 8 |
+
import gradio as gr
|
| 9 |
import uvicorn
|
| 10 |
from asyncer import asyncify
|
| 11 |
from fastapi import Depends, FastAPI, File, Form, Query
|
|
|
|
| 73 |
"url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
|
| 74 |
},
|
| 75 |
openapi_tags=tags_metadata,
|
| 76 |
+
docs_url="/api",
|
| 77 |
)
|
| 78 |
|
| 79 |
app.add_middleware(
|
|
|
|
| 87 |
class CommonQueryParams:
|
| 88 |
def __init__(
|
| 89 |
self,
|
| 90 |
+
model: str = Query(
|
|
|
|
|
|
|
| 91 |
description="Model to use when processing image",
|
| 92 |
+
regex=r"(" + "|".join(sessions_names) + ")",
|
| 93 |
+
default="u2net",
|
| 94 |
),
|
| 95 |
a: bool = Query(default=False, description="Enable Alpha Matting"),
|
| 96 |
af: int = Query(
|
|
|
|
| 132 |
class CommonQueryPostParams:
|
| 133 |
def __init__(
|
| 134 |
self,
|
| 135 |
+
model: str = Form(
|
|
|
|
|
|
|
| 136 |
description="Model to use when processing image",
|
| 137 |
+
regex=r"(" + "|".join(sessions_names) + ")",
|
| 138 |
+
default="u2net",
|
| 139 |
),
|
| 140 |
a: bool = Form(default=False, description="Enable Alpha Matting"),
|
| 141 |
af: int = Form(
|
|
|
|
| 194 |
only_mask=commons.om,
|
| 195 |
post_process_mask=commons.ppm,
|
| 196 |
bgcolor=commons.bgc,
|
| 197 |
+
**kwargs,
|
| 198 |
),
|
| 199 |
media_type="image/png",
|
| 200 |
)
|
| 201 |
|
| 202 |
@app.on_event("startup")
|
| 203 |
def startup():
|
| 204 |
+
try:
|
| 205 |
+
webbrowser.open(f"http://localhost:{port}")
|
| 206 |
+
except Exception:
|
| 207 |
+
pass
|
| 208 |
+
|
| 209 |
if threads is not None:
|
| 210 |
from anyio import CapacityLimiter
|
| 211 |
from anyio.lowlevel import RunVar
|
|
|
|
| 213 |
RunVar("_default_thread_limiter").set(CapacityLimiter(threads))
|
| 214 |
|
| 215 |
@app.get(
|
| 216 |
+
path="/api/remove",
|
| 217 |
tags=["Background Removal"],
|
| 218 |
summary="Remove from URL",
|
| 219 |
description="Removes the background from an image obtained by retrieving an URL.",
|
|
|
|
| 230 |
return await asyncify(im_without_bg)(file, commons)
|
| 231 |
|
| 232 |
@app.post(
|
| 233 |
+
path="/api/remove",
|
| 234 |
tags=["Background Removal"],
|
| 235 |
summary="Remove from Stream",
|
| 236 |
description="Removes the background from an image sent within the request itself.",
|
|
|
|
| 244 |
):
|
| 245 |
return await asyncify(im_without_bg)(file, commons) # type: ignore
|
| 246 |
|
| 247 |
+
def gr_app(app):
|
| 248 |
+
def inference(input_path, model):
|
| 249 |
+
output_path = "output.png"
|
| 250 |
+
with open(input_path, "rb") as i:
|
| 251 |
+
with open(output_path, "wb") as o:
|
| 252 |
+
input = i.read()
|
| 253 |
+
output = remove(input, session=new_session(model))
|
| 254 |
+
o.write(output)
|
| 255 |
+
return os.path.join(output_path)
|
| 256 |
+
|
| 257 |
+
interface = gr.Interface(
|
| 258 |
+
inference,
|
| 259 |
+
[
|
| 260 |
+
gr.components.Image(type="filepath", label="Input"),
|
| 261 |
+
gr.components.Dropdown(
|
| 262 |
+
[
|
| 263 |
+
"u2net",
|
| 264 |
+
"u2netp",
|
| 265 |
+
"u2net_human_seg",
|
| 266 |
+
"u2net_cloth_seg",
|
| 267 |
+
"silueta",
|
| 268 |
+
"isnet-general-use",
|
| 269 |
+
"isnet-anime",
|
| 270 |
+
],
|
| 271 |
+
value="u2net",
|
| 272 |
+
label="Models",
|
| 273 |
+
),
|
| 274 |
+
],
|
| 275 |
+
gr.components.Image(type="filepath", label="Output"),
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
interface.queue(concurrency_count=3)
|
| 279 |
+
app = gr.mount_gradio_app(app, interface, path="/")
|
| 280 |
+
return app
|
| 281 |
+
|
| 282 |
+
print(f"To access the API documentation, go to http://localhost:{port}/api")
|
| 283 |
+
print(f"To access the UI, go to http://localhost:{port}")
|
| 284 |
+
|
| 285 |
+
uvicorn.run(gr_app(app), host="0.0.0.0", port=port, log_level=log_level)
|
rembg/session_factory.py
CHANGED
|
@@ -8,7 +8,9 @@ from .sessions.base import BaseSession
|
|
| 8 |
from .sessions.u2net import U2netSession
|
| 9 |
|
| 10 |
|
| 11 |
-
def new_session(
|
|
|
|
|
|
|
| 12 |
session_class: Type[BaseSession] = U2netSession
|
| 13 |
|
| 14 |
for sc in sessions_class:
|
|
@@ -21,4 +23,4 @@ def new_session(model_name: str = "u2net", *args, **kwargs) -> BaseSession:
|
|
| 21 |
if "OMP_NUM_THREADS" in os.environ:
|
| 22 |
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
| 23 |
|
| 24 |
-
return session_class(model_name, sess_opts, *args, **kwargs)
|
|
|
|
| 8 |
from .sessions.u2net import U2netSession
|
| 9 |
|
| 10 |
|
| 11 |
+
def new_session(
|
| 12 |
+
model_name: str = "u2net", providers=None, *args, **kwargs
|
| 13 |
+
) -> BaseSession:
|
| 14 |
session_class: Type[BaseSession] = U2netSession
|
| 15 |
|
| 16 |
for sc in sessions_class:
|
|
|
|
| 23 |
if "OMP_NUM_THREADS" in os.environ:
|
| 24 |
sess_opts.inter_op_num_threads = int(os.environ["OMP_NUM_THREADS"])
|
| 25 |
|
| 26 |
+
return session_class(model_name, sess_opts, providers, *args, **kwargs)
|
rembg/sessions/base.py
CHANGED
|
@@ -8,11 +8,29 @@ from PIL.Image import Image as PILImage
|
|
| 8 |
|
| 9 |
|
| 10 |
class BaseSession:
|
| 11 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
self.model_name = model_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
self.inner_session = ort.InferenceSession(
|
| 14 |
str(self.__class__.download_models()),
|
| 15 |
-
providers=
|
| 16 |
sess_options=sess_opts,
|
| 17 |
)
|
| 18 |
|
|
@@ -46,6 +64,10 @@ class BaseSession:
|
|
| 46 |
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 47 |
raise NotImplementedError
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
@classmethod
|
| 50 |
def u2net_home(cls, *args, **kwargs):
|
| 51 |
return os.path.expanduser(
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
class BaseSession:
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
model_name: str,
|
| 14 |
+
sess_opts: ort.SessionOptions,
|
| 15 |
+
providers=None,
|
| 16 |
+
*args,
|
| 17 |
+
**kwargs
|
| 18 |
+
):
|
| 19 |
self.model_name = model_name
|
| 20 |
+
|
| 21 |
+
self.providers = []
|
| 22 |
+
|
| 23 |
+
_providers = ort.get_available_providers()
|
| 24 |
+
if providers:
|
| 25 |
+
for provider in providers:
|
| 26 |
+
if provider in _providers:
|
| 27 |
+
self.providers.append(provider)
|
| 28 |
+
else:
|
| 29 |
+
self.providers.extend(_providers)
|
| 30 |
+
|
| 31 |
self.inner_session = ort.InferenceSession(
|
| 32 |
str(self.__class__.download_models()),
|
| 33 |
+
providers=self.providers,
|
| 34 |
sess_options=sess_opts,
|
| 35 |
)
|
| 36 |
|
|
|
|
| 64 |
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 65 |
raise NotImplementedError
|
| 66 |
|
| 67 |
+
@classmethod
|
| 68 |
+
def checksum_disabled(cls, *args, **kwargs):
|
| 69 |
+
return os.getenv("MODEL_CHECKSUM_DISABLED", None) is not None
|
| 70 |
+
|
| 71 |
@classmethod
|
| 72 |
def u2net_home(cls, *args, **kwargs):
|
| 73 |
return os.path.expanduser(
|
rembg/sessions/dis_anime.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pooch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from PIL.Image import Image as PILImage
|
| 8 |
+
|
| 9 |
+
from .base import BaseSession
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DisSession(BaseSession):
|
| 13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 14 |
+
ort_outs = self.inner_session.run(
|
| 15 |
+
None,
|
| 16 |
+
self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
pred = ort_outs[0][:, 0, :, :]
|
| 20 |
+
|
| 21 |
+
ma = np.max(pred)
|
| 22 |
+
mi = np.min(pred)
|
| 23 |
+
|
| 24 |
+
pred = (pred - mi) / (ma - mi)
|
| 25 |
+
pred = np.squeeze(pred)
|
| 26 |
+
|
| 27 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
| 28 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
| 29 |
+
|
| 30 |
+
return [mask]
|
| 31 |
+
|
| 32 |
+
@classmethod
|
| 33 |
+
def download_models(cls, *args, **kwargs):
|
| 34 |
+
fname = f"{cls.name()}.onnx"
|
| 35 |
+
pooch.retrieve(
|
| 36 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
|
| 37 |
+
None
|
| 38 |
+
if cls.checksum_disabled(*args, **kwargs)
|
| 39 |
+
else "md5:6f184e756bb3bd901c8849220a83e38e",
|
| 40 |
+
fname=fname,
|
| 41 |
+
path=cls.u2net_home(*args, **kwargs),
|
| 42 |
+
progressbar=True,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
return os.path.join(cls.u2net_home(), fname)
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def name(cls, *args, **kwargs):
|
| 49 |
+
return "isnet-anime"
|
rembg/sessions/dis_general_use.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pooch
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from PIL.Image import Image as PILImage
|
| 8 |
+
|
| 9 |
+
from .base import BaseSession
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DisSession(BaseSession):
|
| 13 |
+
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
|
| 14 |
+
ort_outs = self.inner_session.run(
|
| 15 |
+
None,
|
| 16 |
+
self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
pred = ort_outs[0][:, 0, :, :]
|
| 20 |
+
|
| 21 |
+
ma = np.max(pred)
|
| 22 |
+
mi = np.min(pred)
|
| 23 |
+
|
| 24 |
+
pred = (pred - mi) / (ma - mi)
|
| 25 |
+
pred = np.squeeze(pred)
|
| 26 |
+
|
| 27 |
+
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
|
| 28 |
+
mask = mask.resize(img.size, Image.LANCZOS)
|
| 29 |
+
|
| 30 |
+
return [mask]
|
| 31 |
+
|
| 32 |
+
@classmethod
|
| 33 |
+
def download_models(cls, *args, **kwargs):
|
| 34 |
+
fname = f"{cls.name()}.onnx"
|
| 35 |
+
pooch.retrieve(
|
| 36 |
+
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
|
| 37 |
+
None
|
| 38 |
+
if cls.checksum_disabled(*args, **kwargs)
|
| 39 |
+
else "md5:fc16ebd8b0c10d971d3513d564d01e29",
|
| 40 |
+
fname=fname,
|
| 41 |
+
path=cls.u2net_home(*args, **kwargs),
|
| 42 |
+
progressbar=True,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
return os.path.join(cls.u2net_home(), fname)
|
| 46 |
+
|
| 47 |
+
@classmethod
|
| 48 |
+
def name(cls, *args, **kwargs):
|
| 49 |
+
return "isnet-general-use"
|
rembg/sessions/sam.py
CHANGED
|
@@ -141,17 +141,21 @@ class SamSession(BaseSession):
|
|
| 141 |
|
| 142 |
pooch.retrieve(
|
| 143 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
|
| 144 |
-
|
|
|
|
|
|
|
| 145 |
fname=fname_encoder,
|
| 146 |
-
path=cls.u2net_home(),
|
| 147 |
progressbar=True,
|
| 148 |
)
|
| 149 |
|
| 150 |
pooch.retrieve(
|
| 151 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
|
| 152 |
-
|
|
|
|
|
|
|
| 153 |
fname=fname_decoder,
|
| 154 |
-
path=cls.u2net_home(),
|
| 155 |
progressbar=True,
|
| 156 |
)
|
| 157 |
|
|
|
|
| 141 |
|
| 142 |
pooch.retrieve(
|
| 143 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
|
| 144 |
+
None
|
| 145 |
+
if cls.checksum_disabled(*args, **kwargs)
|
| 146 |
+
else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
|
| 147 |
fname=fname_encoder,
|
| 148 |
+
path=cls.u2net_home(*args, **kwargs),
|
| 149 |
progressbar=True,
|
| 150 |
)
|
| 151 |
|
| 152 |
pooch.retrieve(
|
| 153 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
|
| 154 |
+
None
|
| 155 |
+
if cls.checksum_disabled(*args, **kwargs)
|
| 156 |
+
else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
|
| 157 |
fname=fname_decoder,
|
| 158 |
+
path=cls.u2net_home(*args, **kwargs),
|
| 159 |
progressbar=True,
|
| 160 |
)
|
| 161 |
|
rembg/sessions/silueta.py
CHANGED
|
@@ -36,9 +36,11 @@ class SiluetaSession(BaseSession):
|
|
| 36 |
fname = f"{cls.name()}.onnx"
|
| 37 |
pooch.retrieve(
|
| 38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
fname=fname,
|
| 41 |
-
path=cls.u2net_home(),
|
| 42 |
progressbar=True,
|
| 43 |
)
|
| 44 |
|
|
|
|
| 36 |
fname = f"{cls.name()}.onnx"
|
| 37 |
pooch.retrieve(
|
| 38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
|
| 39 |
+
None
|
| 40 |
+
if cls.checksum_disabled(*args, **kwargs)
|
| 41 |
+
else "md5:55e59e0d8062d2f5d013f4725ee84782",
|
| 42 |
fname=fname,
|
| 43 |
+
path=cls.u2net_home(*args, **kwargs),
|
| 44 |
progressbar=True,
|
| 45 |
)
|
| 46 |
|
rembg/sessions/u2net.py
CHANGED
|
@@ -36,9 +36,11 @@ class U2netSession(BaseSession):
|
|
| 36 |
fname = f"{cls.name()}.onnx"
|
| 37 |
pooch.retrieve(
|
| 38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
fname=fname,
|
| 41 |
-
path=cls.u2net_home(),
|
| 42 |
progressbar=True,
|
| 43 |
)
|
| 44 |
|
|
|
|
| 36 |
fname = f"{cls.name()}.onnx"
|
| 37 |
pooch.retrieve(
|
| 38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
|
| 39 |
+
None
|
| 40 |
+
if cls.checksum_disabled(*args, **kwargs)
|
| 41 |
+
else "md5:60024c5c889badc19c04ad937298a77b",
|
| 42 |
fname=fname,
|
| 43 |
+
path=cls.u2net_home(*args, **kwargs),
|
| 44 |
progressbar=True,
|
| 45 |
)
|
| 46 |
|
rembg/sessions/u2net_cloth_seg.py
CHANGED
|
@@ -97,9 +97,11 @@ class Unet2ClothSession(BaseSession):
|
|
| 97 |
fname = f"{cls.name()}.onnx"
|
| 98 |
pooch.retrieve(
|
| 99 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
|
| 100 |
-
|
|
|
|
|
|
|
| 101 |
fname=fname,
|
| 102 |
-
path=cls.u2net_home(),
|
| 103 |
progressbar=True,
|
| 104 |
)
|
| 105 |
|
|
|
|
| 97 |
fname = f"{cls.name()}.onnx"
|
| 98 |
pooch.retrieve(
|
| 99 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
|
| 100 |
+
None
|
| 101 |
+
if cls.checksum_disabled(*args, **kwargs)
|
| 102 |
+
else "md5:2434d1f3cb744e0e49386c906e5a08bb",
|
| 103 |
fname=fname,
|
| 104 |
+
path=cls.u2net_home(*args, **kwargs),
|
| 105 |
progressbar=True,
|
| 106 |
)
|
| 107 |
|
rembg/sessions/u2net_human_seg.py
CHANGED
|
@@ -36,9 +36,11 @@ class U2netHumanSegSession(BaseSession):
|
|
| 36 |
fname = f"{cls.name()}.onnx"
|
| 37 |
pooch.retrieve(
|
| 38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
fname=fname,
|
| 41 |
-
path=cls.u2net_home(),
|
| 42 |
progressbar=True,
|
| 43 |
)
|
| 44 |
|
|
|
|
| 36 |
fname = f"{cls.name()}.onnx"
|
| 37 |
pooch.retrieve(
|
| 38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
|
| 39 |
+
None
|
| 40 |
+
if cls.checksum_disabled(*args, **kwargs)
|
| 41 |
+
else "md5:c09ddc2e0104f800e3e1bb4652583d1f",
|
| 42 |
fname=fname,
|
| 43 |
+
path=cls.u2net_home(*args, **kwargs),
|
| 44 |
progressbar=True,
|
| 45 |
)
|
| 46 |
|
rembg/sessions/u2netp.py
CHANGED
|
@@ -36,9 +36,11 @@ class U2netpSession(BaseSession):
|
|
| 36 |
fname = f"{cls.name()}.onnx"
|
| 37 |
pooch.retrieve(
|
| 38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
fname=fname,
|
| 41 |
-
path=cls.u2net_home(),
|
| 42 |
progressbar=True,
|
| 43 |
)
|
| 44 |
|
|
|
|
| 36 |
fname = f"{cls.name()}.onnx"
|
| 37 |
pooch.retrieve(
|
| 38 |
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx",
|
| 39 |
+
None
|
| 40 |
+
if cls.checksum_disabled(*args, **kwargs)
|
| 41 |
+
else "md5:8e83ca70e441ab06c318d82300c84806",
|
| 42 |
fname=fname,
|
| 43 |
+
path=cls.u2net_home(*args, **kwargs),
|
| 44 |
progressbar=True,
|
| 45 |
)
|
| 46 |
|