Spaces:
Running
Running
TheEeeeLin
commited on
Commit
·
4be6b70
1
Parent(s):
59bff44
fix huggingface
Browse files- app.py +1 -1
- hivision/creator/human_matting.py +30 -6
app.py
CHANGED
@@ -618,4 +618,4 @@ if __name__ == "__main__":
|
|
618 |
],
|
619 |
)
|
620 |
|
621 |
-
demo.launch(
|
|
|
618 |
],
|
619 |
)
|
620 |
|
621 |
+
demo.launch()
|
hivision/creator/human_matting.py
CHANGED
@@ -15,7 +15,17 @@ from .context import Context
|
|
15 |
import cv2
|
16 |
import os
|
17 |
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
def extract_human(ctx: Context):
|
@@ -24,7 +34,21 @@ def extract_human(ctx: Context):
|
|
24 |
:param ctx: 上下文
|
25 |
"""
|
26 |
# 抠图
|
27 |
-
matting_image = get_modnet_matting(ctx.processing_image,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
# 修复抠图
|
29 |
ctx.processing_image = hollow_out_fix(matting_image)
|
30 |
ctx.matting_image = ctx.processing_image.copy()
|
@@ -92,13 +116,13 @@ def read_modnet_image(input_image, ref_size=512):
|
|
92 |
return im, width, length
|
93 |
|
94 |
|
95 |
-
sess = None
|
96 |
|
97 |
|
98 |
def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
|
99 |
-
global sess
|
100 |
-
if sess is None:
|
101 |
-
|
102 |
|
103 |
input_name = sess.get_inputs()[0].name
|
104 |
output_name = sess.get_outputs()[0].name
|
|
|
15 |
import cv2
|
16 |
import os
|
17 |
|
18 |
+
|
19 |
+
WEIGHTS = {
|
20 |
+
"hivision_modnet": os.path.join(
|
21 |
+
os.path.dirname(__file__), "weights", "hivision_modnet.onnx"
|
22 |
+
),
|
23 |
+
"modnet_photographic_portrait_matting": os.path.join(
|
24 |
+
os.path.dirname(__file__),
|
25 |
+
"weights",
|
26 |
+
"modnet_photographic_portrait_matting.onnx",
|
27 |
+
),
|
28 |
+
}
|
29 |
|
30 |
|
31 |
def extract_human(ctx: Context):
|
|
|
34 |
:param ctx: 上下文
|
35 |
"""
|
36 |
# 抠图
|
37 |
+
matting_image = get_modnet_matting(ctx.processing_image, WEIGHTS["hivision_modnet"])
|
38 |
+
# 修复抠图
|
39 |
+
ctx.processing_image = hollow_out_fix(matting_image)
|
40 |
+
ctx.matting_image = ctx.processing_image.copy()
|
41 |
+
|
42 |
+
|
43 |
+
def extract_human_modnet_photographic_portrait_matting(ctx: Context):
|
44 |
+
"""
|
45 |
+
人像抠图
|
46 |
+
:param ctx: 上下文
|
47 |
+
"""
|
48 |
+
# 抠图
|
49 |
+
matting_image = get_modnet_matting(
|
50 |
+
ctx.processing_image, WEIGHTS["modnet_photographic_portrait_matting"]
|
51 |
+
)
|
52 |
# 修复抠图
|
53 |
ctx.processing_image = hollow_out_fix(matting_image)
|
54 |
ctx.matting_image = ctx.processing_image.copy()
|
|
|
116 |
return im, width, length
|
117 |
|
118 |
|
119 |
+
# sess = None
|
120 |
|
121 |
|
122 |
def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
|
123 |
+
# global sess
|
124 |
+
# if sess is None:
|
125 |
+
sess = onnxruntime.InferenceSession(checkpoint_path)
|
126 |
|
127 |
input_name = sess.get_inputs()[0].name
|
128 |
output_name = sess.get_outputs()[0].name
|