# unet-swin-sky-replacement / replacements.py
# Last commit by Svane20 (f28556a): "Updated model to use PyTorch instead of ONNX"
import pymatting
import numpy as np
from PIL import Image
import random
from pathlib import Path
def get_foreground_estimation(image, alpha):
    """Estimate the foreground colours of ``image`` with pymatting's ML solver.

    Args:
        image: PIL image. It is resized with Lanczos resampling to the spatial
            size of ``alpha``; depending on the sizes this may down- or upscale.
        alpha: 2-D matte in which high values mark the sky, i.e. the region the
            caller later replaces. Presumably float values in [0, 1]; confirm
            against the model output.

    Returns:
        The result of ``pymatting.estimate_foreground_ml`` for the resized image
        (scaled to [0, 1]) and the complement of ``alpha``.
    """
    target_height, target_width = alpha.shape
    resized = image.resize((target_width, target_height), Image.Resampling.LANCZOS)
    scaled_pixels = np.array(resized) / 255.0
    # pymatting expects the sky to be the *background*, while this matte selects
    # the sky, so its complement is what gets passed in.
    return pymatting.estimate_foreground_ml(image=scaled_pixels, alpha=1 - alpha)
def sky_replacement(foreground, alpha_mask, sky_path=None, rng=None):
    """Blend ``foreground`` with a randomly cropped replacement sky image.

    The sky image is upscaled (preserving its aspect ratio) when it is smaller
    than the foreground. A random ``w x h`` window is then cropped from it and
    alpha-blended with the foreground.

    Args:
        foreground: Image array of shape (h, w), (h, w, 3) or (h, w, 4).
            Integer arrays (e.g. uint8) are treated as 0-255 values and rescaled
            to [0, 1]. Floating-point arrays are used as-is and are assumed to be
            in [0, 1] already. A fourth (alpha) channel is dropped, and a 2-D
            array is replicated to three channels.
        alpha_mask: Array of shape (h, w). Values are clipped to [0, 1]. A value
            of 1 takes the pixel from the new sky; 0 keeps the foreground.
        sky_path: Optional path to the replacement sky image. Defaults to the
            bundled ``assets/skies/francesco-ungaro-i75WTJn-RBY-unsplash.jpg``
            next to this module.
        rng: Optional object with a ``randint(a, b)`` method, for example a
            seeded ``random.Random``, used to choose the crop offset. Defaults
            to the module-level ``random`` functions.

    Returns:
        Array of shape (h, w, 3) equal to ``(1 - alpha) * foreground + alpha * sky``.
    """
    if sky_path is None:
        sky_path = Path(__file__).parent / "assets/skies/francesco-ungaro-i75WTJn-RBY-unsplash.jpg"
    randint = random.randint if rng is None else rng.randint

    # The target size comes from the foreground image.
    h, w = foreground.shape[:2]

    # The context manager releases the file handle even if decoding fails.
    with Image.open(sky_path) as sky_file:
        new_sky_img = sky_file.convert("RGB")

    sky_width, sky_height = new_sky_img.size
    if sky_width < w or sky_height < h:
        # Upscale uniformly so both dimensions cover the target. Rounding up
        # (rather than truncating with int()) keeps float round-off from leaving
        # the image one pixel short; a short image would make the crop below
        # extend past the image edge.
        scale = max(w / sky_width, h / sky_height)
        new_size = (
            max(w, int(np.ceil(sky_width * scale))),
            max(h, int(np.ceil(sky_height * scale))),
        )
        new_sky_img = new_sky_img.resize(new_size, resample=Image.Resampling.LANCZOS)
        sky_width, sky_height = new_sky_img.size

    # Pick a random top-left corner such that the w x h window fits in the sky.
    max_left = sky_width - w
    max_top = sky_height - h
    left = randint(a=0, b=max_left) if max_left > 0 else 0
    top = randint(a=0, b=max_top) if max_top > 0 else 0
    new_sky_img = new_sky_img.crop((left, top, left + w, top + h))
    new_sky = np.asarray(new_sky_img).astype(np.float32) / 255.0

    if foreground.ndim == 2:
        # Grayscale input: replicate it so it broadcasts against the RGB sky.
        foreground = np.repeat(foreground[:, :, None], 3, axis=2)
    if np.issubdtype(foreground.dtype, np.integer):
        # Only integer images carry 0-255 values. Float images (float32 *or*
        # float64) are already in [0, 1] and must not be divided again.
        foreground = foreground.astype(np.float32) / 255.0
    if foreground.shape[2] == 4:
        # Drop the alpha channel of an RGBA input.
        foreground = foreground[:, :, :3]

    # Keep the blend weights inside [0, 1], then add a channel axis so the
    # mask broadcasts over RGB.
    alpha = np.clip(alpha_mask, a_min=0, a_max=1)[:, :, None]
    return (1 - alpha) * foreground + alpha * new_sky