Upload 6 files
- utils/common.py +159 -0
- utils/cond_fn.py +98 -0
- utils/face_restoration_helper.py +517 -0
- utils/helpers.py +216 -0
- utils/inference.py +320 -0
- utils/sampler.py +341 -0
utils/common.py
ADDED
@@ -0,0 +1,159 @@
from typing import Mapping, Any, Tuple, List, Callable
import importlib
import os
from urllib.parse import urlparse

import torch
from torch import Tensor
from torch.nn import functional as F
import numpy as np

from torch.hub import download_url_to_file, get_dir


def get_obj_from_str(string: str, reload: bool = False) -> Any:
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config: Mapping[str, Any]) -> Any:
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.
    """
    # input shape: (1, 3, H, W)
    # convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add channel dimensions to the kernel to make it a 4D tensor
    kernel = kernel[None, None]
    # repeat the kernel across all input channels
    kernel = kernel.repeat(3, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # apply convolution
    output = F.conv2d(image, kernel, groups=3, dilation=radius)
    return output


def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the low frequency & the high frequency.
    """
    high_freq = torch.zeros_like(image)
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur(image, radius)
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq


def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor):
    """
    Apply wavelet decomposition, so that the content will have the same color as the style.
    """
    # calculate the wavelet decomposition of the content feature
    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
    del content_low_freq
    # calculate the wavelet decomposition of the style feature
    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
    del style_high_freq
    # reconstruct the content's high frequency with the style's low frequency
    return content_high_freq + style_low_freq


# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Load file from http url, will download models if necessary.

    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url (str): URL to be downloaded.
        model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
            Default: None.
        progress (bool): Whether to show the download progress. Default: True.
        file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.

    Returns:
        str: The path to the downloaded file.
    """
    if model_dir is None:  # use the pytorch hub_dir
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.abspath(os.path.join(model_dir, filename))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file


def sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> List[Tuple[int, int, int, int]]:
    hi_list = list(range(0, h - tile_size + 1, tile_stride))
    if (h - tile_size) % tile_stride != 0:
        hi_list.append(h - tile_size)

    wi_list = list(range(0, w - tile_size + 1, tile_stride))
    if (w - tile_size) % tile_stride != 0:
        wi_list.append(w - tile_size)

    coords = []
    for hi in hi_list:
        for wi in wi_list:
            coords.append((hi, hi + tile_size, wi, wi + tile_size))
    return coords


# https://github.com/csslc/CCSR/blob/main/model/q_sampler.py#L503
def gaussian_weights(tile_width: int, tile_height: int) -> np.ndarray:
    """Generates a gaussian mask of weights for tile contributions"""
    latent_width = tile_width
    latent_height = tile_height
    var = 0.01
    midpoint = (latent_width - 1) / 2  # -1 because index goes from 0 to latent_width - 1
    x_probs = [
        np.exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / np.sqrt(2 * np.pi * var)
        for x in range(latent_width)]
    midpoint = latent_height / 2
    y_probs = [
        np.exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / np.sqrt(2 * np.pi * var)
        for y in range(latent_height)]
    weights = np.outer(y_probs, x_probs)
    return weights


COUNT_VRAM = bool(os.environ.get("COUNT_VRAM", False))


def count_vram_usage(func: Callable) -> Callable:
    if not COUNT_VRAM:
        return func

    def wrapper(*args, **kwargs):
        peak_before = torch.cuda.max_memory_allocated() / (1024 ** 3)
        ret = func(*args, **kwargs)
        torch.cuda.synchronize()
        peak_after = torch.cuda.max_memory_allocated() / (1024 ** 3)
        print(f"VRAM peak before {func.__name__}: {peak_before:.5f} GB, after: {peak_after:.5f} GB")
        return ret
    return wrapper

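Usage sketch (not part of the uploaded file): `sliding_windows` and `gaussian_weights` are meant to be combined for tiled processing, where overlapping tile outputs are blended with Gaussian weights. A minimal sketch, assuming the input is at least one tile large and `process` is any caller-supplied per-tile function:

import torch

from utils.common import sliding_windows, gaussian_weights

def blend_tiles(img: torch.Tensor, process, tile_size: int = 512, tile_stride: int = 256) -> torch.Tensor:
    # img: (1, C, H, W) with H, W >= tile_size; process maps a tile to a tile of the same size (assumption)
    _, _, h, w = img.shape
    weights = torch.from_numpy(gaussian_weights(tile_size, tile_size))[None, None].to(img)
    out = torch.zeros_like(img)
    acc = torch.zeros_like(img)
    for hi, hi_end, wi, wi_end in sliding_windows(h, w, tile_size, tile_stride):
        tile = img[:, :, hi:hi_end, wi:wi_end]
        out[:, :, hi:hi_end, wi:wi_end] += process(tile) * weights
        acc[:, :, hi:hi_end, wi:wi_end] += weights
    # every pixel is covered by at least one tile, so acc is strictly positive
    return out / acc
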
utils/cond_fn.py
ADDED
@@ -0,0 +1,98 @@
from typing import overload, Tuple
import torch
from torch.nn import functional as F


class Guidance:

    def __init__(self, scale: float, t_start: int, t_stop: int, space: str, repeat: int) -> None:
        """
        Initialize restoration guidance.

        Args:
            scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale,
                the closer the final result will be to the output of the first stage model.
            t_start (int), t_stop (int): The timesteps to start and stop guidance. Note that the sampling
                process runs from t=1000 to t=0, so `t_start` should be larger than `t_stop`.
            space (str): The data space for computing the loss function (rgb or latent).

        Our restoration guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior).
        Thanks for their work!
        """
        self.scale = scale * 3000
        self.t_start = t_start
        self.t_stop = t_stop
        self.target = None
        self.space = space
        self.repeat = repeat

    def load_target(self, target: torch.Tensor) -> None:
        self.target = target

    def __call__(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
        # avoid propagating gradient out of this scope
        pred_x0 = pred_x0.detach().clone()
        target_x0 = target_x0.detach().clone()
        return self._forward(target_x0, pred_x0, t)

    @overload
    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
        ...


class MSEGuidance(Guidance):

    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
        # inputs: [-1, 1], nchw, rgb
        with torch.enable_grad():
            pred_x0.requires_grad_(True)
            loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum()
        scale = self.scale
        g = -torch.autograd.grad(loss, pred_x0)[0] * scale
        return g, loss.item()


class WeightedMSEGuidance(Guidance):

    def _get_weight(self, target: torch.Tensor) -> torch.Tensor:
        # convert RGB to grayscale
        rgb_to_gray_kernel = torch.tensor([0.2989, 0.5870, 0.1140]).view(1, 3, 1, 1)
        target = torch.sum(target * rgb_to_gray_kernel.to(target.device), dim=1, keepdim=True)
        # initialize sobel kernel in x and y axis
        G_x = [
            [1, 0, -1],
            [2, 0, -2],
            [1, 0, -1]
        ]
        G_y = [
            [1, 2, 1],
            [0, 0, 0],
            [-1, -2, -1]
        ]
        G_x = torch.tensor(G_x, dtype=target.dtype, device=target.device)[None]
        G_y = torch.tensor(G_y, dtype=target.dtype, device=target.device)[None]
        G = torch.stack((G_x, G_y))

        target = F.pad(target, (1, 1, 1, 1), mode='replicate')  # padding = 1
        grad = F.conv2d(target, G, stride=1)
        mag = grad.pow(2).sum(dim=1, keepdim=True).sqrt()

        n, c, h, w = mag.size()
        block_size = 2
        blocks = mag.view(n, c, h // block_size, block_size, w // block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous()
        block_mean = blocks.sum(dim=(-2, -1), keepdim=True).tanh().repeat(1, 1, 1, 1, block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous()
        block_mean = block_mean.view(n, c, h, w)
        weight_map = 1 - block_mean

        return weight_map

    def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]:
        # inputs: [-1, 1], nchw, rgb
        with torch.no_grad():
            w = self._get_weight((target_x0 + 1) / 2)
        with torch.enable_grad():
            pred_x0.requires_grad_(True)
            loss = ((pred_x0 - target_x0).pow(2) * w).mean((1, 2, 3)).sum()
        scale = self.scale
        g = -torch.autograd.grad(loss, pred_x0)[0] * scale
        return g, loss.item()

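Illustrative sketch of how a guidance object could be applied during sampling. The real call site lives in utils/sampler.py, which is not part of the text shown here, so the tensor shapes, the argument values, and the update rule below are assumptions rather than the project's exact logic:

import torch

from utils.cond_fn import MSEGuidance

# Hypothetical tensors and settings
cond_fn = MSEGuidance(scale=0.1, t_start=601, t_stop=-1, space="rgb", repeat=5)
target = torch.rand(1, 3, 64, 64) * 2 - 1       # stage-1 output mapped to [-1, 1]
cond_fn.load_target(target)

pred_x0 = torch.rand(1, 3, 64, 64) * 2 - 1      # current x0 prediction from the diffusion model
t = 500
if cond_fn.t_stop < t <= cond_fn.t_start:       # guidance is only active inside this timestep window
    for _ in range(cond_fn.repeat):
        grad, loss = cond_fn(cond_fn.target, pred_x0, t)
        pred_x0 = pred_x0 + grad                # nudge the prediction toward the stage-1 target
        print(f"guidance loss: {loss:.4f}")
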
utils/face_restoration_helper.py
ADDED
@@ -0,0 +1,517 @@
import cv2
import numpy as np
import os
import torch
from torchvision.transforms.functional import normalize

from facexlib.detection import init_detection_model
from facexlib.parsing import init_parsing_model
from facexlib.utils.misc import img2tensor, imwrite

from utils.common import load_file_from_url


def get_largest_face(det_faces, h, w):

    def get_location(val, length):
        if val < 0:
            return 0
        elif val > length:
            return length
        else:
            return val

    face_areas = []
    for det_face in det_faces:
        left = get_location(det_face[0], w)
        right = get_location(det_face[2], w)
        top = get_location(det_face[1], h)
        bottom = get_location(det_face[3], h)
        face_area = (right - left) * (bottom - top)
        face_areas.append(face_area)
    largest_idx = face_areas.index(max(face_areas))
    return det_faces[largest_idx], largest_idx


def get_center_face(det_faces, h=0, w=0, center=None):
    if center is not None:
        center = np.array(center)
    else:
        center = np.array([w / 2, h / 2])
    center_dist = []
    for det_face in det_faces:
        face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
        dist = np.linalg.norm(face_center - center)
        center_dist.append(dist)
    center_idx = center_dist.index(min(center_dist))
    return det_faces[center_idx], center_idx


class FaceRestoreHelper(object):
    """Helper for the face restoration pipeline (base class)."""

    def __init__(self,
                 upscale_factor,
                 face_size=512,
                 crop_ratio=(1, 1),
                 det_model='retinaface_resnet50',
                 save_ext='png',
                 template_3points=False,
                 pad_blur=False,
                 use_parse=False,
                 device=None):
        self.template_3points = template_3points  # improve robustness
        self.upscale_factor = int(upscale_factor)
        # the cropped face ratio based on the square face
        self.crop_ratio = crop_ratio  # (h, w)
        assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ratio only supports >=1'
        self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
        self.det_model = det_model

        if self.det_model == 'dlib':
            # standard 5 landmarks for FFHQ faces with 1024 x 1024
            self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
                                           [337.91089109, 488.38613861], [437.95049505, 493.51485149],
                                           [513.58415842, 678.5049505]])
            self.face_template = self.face_template / (1024 // face_size)
        elif self.template_3points:
            self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
        else:
            # standard 5 landmarks for FFHQ faces with 512 x 512
            # facexlib
            self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
                                           [201.26117, 371.41043], [313.08905, 371.15118]])

            # dlib: left_eye: 36:41  right_eye: 42:47  nose: 30,32,33,34  left mouth corner: 48  right mouth corner: 54
            # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
            #                                [198.22603, 372.82502], [313.91018, 372.75659]])

        self.face_template = self.face_template * (face_size / 512.0)
        if self.crop_ratio[0] > 1:
            self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
        if self.crop_ratio[1] > 1:
            self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
        self.save_ext = save_ext
        self.pad_blur = pad_blur
        if self.pad_blur is True:
            self.template_3points = False

        self.all_landmarks_5 = []
        self.det_faces = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []
        self.pad_input_imgs = []

        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            # self.device = get_device()
        else:
            self.device = device

        # init face detection model
        self.face_detector = init_detection_model(det_model, half=False, device=self.device)

        # init face parsing model
        self.use_parse = use_parse
        self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)

    def set_upscale_factor(self, upscale_factor):
        self.upscale_factor = upscale_factor

    def read_image(self, img):
        """img can be image path or cv2 loaded image."""
        # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
        if isinstance(img, str):
            img = cv2.imread(img)

        if np.max(img) > 256:  # 16-bit image
            img = img / 65535 * 255
        if len(img.shape) == 2:  # gray image
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:  # BGRA image with alpha channel
            img = img[:, :, 0:3]

        self.input_img = img
        # self.is_gray = is_gray(img, threshold=10)
        # if self.is_gray:
        #     print('Grayscale input: True')

        if min(self.input_img.shape[:2]) < 512:
            f = 512.0 / min(self.input_img.shape[:2])
            self.input_img = cv2.resize(self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)

    def init_dlib(self, detection_path, landmark5_path):
        """Initialize the dlib detectors and predictors."""
        try:
            import dlib
        except ImportError:
            print('Please install dlib by running: conda install -c conda-forge dlib')
        detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
        landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
        face_detector = dlib.cnn_face_detection_model_v1(detection_path)
        shape_predictor_5 = dlib.shape_predictor(landmark5_path)
        return face_detector, shape_predictor_5

    def get_face_landmarks_5_dlib(self,
                                  only_keep_largest=False,
                                  scale=1):
        det_faces = self.face_detector(self.input_img, scale)

        if len(det_faces) == 0:
            print('No face detected. Try to increase upsample_num_times.')
            return 0
        else:
            if only_keep_largest:
                print('Detect several faces and only keep the largest.')
                face_areas = []
                for i in range(len(det_faces)):
                    face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
                        det_faces[i].rect.bottom() - det_faces[i].rect.top())
                    face_areas.append(face_area)
                largest_idx = face_areas.index(max(face_areas))
                self.det_faces = [det_faces[largest_idx]]
            else:
                self.det_faces = det_faces

        if len(self.det_faces) == 0:
            return 0

        for face in self.det_faces:
            shape = self.shape_predictor_5(self.input_img, face.rect)
            landmark = np.array([[part.x, part.y] for part in shape.parts()])
            self.all_landmarks_5.append(landmark)

        return len(self.all_landmarks_5)

    def get_face_landmarks_5(self,
                             only_keep_largest=False,
                             only_center_face=False,
                             resize=None,
                             blur_ratio=0.01,
                             eye_dist_threshold=None):
        if self.det_model == 'dlib':
            return self.get_face_landmarks_5_dlib(only_keep_largest)

        if resize is None:
            scale = 1
            input_img = self.input_img
        else:
            h, w = self.input_img.shape[0:2]
            scale = resize / min(h, w)
            scale = max(1, scale)  # always scale up
            h, w = int(h * scale), int(w * scale)
            interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
            input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)

        with torch.no_grad():
            bboxes = self.face_detector.detect_faces(input_img)

        if bboxes is None or bboxes.shape[0] == 0:
            return 0
        else:
            bboxes = bboxes / scale

        for bbox in bboxes:
            # remove faces with too small eye distance: side faces or too small faces
            eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
            if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
                continue

            if self.template_3points:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
            else:
                landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
            self.all_landmarks_5.append(landmark)
            self.det_faces.append(bbox[0:5])

        if len(self.det_faces) == 0:
            return 0
        if only_keep_largest:
            h, w, _ = self.input_img.shape
            self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
        elif only_center_face:
            h, w, _ = self.input_img.shape
            self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
            self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

        # pad blurry images
        if self.pad_blur:
            self.pad_input_imgs = []
            for landmarks in self.all_landmarks_5:
                # get landmarks
                eye_left = landmarks[0, :]
                eye_right = landmarks[1, :]
                eye_avg = (eye_left + eye_right) * 0.5
                mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
                eye_to_eye = eye_right - eye_left
                eye_to_mouth = mouth_avg - eye_avg

                # Get the oriented crop rectangle
                # x: half width of the oriented crop rectangle
                x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
                # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
                # norm with the hypotenuse: get the direction
                x /= np.hypot(*x)  # get the hypotenuse of a right triangle
                rect_scale = 1.5
                x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
                # y: half height of the oriented crop rectangle
                y = np.flipud(x) * [-1, 1]

                # c: center
                c = eye_avg + eye_to_mouth * 0.1
                # quad: (left_top, left_bottom, right_bottom, right_top)
                quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
                # qsize: side length of the square
                qsize = np.hypot(*x) * 2
                border = max(int(np.rint(qsize * 0.1)), 3)

                # get pad
                # pad: (width_left, height_top, width_right, height_bottom)
                pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                       int(np.ceil(max(quad[:, 1]))))
                pad = [
                    max(-pad[0] + border, 1),
                    max(-pad[1] + border, 1),
                    max(pad[2] - self.input_img.shape[0] + border, 1),
                    max(pad[3] - self.input_img.shape[1] + border, 1)
                ]

                if max(pad) > 1:
                    # pad image
                    pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
                    # modify landmark coords
                    landmarks[:, 0] += pad[0]
                    landmarks[:, 1] += pad[1]
                    # blur pad images
                    h, w, _ = pad_img.shape
                    y, x, _ = np.ogrid[:h, :w, :1]
                    mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                                       np.float32(w - 1 - x) / pad[2]),
                                      1.0 - np.minimum(np.float32(y) / pad[1],
                                                       np.float32(h - 1 - y) / pad[3]))
                    blur = int(qsize * blur_ratio)
                    if blur % 2 == 0:
                        blur += 1
                    blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
                    # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)

                    pad_img = pad_img.astype('float32')
                    pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
                    pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
                    pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                    self.pad_input_imgs.append(pad_img)
                else:
                    self.pad_input_imgs.append(np.copy(self.input_img))

        return len(self.all_landmarks_5)

    def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
        """Align and warp faces with face template.
        """
        if self.pad_blur:
            assert len(self.pad_input_imgs) == len(
                self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use 5 landmarks to get affine matrix
            # use cv2.LMEDS method for the equivalence to skimage transform
            # ref: https://blog.csdn.net/yichxi/article/details/115827338
            affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
            self.affine_matrices.append(affine_matrix)
            # warp and crop faces
            if border_mode == 'constant':
                border_mode = cv2.BORDER_CONSTANT
            elif border_mode == 'reflect101':
                border_mode = cv2.BORDER_REFLECT101
            elif border_mode == 'reflect':
                border_mode = cv2.BORDER_REFLECT
            if self.pad_blur:
                input_img = self.pad_input_imgs[idx]
            else:
                input_img = self.input_img
            cropped_face = cv2.warpAffine(
                input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path = os.path.splitext(save_cropped_path)[0]
                save_path = f'{path}_{idx:02d}.{self.save_ext}'
                imwrite(cropped_face, save_path)

    def get_inverse_affine(self, save_inverse_affine_path=None):
        """Get inverse affine matrix."""
        for idx, affine_matrix in enumerate(self.affine_matrices):
            inverse_affine = cv2.invertAffineTransform(affine_matrix)
            inverse_affine *= self.upscale_factor
            self.inverse_affine_matrices.append(inverse_affine)
            # save inverse affine matrices
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f'{path}_{idx:02d}.pth'
                torch.save(inverse_affine, save_path)

    def add_restored_face(self, restored_face, input_face=None):
        # if self.is_gray:
        #     restored_face = bgr2gray(restored_face)  # convert img into grayscale
        #     if input_face is not None:
        #         restored_face = adain_npy(restored_face, input_face)  # transfer the color
        self.restored_faces.append(restored_face)

    def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
        h, w, _ = self.input_img.shape
        h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

        if upsample_img is None:
            # simply resize the background
            # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
            upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
        else:
            upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)

        assert len(self.restored_faces) == len(
            self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')

        inv_mask_borders = []
        for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
            if face_upsampler is not None:
                restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
                inverse_affine /= self.upscale_factor
                inverse_affine[:, 2] *= self.upscale_factor
                face_size = (self.face_size[0] * self.upscale_factor, self.face_size[1] * self.upscale_factor)
            else:
                # Add an offset to inverse affine matrix, for more precise back alignment
                if self.upscale_factor > 1:
                    extra_offset = 0.5 * self.upscale_factor
                else:
                    extra_offset = 0
                inverse_affine[:, 2] += extra_offset
                face_size = self.face_size
            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))

            # if draw_box or not self.use_parse:  # use square parse maps
            #     mask = np.ones(face_size, dtype=np.float32)
            #     inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            #     # remove the black borders
            #     inv_mask_erosion = cv2.erode(
            #         inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            #     pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            #     total_face_area = np.sum(inv_mask_erosion)  # // 3
            #     # add border
            #     if draw_box:
            #         h, w = face_size
            #         mask_border = np.ones((h, w, 3), dtype=np.float32)
            #         border = int(1400 / np.sqrt(total_face_area))
            #         mask_border[border:h - border, border:w - border, :] = 0
            #         inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
            #         inv_mask_borders.append(inv_mask_border)
            #     if not self.use_parse:
            #         # compute the fusion edge based on the area of face
            #         w_edge = int(total_face_area ** 0.5) // 20
            #         erosion_radius = w_edge * 2
            #         inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            #         blur_size = w_edge * 2
            #         inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            #         if len(upsample_img.shape) == 2:  # upsample_img is gray image
            #             upsample_img = upsample_img[:, :, None]
            #             inv_soft_mask = inv_soft_mask[:, :, None]

            # always use square mask
            mask = np.ones(face_size, dtype=np.float32)
            inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            # remove the black borders
            inv_mask_erosion = cv2.erode(
                inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
            pasted_face = inv_mask_erosion[:, :, None] * inv_restored
            total_face_area = np.sum(inv_mask_erosion)  # // 3
            # add border
            if draw_box:
                h, w = face_size
                mask_border = np.ones((h, w, 3), dtype=np.float32)
                border = int(1400 / np.sqrt(total_face_area))
                mask_border[border:h - border, border:w - border, :] = 0
                inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
                inv_mask_borders.append(inv_mask_border)
            # compute the fusion edge based on the area of face
            w_edge = int(total_face_area ** 0.5) // 20
            erosion_radius = w_edge * 2
            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            blur_size = w_edge * 2
            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            if len(upsample_img.shape) == 2:  # upsample_img is gray image
                upsample_img = upsample_img[:, :, None]
            inv_soft_mask = inv_soft_mask[:, :, None]

            # parse mask
            if self.use_parse:
                # inference
                face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
                face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
                normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                face_input = torch.unsqueeze(face_input, 0).to(self.device)
                with torch.no_grad():
                    out = self.face_parse(face_input)[0]
                out = out.argmax(dim=1).squeeze().cpu().numpy()

                parse_mask = np.zeros(out.shape)
                MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
                for idx, color in enumerate(MASK_COLORMAP):
                    parse_mask[out == idx] = color
                # blur the mask
                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
                parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
                # remove the black borders
                thres = 10
                parse_mask[:thres, :] = 0
                parse_mask[-thres:, :] = 0
                parse_mask[:, :thres] = 0
                parse_mask[:, -thres:] = 0
                parse_mask = parse_mask / 255.

                parse_mask = cv2.resize(parse_mask, face_size)
                parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
                inv_soft_parse_mask = parse_mask[:, :, None]
                # pasted_face = inv_restored
                fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
                inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)

            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
                alpha = upsample_img[:, :, 3:]
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
            else:
                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img

        if np.max(upsample_img) > 256:  # 16-bit image
            upsample_img = upsample_img.astype(np.uint16)
        else:
            upsample_img = upsample_img.astype(np.uint8)

        # draw bounding box
        if draw_box:
            # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
            img_color = np.ones([*upsample_img.shape], dtype=np.float32)
            img_color[:, :, 0] = 0
            img_color[:, :, 1] = 255
            img_color[:, :, 2] = 0
            for inv_mask_border in inv_mask_borders:
                upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
                # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img

        if save_path is not None:
            path = os.path.splitext(save_path)[0]
            save_path = f'{path}.{self.save_ext}'
            imwrite(upsample_img, save_path)
        return upsample_img

    def clean_all(self):
        self.all_landmarks_5 = []
        self.restored_faces = []
        self.affine_matrices = []
        self.cropped_faces = []
        self.inverse_affine_matrices = []
        self.det_faces = []
        self.pad_input_imgs = []

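A hedged end-to-end sketch of the helper's intended workflow (detect, align, restore each crop, paste back). The per-face restorer is only a placeholder here; in DiffBIR it would be the face pipeline defined elsewhere in the repository, and the input path is hypothetical:

import cv2

from utils.face_restoration_helper import FaceRestoreHelper

restore = lambda face: face  # placeholder for a real per-face restorer (assumption)

helper = FaceRestoreHelper(upscale_factor=2, face_size=512, use_parse=True)
helper.clean_all()
helper.read_image("input.png")                                    # path or BGR uint8 array
helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
helper.align_warp_face()                                          # fills helper.cropped_faces
for cropped_face in helper.cropped_faces:
    restored_face = restore(cropped_face)                         # BGR uint8, same size as the crop
    helper.add_restored_face(restored_face, cropped_face)
helper.get_inverse_affine()
result = helper.paste_faces_to_input_image(draw_box=False)
cv2.imwrite("output.png", result)
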
utils/helpers.py
ADDED
@@ -0,0 +1,216 @@
from typing import overload, Tuple, Optional

import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from PIL import Image
from einops import rearrange

from model.cldm import ControlLDM
from model.gaussian_diffusion import Diffusion
from model.bsrnet import RRDBNet
from model.swinir import SwinIR
from model.scunet import SCUNet
from utils.sampler import SpacedSampler
from utils.cond_fn import Guidance
from utils.common import wavelet_decomposition, wavelet_reconstruction, count_vram_usage


def bicubic_resize(img: np.ndarray, scale: float) -> np.ndarray:
    pil = Image.fromarray(img)
    res = pil.resize(tuple(int(x * scale) for x in pil.size), Image.BICUBIC)
    return np.array(res)


def resize_short_edge_to(imgs: torch.Tensor, size: int) -> torch.Tensor:
    _, _, h, w = imgs.size()
    if h == w:
        new_h, new_w = size, size
    elif h < w:
        new_h, new_w = size, int(w * (size / h))
    else:
        new_h, new_w = int(h * (size / w)), size
    return F.interpolate(imgs, size=(new_h, new_w), mode="bicubic", antialias=True)


def pad_to_multiples_of(imgs: torch.Tensor, multiple: int) -> torch.Tensor:
    _, _, h, w = imgs.size()
    if h % multiple == 0 and w % multiple == 0:
        return imgs.clone()
    # get_pad = lambda x: (x // multiple + 1) * multiple - x
    get_pad = lambda x: (x // multiple + int(x % multiple != 0)) * multiple - x
    ph, pw = get_pad(h), get_pad(w)
    return F.pad(imgs, pad=(0, pw, 0, ph), mode="constant", value=0)


class Pipeline:

    def __init__(self, stage1_model: nn.Module, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        self.stage1_model = stage1_model
        self.cldm = cldm
        self.diffusion = diffusion
        self.cond_fn = cond_fn
        self.device = device
        self.final_size: Optional[Tuple[int, int]] = None

    def set_final_size(self, lq: torch.Tensor) -> None:
        h, w = lq.shape[2:]
        self.final_size = (h, w)

    @overload
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        ...

    @count_vram_usage
    def run_stage2(
        self,
        clean: torch.Tensor,
        steps: int,
        strength: float,
        tiled: bool,
        tile_size: int,
        tile_stride: int,
        pos_prompt: str,
        neg_prompt: str,
        cfg_scale: float,
        better_start: bool
    ) -> torch.Tensor:
        ### preprocess
        bs, _, ori_h, ori_w = clean.shape
        # pad: ensure that height & width are multiples of 64
        pad_clean = pad_to_multiples_of(clean, multiple=64)
        h, w = pad_clean.shape[2:]
        # prepare condition
        if not tiled:
            cond = self.cldm.prepare_condition(pad_clean, [pos_prompt] * bs)
            uncond = self.cldm.prepare_condition(pad_clean, [neg_prompt] * bs)
        else:
            cond = self.cldm.prepare_condition_tiled(pad_clean, [pos_prompt] * bs, tile_size, tile_stride)
            uncond = self.cldm.prepare_condition_tiled(pad_clean, [neg_prompt] * bs, tile_size, tile_stride)
        if self.cond_fn:
            self.cond_fn.load_target(pad_clean * 2 - 1)
        old_control_scales = self.cldm.control_scales
        self.cldm.control_scales = [strength] * 13
        if better_start:
            # using noised low frequency part of condition as a better start point of
            # reverse sampling, which can prevent our model from generating noise in
            # image background.
            _, low_freq = wavelet_decomposition(pad_clean)
            if not tiled:
                x_0 = self.cldm.vae_encode(low_freq)
            else:
                x_0 = self.cldm.vae_encode_tiled(low_freq, tile_size, tile_stride)
            x_T = self.diffusion.q_sample(
                x_0,
                torch.full((bs, ), self.diffusion.num_timesteps - 1, dtype=torch.long, device=self.device),
                torch.randn(x_0.shape, dtype=torch.float32, device=self.device)
            )
            # print(f"diffusion sqrt_alphas_cumprod: {self.diffusion.sqrt_alphas_cumprod[-1]}")
        else:
            x_T = torch.randn((bs, 4, h // 8, w // 8), dtype=torch.float32, device=self.device)
        ### run sampler
        sampler = SpacedSampler(self.diffusion.betas)
        z = sampler.sample(
            model=self.cldm, device=self.device, steps=steps, batch_size=bs, x_size=(4, h // 8, w // 8),
            cond=cond, uncond=uncond, cfg_scale=cfg_scale, x_T=x_T, progress=True,
            progress_leave=True, cond_fn=self.cond_fn, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
        )
        if not tiled:
            x = self.cldm.vae_decode(z)
        else:
            x = self.cldm.vae_decode_tiled(z, tile_size // 8, tile_stride // 8)
        ### postprocess
        self.cldm.control_scales = old_control_scales
        sample = x[:, :, :ori_h, :ori_w]
        return sample

    @torch.no_grad()
    def run(
        self,
        lq: np.ndarray,
        steps: int,
        strength: float,
        tiled: bool,
        tile_size: int,
        tile_stride: int,
        pos_prompt: str,
        neg_prompt: str,
        cfg_scale: float,
        better_start: bool
    ) -> np.ndarray:
        # image to tensor
        lq = torch.tensor((lq / 255.).clip(0, 1), dtype=torch.float32, device=self.device)
        lq = rearrange(lq, "n h w c -> n c h w").contiguous()
        # set pipeline output size
        self.set_final_size(lq)
        clean = self.run_stage1(lq)
        sample = self.run_stage2(
            clean, steps, strength, tiled, tile_size, tile_stride,
            pos_prompt, neg_prompt, cfg_scale, better_start
        )
        # colorfix (borrowed from StableSR, thanks for their work)
        sample = (sample + 1) / 2
        sample = wavelet_reconstruction(sample, clean)
        # resize to desired output size
        sample = F.interpolate(sample, size=self.final_size, mode="bicubic", antialias=True)
        # tensor to image
        sample = rearrange(sample * 255., "n c h w -> n h w c")
        sample = sample.contiguous().clamp(0, 255).to(torch.uint8).cpu().numpy()
        return sample


class BSRNetPipeline(Pipeline):

    def __init__(self, bsrnet: RRDBNet, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str, upscale: float) -> None:
        super().__init__(bsrnet, cldm, diffusion, cond_fn, device)
        self.upscale = upscale

    def set_final_size(self, lq: torch.Tensor) -> None:
        h, w = lq.shape[2:]
        self.final_size = (int(h * self.upscale), int(w * self.upscale))

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        # NOTE: upscale is always set to 4 in our experiments
        clean = self.stage1_model(lq)
        # if self.final_size[0] < 512 and self.final_size[1] < 512:
        if min(self.final_size) < 512:
            clean = resize_short_edge_to(clean, size=512)
        else:
            clean = F.interpolate(clean, size=self.final_size, mode="bicubic", antialias=True)
        return clean


class SwinIRPipeline(Pipeline):

    def __init__(self, swinir: SwinIR, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        super().__init__(swinir, cldm, diffusion, cond_fn, device)

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        # NOTE: lq size is always equal to 512 in our experiments
        # resize: ensure the input lq size is at least 512, since SwinIR is trained on 512 resolution
        if min(lq.shape[2:]) < 512:
            lq = resize_short_edge_to(lq, size=512)
        ori_h, ori_w = lq.shape[2:]
        # pad: ensure that height & width are multiples of 64
        pad_lq = pad_to_multiples_of(lq, multiple=64)
        # run
        clean = self.stage1_model(pad_lq)
        # remove padding
        clean = clean[:, :, :ori_h, :ori_w]
        return clean


class SCUNetPipeline(Pipeline):

    def __init__(self, scunet: SCUNet, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None:
        super().__init__(scunet, cldm, diffusion, cond_fn, device)

    @count_vram_usage
    def run_stage1(self, lq: torch.Tensor) -> torch.Tensor:
        clean = self.stage1_model(lq)
        if min(clean.shape[2:]) < 512:
            clean = resize_short_edge_to(clean, size=512)
        return clean

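A small self-contained sketch of the two shape helpers defined above; the Pipeline subclasses themselves need the full DiffBIR models and configs, so only the preprocessing step is shown, with a dummy input tensor:

import torch

from utils.helpers import pad_to_multiples_of, resize_short_edge_to

x = torch.randn(1, 3, 300, 500)              # dummy low-quality batch (n, c, h, w)
x = resize_short_edge_to(x, size=512)        # short edge -> 512, aspect ratio kept: (1, 3, 512, 853)
x = pad_to_multiples_of(x, multiple=64)      # zero-pad bottom/right: (1, 3, 512, 896)
print(x.shape)
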
utils/inference.py
ADDED
@@ -0,0 +1,320 @@
import os
from typing import overload, Generator, Callable, Dict
from argparse import Namespace

import numpy as np
import torch
from PIL import Image
from omegaconf import OmegaConf

from model.cldm import ControlLDM
from model.gaussian_diffusion import Diffusion
from model.bsrnet import RRDBNet
from model.scunet import SCUNet
from model.swinir import SwinIR
from utils.common import instantiate_from_config, load_file_from_url, count_vram_usage
from utils.face_restoration_helper import FaceRestoreHelper
from utils.helpers import (
    Pipeline,
    BSRNetPipeline, SwinIRPipeline, SCUNetPipeline,
    bicubic_resize
)
from utils.cond_fn import MSEGuidance, WeightedMSEGuidance


MODELS = {
    ### stage_1 model weights
    "bsrnet": "https://github.com/cszn/KAIR/releases/download/v1.0/BSRNet.pth",
    # the following checkpoint is up-to-date, but we use the old version in our paper
    # "swinir_face": "https://github.com/zsyOAOA/DifFace/releases/download/V1.0/General_Face_ffhq512.pth",
    "swinir_face": "https://huggingface.co/lxq007/DiffBIR/resolve/main/face_swinir_v1.ckpt",
    "scunet_psnr": "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth",
    "swinir_general": "https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt",
    ### stage_2 model weights
    "sd_v21": "https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt",
    "v1_face": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_face.pth",
    "v1_general": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v1_general.pth",
    "v2": "https://huggingface.co/lxq007/DiffBIR-v2/resolve/main/v2.pth"
}


def load_model_from_url(url: str) -> Dict[str, torch.Tensor]:
    sd_path = load_file_from_url(url, model_dir="weights")
    sd = torch.load(sd_path, map_location="cpu")
    if "state_dict" in sd:
        sd = sd["state_dict"]
    if list(sd.keys())[0].startswith("module"):
        sd = {k[len("module."):]: v for k, v in sd.items()}
    return sd


class InferenceLoop:

    def __init__(self, args: Namespace) -> None:
        self.args = args
        self.loop_ctx = {}
        self.pipeline: Pipeline = None
        self.init_stage1_model()
        self.init_stage2_model()
        self.init_cond_fn()
        self.init_pipeline()

    @overload
    def init_stage1_model(self) -> None:
        ...

    @count_vram_usage
    def init_stage2_model(self) -> None:
        ### load unet, vae, clip
        self.cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/cldm.yaml"))
        sd = load_model_from_url(MODELS["sd_v21"])
        unused = self.cldm.load_pretrained_sd(sd)
        print(f"strictly load pretrained sd_v2.1, unused weights: {unused}")
        ### load controlnet
        if self.args.version == "v1":
            if self.args.task == "fr":
                control_sd = load_model_from_url(MODELS["v1_face"])
            elif self.args.task == "sr":
                control_sd = load_model_from_url(MODELS["v1_general"])
            else:
                raise ValueError(f"DiffBIR v1 doesn't support task: {self.args.task}, please use v2 by passing '--version v2'")
        else:
            control_sd = load_model_from_url(MODELS["v2"])
        self.cldm.load_controlnet_from_ckpt(control_sd)
        print(f"strictly load controlnet weight")
        self.cldm.eval().to(self.args.device)
        ### load diffusion
        self.diffusion: Diffusion = instantiate_from_config(OmegaConf.load("configs/inference/diffusion.yaml"))
        self.diffusion.to(self.args.device)

    def init_cond_fn(self) -> None:
        if not self.args.guidance:
            self.cond_fn = None
            return
        if self.args.g_loss == "mse":
            cond_fn_cls = MSEGuidance
        elif self.args.g_loss == "w_mse":
            cond_fn_cls = WeightedMSEGuidance
        else:
            raise ValueError(self.args.g_loss)
        self.cond_fn = cond_fn_cls(
            scale=self.args.g_scale, t_start=self.args.g_start, t_stop=self.args.g_stop,
            space=self.args.g_space, repeat=self.args.g_repeat
        )

    @overload
    def init_pipeline(self) -> None:
        ...

    def setup(self) -> None:
        self.output_dir = self.args.output
        os.makedirs(self.output_dir, exist_ok=True)

    def lq_loader(self) -> Callable[[], Generator[np.ndarray, None, None]]:
        img_exts = [".png", ".jpg", ".jpeg"]
        if os.path.isdir(self.args.input):
            file_names = sorted([
                file_name for file_name in os.listdir(self.args.input) if os.path.splitext(file_name)[-1] in img_exts
            ])
            file_paths = [os.path.join(self.args.input, file_name) for file_name in file_names]
        else:
            assert os.path.splitext(self.args.input)[-1] in img_exts
            file_paths = [self.args.input]

        def _loader() -> Generator[np.ndarray, None, None]:
            for file_path in file_paths:
                ### load lq
                lq = np.array(Image.open(file_path).convert("RGB"))
                print(f"load lq: {file_path}")
                ### set context for saving results
                self.loop_ctx["file_stem"] = os.path.splitext(os.path.basename(file_path))[0]
                for i in range(self.args.n_samples):
                    self.loop_ctx["repeat_idx"] = i
                    yield lq

        return _loader

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        return lq

    @torch.no_grad()
    def run(self) -> None:
        self.setup()
        # We don't support batch processing since input images may have different size
        loader = self.lq_loader()
        for lq in loader():
            lq = self.after_load_lq(lq)
            sample = self.pipeline.run(
| 148 |
+
lq[None], self.args.steps, 1.0, self.args.tiled,
|
| 149 |
+
self.args.tile_size, self.args.tile_stride,
|
| 150 |
+
self.args.pos_prompt, self.args.neg_prompt, self.args.cfg_scale,
|
| 151 |
+
self.args.better_start
|
| 152 |
+
)[0]
|
| 153 |
+
self.save(sample)
|
| 154 |
+
|
| 155 |
+
def save(self, sample: np.ndarray) -> None:
|
| 156 |
+
file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"]
|
| 157 |
+
file_name = f"{file_stem}_{repeat_idx}.png" if self.args.n_samples > 1 else f"{file_stem}.png"
|
| 158 |
+
save_path = os.path.join(self.args.output, file_name)
|
| 159 |
+
Image.fromarray(sample).save(save_path)
|
| 160 |
+
print(f"save result to {save_path}")
|
| 161 |
+
|
| 162 |
+
|
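
# The task-specific loops below each plug a different stage-1 restorer into this
# skeleton: they override `init_stage1_model` and `init_pipeline`, and optionally
# `after_load_lq` when the low-quality input should be pre-upscaled before the
# diffusion stage.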


class BSRInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.bsrnet: RRDBNet = instantiate_from_config(OmegaConf.load("configs/inference/bsrnet.yaml"))
        sd = load_model_from_url(MODELS["bsrnet"])
        self.bsrnet.load_state_dict(sd, strict=True)
        self.bsrnet.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipeline = BSRNetPipeline(self.bsrnet, self.cldm, self.diffusion, self.cond_fn, self.args.device, self.args.upscale)


class BFRInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.swinir_face: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
        sd = load_model_from_url(MODELS["swinir_face"])
        self.swinir_face.load_state_dict(sd, strict=True)
        self.swinir_face.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipeline = SwinIRPipeline(self.swinir_face, self.cldm, self.diffusion, self.cond_fn, self.args.device)

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        # For the BFR task, super resolution is achieved by directly upscaling lq
        return bicubic_resize(lq, self.args.upscale)


class BIDInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.scunet_psnr: SCUNet = instantiate_from_config(OmegaConf.load("configs/inference/scunet.yaml"))
        sd = load_model_from_url(MODELS["scunet_psnr"])
        self.scunet_psnr.load_state_dict(sd, strict=True)
        self.scunet_psnr.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipeline = SCUNetPipeline(self.scunet_psnr, self.cldm, self.diffusion, self.cond_fn, self.args.device)

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        # For the BID task, super resolution is achieved by directly upscaling lq
        return bicubic_resize(lq, self.args.upscale)


class V1InferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.swinir: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
        if self.args.task == "fr":
            sd = load_model_from_url(MODELS["swinir_face"])
        elif self.args.task == "sr":
            sd = load_model_from_url(MODELS["swinir_general"])
        else:
            raise ValueError(f"DiffBIR v1 doesn't support task: {self.args.task}, please use v2 by passing '--version v2'")
        self.swinir.load_state_dict(sd, strict=True)
        self.swinir.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipeline = SwinIRPipeline(self.swinir, self.cldm, self.diffusion, self.cond_fn, self.args.device)

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        # As in the BFR task, super resolution is achieved by directly upscaling lq
        return bicubic_resize(lq, self.args.upscale)


class UnAlignedBFRInferenceLoop(InferenceLoop):

    @count_vram_usage
    def init_stage1_model(self) -> None:
        self.bsrnet: RRDBNet = instantiate_from_config(OmegaConf.load("configs/inference/bsrnet.yaml"))
        sd = load_model_from_url(MODELS["bsrnet"])
        self.bsrnet.load_state_dict(sd, strict=True)
        self.bsrnet.eval().to(self.args.device)

        self.swinir_face: SwinIR = instantiate_from_config(OmegaConf.load("configs/inference/swinir.yaml"))
        sd = load_model_from_url(MODELS["swinir_face"])
        self.swinir_face.load_state_dict(sd, strict=True)
        self.swinir_face.eval().to(self.args.device)

    def init_pipeline(self) -> None:
        self.pipes = {
            "bg": BSRNetPipeline(self.bsrnet, self.cldm, self.diffusion, self.cond_fn, self.args.device, self.args.upscale),
            "face": SwinIRPipeline(self.swinir_face, self.cldm, self.diffusion, self.cond_fn, self.args.device)
        }
        self.pipeline = self.pipes["face"]

    def setup(self) -> None:
        super().setup()
        self.cropped_face_dir = os.path.join(self.args.output, "cropped_faces")
        os.makedirs(self.cropped_face_dir, exist_ok=True)
        self.restored_face_dir = os.path.join(self.args.output, "restored_faces")
        os.makedirs(self.restored_face_dir, exist_ok=True)
        self.restored_bg_dir = os.path.join(self.args.output, "restored_backgrounds")
        os.makedirs(self.restored_bg_dir, exist_ok=True)

    def lq_loader(self) -> Generator[np.ndarray, None, None]:
        base_loader = super().lq_loader()
        self.face_helper = FaceRestoreHelper(
            device=self.args.device,
            upscale_factor=1,
            face_size=512,
            use_parse=True,
            det_model="retinaface_resnet50"
        )

        def _loader() -> Generator[np.ndarray, None, None]:
            for lq in base_loader():
                ### set input image
                self.face_helper.clean_all()
                upscaled_bg = bicubic_resize(lq, self.args.upscale)
                self.face_helper.read_image(upscaled_bg)
                ### get face landmarks for each face
                self.face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
                self.face_helper.align_warp_face()
                print(f"detect {len(self.face_helper.cropped_faces)} faces")
                ### restore each face (has been upscaled)
                for i, lq_face in enumerate(self.face_helper.cropped_faces):
                    self.loop_ctx["is_face"] = True
                    self.loop_ctx["face_idx"] = i
                    self.loop_ctx["cropped_face"] = lq_face
                    yield lq_face
                ### restore background (hasn't been upscaled)
                self.loop_ctx["is_face"] = False
                yield lq

        return _loader

    def after_load_lq(self, lq: np.ndarray) -> np.ndarray:
        if self.loop_ctx["is_face"]:
            self.pipeline = self.pipes["face"]
        else:
            self.pipeline = self.pipes["bg"]
        return lq

    def save(self, sample: np.ndarray) -> None:
        file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"]
        if self.loop_ctx["is_face"]:
            face_idx = self.loop_ctx["face_idx"]
            file_name = f"{file_stem}_{repeat_idx}_face_{face_idx}.png"
            Image.fromarray(sample).save(os.path.join(self.restored_face_dir, file_name))

            cropped_face = self.loop_ctx["cropped_face"]
            Image.fromarray(cropped_face).save(os.path.join(self.cropped_face_dir, file_name))

            self.face_helper.add_restored_face(sample)
        else:
            self.face_helper.get_inverse_affine()
            # paste each restored face to the input image
            restored_img = self.face_helper.paste_faces_to_input_image(
                upsample_img=sample
            )
            file_name = f"{file_stem}_{repeat_idx}.png"
            Image.fromarray(sample).save(os.path.join(self.restored_bg_dir, file_name))
            Image.fromarray(restored_img).save(os.path.join(self.output_dir, file_name))
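
For orientation, a minimal driver sketch (not part of the uploaded files) of how a command-line entry point might dispatch to these loop classes; the task names ("sr", "fr", "dn") and the `faces_aligned` flag are illustrative assumptions, everything else mirrors the classes above.

    # hypothetical entry point; assumes an argparse Namespace carrying the fields used above
    def run_inference(args: Namespace) -> None:
        if args.version == "v1":
            loop = V1InferenceLoop(args)
        elif args.task == "sr":
            loop = BSRInferenceLoop(args)
        elif args.task == "dn":
            loop = BIDInferenceLoop(args)
        elif args.task == "fr":
            loop = BFRInferenceLoop(args) if args.faces_aligned else UnAlignedBFRInferenceLoop(args)
        else:
            raise ValueError(args.task)
        loop.run()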
utils/sampler.py
ADDED
@@ -0,0 +1,341 @@
from typing import Optional, Tuple, Dict

import torch
from torch import nn
import numpy as np
from tqdm import tqdm

from model.gaussian_diffusion import extract_into_tensor
from model.cldm import ControlLDM
from utils.cond_fn import Guidance
from utils.common import sliding_windows, gaussian_weights


# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there are 300 timesteps and the section counts are [10, 15, 20],
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    """
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim"):])
            for i in range(1, num_timesteps):
                if len(range(0, num_timesteps, i)) == desired_count:
                    return set(range(0, num_timesteps, i))
            raise ValueError(
                f"cannot create exactly {desired_count} steps with an integer stride"
            )
        section_counts = [int(x) for x in section_counts.split(",")]
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(
                f"cannot divide section of {size} steps into {section_count}"
            )
        if section_count <= 1:
            frac_stride = 1
        else:
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(section_count):
            taken_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        all_steps += taken_steps
        start_idx += size
    return set(all_steps)
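
# Worked example (illustrative, not part of the original file): with the usual
# 1000 training timesteps, space_timesteps(1000, "10") keeps 10 evenly spaced steps:
#
#     >>> sorted(space_timesteps(1000, "10"))
#     [0, 111, 222, 333, 444, 555, 666, 777, 888, 999]
#
# make_schedule below then recomputes betas on this subsequence so that the
# marginals q(x_t | x_0) at the kept steps match the original schedule.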


class SpacedSampler(nn.Module):
    """
    Implementation of the spaced sampling schedule proposed in IDDPM. This class is
    designed for sampling ControlLDM.

    https://arxiv.org/pdf/2102.09672.pdf
    """

    def __init__(self, betas: np.ndarray) -> None:
        super().__init__()
        self.num_timesteps = len(betas)
        self.original_betas = betas
        self.original_alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
        self.context = {}

    def register(self, name: str, value: np.ndarray) -> None:
        self.register_buffer(name, torch.tensor(value, dtype=torch.float32))

    def make_schedule(self, num_steps: int) -> None:
        # calculate betas for spaced sampling
        # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
        used_timesteps = space_timesteps(self.num_timesteps, str(num_steps))
        betas = []
        last_alpha_cumprod = 1.0
        for i, alpha_cumprod in enumerate(self.original_alphas_cumprod):
            if i in used_timesteps:
                # marginal distribution is the same as q(x_{S_t}|x_0)
                betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
        assert len(betas) == num_steps
        self.timesteps = np.array(sorted(list(used_timesteps)), dtype=np.int32)  # e.g. [0, 10, 20, ...]

        betas = np.array(betas, dtype=np.float64)
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # print(f"sampler sqrt_alphas_cumprod: {np.sqrt(alphas_cumprod)[-1]}")
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
        sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
        sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        posterior_log_variance_clipped = np.log(
            np.append(posterior_variance[1], posterior_variance[1:])
        )
        posterior_mean_coef1 = (
            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        posterior_mean_coef2 = (
            (1.0 - alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - alphas_cumprod)
        )

        self.register("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod)
        self.register("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod)
        self.register("posterior_variance", posterior_variance)
        self.register("posterior_log_variance_clipped", posterior_log_variance_clipped)
        self.register("posterior_mean_coef1", posterior_mean_coef1)
        self.register("posterior_mean_coef2", posterior_mean_coef2)
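
    # For reference, the registered coefficients implement the standard DDPM posterior
    # q(x_{t-1} | x_t, x_0) = N(mu_t, beta_tilde_t * I) on the spaced schedule, with
    #   mu_t         = posterior_mean_coef1 * x_0 + posterior_mean_coef2 * x_t
    #   coef1        = beta_t * sqrt(alphabar_{t-1}) / (1 - alphabar_t)
    #   coef2        = (1 - alphabar_{t-1}) * sqrt(alpha_t) / (1 - alphabar_t)
    #   beta_tilde_t = beta_t * (1 - alphabar_{t-1}) / (1 - alphabar_t)
    # which is exactly what q_posterior_mean_variance below evaluates.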

    def q_posterior_mean_variance(self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the posterior distribution q(x_{t-1}|x_t, x_0).

        Args:
            x_start (torch.Tensor): The predicted images (NCHW) in timestep `t`.
            x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`.
            t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get
                parameters for each timestep.

        Returns:
            posterior_mean (torch.Tensor): Mean of the posterior distribution.
            posterior_variance (torch.Tensor): Variance of the posterior distribution.
            posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution.
        """
        posterior_mean = (
            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract_into_tensor(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def _predict_xstart_from_eps(self, x_t: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def apply_cond_fn(
        self,
        model: ControlLDM,
        pred_x0: torch.Tensor,
        t: torch.Tensor,
        index: torch.Tensor,
        cond_fn: Guidance
    ) -> torch.Tensor:
        t_now = int(t[0].item()) + 1
        if not (cond_fn.t_stop < t_now and t_now < cond_fn.t_start):
            # stop guidance
            self.context["g_apply"] = False
            return pred_x0
        grad_rescale = 1 / extract_into_tensor(self.posterior_mean_coef1, index, pred_x0.shape)
        # apply guidance multiple times
        loss_vals = []
        for _ in range(cond_fn.repeat):
            # set target and pred for gradient computation
            target, pred = None, None
            if cond_fn.space == "latent":
                target = model.vae_encode(cond_fn.target)
                pred = pred_x0
            elif cond_fn.space == "rgb":
                # We need to backpropagate the gradient to x0 in latent space, so it's
                # required to trace the computation graph while decoding the latent.
                with torch.enable_grad():
                    target = cond_fn.target
                    pred_x0_rg = pred_x0.detach().clone().requires_grad_(True)
                    pred = model.vae_decode(pred_x0_rg)
                    assert pred.requires_grad
            else:
                raise NotImplementedError(cond_fn.space)
            # compute gradient
            delta_pred, loss_val = cond_fn(target, pred, t_now)
            loss_vals.append(loss_val)
            # update pred_x0 w.r.t. the gradient
            if cond_fn.space == "latent":
                delta_pred_x0 = delta_pred
                pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale
            elif cond_fn.space == "rgb":
                pred.backward(delta_pred)
                delta_pred_x0 = pred_x0_rg.grad
                pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale
            else:
                raise NotImplementedError(cond_fn.space)
        self.context["g_apply"] = True
        self.context["g_loss"] = float(np.mean(loss_vals))
        return pred_x0

    def predict_noise(
        self,
        model: ControlLDM,
        x: torch.Tensor,
        t: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        uncond: Optional[Dict[str, torch.Tensor]],
        cfg_scale: float
    ) -> torch.Tensor:
        if uncond is None or cfg_scale == 1.:
            model_output = model(x, t, cond)
        else:
            # apply classifier-free guidance
            model_cond = model(x, t, cond)
            model_uncond = model(x, t, uncond)
            model_output = model_uncond + cfg_scale * (model_cond - model_uncond)
        return model_output

    @torch.no_grad()
    def predict_noise_tiled(
        self,
        model: ControlLDM,
        x: torch.Tensor,
        t: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        uncond: Optional[Dict[str, torch.Tensor]],
        cfg_scale: float,
        tile_size: int,
        tile_stride: int
    ):
        _, _, h, w = x.shape
        tiles = tqdm(sliding_windows(h, w, tile_size // 8, tile_stride // 8), unit="tile", leave=False)
        eps = torch.zeros_like(x)
        count = torch.zeros_like(x, dtype=torch.float32)
        weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None]
        weights = torch.tensor(weights, dtype=torch.float32, device=x.device)
        for hi, hi_end, wi, wi_end in tiles:
            tiles.set_description(f"Process tile ({hi} {hi_end}), ({wi} {wi_end})")
            tile_x = x[:, :, hi:hi_end, wi:wi_end]
            tile_cond = {
                "c_img": cond["c_img"][:, :, hi:hi_end, wi:wi_end],
                "c_txt": cond["c_txt"]
            }
            if uncond:
                tile_uncond = {
                    "c_img": uncond["c_img"][:, :, hi:hi_end, wi:wi_end],
                    "c_txt": uncond["c_txt"]
                }
            else:
                # no unconditional branch: fall back to plain conditional prediction
                tile_uncond = None
            tile_eps = self.predict_noise(model, tile_x, t, tile_cond, tile_uncond, cfg_scale)
            # accumulate noise
            eps[:, :, hi:hi_end, wi:wi_end] += tile_eps * weights
            count[:, :, hi:hi_end, wi:wi_end] += weights
        # average on noise (score)
        eps.div_(count)
        return eps

    @torch.no_grad()
    def p_sample(
        self,
        model: ControlLDM,
        x: torch.Tensor,
        t: torch.Tensor,
        index: torch.Tensor,
        cond: Dict[str, torch.Tensor],
        uncond: Optional[Dict[str, torch.Tensor]],
        cfg_scale: float,
        cond_fn: Optional[Guidance],
        tiled: bool,
        tile_size: int,
        tile_stride: int
    ) -> torch.Tensor:
        if tiled:
            eps = self.predict_noise_tiled(model, x, t, cond, uncond, cfg_scale, tile_size, tile_stride)
        else:
            eps = self.predict_noise(model, x, t, cond, uncond, cfg_scale)
        pred_x0 = self._predict_xstart_from_eps(x, index, eps)
        if cond_fn:
            assert not tiled, "tiled sampling currently doesn't support guidance"
            pred_x0 = self.apply_cond_fn(model, pred_x0, t, index, cond_fn)
        model_mean, model_variance, _ = self.q_posterior_mean_variance(pred_x0, x, index)
        noise = torch.randn_like(x)
        nonzero_mask = (
            (index != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
        )
        x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise
        return x_prev

    @torch.no_grad()
    def sample(
        self,
        model: ControlLDM,
        device: str,
        steps: int,
        batch_size: int,
        x_size: Tuple[int],
        cond: Dict[str, torch.Tensor],
        uncond: Dict[str, torch.Tensor],
        cfg_scale: float,
        cond_fn: Optional[Guidance] = None,
        tiled: bool = False,
        tile_size: int = -1,
        tile_stride: int = -1,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True,
        progress_leave: bool = True,
    ) -> torch.Tensor:
        self.make_schedule(steps)
        self.to(device)
        if x_T is None:
            # TODO: not converting to float32 may trigger an error
            img = torch.randn((batch_size, *x_size), device=device)
        else:
            img = x_T
        timesteps = np.flip(self.timesteps)  # descending order, e.g. [999, ..., 0]
        total_steps = len(self.timesteps)
        iterator = tqdm(timesteps, total=total_steps, leave=progress_leave, disable=not progress)
        for i, step in enumerate(iterator):
            ts = torch.full((batch_size,), step, device=device, dtype=torch.long)
            index = torch.full_like(ts, fill_value=total_steps - i - 1)
            img = self.p_sample(
                model, img, ts, index, cond, uncond, cfg_scale, cond_fn,
                tiled, tile_size, tile_stride
            )
            if cond_fn and self.context["g_apply"]:
                loss_val = self.context["g_loss"]
                desc = f"Spaced Sampler With Guidance, Loss: {loss_val:.6f}"
            else:
                desc = "Spaced Sampler"
            iterator.set_description(desc)
        return img
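
A minimal usage sketch (not part of the uploaded files), assuming a `betas` array from the Diffusion config, a loaded ControlLDM `cldm`, and conditioning dicts carrying the "c_img"/"c_txt" keys used above; the latent shape and cfg scale are placeholders:

    sampler = SpacedSampler(betas)              # betas: np.ndarray of length 1000
    z = sampler.sample(
        model=cldm, device="cuda", steps=50, batch_size=1,
        x_size=(4, 64, 64),                     # latent size for a 512x512 output
        cond=cond, uncond=uncond, cfg_scale=4.0,
    )
    image = cldm.vae_decode(z)                  # back to RGB, as in apply_cond_fn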