Spaces:
Runtime error
Runtime error
Update Time_TravelRephotography/models/encoder4editing/models/stylegan2/op/upfirdn2d.py
Browse files
Time_TravelRephotography/models/encoder4editing/models/stylegan2/op/upfirdn2d.py
CHANGED
|
@@ -81,44 +81,84 @@ class UpFirDn2dBackward(Function):
|
|
| 81 |
|
| 82 |
return gradgrad_out, None, None, None, None, None, None, None, None
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
@staticmethod
|
| 124 |
def backward(ctx, grad_output):
|
|
|
|
| 81 |
|
| 82 |
return gradgrad_out, None, None, None, None, None, None, None, None
|
| 83 |
|
| 84 |
+
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
|
| 85 |
+
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
|
| 86 |
+
"""
|
| 87 |
+
# Validate arguments.
|
| 88 |
+
assert isinstance(x, torch.Tensor) and x.ndim == 4
|
| 89 |
+
if f is None:
|
| 90 |
+
f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
|
| 91 |
+
assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
|
| 92 |
+
assert f.dtype == torch.float32 and not f.requires_grad
|
| 93 |
+
batch_size, num_channels, in_height, in_width = x.shape
|
| 94 |
+
upx, upy = _parse_scaling(up)
|
| 95 |
+
downx, downy = _parse_scaling(down)
|
| 96 |
+
padx0, padx1, pady0, pady1 = _parse_padding(padding)
|
| 97 |
+
|
| 98 |
+
# Upsample by inserting zeros.
|
| 99 |
+
x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
|
| 100 |
+
x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
|
| 101 |
+
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
|
| 102 |
+
|
| 103 |
+
# Pad or crop.
|
| 104 |
+
x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
|
| 105 |
+
x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
|
| 106 |
+
|
| 107 |
+
# Setup filter.
|
| 108 |
+
f = f * (gain ** (f.ndim / 2))
|
| 109 |
+
f = f.to(x.dtype)
|
| 110 |
+
if not flip_filter:
|
| 111 |
+
f = f.flip(list(range(f.ndim)))
|
| 112 |
+
|
| 113 |
+
# Convolve with the filter.
|
| 114 |
+
f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
|
| 115 |
+
if f.ndim == 4:
|
| 116 |
+
x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
|
| 117 |
+
else:
|
| 118 |
+
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
|
| 119 |
+
x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
|
| 120 |
+
|
| 121 |
+
# Downsample by throwing away pixels.
|
| 122 |
+
x = x[:, :, ::downy, ::downx]
|
| 123 |
+
return
|
| 124 |
+
|
| 125 |
+
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
|
| 126 |
+
r"""Pad, upsample, filter, and downsample a batch of 2D images.
|
| 127 |
+
Performs the following sequence of operations for each channel:
|
| 128 |
+
1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
|
| 129 |
+
2. Pad the image with the specified number of zeros on each side (`padding`).
|
| 130 |
+
Negative padding corresponds to cropping the image.
|
| 131 |
+
3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
|
| 132 |
+
so that the footprint of all output pixels lies within the input image.
|
| 133 |
+
4. Downsample the image by keeping every Nth pixel (`down`).
|
| 134 |
+
This sequence of operations bears close resemblance to scipy.signal.upfirdn().
|
| 135 |
+
The fused op is considerably more efficient than performing the same calculation
|
| 136 |
+
using standard PyTorch ops. It supports gradients of arbitrary order.
|
| 137 |
+
Args:
|
| 138 |
+
x: Float32/float64/float16 input tensor of the shape
|
| 139 |
+
`[batch_size, num_channels, in_height, in_width]`.
|
| 140 |
+
f: Float32 FIR filter of the shape
|
| 141 |
+
`[filter_height, filter_width]` (non-separable),
|
| 142 |
+
`[filter_taps]` (separable), or
|
| 143 |
+
`None` (identity).
|
| 144 |
+
up: Integer upsampling factor. Can be a single int or a list/tuple
|
| 145 |
+
`[x, y]` (default: 1).
|
| 146 |
+
down: Integer downsampling factor. Can be a single int or a list/tuple
|
| 147 |
+
`[x, y]` (default: 1).
|
| 148 |
+
padding: Padding with respect to the upsampled image. Can be a single number
|
| 149 |
+
or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
|
| 150 |
+
(default: 0).
|
| 151 |
+
flip_filter: False = convolution, True = correlation (default: False).
|
| 152 |
+
gain: Overall scaling factor for signal magnitude (default: 1).
|
| 153 |
+
impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
|
| 154 |
+
Returns:
|
| 155 |
+
Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
|
| 156 |
+
"""
|
| 157 |
+
assert isinstance(x, torch.Tensor)
|
| 158 |
+
assert impl in ['ref', 'cuda']
|
| 159 |
+
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
| 160 |
+
return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
|
| 161 |
+
return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
|
| 162 |
|
| 163 |
@staticmethod
|
| 164 |
def backward(ctx, grad_output):
|