Spaces:
Running
Running
File size: 11,040 Bytes
39c0f4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 |
#!/usr/bin/env python3
from argparse import ArgumentParser
from io import BytesIO
from os import listdir, makedirs
from os.path import basename, isdir, join, splitext
from random import randint
from typing import Union
from cairosvg import svg2png
import numpy as np
from imageio.v3 import imread, imwrite
from skimage.transform import rescale
from svgpathtools import CubicBezier, Line, QuadraticBezier, disvg, wsvg
import onnx
import onnxruntime as ort
def raster_bezier_hard(all_points, image_width=128, image_height=128, stroke_width=2., colors=None, white_background=True, mark=None):
    """Rasterize Bezier curves by rendering them as SVG and converting to PNG.

    Args:
        all_points: Array indexed as ``[curve, control point, coordinate]``.
            Coordinate 0 is multiplied by ``image_width`` and coordinate 1 by
            ``image_height``, so the input is expected in normalized units.
            The input array itself is not modified.
        image_width: Width of the rendered image in pixels.
        image_height: Height of the rendered image in pixels.
        stroke_width: Stroke width written into the SVG attributes.
        colors: ``None`` (every curve is drawn black), a single colour
            (1-D sequence applied to every curve), or one colour per curve
            (2-D sequence). The first three components of supplied colours
            are multiplied by 255 before being written as ``rgb(...)``.
        white_background: Render on a white background when true; otherwise
            no background colour is passed to the PNG converter.
        mark: Optional 2-element array in normalized units. When given, an
            extra blue segment centred on that position is drawn.

    Returns:
        A tuple ``(image, scaled_points)``. ``image`` is the decoded PNG as
        ``float32`` divided by 255 with the last axis moved to the front;
        ``scaled_points`` is the pixel-space copy of ``all_points``.
    """
    if colors is None:
        colors = [[0., 0., 0., 1.]] * len(all_points)
    else:
        # np.array copies, so the caller's colour data is never scaled in place.
        colors = np.array(colors)
        if colors.ndim == 1:
            # A single colour was given: repeat it for every curve. The
            # original check (`colors is list`) compared against the type
            # object and could never be true, so this case used to crash.
            colors = np.tile(colors, (len(all_points), 1))
        colors[:, :3] *= 255
    colors = ["rgb(" + ",".join(map(str, color[:3])) + ")" for color in colors]
    background_color = "white" if white_background else None

    # "+ 0" creates a copy so the scaling below does not touch the input.
    all_points = all_points + 0
    all_points[:, :, 0] *= image_width
    all_points[:, :, 1] *= image_height

    bezier_curves = [numpy_to_bezier(points) for points in all_points]
    attributes = [
        {"stroke": colors[i], "stroke-width": str(stroke_width), "fill": "none"}
        for i in range(len(bezier_curves))
    ]

    if mark is not None:
        mark = mark + 0
        mark[0] *= image_width
        mark[1] *= image_height
        # Two points offset by the stroke width on either side of the mark.
        mark_points = np.vstack([mark - stroke_width, mark + stroke_width])
        mark_path = numpy_to_bezier(mark_points)
        mark_attr = {"stroke": "blue", "stroke-width": str(stroke_width * 2), "fill": "blue"}
        bezier_curves.append(mark_path)
        attributes.append(mark_attr)

    svg_attributes = {"width": f"{image_width}px", "height": f"{image_height}px"}
    svg_string = disvg(bezier_curves, attributes=attributes, svg_attributes=svg_attributes, paths2Drawing=True).tostring()
    png_string = svg2png(bytestring=svg_string, background_color=background_color)
    image = imread(BytesIO(png_string), extension=".png")

    output = image.astype("float32")
    output /= 255
    # Move the last axis of the decoded image to the front.
    output = np.moveaxis(output, 2, 0)
    return output, all_points
def diff_remaining_img(raster_img: np.ndarray, recons_img: np.ndarray):
    """Return a copy of ``raster_img`` with already-reconstructed pixels set to 1.

    Both images are thresholded (every value below 1 becomes 0) and compared
    element-wise. Wherever the thresholded images agree, the returned copy of
    ``raster_img`` is set to 1; elsewhere it keeps the original raster value.

    Neither input array is modified. The previous implementation thresholded
    ``recons_img`` in place, which silently altered the caller's array.

    Args:
        raster_img: The target raster image.
        recons_img: The current reconstruction, same shape as ``raster_img``.

    Returns:
        A new array with the same shape and dtype as ``raster_img``.
    """
    # Threshold into temporaries instead of writing into the arguments.
    raster_thresholded = np.where(raster_img < 1, 0., raster_img)
    recons_thresholded = np.where(recons_img < 1, 0., recons_img)

    remaining_img = raster_img.copy()
    remaining_img[raster_thresholded == recons_thresholded] = 1
    return remaining_img
