Spaces:
Running
Running
nopperl
commited on
Commit
·
39c0f4e
1
Parent(s):
3455245
Add application
Browse files- README.md +12 -0
- app.py +30 -0
- examples/01.png +0 -0
- examples/02.png +0 -0
- onnx_inference.py +245 -0
- requirements.txt +7 -0
README.md
CHANGED
@@ -8,6 +8,18 @@ sdk_version: 4.8.0
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
+
preload_from_hub:
|
12 |
+
- nopperl/marked-lineart-vectorizer model.onnx
|
13 |
+
datasets:
|
14 |
+
- kmewhort/tu-berlin-svgs
|
15 |
+
tags:
|
16 |
+
- image-vectorization
|
17 |
+
- sketch
|
18 |
+
- sketch-synthesis
|
19 |
+
- svg
|
20 |
+
- vector-image
|
21 |
+
- line-drawing
|
22 |
+
- line-art
|
23 |
---
|
24 |
|
25 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os.path import basename, splitext
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
from huggingface_hub import hf_hub_download
|
5 |
+
|
6 |
+
from onnx_inference import vectorize_image
|
7 |
+
|
8 |
+
|
9 |
+
MODEL_PATH = hf_hub_download("nopperl/marked-lineart-vectorizer", "model.onnx")
|
10 |
+
|
11 |
+
|
12 |
+
def predict(input_image_path, threshold, stroke_width):
|
13 |
+
output_filepath = splitext(basename(input_image_path))[0] + ".svg"
|
14 |
+
for recons_img in vectorize_image(input_image_path, model=MODEL_PATH, output=output_filepath, threshold_ratio=threshold, stroke_width=stroke_width):
|
15 |
+
yield recons_img
|
16 |
+
yield output_filepath
|
17 |
+
|
18 |
+
|
19 |
+
interface = gr.Interface(
|
20 |
+
predict,
|
21 |
+
inputs=[gr.Image(sources="upload", type="filepath"), gr.Slider(minimum=0.1, maximum=0.9, value=0.1, label="threshold"), gr.Slider(minimum=0.1, maximum=4.0, value=0.512, label="stroke_width")],
|
22 |
+
outputs=gr.Image(),
|
23 |
+
description="Demo for a model that converts raster line-art images into vector images iteratively. The model is trained on black-and-white line-art images, hence it won't work with other images. Inference time will be quite slow due to a lack of GPU resources. More information at https://github.com/nopperl/marked-lineart-vectorization.",
|
24 |
+
examples = [
|
25 |
+
["examples/01.png", 0.1, 0.512],
|
26 |
+
["examples/02.png", 0.1, 0.512]
|
27 |
+
],
|
28 |
+
analytics_enabled=False
|
29 |
+
)
|
30 |
+
interface.launch()
|
examples/01.png
ADDED
examples/02.png
ADDED
onnx_inference.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from io import BytesIO
|
4 |
+
from os import listdir, makedirs
|
5 |
+
from os.path import basename, isdir, join, splitext
|
6 |
+
from random import randint
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
from cairosvg import svg2png
|
10 |
+
import numpy as np
|
11 |
+
from imageio.v3 import imread, imwrite
|
12 |
+
from skimage.transform import rescale
|
13 |
+
from svgpathtools import CubicBezier, Line, QuadraticBezier, disvg, wsvg
|
14 |
+
|
15 |
+
import onnx
|
16 |
+
import onnxruntime as ort
|
17 |
+
|
18 |
+
|
19 |
+
def raster_bezier_hard(all_points, image_width=128, image_height=128, stroke_width=2., colors=None, white_background=True, mark=None):
|
20 |
+
if colors is None:
|
21 |
+
colors = [[0., 0., 0., 1.]] * len(all_points)
|
22 |
+
elif colors is list and colors[0] is not list:
|
23 |
+
colors = [colors] * len(all_points)
|
24 |
+
else:
|
25 |
+
colors = np.array(colors)
|
26 |
+
colors[:, :3] *= 255
|
27 |
+
colors = ["rgb(" + ",".join(map(str, color[:3])) + ")" for color in colors]
|
28 |
+
background_color = "white" if white_background else None
|
29 |
+
all_points = all_points + 0
|
30 |
+
all_points[:, :, 0] *= image_width
|
31 |
+
all_points[:, :, 1] *= image_height
|
32 |
+
bezier_curves = [numpy_to_bezier(points) for points in all_points]
|
33 |
+
attributes = [{"stroke": colors[i], "stroke-width": str(stroke_width), "fill": "none"} for i in range(len(bezier_curves))]
|
34 |
+
if mark is not None:
|
35 |
+
mark = mark + 0
|
36 |
+
mark[0] *= image_width
|
37 |
+
mark[1] *= image_height
|
38 |
+
mark_points = np.vstack([mark - stroke_width, mark + stroke_width])
|
39 |
+
mark_path = numpy_to_bezier(mark_points)
|
40 |
+
mark_attr = {"stroke": "blue", "stroke-width": str(stroke_width * 2), "fill": "blue"}
|
41 |
+
bezier_curves.append(mark_path)
|
42 |
+
attributes.append(mark_attr)
|
43 |
+
svg_attributes = {"width": f"{image_width}px", "height": f"{image_height}px"}
|
44 |
+
svg_string = disvg(bezier_curves, attributes=attributes, svg_attributes=svg_attributes, paths2Drawing=True).tostring()
|
45 |
+
png_string = svg2png(bytestring=svg_string, background_color=background_color)
|
46 |
+
image = imread(BytesIO(png_string), extension=".png")
|
47 |
+
output = image.astype("float32")
|
48 |
+
output /= 255
|
49 |
+
output = np.moveaxis(output, 2, 0)
|
50 |
+
return output, all_points
|
51 |
+
|
52 |
+
def diff_remaining_img(raster_img: np.ndarray, recons_img: np.ndarray):
|
53 |
+
remaining_img = raster_img.copy()
|
54 |
+
tmp_remaining_img = remaining_img.copy()
|
55 |
+
tmp_remaining_img[tmp_remaining_img < 1] = 0.
|
56 |
+
recons_img[recons_img < 1] = 0.
|
57 |
+
same_mask = (tmp_remaining_img == recons_img).copy()
|
58 |
+
remaining_img[same_mask] = 1
|
59 |
+
return remaining_img
|
60 |
+
|
61 |
+
|
62 |
+
def place_point_on_img(image, point):
|
63 |
+
if np.any(point == point.astype(int)):
|
64 |
+
point_idx_start = point.astype(int)
|
65 |
+
point_idx_end = point.astype(int) + 1
|
66 |
+
else:
|
67 |
+
point_idx_start = np.floor(point).astype(int)
|
68 |
+
point_idx_end = np.ceil(point).astype(int)
|
69 |
+
if image.shape[0] == 3:
|
70 |
+
image[0, point_idx_start[1]:point_idx_end[1], point_idx_start[0]:point_idx_end[0]] = 0
|
71 |
+
image[1, point_idx_start[1]:point_idx_end[1], point_idx_start[0]:point_idx_end[0]] = 0
|
72 |
+
image[2, point_idx_start[1]:point_idx_end[1], point_idx_start[0]:point_idx_end[0]] = 1
|
73 |
+
else:
|
74 |
+
image[0, point_idx_start[1]:point_idx_end[1], point_idx_start[0]:point_idx_end[0]] = 0.5
|
75 |
+
return image
|
76 |
+
|
77 |
+
|
78 |
+
def rgb_to_grayscale(image: np.ndarray):
|
79 |
+
image = image[0] * .2989 + image[1] *.587 + image[2] *.114
|
80 |
+
return image
|
81 |
+
|
82 |
+
|
83 |
+
def sample_black_pixel(image: np.ndarray):
|
84 |
+
image = rgb_to_grayscale(image.copy())
|
85 |
+
black_indices = np.argwhere(~np.isclose(image, np.ones_like(image, dtype="float32"), atol=0.5) != 0)
|
86 |
+
black_idx = black_indices[randint(0, len(black_indices) - 1)].astype("float32")
|
87 |
+
black_idx[0] /= image.shape[0]
|
88 |
+
black_idx[1] /= image.shape[1]
|
89 |
+
black_idx = black_idx[[1, 0]]
|
90 |
+
return black_idx
|
91 |
+
|
92 |
+
|
93 |
+
def numpy_to_bezier(points: np.ndarray):
|
94 |
+
if len(points) == 2:
|
95 |
+
return Line(*(complex(point[0], point[1]) for point in points))
|
96 |
+
elif len(points) == 3:
|
97 |
+
return QuadraticBezier(*(complex(point[0], point[1]) for point in points))
|
98 |
+
elif len(points) == 4:
|
99 |
+
return CubicBezier(*(complex(point[0], point[1]) for point in points))
|
100 |
+
|
101 |
+
|
102 |
+
def center_on_point(image, point, new_width=None, new_height=None):
|
103 |
+
_, height, width = image.shape
|
104 |
+
if new_width is None:
|
105 |
+
new_width = width
|
106 |
+
if new_height is None:
|
107 |
+
new_height = height
|
108 |
+
half_width = round(width / 2)
|
109 |
+
half_height = round(height / 2)
|
110 |
+
point = point.copy()
|
111 |
+
point[0] *= width
|
112 |
+
point[1] *= height
|
113 |
+
point = point.round().astype(int)
|
114 |
+
top=half_height - (half_height - point[1])
|
115 |
+
left=half_width - (half_width - point[0])
|
116 |
+
padded = np.pad(image, ((0, 0), (half_height, half_height), (half_width, half_width)), constant_values=1)
|
117 |
+
cropped = padded[:, top:top+new_height, left:left+new_width]
|
118 |
+
return cropped
|
119 |
+
|
120 |
+
|
121 |
+
def reverse_center_on_point(paths, point):
|
122 |
+
for i in range(len(paths)):
|
123 |
+
paths[i, :, 0] -= 0.5 - point[i, 0]
|
124 |
+
paths[i, :, 1] -= 0.5 - point[i, 1]
|
125 |
+
|
126 |
+
|
127 |
+
def save_as_svg(curves: np.ndarray, filename, img_width, img_height, stroke_width=2.0):
|
128 |
+
svg_paths = [numpy_to_bezier(curve) for curve in curves]
|
129 |
+
output_attributes = [{"stroke": "black", "stroke-width": stroke_width, "stroke-linecap": "round", "fill": "none"}] * len(svg_paths)
|
130 |
+
svg_attributes = {"width": f"{img_width}px", "height": f"{img_height}px"}
|
131 |
+
wsvg(svg_paths, attributes=output_attributes, svg_attributes=svg_attributes, filename=filename)
|
132 |
+
|
133 |
+
|
134 |
+
def save_as_png(filename: str, image: np.ndarray):
|
135 |
+
image = np.moveaxis(image.copy(), 0, 2)
|
136 |
+
image *= 255
|
137 |
+
imwrite(filename, image.round().astype("uint8"))
|
138 |
+
|
139 |
+
|
140 |
+
def setup_model(model_path):
|
141 |
+
model = onnx.load(model_path)
|
142 |
+
onnx.checker.check_model(model)
|
143 |
+
ort_sess = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider"])
|
144 |
+
return ort_sess
|
145 |
+
|
146 |
+
|
147 |
+
def vectorize_image(input_image_path, model: Union[str, ort.InferenceSession], output=None, threshold_ratio=0.1, stroke_width=0.512, width=512, height=512, binarization_threshold=0, force_grayscale=False):
|
148 |
+
if type(model) is str:
|
149 |
+
ort_sess = setup_model(model)
|
150 |
+
elif type(model) is ort.InferenceSession:
|
151 |
+
ort_sess = model
|
152 |
+
else:
|
153 |
+
raise ValueError("Invalid value for the model argument")
|
154 |
+
|
155 |
+
# Get dimensions expected by the model
|
156 |
+
_, channels, height, width = ort_sess.get_inputs()[0].shape
|
157 |
+
input_image = imread(input_image_path, pilmode="RGB") / 255
|
158 |
+
original_height, original_width, _ = input_image.shape
|
159 |
+
# scale and white pad image to dimensions expected by the model
|
160 |
+
if original_height >= original_width:
|
161 |
+
scale = height / original_height
|
162 |
+
else:
|
163 |
+
scale = width / original_width
|
164 |
+
print(f"Rescale factor: {scale}")
|
165 |
+
input_image = rescale(input_image, scale, channel_axis=2, order=5)
|
166 |
+
scaled_height, scaled_width = input_image.shape[:2]
|
167 |
+
raster_img = np.ones((height, width, channels), dtype="float32")
|
168 |
+
raster_img[:input_image.shape[0], :input_image.shape[1]] = input_image
|
169 |
+
# convert CHW
|
170 |
+
raster_img = np.moveaxis(raster_img, 2, 0)
|
171 |
+
if binarization_threshold > 0:
|
172 |
+
raster_img[raster_img < binarization_threshold] = 0.
|
173 |
+
width = raster_img.shape[2]
|
174 |
+
height = raster_img.shape[1]
|
175 |
+
curve_pixels = (raster_img < .5).sum()
|
176 |
+
threshold = curve_pixels * threshold_ratio
|
177 |
+
print(f"Reconstruction candidate pixels: {curve_pixels}")
|
178 |
+
print(f"Reconstruction threshold: {threshold.astype(int)}")
|
179 |
+
recons_points = None
|
180 |
+
recons_img = np.ones_like(raster_img, dtype="float32")
|
181 |
+
remaining_img = raster_img.copy()
|
182 |
+
while (remaining_img < .5).sum() > threshold:
|
183 |
+
remaining_img = diff_remaining_img(raster_img, recons_img)
|
184 |
+
try:
|
185 |
+
mark = sample_black_pixel(remaining_img)
|
186 |
+
except ValueError:
|
187 |
+
break
|
188 |
+
centered_img = remaining_img.copy()
|
189 |
+
mark_real = mark.copy()
|
190 |
+
mark_real[0] *= width
|
191 |
+
mark_real[1] *= height
|
192 |
+
centered_img = place_point_on_img(centered_img, mark_real)
|
193 |
+
centered_img = center_on_point(centered_img, mark)
|
194 |
+
result = ort_sess.run(None, {"marked_raster_image": np.expand_dims(centered_img, 0)})
|
195 |
+
points = result[0]
|
196 |
+
reverse_center_on_point(points, np.expand_dims(mark, 0))
|
197 |
+
points = np.expand_dims(points, 1)
|
198 |
+
if recons_points is None:
|
199 |
+
recons_points = points
|
200 |
+
else:
|
201 |
+
recons_points = np.concatenate((recons_points, points), axis=1)
|
202 |
+
recons_img, _ = raster_bezier_hard(recons_points.squeeze(0), image_width=width, image_height=height, stroke_width=stroke_width)
|
203 |
+
yield np.moveaxis(recons_img, 0, 2)
|
204 |
+
|
205 |
+
output_filepath = splitext(basename(input_image_path))[0] + ".svg"
|
206 |
+
if output is not None:
|
207 |
+
if isdir(output):
|
208 |
+
makedirs(output, exist_ok=True)
|
209 |
+
output_filepath = join(output, output_filepath)
|
210 |
+
elif type(output) is str and output.endswith(".svg"):
|
211 |
+
output_filepath = output
|
212 |
+
recons_points = recons_points.squeeze(0)
|
213 |
+
recons_points[:, :, 0] *= width * (1 / scale)
|
214 |
+
recons_points[:, :, 1] *= height * (1 / scale)
|
215 |
+
save_as_svg(recons_points, output_filepath, original_width, original_height, stroke_width=stroke_width)
|
216 |
+
|
217 |
+
|
218 |
+
def main():
|
219 |
+
parser = ArgumentParser(description="Inference script for the marked curve reconstruction model in ONNX format.")
|
220 |
+
parser.add_argument("model", metavar="FIlE", help="path to the *.onnx file")
|
221 |
+
parser.add_argument("-i", "--input_images", nargs="*", metavar="FILE", help="one or multiple paths to raster images that should be vectorized.")
|
222 |
+
parser.add_argument("-d", "--input_dir", metavar="DIR", help="path to a directory of raster images that should be vectorized.")
|
223 |
+
parser.add_argument("-o", "--output", help="optional output directory or file")
|
224 |
+
parser.add_argument("--threshold_ratio", "-t", default=0.1, type=float, help="The ratio of black pixels which need to be reconstructed before the algorithm terminates")
|
225 |
+
parser.add_argument("--stroke_width", "-r", default=0.512, type=float, help="stroke width if it should be different from the one specified in the model")
|
226 |
+
parser.add_argument("--seed", "-s", default=1234, help="Fixed random number generation seed. Set to negative number to deactivate")
|
227 |
+
parser.add_argument("-b", "--binarization_threshold", default=0., type=float, help="Set to a value in (0,1) to binarize the image.")
|
228 |
+
|
229 |
+
args = parser.parse_args()
|
230 |
+
|
231 |
+
if args.seed >= 0:
|
232 |
+
np.random.seed(args.seed)
|
233 |
+
if args.input_images is not None:
|
234 |
+
input_images = args.input_images
|
235 |
+
elif args.input_dir is not None and isdir(args.input_dir):
|
236 |
+
input_images = [join(args.input_dir, f) for f in listdir(args.input_dir)]
|
237 |
+
else:
|
238 |
+
print("-i or -d need to be passed")
|
239 |
+
exit(1)
|
240 |
+
for input_image in input_images:
|
241 |
+
vectorize_image(input_image, args.model, output=args.output, threshold_ratio=args.threshold_ratio, stroke_width=args.stroke_width, binarization_threshold=args.binarization_threshold, force_grayscale=False)
|
242 |
+
|
243 |
+
|
244 |
+
if __name__ == "__main__":
|
245 |
+
main()
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
onnx==1.14.0
|
2 |
+
onnxruntime==1.15.1
|
3 |
+
imageio==2.31.1
|
4 |
+
svgpathtools==1.6.1
|
5 |
+
cairosvg==2.7.0
|
6 |
+
scikit-image==0.21.0
|
7 |
+
huggingface_hub==0.20.1
|