TheEeeeLin commited on
Commit
4be6b70
·
1 Parent(s): 59bff44

fix huggingface

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. hivision/creator/human_matting.py +30 -6
app.py CHANGED
@@ -618,4 +618,4 @@ if __name__ == "__main__":
618
  ],
619
  )
620
 
621
- demo.launch(server_name=args.host, server_port=args.port)
 
618
  ],
619
  )
620
 
621
+ demo.launch()
hivision/creator/human_matting.py CHANGED
@@ -15,7 +15,17 @@ from .context import Context
15
  import cv2
16
  import os
17
 
18
- weight_path = os.path.join(os.path.dirname(__file__), "weights", "hivision_modnet.onnx")
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  def extract_human(ctx: Context):
@@ -24,7 +34,21 @@ def extract_human(ctx: Context):
24
  :param ctx: 上下文
25
  """
26
  # 抠图
27
- matting_image = get_modnet_matting(ctx.processing_image, weight_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # 修复抠图
29
  ctx.processing_image = hollow_out_fix(matting_image)
30
  ctx.matting_image = ctx.processing_image.copy()
@@ -92,13 +116,13 @@ def read_modnet_image(input_image, ref_size=512):
92
  return im, width, length
93
 
94
 
95
- sess = None
96
 
97
 
98
  def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
99
- global sess
100
- if sess is None:
101
- sess = onnxruntime.InferenceSession(checkpoint_path)
102
 
103
  input_name = sess.get_inputs()[0].name
104
  output_name = sess.get_outputs()[0].name
 
15
  import cv2
16
  import os
17
 
18
+
19
+ WEIGHTS = {
20
+ "hivision_modnet": os.path.join(
21
+ os.path.dirname(__file__), "weights", "hivision_modnet.onnx"
22
+ ),
23
+ "modnet_photographic_portrait_matting": os.path.join(
24
+ os.path.dirname(__file__),
25
+ "weights",
26
+ "modnet_photographic_portrait_matting.onnx",
27
+ ),
28
+ }
29
 
30
 
31
  def extract_human(ctx: Context):
 
34
  :param ctx: 上下文
35
  """
36
  # 抠图
37
+ matting_image = get_modnet_matting(ctx.processing_image, WEIGHTS["hivision_modnet"])
38
+ # 修复抠图
39
+ ctx.processing_image = hollow_out_fix(matting_image)
40
+ ctx.matting_image = ctx.processing_image.copy()
41
+
42
+
43
+ def extract_human_modnet_photographic_portrait_matting(ctx: Context):
44
+ """
45
+ 人像抠图
46
+ :param ctx: 上下文
47
+ """
48
+ # 抠图
49
+ matting_image = get_modnet_matting(
50
+ ctx.processing_image, WEIGHTS["modnet_photographic_portrait_matting"]
51
+ )
52
  # 修复抠图
53
  ctx.processing_image = hollow_out_fix(matting_image)
54
  ctx.matting_image = ctx.processing_image.copy()
 
116
  return im, width, length
117
 
118
 
119
+ # sess = None
120
 
121
 
122
  def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
123
+ # global sess
124
+ # if sess is None:
125
+ sess = onnxruntime.InferenceSession(checkpoint_path)
126
 
127
  input_name = sess.get_inputs()[0].name
128
  output_name = sess.get_outputs()[0].name