Upload 8 files

- LICENSE +21 -0
- README.md +60 -12
- enhence_reinhard.py +213 -0
- function.py +67 -0
- net.py +152 -0
- requirements.txt +11 -0
- sampler.py +26 -0
- torch_to_pytorch.py +322 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 Naoto Inoue
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,60 @@
-
-
-
-
-
-
-
-
-
-
-
-
+# pytorch-AdaIN
+
+This is an unofficial PyTorch implementation of the paper Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization [Huang+, ICCV 2017].
+I'm really grateful to the [original implementation](https://github.com/xunhuang1995/AdaIN-style) in Torch by the authors, which is very useful.
+
+![]()
+
+## Requirements
+Please install requirements by `pip install -r requirements.txt`.
+
+- Python 3.5+
+- PyTorch 0.4+
+- TorchVision
+- Pillow
+
+(optional, for training)
+- tqdm
+- TensorboardX
+
+## Usage
+
+### Download models
+Download [decoder.pth](https://drive.google.com/file/d/1bMfhMMwPeXnYSQI6cDWElSZxOxc6aVyr/view?usp=sharing)/[vgg_normalized.pth](https://drive.google.com/file/d/1EpkBA2K2eYILDSyPTt0fztz59UjAIpZU/view?usp=sharing) and put them under `models/`.
+
+### Test
+Use `--content` and `--style` to provide the respective paths to the content and style images.
+```
+CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg
+```
+
+You can also run the code on directories of content and style images using `--content_dir` and `--style_dir`. It will save every possible combination of contents and styles to the output directory.
+```
+CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content_dir input/content --style_dir input/style
+```
+
+This is an example of mixing four styles by specifying the `--style` and `--style_interpolation_weights` options:
+```
+CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/avril.jpg --style input/style/picasso_self_portrait.jpg,input/style/impronte_d_artista.jpg,input/style/trial.jpg,input/style/antimonocromatismo.jpg --style_interpolation_weights 1,1,1,1 --content_size 512 --style_size 512 --crop
+```
+
+Some other options:
+* `--content_size`: New (minimum) size for the content image. The original size is kept if set to 0.
+* `--style_size`: New (minimum) size for the style image. The original size is kept if set to 0.
+* `--alpha`: Adjusts the degree of stylization. It should be a value between 0.0 and 1.0 (default).
+* `--preserve_color`: Preserve the color of the content image.
+
+
+### Train
+Use `--content_dir` and `--style_dir` to provide the respective directories of content and style images.
+```
+CUDA_VISIBLE_DEVICES=<gpu_id> python train.py --content_dir <content_dir> --style_dir <style_dir>
+```
+
+For more details and parameters, please refer to the `--help` option.
+
+I share the model trained by this code [here](https://drive.google.com/file/d/1YIBRdgGBoVllLhmz_N7PwfeP5V9Vz2Nr/view?usp=sharing).
+
+## References
+- [1]: X. Huang and S. Belongie. "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization." In ICCV, 2017.
+- [2]: [Original implementation in Torch](https://github.com/xunhuang1995/AdaIN-style)
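Note on style interpolation: as described in [1], mixing works by AdaIN-transferring the content features to each style's statistics and blending the resulting targets with the given weights before decoding. `test.py` itself is not part of this upload, so the sketch below is illustrative, built only on `function.py`; `interpolate_styles` is a hypothetical helper:

```python
import torch
from function import adaptive_instance_normalization as adain

def interpolate_styles(content_feat, style_feats, weights):
    # Hypothetical helper: blend AdaIN targets for several styles.
    # weights should be normalized to sum to 1 (e.g. 1,1,1,1 -> 0.25 each).
    t = torch.zeros_like(content_feat)
    for style_feat, w in zip(style_feats, weights):
        t = t + w * adain(content_feat, style_feat)
    return t  # decode this to obtain the interpolated stylization
```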
enhence_reinhard.py
ADDED
@@ -0,0 +1,213 @@
+import argparse
+import sys
+import numpy as np
+import os
+from PIL import Image
+import cv2
+import skimage
+import skimage.io
+import skimage.util
+
+METHODS = ('lhm', 'pccm', 'reinhard')
+
+img_dir = 'C:/Users/joshua.lin/Desktop/AdaIN/pytorch-AdaIN-master/input/'
+
+
+def transfer_lhm(content, reference):
+    """Transfers colors from a reference image to a content image using
+    Linear Histogram Matching.
+
+    content: NumPy array (HxWxC)
+    reference: NumPy array (HxWxC)
+    """
+    # Convert HxWxC image to a (H*W)xC matrix.
+    shape = content.shape
+    assert len(shape) == 3
+    content = content.reshape(-1, shape[-1]).astype(np.float32)
+    reference = reference.reshape(-1, shape[-1]).astype(np.float32)
+
+    def matrix_sqrt(X):
+        eig_val, eig_vec = np.linalg.eig(X)
+        return eig_vec.dot(np.diag(np.sqrt(eig_val))).dot(eig_vec.T)
+
+    # Per-channel means of both pixel matrices.
+    mu_content = np.mean(content, axis=0)
+    mu_reference = np.mean(reference, axis=0)
+
+    cov_content = np.cov(content, rowvar=False)
+    cov_reference = np.cov(reference, rowvar=False)
+
+    # Linear map that whitens the content covariance and re-colors it with
+    # the reference covariance: sqrt(cov_ref) @ inv(sqrt(cov_content)).
+    result = matrix_sqrt(cov_reference)
+    result = result.dot(np.linalg.inv(matrix_sqrt(cov_content)))
+    # Apply the map to the centered content pixels...
+    result = result.dot((content - mu_content).T).T
+    # result = result.dot((content*1 - mu_content*0.5).T).T*3
+    # ...and shift to the reference mean.
+    result = result + mu_reference
+
+    # Restore image dimensions.
+    result = result.reshape(shape).clip(0, 255).round().astype(np.uint8)
+
+    return result
+
+
+def transfer_pccm(content, reference):
+    """Transfers colors from a reference image to a content image using
+    Principal Component Color Matching.
+
+    content: NumPy array (HxWxC)
+    reference: NumPy array (HxWxC)
+    """
+    # Convert HxWxC image to a (H*W)xC matrix.
+    shape = content.shape
+    assert len(shape) == 3
+    content = content.reshape(-1, shape[-1]).astype(np.float32)
+    reference = reference.reshape(-1, shape[-1]).astype(np.float32)
+
+    mu_content = np.mean(content, axis=0)
+    mu_reference = np.mean(reference, axis=0)
+
+    cov_content = np.cov(content, rowvar=False)
+    cov_reference = np.cov(reference, rowvar=False)
+
+    eigval_content, eigvec_content = np.linalg.eig(cov_content)
+    eigval_reference, eigvec_reference = np.linalg.eig(cov_reference)
+
+    scaling = np.diag(np.sqrt(eigval_reference / eigval_content))
+    transform = eigvec_reference.dot(scaling).dot(eigvec_content.T)
+    result = (content - mu_content).dot(transform.T) + mu_reference
+    # Restore image dimensions.
+    result = result.reshape(shape).clip(0, 255).round().astype(np.uint8)
+
+    return result
+
+
+def transfer_reinhard(content, reference):
+    """Transfers colors from a reference image to a content image using the
+    technique from Reinhard et al.
+
+    content: NumPy array (HxWxC)
+    reference: NumPy array (HxWxC)
+    """
+    # Convert HxWxC image to a (H*W)xC matrix.
+    shape = content.shape
+    assert len(shape) == 3
+    content = content.reshape(-1, shape[-1]).astype(np.float32)
+    reference = reference.reshape(-1, shape[-1]).astype(np.float32)
+
+    # RGB -> LMS (transposed for row-vector pixels).
+    m1 = np.array([
+        [0.3811, 0.1967, 0.0241],
+        [0.5783, 0.7244, 0.1288],
+        [0.0402, 0.0782, 0.8444],
+    ])
+
+    # log10(LMS) -> decorrelated lab space.
+    m2 = np.array([
+        [0.5774, 0.4082, 0.7071],
+        [0.5774, 0.4082, -0.7071],
+        [0.5774, -0.8165, 0.0000],
+    ])
+
+    # lab -> log10(LMS).
+    m3 = np.array([
+        [0.5774, 0.5774, 0.5774],
+        [0.4082, 0.4082, -0.8165],
+        [0.7071, -0.7071, 0.0000],
+    ])
+
+    # LMS -> RGB.
+    m4 = np.array([
+        [4.4679, -1.2186, 0.0497],
+        [-3.5873, 2.3809, -0.2439],
+        [0.1193, -0.1624, 1.2045],
+    ])
+
+    # Avoid log of 0. Clipping is used instead of adding an epsilon, to avoid
+    # taking the log of a small number whose very low output distorts the results.
+    # WARN: This differs from the Reinhard paper, where no adjustment is made.
+    lab_content = np.log10(np.maximum(1.0, content.dot(m1))).dot(m2)
+    lab_reference = np.log10(np.maximum(1.0, reference.dot(m1))).dot(m2)
+
+    mu_content = lab_content.mean(axis=0)  # shape=3
+    mu_reference = lab_reference.mean(axis=0)
+
+    # Channel-wise standard deviations, taken in lab space like the means
+    # (these are the statistics Reinhard et al. match).
+    std_source = np.std(lab_content, axis=0)
+    std_target = np.std(lab_reference, axis=0)
+
+    # Shift and scale the content statistics onto the reference statistics.
+    result = lab_content - mu_content
+    result *= std_target
+    result /= std_source
+    result += mu_reference
+    result = (10 ** result.dot(m3)).dot(m4)
+    # Restore image dimensions.
+    result = result.reshape(shape).clip(0, 255).round().astype(np.uint8)
+
+    return result
+
+
+# ===================================================================================
+
+
+def parse_args(argv):
+    parser = argparse.ArgumentParser(
+        prog='colortrans',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    # Optional arguments
+    parser.add_argument(
+        '--method', default='lhm', choices=METHODS,
+        help='Algorithm to use for color transfer.')
+
+    # Required arguments
+    parser.add_argument('content', help='Path to content image (qualitative appearance).')
+    parser.add_argument('reference', help='Path to reference image (desired colors).')
+    parser.add_argument('output', help='Path to output image.')
+
+    args = parser.parse_args(argv[1:])
+
+    return args
+
+
+def main(argv=sys.argv):
+    args = parse_args(argv)
+    content_img = Image.open(args.content).convert('RGB')
+    # The slicing is to remove transparency channels if they exist.
+    content = np.array(content_img)[:, :, :3]
+    reference_img = Image.open(args.reference).convert('RGB')
+    reference = np.array(reference_img)[:, :, :3]
+    transfer = globals()[f'transfer_{args.method}']
+    output = transfer(content, reference)
+    Image.fromarray(output).save(args.output)
+
+
+# ==================================================================================
+
+
+def test_reinhard():
+    content_path = img_dir + 'content/brad_pitt.jpg'
+    style_path = 'output/brad_pitt_stylized_Neon_City.jpg'
+    content_img = Image.open(content_path).convert('RGB')
+    content = np.array(content_img)[:, :, :3]
+    style_img = Image.open(style_path).convert('RGB')
+    style = np.array(style_img)[:, :, :3]
+    # NOTE: despite the name, this applies Linear Histogram Matching;
+    # swap in transfer_reinhard to exercise the Reinhard variant.
+    output = transfer_lhm(content, style)
+    Image.fromarray(output).save('output/processed.jpg')
+
+
+def test1():
+    img_path = img_dir + '2.jpg'
+    img = skimage.io.imread(img_path)
+    sk_imgf = skimage.util.img_as_float32(img)
+    cv_img = skimage.img_as_ubyte(img)
+
+    print('')
+
+
+# ==============================================================
+if __name__ == '__main__':
+    test_reinhard()
+    # test1()
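Since `parse_args` skips `argv[0]`, the color-transfer script can also be driven programmatically. A minimal sketch; the image paths are placeholders:

```python
from enhence_reinhard import main

# 'reinhard' can be swapped for 'lhm' or 'pccm' (see METHODS).
main(['colortrans', '--method', 'reinhard',
      'input/content/brad_pitt.jpg',  # content image (appearance to keep)
      'input/style/neon_city.jpg',    # reference image (colors to copy)
      'output/recolored.jpg'])        # where the result is saved
```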
function.py
ADDED
@@ -0,0 +1,67 @@
+import torch
+
+
+def calc_mean_std(feat, eps=1e-5):
+    # eps is a small value added to the variance to avoid divide-by-zero.
+    size = feat.size()
+    assert (len(size) == 4)
+    N, C = size[:2]
+    feat_var = feat.view(N, C, -1).var(dim=2) + eps
+    feat_std = feat_var.sqrt().view(N, C, 1, 1)
+    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+    return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+    assert (content_feat.size()[:2] == style_feat.size()[:2])
+    size = content_feat.size()
+    style_mean, style_std = calc_mean_std(style_feat)
+    content_mean, content_std = calc_mean_std(content_feat)
+
+    normalized_feat = (content_feat - content_mean.expand(
+        size)) / content_std.expand(size)
+    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+def _calc_feat_flatten_mean_std(feat):
+    # takes 3D feat (C, H, W), returns mean and std of the array within channels
+    assert (feat.size()[0] == 3)
+    assert (isinstance(feat, torch.FloatTensor))
+    feat_flatten = feat.view(3, -1)
+    mean = feat_flatten.mean(dim=-1, keepdim=True)
+    std = feat_flatten.std(dim=-1, keepdim=True)
+    return feat_flatten, mean, std
+
+
+def _mat_sqrt(x):
+    U, D, V = torch.svd(x)
+    return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
+
+
+def coral(source, target):
+    # assume both source and target are 3D arrays (C, H, W)
+    # Note: flatten -> f
+
+    source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
+    source_f_norm = (source_f - source_f_mean.expand_as(
+        source_f)) / source_f_std.expand_as(source_f)
+    source_f_cov_eye = \
+        torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)
+
+    target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
+    target_f_norm = (target_f - target_f_mean.expand_as(
+        target_f)) / target_f_std.expand_as(target_f)
+    target_f_cov_eye = \
+        torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)
+
+    source_f_norm_transfer = torch.mm(
+        _mat_sqrt(target_f_cov_eye),
+        torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
+                 source_f_norm)
+    )
+
+    source_f_transfer = source_f_norm_transfer * \
+                        target_f_std.expand_as(source_f_norm) + \
+                        target_f_mean.expand_as(source_f_norm)
+
+    return source_f_transfer.view(source.size())
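`adaptive_instance_normalization` is the AdaIN operation from [1]: AdaIN(x, y) = σ(y) (x − μ(x)) / σ(x) + μ(y), computed per sample and per channel. A quick self-contained check on random tensors (shapes are arbitrary stand-ins for VGG relu4_1 features):

```python
import torch
from function import adaptive_instance_normalization, calc_mean_std

content_feat = torch.randn(2, 512, 32, 32)  # (N, C, H, W)
style_feat = torch.randn(2, 512, 32, 32)

out = adaptive_instance_normalization(content_feat, style_feat)

# The output's channel-wise statistics now match the style features.
out_mean, out_std = calc_mean_std(out)
style_mean, style_std = calc_mean_std(style_feat)
print(torch.allclose(out_mean, style_mean, atol=1e-3))  # True
print(torch.allclose(out_std, style_std, atol=1e-2))    # True (up to eps)
```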
net.py
ADDED
@@ -0,0 +1,152 @@
+import torch.nn as nn
+
+from function import adaptive_instance_normalization as adain
+from function import calc_mean_std
+
+decoder = nn.Sequential(
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(512, 256, (3, 3)),
+    nn.ReLU(),
+    nn.Upsample(scale_factor=2, mode='nearest'),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(256, 256, (3, 3)),
+    nn.ReLU(),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(256, 256, (3, 3)),
+    nn.ReLU(),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(256, 256, (3, 3)),
+    nn.ReLU(),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(256, 128, (3, 3)),
+    nn.ReLU(),
+    nn.Upsample(scale_factor=2, mode='nearest'),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(128, 128, (3, 3)),
+    nn.ReLU(),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(128, 64, (3, 3)),
+    nn.ReLU(),
+    nn.Upsample(scale_factor=2, mode='nearest'),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(64, 64, (3, 3)),
+    nn.ReLU(),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(64, 3, (3, 3)),
+)
+
+vgg = nn.Sequential(
+    nn.Conv2d(3, 3, (1, 1)),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(3, 64, (3, 3)),
+    nn.ReLU(),  # relu1-1
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(64, 64, (3, 3)),
+    nn.ReLU(),  # relu1-2
+    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(64, 128, (3, 3)),
+    nn.ReLU(),  # relu2-1
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(128, 128, (3, 3)),
+    nn.ReLU(),  # relu2-2
+    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(128, 256, (3, 3)),
+    nn.ReLU(),  # relu3-1
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(256, 256, (3, 3)),
+    nn.ReLU(),  # relu3-2
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(256, 256, (3, 3)),
+    nn.ReLU(),  # relu3-3
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(256, 256, (3, 3)),
+    nn.ReLU(),  # relu3-4
+    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(256, 512, (3, 3)),
+    nn.ReLU(),  # relu4-1, this is the last layer used
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(512, 512, (3, 3)),
+    nn.ReLU(),  # relu4-2
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(512, 512, (3, 3)),
+    nn.ReLU(),  # relu4-3
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(512, 512, (3, 3)),
+    nn.ReLU(),  # relu4-4
+    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(512, 512, (3, 3)),
+    nn.ReLU(),  # relu5-1
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(512, 512, (3, 3)),
+    nn.ReLU(),  # relu5-2
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(512, 512, (3, 3)),
+    nn.ReLU(),  # relu5-3
+    nn.ReflectionPad2d((1, 1, 1, 1)),
+    nn.Conv2d(512, 512, (3, 3)),
+    nn.ReLU()  # relu5-4
+)
+
+
+class Net(nn.Module):
+    def __init__(self, encoder, decoder):
+        super(Net, self).__init__()
+        enc_layers = list(encoder.children())
+        self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
+        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
+        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
+        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
+        self.decoder = decoder
+        self.mse_loss = nn.MSELoss()
+
+        # fix the encoder
+        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
+            for param in getattr(self, name).parameters():
+                param.requires_grad = False
+
+    # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
+    def encode_with_intermediate(self, input):
+        results = [input]
+        for i in range(4):
+            func = getattr(self, 'enc_{:d}'.format(i + 1))
+            results.append(func(results[-1]))
+        return results[1:]
+
+    # extract relu4_1 from input image
+    def encode(self, input):
+        for i in range(4):
+            input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
+        return input
+
+    def calc_content_loss(self, input, target):
+        assert (input.size() == target.size())
+        assert (target.requires_grad is False)
+        return self.mse_loss(input, target)
+
+    def calc_style_loss(self, input, target):
+        assert (input.size() == target.size())
+        assert (target.requires_grad is False)
+        input_mean, input_std = calc_mean_std(input)
+        target_mean, target_std = calc_mean_std(target)
+        return self.mse_loss(input_mean, target_mean) + \
+               self.mse_loss(input_std, target_std)
+
+    def forward(self, content, style, alpha=1.0):
+        assert 0 <= alpha <= 1
+        style_feats = self.encode_with_intermediate(style)
+        content_feat = self.encode(content)
+        t = adain(content_feat, style_feats[-1])
+        t = alpha * t + (1 - alpha) * content_feat
+
+        g_t = self.decoder(t)
+        g_t_feats = self.encode_with_intermediate(g_t)
+
+        loss_c = self.calc_content_loss(g_t_feats[-1], t)
+        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
+        for i in range(1, 4):
+            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
+        return loss_c, loss_s
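`Net.forward` carries the whole training objective: the AdaIN target `t` is decoded, the result is re-encoded, the content loss is the MSE against `t` at relu4_1, and the style loss matches channel-wise means/stds at relu1_1 through relu4_1. A hedged sketch of wiring it up (`train.py` is not in this upload; the weight 10.0 and the checkpoint path are assumptions):

```python
import torch
import torch.nn as nn
import net

vgg = net.vgg
vgg.load_state_dict(torch.load('models/vgg_normalized.pth'))  # assumed path
vgg = nn.Sequential(*list(vgg.children())[:31])  # keep layers up to relu4_1
network = net.Net(vgg, net.decoder)

content = torch.randn(4, 3, 256, 256)  # stand-in image batches
style = torch.randn(4, 3, 256, 256)
loss_c, loss_s = network(content, style)
loss = loss_c + 10.0 * loss_s  # style weight is a tunable hyperparameter
```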
requirements.txt
ADDED
@@ -0,0 +1,11 @@
+
+Pillow==9.0.1
+pkg-resources==0.0.0
+protobuf==3.15.0
+six==1.12.0
+tensorboardX==1.8
+torch==1.2.0
+torchvision==0.4.0
+tqdm==4.35.0
+opencv-python==4.4.0.46
+imageio==2.9.0
sampler.py
ADDED
@@ -0,0 +1,26 @@
+import numpy as np
+from torch.utils import data
+
+
+def InfiniteSampler(n):
+    # i = 0
+    i = n - 1
+    order = np.random.permutation(n)
+    while True:
+        yield order[i]
+        i += 1
+        if i >= n:
+            np.random.seed()
+            order = np.random.permutation(n)
+            i = 0
+
+
+class InfiniteSamplerWrapper(data.sampler.Sampler):
+    def __init__(self, data_source):
+        self.num_samples = len(data_source)
+
+    def __iter__(self):
+        return iter(InfiniteSampler(self.num_samples))
+
+    def __len__(self):
+        return 2 ** 31
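`InfiniteSamplerWrapper` turns a finite dataset into an endless shuffled index stream, so training code can pull batches from one long-lived iterator instead of restarting a loader every epoch. Typical hookup (the tensor dataset is a stand-in):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from sampler import InfiniteSamplerWrapper

dataset = TensorDataset(torch.randn(100, 3, 64, 64))  # stand-in dataset
loader = iter(DataLoader(dataset, batch_size=8,
                         sampler=InfiniteSamplerWrapper(dataset)))

for step in range(3):        # open-ended training loop
    (batch,) = next(loader)  # never raises StopIteration
    print(step, batch.shape)
```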
torch_to_pytorch.py
ADDED
@@ -0,0 +1,322 @@
+from __future__ import print_function
+
+import argparse
+from functools import reduce
+
+import torch
+assert torch.__version__.split('.')[0] == '0', 'Only working on PyTorch 0.x.x'
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.utils.serialization import load_lua
+
+
+class LambdaBase(nn.Sequential):
+    def __init__(self, fn, *args):
+        super(LambdaBase, self).__init__(*args)
+        self.lambda_func = fn
+
+    def forward_prepare(self, input):
+        output = []
+        for module in self._modules.values():
+            output.append(module(input))
+        return output if output else input
+
+
+class Lambda(LambdaBase):
+    def forward(self, input):
+        return self.lambda_func(self.forward_prepare(input))
+
+
+class LambdaMap(LambdaBase):
+    def forward(self, input):
+        # result is a list of Variables [Variable1, Variable2, ...]
+        return list(map(self.lambda_func, self.forward_prepare(input)))
+
+
+class LambdaReduce(LambdaBase):
+    def forward(self, input):
+        # result is a Variable
+        return reduce(self.lambda_func, self.forward_prepare(input))
+
+
+def copy_param(m, n):
+    if m.weight is not None: n.weight.data.copy_(m.weight)
+    if m.bias is not None: n.bias.data.copy_(m.bias)
+    if hasattr(n, 'running_mean'): n.running_mean.copy_(m.running_mean)
+    if hasattr(n, 'running_var'): n.running_var.copy_(m.running_var)
+
+
+def add_submodule(seq, *args):
+    for n in args:
+        seq.add_module(str(len(seq._modules)), n)
+
+
+def lua_recursive_model(module, seq):
+    for m in module.modules:
+        name = type(m).__name__
+        real = m
+        if name == 'TorchObject':
+            name = m._typename.replace('cudnn.', '')
+            m = m._obj
+
+        if name == 'SpatialConvolution':
+            if not hasattr(m, 'groups'): m.groups = 1
+            n = nn.Conv2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH),
+                          (m.dW, m.dH), (m.padW, m.padH), 1, m.groups,
+                          bias=(m.bias is not None))
+            copy_param(m, n)
+            add_submodule(seq, n)
+        elif name == 'SpatialBatchNormalization':
+            n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum,
+                               m.affine)
+            copy_param(m, n)
+            add_submodule(seq, n)
+        elif name == 'ReLU':
+            n = nn.ReLU()
+            add_submodule(seq, n)
+        elif name == 'SpatialMaxPooling':
+            n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH),
+                             ceil_mode=m.ceil_mode)
+            add_submodule(seq, n)
+        elif name == 'SpatialAveragePooling':
+            n = nn.AvgPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH),
+                             ceil_mode=m.ceil_mode)
+            add_submodule(seq, n)
+        elif name == 'SpatialUpSamplingNearest':
+            n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor)
+            add_submodule(seq, n)
+        elif name == 'View':
+            n = Lambda(lambda x: x.view(x.size(0), -1))
+            add_submodule(seq, n)
+        elif name == 'Linear':
+            # Linear in pytorch only accepts 2D input
+            n1 = Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x)
+            n2 = nn.Linear(m.weight.size(1), m.weight.size(0),
+                           bias=(m.bias is not None))
+            copy_param(m, n2)
+            n = nn.Sequential(n1, n2)
+            add_submodule(seq, n)
+        elif name == 'Dropout':
+            m.inplace = False
+            n = nn.Dropout(m.p)
+            add_submodule(seq, n)
+        elif name == 'SoftMax':
+            n = nn.Softmax()
+            add_submodule(seq, n)
+        elif name == 'Identity':
+            n = Lambda(lambda x: x)  # do nothing
+            add_submodule(seq, n)
+        elif name == 'SpatialFullConvolution':
+            n = nn.ConvTranspose2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH),
+                                   (m.dW, m.dH), (m.padW, m.padH))
+            add_submodule(seq, n)
+        elif name == 'SpatialReplicationPadding':
+            n = nn.ReplicationPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
+            add_submodule(seq, n)
+        elif name == 'SpatialReflectionPadding':
+            n = nn.ReflectionPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
+            add_submodule(seq, n)
+        elif name == 'Copy':
+            n = Lambda(lambda x: x)  # do nothing
+            add_submodule(seq, n)
+        elif name == 'Narrow':
+            n = Lambda(
+                lambda x, a=(m.dimension, m.index, m.length): x.narrow(*a))
+            add_submodule(seq, n)
+        elif name == 'SpatialCrossMapLRN':
+            lrn = torch.legacy.nn.SpatialCrossMapLRN(m.size, m.alpha, m.beta,
+                                                     m.k)
+            n = Lambda(lambda x, lrn=lrn: lrn.forward(x))
+            add_submodule(seq, n)
+        elif name == 'Sequential':
+            n = nn.Sequential()
+            lua_recursive_model(m, n)
+            add_submodule(seq, n)
+        elif name == 'ConcatTable':  # output is list
+            n = LambdaMap(lambda x: x)
+            lua_recursive_model(m, n)
+            add_submodule(seq, n)
+        elif name == 'CAddTable':  # input is list
+            n = LambdaReduce(lambda x, y: x + y)
+            add_submodule(seq, n)
+        elif name == 'Concat':
+            dim = m.dimension
+            n = LambdaReduce(lambda x, y, dim=dim: torch.cat((x, y), dim))
+            lua_recursive_model(m, n)
+            add_submodule(seq, n)
+        elif name == 'TorchObject':
+            print('Not Implemented', name, real._typename)
+        else:
+            print('Not Implemented', name)
+
+
+def lua_recursive_source(module):
+    s = []
+    for m in module.modules:
+        name = type(m).__name__
+        real = m
+        if name == 'TorchObject':
+            name = m._typename.replace('cudnn.', '')
+            m = m._obj
+
+        if name == 'SpatialConvolution':
+            if not hasattr(m, 'groups'): m.groups = 1
+            s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(
+                m.nInputPlane,
+                m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH),
+                1, m.groups, m.bias is not None)]
+        elif name == 'SpatialBatchNormalization':
+            s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(
+                m.running_mean.size(0), m.eps, m.momentum, m.affine)]
+        elif name == 'ReLU':
+            s += ['nn.ReLU()']
+        elif name == 'SpatialMaxPooling':
+            s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format(
+                (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
+        elif name == 'SpatialAveragePooling':
+            s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format(
+                (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
+        elif name == 'SpatialUpSamplingNearest':
+            s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(
+                m.scale_factor)]
+        elif name == 'View':
+            s += ['Lambda(lambda x: x.view(x.size(0),-1)), # View']
+        elif name == 'Linear':
+            s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
+            s2 = 'nn.Linear({},{},bias={})'.format(m.weight.size(1),
+                                                   m.weight.size(0),
+                                                   (m.bias is not None))
+            s += ['nn.Sequential({},{}),#Linear'.format(s1, s2)]
+        elif name == 'Dropout':
+            s += ['nn.Dropout({})'.format(m.p)]
+        elif name == 'SoftMax':
+            s += ['nn.Softmax()']
+        elif name == 'Identity':
+            s += ['Lambda(lambda x: x), # Identity']
+        elif name == 'SpatialFullConvolution':
+            s += ['nn.ConvTranspose2d({},{},{},{},{})'.format(m.nInputPlane,
+                                                              m.nOutputPlane,
+                                                              (m.kW, m.kH),
+                                                              (m.dW, m.dH),
+                                                              (m.padW, m.padH))]
+        elif name == 'SpatialReplicationPadding':
+            s += ['nn.ReplicationPad2d({})'.format(
+                (m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
+        elif name == 'SpatialReflectionPadding':
+            s += ['nn.ReflectionPad2d({})'.format(
+                (m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
+        elif name == 'Copy':
+            s += ['Lambda(lambda x: x), # Copy']
+        elif name == 'Narrow':
+            s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format(
+                (m.dimension, m.index, m.length))]
+        elif name == 'SpatialCrossMapLRN':
+            lrn = 'torch.legacy.nn.SpatialCrossMapLRN(*{})'.format(
+                (m.size, m.alpha, m.beta, m.k))
+            s += [
+                'Lambda(lambda x,lrn={}: Variable(lrn.forward(x)))'.format(
+                    lrn)]
+
+        elif name == 'Sequential':
+            s += ['nn.Sequential( # Sequential']
+            s += lua_recursive_source(m)
+            s += [')']
+        elif name == 'ConcatTable':
+            s += ['LambdaMap(lambda x: x, # ConcatTable']
+            s += lua_recursive_source(m)
+            s += [')']
+        elif name == 'CAddTable':
+            s += ['LambdaReduce(lambda x,y: x+y), # CAddTable']
+        elif name == 'Concat':
+            dim = m.dimension
+            s += [
+                'LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(
+                    m.dimension)]
+            s += lua_recursive_source(m)
+            s += [')']
+        else:
+            # Wrap in a list so the line is appended whole instead of
+            # extending s character by character.
+            s += ['# ' + name + ' Not Implemented,\n']
+    s = map(lambda x: '\t{}'.format(x), s)
+    return s
+
+
+def simplify_source(s):
+    s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d', ')'),
+            s)
+    s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d', ')'), s)
+    s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d', ')'), s)
+    s = map(lambda x: x.replace(',bias=True),#Conv2d', ')'), s)
+    s = map(lambda x: x.replace('),#Conv2d', ')'), s)
+    s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d', ')'), s)
+    s = map(lambda x: x.replace('),#BatchNorm2d', ')'), s)
+    s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d', ')'), s)
+    s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d', ')'), s)
+    s = map(lambda x: x.replace('),#MaxPool2d', ')'), s)
+    s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d', ')'), s)
+    s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d', ')'), s)
+    s = map(lambda x: x.replace(',bias=True)),#Linear', ')), # Linear'), s)
+    s = map(lambda x: x.replace(')),#Linear', ')), # Linear'), s)
+
+    s = map(lambda x: '{},\n'.format(x), s)
+    s = map(lambda x: x[1:], s)
+    s = reduce(lambda x, y: x + y, s)
+    return s
+
+
+def torch_to_pytorch(t7_filename, outputname=None):
+    model = load_lua(t7_filename, unknown_classes=True)
+    if type(model).__name__ == 'hashable_uniq_dict': model = model.model
+    model.gradInput = None
+    slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model))
+    s = simplify_source(slist)
+    header = '''
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from functools import reduce
+
+class LambdaBase(nn.Sequential):
+    def __init__(self, fn, *args):
+        super(LambdaBase, self).__init__(*args)
+        self.lambda_func = fn
+
+    def forward_prepare(self, input):
+        output = []
+        for module in self._modules.values():
+            output.append(module(input))
+        return output if output else input
+
+class Lambda(LambdaBase):
+    def forward(self, input):
+        return self.lambda_func(self.forward_prepare(input))
+
+class LambdaMap(LambdaBase):
+    def forward(self, input):
+        return list(map(self.lambda_func,self.forward_prepare(input)))
+
+class LambdaReduce(LambdaBase):
+    def forward(self, input):
+        return reduce(self.lambda_func,self.forward_prepare(input))
+'''
+    varname = t7_filename.replace('.t7', '').replace('.', '_').replace('-',
+                                                                       '_')
+    s = '{}\n\n{} = {}'.format(header, varname, s[:-2])
+
+    if outputname is None: outputname = varname
+    with open(outputname + '.py', "w") as pyfile:
+        pyfile.write(s)
+
+    n = nn.Sequential()
+    lua_recursive_model(model, n)
+    torch.save(n.state_dict(), outputname + '.pth')
+
+
+parser = argparse.ArgumentParser(
+    description='Convert torch t7 model to pytorch')
+parser.add_argument('--model', '-m', type=str, required=True,
+                    help='torch model file in t7 format')
+parser.add_argument('--output', '-o', type=str, default=None,
+                    help='output file name prefix, xxx.py xxx.pth')
+args = parser.parse_args()
+
+torch_to_pytorch(args.model, args.output)
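`torch_to_pytorch.py` is a standalone CLI (note the assert at the top: it only runs under PyTorch 0.x, where `torch.utils.serialization.load_lua` still exists). Assuming the original Torch weights are available, e.g. `models/vgg_normalized.t7` from the AdaIN-style repo, a conversion looks like:

```
python torch_to_pytorch.py --model models/vgg_normalized.t7 --output models/vgg_normalized
```

This writes `models/vgg_normalized.py` (the rebuilt network source) and `models/vgg_normalized.pth` (the converted state dict).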