def place_point_on_img(image, point):
    """Paint a marker into ``image`` at ``point`` and return the same array.

    ``point`` holds pixel coordinates with index 0 along the last image axis
    and index 1 along the middle axis. The image is modified in place.

    When the first axis has length 3, the first two planes are set to 0 and
    the third to 1 inside the marked cell; otherwise plane 0 is set to 0.5.
    """
    truncated = point.astype(int)
    if np.any(point == truncated):
        # At least one coordinate is integral: mark the cell starting there.
        lower, upper = truncated, truncated + 1
    else:
        lower = np.floor(point).astype(int)
        upper = np.ceil(point).astype(int)

    rows = slice(lower[1], upper[1])
    cols = slice(lower[0], upper[0])

    if image.shape[0] != 3:
        image[0, rows, cols] = 0.5
        return image

    image[0, rows, cols] = 0
    image[1, rows, cols] = 0
    image[2, rows, cols] = 1
    return image
def rgb_to_grayscale(image: np.ndarray):
    """Combine the first three planes of ``image`` into one weighted plane.

    Planes 0, 1 and 2 along the first axis are weighted by 0.2989, 0.587 and
    0.114 respectively and summed.
    """
    red, green, blue = image[0], image[1], image[2]
    return red * .2989 + green * .587 + blue * .114
def sample_black_pixel(image: np.ndarray):
    """Pick a random pixel that is not close to white and return its position.

    The image is converted with ``rgb_to_grayscale``; candidate pixels are
    those whose value is not within ``atol=0.5`` of 1. One candidate is drawn
    with ``random.randint`` and its indices are divided by the respective
    plane dimensions.

    Returns:
        A 2-element ``float32`` array ordered as (second-axis position,
        first-axis position) of the grayscale plane.

    Raises:
        ValueError: From ``randint`` when there is no candidate pixel.
    """
    gray = rgb_to_grayscale(image.copy())
    white = np.ones_like(gray, dtype="float32")
    dark_mask = ~np.isclose(gray, white, atol=0.5)
    candidates = np.argwhere(dark_mask)

    chosen = candidates[randint(0, len(candidates) - 1)].astype("float32")
    n_rows, n_cols = gray.shape[0], gray.shape[1]
    chosen[0] /= n_rows
    chosen[1] /= n_cols
    # Swap the two entries so the second-axis position comes first.
    return chosen[[1, 0]]
def numpy_to_bezier(points: np.ndarray):
    """Convert a sequence of control points into an svgpathtools segment.

    Each point contributes ``complex(point[0], point[1])``. Two points yield
    a ``Line``, three a ``QuadraticBezier`` and four a ``CubicBezier``. Any
    other number of points yields ``None``.
    """
    count = len(points)
    if count < 2 or count > 4:
        return None

    control_points = [complex(p[0], p[1]) for p in points]
    if count == 2:
        segment_type = Line
    elif count == 3:
        segment_type = QuadraticBezier
    else:
        segment_type = CubicBezier
    return segment_type(*control_points)
def center_on_point(image, point, new_width=None, new_height=None):
    """Return a crop of ``image`` in which ``point`` sits at the centre.

    The image (three axes; the last two are height and width) is padded with
    the value 1 by half its height/width on each side, and a window whose
    top-left corner is the pixel position of ``point`` in padded coordinates
    is cut out.

    Args:
        image: Array with three axes.
        point: 2-element array in normalized units; index 0 is scaled by the
            width and index 1 by the height. The argument is not modified.
        new_width: Crop width; defaults to the image width.
        new_height: Crop height; defaults to the image height.

    Returns:
        The cropped view of the padded image.
    """
    _channels, height, width = image.shape
    crop_width = width if new_width is None else new_width
    crop_height = height if new_height is None else new_height

    pad_x = round(width / 2)
    pad_y = round(height / 2)

    pixel = point.copy()
    pixel[0] *= width
    pixel[1] *= height
    pixel = pixel.round().astype(int)

    # In padded coordinates the window that centres the point starts exactly
    # at the point's unpadded pixel position.
    left, top = pixel[0], pixel[1]

    padded = np.pad(
        image,
        ((0, 0), (pad_y, pad_y), (pad_x, pad_x)),
        constant_values=1,
    )
    return padded[:, top:top + crop_height, left:left + crop_width]
