Luisgust committed
Commit 20b919d · verified · 1 Parent(s): abee8a9

Create vtoonify/model/stylegan/op_gpu/upfirdn2d.py

vtoonify/model/stylegan/op_gpu/upfirdn2d.py ADDED
from collections import abc
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


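# Build and load the fused CUDA op at import time from the C++/CUDA sources
# that sit next to this file.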
module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
    "upfirdn2d",
    sources=[
        os.path.join(module_path, "upfirdn2d.cpp"),
        os.path.join(module_path, "upfirdn2d_kernel.cu"),
    ],
)


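# Custom backward pass: the gradient of upfirdn2d is another upfirdn2d call
# with the spatially flipped kernel, the up/down factors swapped, and the
# adjusted "gradient" padding computed in UpFirDn2d.forward.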
class UpFirDn2dBackward(Function):
    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
    ):

        up_x, up_y = up
        down_x, down_y = down
        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

        grad_input = upfirdn2d_op.upfirdn2d(
            grad_output,
            grad_kernel,
            down_x,
            down_y,
            up_x,
            up_y,
            g_pad_x0,
            g_pad_x1,
            g_pad_y0,
            g_pad_y1,
        )
        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])

        ctx.save_for_backward(kernel)

        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        ctx.up_x = up_x
        ctx.up_y = up_y
        ctx.down_x = down_x
        ctx.down_y = down_y
        ctx.pad_x0 = pad_x0
        ctx.pad_x1 = pad_x1
        ctx.pad_y0 = pad_y0
        ctx.pad_y1 = pad_y1
        ctx.in_size = in_size
        ctx.out_size = out_size

        return grad_input

    @staticmethod
    def backward(ctx, gradgrad_input):
        kernel, = ctx.saved_tensors

        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)

        gradgrad_out = upfirdn2d_op.upfirdn2d(
            gradgrad_input,
            kernel,
            ctx.up_x,
            ctx.up_y,
            ctx.down_x,
            ctx.down_y,
            ctx.pad_x0,
            ctx.pad_x1,
            ctx.pad_y0,
            ctx.pad_y1,
        )
        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
        gradgrad_out = gradgrad_out.view(
            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
        )

        return gradgrad_out, None, None, None, None, None, None, None, None


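# Autograd wrapper around the compiled op. forward() flattens the input to
# (batch * channel, in_h, in_w, 1), records the sizes and paddings needed by
# the backward pass, and caches the kernel together with its flipped copy.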
class UpFirDn2d(Function):
    @staticmethod
    def forward(ctx, input, kernel, up, down, pad):
        up_x, up_y = up
        down_x, down_y = down
        pad_x0, pad_x1, pad_y0, pad_y1 = pad

        kernel_h, kernel_w = kernel.shape
        batch, channel, in_h, in_w = input.shape
        ctx.in_size = input.shape

        input = input.reshape(-1, in_h, in_w, 1)

        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
        ctx.out_size = (out_h, out_w)

        ctx.up = (up_x, up_y)
        ctx.down = (down_x, down_y)
        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

        g_pad_x0 = kernel_w - pad_x0 - 1
        g_pad_y0 = kernel_h - pad_y0 - 1
        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
        g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

        out = upfirdn2d_op.upfirdn2d(
            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
        )
        # out = out.view(major, out_h, out_w, minor)
        out = out.view(-1, channel, out_h, out_w)

        return out

    @staticmethod
    def backward(ctx, grad_output):
        kernel, grad_kernel = ctx.saved_tensors

        grad_input = None

        if ctx.needs_input_grad[0]:
            grad_input = UpFirDn2dBackward.apply(
                grad_output,
                kernel,
                grad_kernel,
                ctx.up,
                ctx.down,
                ctx.pad,
                ctx.g_pad,
                ctx.in_size,
                ctx.out_size,
            )

        return grad_input, None, None, None, None


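# Public entry point: upsample by `up`, convolve with the FIR filter `kernel`,
# then downsample by `down`. Scalar factors are broadcast to (x, y) pairs and a
# 2-element pad expands to (x0, x1, y0, y1); CPU tensors take the native path.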
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    if not isinstance(up, abc.Iterable):
        up = (up, up)

    if not isinstance(down, abc.Iterable):
        down = (down, down)

    if len(pad) == 2:
        pad = (pad[0], pad[1], pad[0], pad[1])

    if input.device.type == "cpu":
        out = upfirdn2d_native(input, kernel, *up, *down, *pad)

    else:
        out = UpFirDn2d.apply(input, kernel, up, down, pad)

    return out


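# Pure-PyTorch reference implementation: zero-stuff for upsampling, pad (or
# crop when the padding is negative), convolve with the flipped kernel via
# F.conv2d, then stride-slice for downsampling.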
def upfirdn2d_native(
    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
    _, channel, in_h, in_w = input.shape
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    out = F.pad(
        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    )
    out = out[
        :,
        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
        :,
    ]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
    )
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

    return out.view(-1, channel, out_h, out_w)
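
For reference, a minimal usage sketch (not part of the committed file): it exercises the CPU fallback path, and the make_kernel helper plus the gain/padding arithmetic follow the usual StyleGAN2 convention for a blurred 2x upsample; both are illustrative assumptions, not something this file defines.

# Usage sketch (illustrative; make_kernel is a hypothetical helper, not part of this file).
import torch

def make_kernel(k):
    # Outer-product a 1D tap list into a normalized 2D FIR kernel.
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    return k / k.sum()

x = torch.randn(1, 3, 64, 64)       # NCHW input
kernel = make_kernel([1, 3, 3, 1])  # 4-tap binomial smoothing filter

# Blurred 2x upsample: zero-insert by 2, filter, no downsampling. The gain
# (kernel * up**2) and the padding below are the common choice for up=2 with
# a 4-tap kernel.
p = kernel.shape[0] - 2
out = upfirdn2d(x, kernel * 4, up=2, down=1, pad=((p + 1) // 2 + 1, p // 2))
print(out.shape)  # torch.Size([1, 3, 128, 128]) on the CPU path

Note that merely importing this module triggers the JIT build of the CUDA extension, so running the sketch still assumes a working CUDA toolchain even though the computation itself stays on the CPU.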