def reverse_center_on_point(paths, point):
    """Shift each path in place by the offset of its point from 0.5.

    For every index ``i`` along the first axis of ``paths``,
    ``0.5 - point[i, 0]`` is subtracted from coordinate 0 and
    ``0.5 - point[i, 1]`` from coordinate 1. ``paths`` is modified in place
    and nothing is returned.
    """
    for index, path in enumerate(paths):
        shift_x = 0.5 - point[index, 0]
        shift_y = 0.5 - point[index, 1]
        # `path` is a view into `paths`, so this updates the caller's array.
        path[:, 0] -= shift_x
        path[:, 1] -= shift_y
def save_as_svg(curves: np.ndarray, filename, img_width, img_height, stroke_width=2.0):
    """Write ``curves`` to ``filename`` as an SVG of the given pixel size.

    Every curve is converted with ``numpy_to_bezier`` and drawn in black with
    round line caps, no fill, and the given stroke width.
    """
    segments = [numpy_to_bezier(curve) for curve in curves]
    stroke_style = {
        "stroke": "black",
        "stroke-width": stroke_width,
        "stroke-linecap": "round",
        "fill": "none",
    }
    # The same style dictionary is reused for every segment.
    per_segment_style = [stroke_style for _ in segments]
    canvas = {"width": f"{img_width}px", "height": f"{img_height}px"}
    wsvg(segments, attributes=per_segment_style, svg_attributes=canvas, filename=filename)
def save_as_png(filename: str, image: np.ndarray):
    """Write ``image`` to ``filename`` after scaling it to 8-bit values.

    A copy of the image has its first axis moved to the end, is multiplied by
    255, rounded, and cast to ``uint8`` before being written. The input array
    is left untouched.
    """
    pixels = np.moveaxis(image.copy(), 0, 2)
    pixels *= 255
    as_bytes = pixels.round().astype("uint8")
    imwrite(filename, as_bytes)
def setup_model(model_path):
    """Validate the ONNX model at ``model_path`` and open an inference session.

    The model is loaded and passed through ``onnx.checker.check_model``, then
    an ``onnxruntime`` session is created for the same path with
    ``CUDAExecutionProvider`` as the requested provider.

    Returns:
        The created ``InferenceSession``.
    """
    onnx.checker.check_model(onnx.load(model_path))
    return ort.InferenceSession(model_path, providers=["CUDAExecutionProvider"])
def vectorize_image(input_image_path, model: Union[str, ort.InferenceSession], output=None, threshold_ratio=0.1, stroke_width=0.512, width=512, height=512, binarization_threshold=0, force_grayscale=False):
    """Vectorize a raster image curve by curve and write the result as SVG.

    This is a generator. It must be iterated for any work to happen; the SVG
    file is written only once the generator has been exhausted.

    Args:
        input_image_path: Path of the raster image to vectorize.
        model: Path to an ONNX model or an existing ``InferenceSession``.
        output: ``None`` (write ``<input name>.svg`` to the working
            directory), a path ending in ``.svg`` (used as the output file),
            or a directory (created if missing) that receives
            ``<input name>.svg``.
        threshold_ratio: The loop stops once the number of remaining
            below-0.5 values is no larger than this fraction of the initial
            count.
        stroke_width: Stroke width used for rasterizing and for the SVG.
        width: Overwritten by the model's input shape; kept for
            compatibility.
        height: Overwritten by the model's input shape; kept for
            compatibility.
        binarization_threshold: When greater than 0, values below it are set
            to 0 before reconstruction.
        force_grayscale: Currently not used by this function.

    Yields:
        The current reconstruction image after each predicted curve, with the
        first axis moved to the end.

    Raises:
        ValueError: If ``model`` is neither a string nor an
            ``InferenceSession``.
    """
    if isinstance(model, str):
        ort_sess = setup_model(model)
    elif isinstance(model, ort.InferenceSession):
        ort_sess = model
    else:
        raise ValueError("Invalid value for the model argument")

    # Get dimensions expected by the model
    _, channels, height, width = ort_sess.get_inputs()[0].shape

    input_image = imread(input_image_path, pilmode="RGB") / 255
    original_height, original_width, _ = input_image.shape

    # scale and white pad image to dimensions expected by the model
    if original_height >= original_width:
        scale = height / original_height
    else:
        scale = width / original_width
    print(f"Rescale factor: {scale}")
    input_image = rescale(input_image, scale, channel_axis=2, order=5)
    raster_img = np.ones((height, width, channels), dtype="float32")
    raster_img[:input_image.shape[0], :input_image.shape[1]] = input_image

    # convert CHW
    raster_img = np.moveaxis(raster_img, 2, 0)

    if binarization_threshold > 0:
        raster_img[raster_img < binarization_threshold] = 0.

    width = raster_img.shape[2]
    height = raster_img.shape[1]

    curve_pixels = (raster_img < .5).sum()
    threshold = curve_pixels * threshold_ratio
    print(f"Reconstruction candidate pixels: {curve_pixels}")
    print(f"Reconstruction threshold: {threshold.astype(int)}")

    recons_points = None
    recons_img = np.ones_like(raster_img, dtype="float32")
    remaining_img = raster_img.copy()

    while (remaining_img < .5).sum() > threshold:
        remaining_img = diff_remaining_img(raster_img, recons_img)
        try:
            mark = sample_black_pixel(remaining_img)
        except ValueError:
            # No candidate pixel is left to sample from.
            break

        centered_img = remaining_img.copy()
        mark_real = mark.copy()
        mark_real[0] *= width
        mark_real[1] *= height
        centered_img = place_point_on_img(centered_img, mark_real)
        centered_img = center_on_point(centered_img, mark)

        result = ort_sess.run(None, {"marked_raster_image": np.expand_dims(centered_img, 0)})
        points = result[0]
        # Undo the centring so the curve is expressed in full-image coordinates.
        reverse_center_on_point(points, np.expand_dims(mark, 0))
        points = np.expand_dims(points, 1)

        if recons_points is None:
            recons_points = points
        else:
            recons_points = np.concatenate((recons_points, points), axis=1)

        recons_img, _ = raster_bezier_hard(recons_points.squeeze(0), image_width=width, image_height=height, stroke_width=stroke_width)
        yield np.moveaxis(recons_img, 0, 2)

    if recons_points is None:
        # Nothing was reconstructed (e.g. an image without dark pixels);
        # previously this crashed on `None.squeeze`.
        print("No curves were reconstructed; no SVG was written.")
        return

    output_filepath = splitext(basename(input_image_path))[0] + ".svg"
    if output is not None:
        if isinstance(output, str) and output.endswith(".svg") and not isdir(output):
            output_filepath = output
        else:
            # Create the directory when it does not exist yet; the old code
            # only called makedirs for directories that already existed.
            makedirs(output, exist_ok=True)
            output_filepath = join(output, output_filepath)

    recons_points = recons_points.squeeze(0)
    # Map normalized model coordinates back to the original image size.
    recons_points[:, :, 0] *= width * (1 / scale)
    recons_points[:, :, 1] *= height * (1 / scale)
    save_as_svg(recons_points, output_filepath, original_width, original_height, stroke_width=stroke_width)
def main():
    """Command-line entry point: vectorize the given images with an ONNX model.

    Parses the arguments, seeds the random number generators when the seed is
    non-negative, collects the input images from ``-i`` or ``-d`` and runs
    ``vectorize_image`` on each of them. Exits with status 1 when neither
    input option yields images.
    """
    # `random` is only partially imported at file level, so bring in `seed` here.
    from random import seed as seed_python_random

    parser = ArgumentParser(description="Inference script for the marked curve reconstruction model in ONNX format.")
    parser.add_argument("model", metavar="FIlE", help="path to the *.onnx file")
    parser.add_argument("-i", "--input_images", nargs="*", metavar="FILE", help="one or multiple paths to raster images that should be vectorized.")
    parser.add_argument("-d", "--input_dir", metavar="DIR", help="path to a directory of raster images that should be vectorized.")
    parser.add_argument("-o", "--output", help="optional output directory or file")
    parser.add_argument("--threshold_ratio", "-t", default=0.1, type=float, help="The ratio of black pixels which need to be reconstructed before the algorithm terminates")
    parser.add_argument("--stroke_width", "-r", default=0.512, type=float, help="stroke width if it should be different from the one specified in the model")
    # type=int: without it a user-supplied seed is a string and `>= 0` raises.
    parser.add_argument("--seed", "-s", default=1234, type=int, help="Fixed random number generation seed. Set to negative number to deactivate")
    parser.add_argument("-b", "--binarization_threshold", default=0., type=float, help="Set to a value in (0,1) to binarize the image.")
    args = parser.parse_args()

    if args.seed >= 0:
        np.random.seed(args.seed)
        # Pixel sampling uses `random.randint`, so seed that generator too.
        seed_python_random(args.seed)

    if args.input_images is not None:
        input_images = args.input_images
    elif args.input_dir is not None and isdir(args.input_dir):
        input_images = [join(args.input_dir, f) for f in listdir(args.input_dir)]
    else:
        print("-i or -d need to be passed")
        raise SystemExit(1)

    for input_image in input_images:
        # vectorize_image is a generator: it has to be consumed, otherwise no
        # reconstruction runs and no SVG is written.
        for _ in vectorize_image(input_image, args.model, output=args.output, threshold_ratio=args.threshold_ratio, stroke_width=args.stroke_width, binarization_threshold=args.binarization_threshold, force_grayscale=False):
            pass
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
|