JamesL404 committed
Commit 0c50ae6 · verified · 1 Parent(s): 67348e9

Upload 8 files

Files changed (8)
  1. LICENSE +21 -0
  2. README.md +60 -12
  3. enhence_reinhard.py +213 -0
  4. function.py +67 -0
  5. net.py +152 -0
  6. requirements.txt +11 -0
  7. sampler.py +26 -0
  8. torch_to_pytorch.py +322 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2018 Naoto Inoue
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,60 @@
- ---
- title: Color Transfer
- emoji: 🌍
- colorFrom: pink
- colorTo: indigo
- sdk: gradio
- sdk_version: 4.19.2
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # pytorch-AdaIN
+
+ This is an unofficial PyTorch implementation of the paper Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization [Huang+, ICCV 2017].
+ I'm really grateful to the authors' [original implementation](https://github.com/xunhuang1995/AdaIN-style) in Torch, which was very useful.
+
+ ![Results](results.png)
+
+ ## Requirements
+ Install the requirements with `pip install -r requirements.txt`:
+
+ - Python 3.5+
+ - PyTorch 0.4+
+ - TorchVision
+ - Pillow
+
+ (optional, for training)
+ - tqdm
+ - TensorboardX
+
+ ## Usage
+
+ ### Download models
+ Download [decoder.pth](https://drive.google.com/file/d/1bMfhMMwPeXnYSQI6cDWElSZxOxc6aVyr/view?usp=sharing)/[vgg_normalized.pth](https://drive.google.com/file/d/1EpkBA2K2eYILDSyPTt0fztz59UjAIpZU/view?usp=sharing) and put them under `models/`.
+
+ ### Test
+ Use `--content` and `--style` to provide the paths to the content and style image, respectively.
+ ```
+ CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg
+ ```
+
+ You can also run the code on directories of content and style images using `--content_dir` and `--style_dir`. Every possible combination of content and style is saved to the output directory.
+ ```
+ CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content_dir input/content --style_dir input/style
+ ```
+
+ This is an example of mixing four styles by specifying the `--style` and `--style_interpolation_weights` options.
+ ```
+ CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/avril.jpg --style input/style/picasso_self_portrait.jpg,input/style/impronte_d_artista.jpg,input/style/trial.jpg,input/style/antimonocromatismo.jpg --style_interpolation_weights 1,1,1,1 --content_size 512 --style_size 512 --crop
+ ```
+
+ Some other options (combined in the example after this list):
+ * `--content_size`: New (minimum) size for the content image. The original size is kept if set to 0.
+ * `--style_size`: New (minimum) size for the style image. The original size is kept if set to 0.
+ * `--alpha`: Adjusts the degree of stylization. It should be a value between 0.0 and 1.0 (default).
+ * `--preserve_color`: Preserve the color of the content image.
+
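+ For example, to apply a style at 60% strength while keeping the content colors (paths as in the examples above):
+ ```
+ CUDA_VISIBLE_DEVICES=<gpu_id> python test.py --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg --alpha 0.6 --preserve_color
+ ```
+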
+ ### Train
+ Use `--content_dir` and `--style_dir` to provide the directories of content and style images, respectively.
+ ```
+ CUDA_VISIBLE_DEVICES=<gpu_id> python train.py --content_dir <content_dir> --style_dir <style_dir>
+ ```
+
+ For more details and parameters, please refer to the `--help` option.
+
+ I share a model trained by this code [here](https://drive.google.com/file/d/1YIBRdgGBoVllLhmz_N7PwfeP5V9Vz2Nr/view?usp=sharing).
+
+ ## References
+ - [1]: X. Huang and S. Belongie. "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization." In ICCV, 2017.
+ - [2]: [Original implementation in Torch](https://github.com/xunhuang1995/AdaIN-style)
enhence_reinhard.py ADDED
@@ -0,0 +1,213 @@
+ import argparse
+ import sys
+
+ import numpy as np
+ import skimage
+ import skimage.io
+ import skimage.util
+ from PIL import Image
+
+ METHODS = ('lhm', 'pccm', 'reinhard')
+
+ img_dir = 'C:/Users/joshua.lin/Desktop/AdaIN/pytorch-AdaIN-master/input/'
+
+
+ def transfer_lhm(content, reference):
+     """Transfers colors from a reference image to a content image using
+     Linear Histogram Matching (LHM).
+
+     content: NumPy array (HxWxC)
+     reference: NumPy array (HxWxC)
+     """
+     # Convert HxWxC image to a (H*W)xC matrix.
+     shape = content.shape
+     assert len(shape) == 3
+     content = content.reshape(-1, shape[-1]).astype(np.float32)
+     reference = reference.reshape(-1, shape[-1]).astype(np.float32)
+
+     def matrix_sqrt(X):
+         eig_val, eig_vec = np.linalg.eig(X)
+         return eig_vec.dot(np.diag(np.sqrt(eig_val))).dot(eig_vec.T)
+
+     # Per-channel means.
+     mu_content = np.mean(content, axis=0)
+     mu_reference = np.mean(reference, axis=0)
+
+     # Per-channel covariance matrices.
+     cov_content = np.cov(content, rowvar=False)
+     cov_reference = np.cov(reference, rowvar=False)
+
+     # Coloring transform: square root of the reference covariance ...
+     result = matrix_sqrt(cov_reference)
+     # ... composed with the whitening transform of the content covariance.
+     result = result.dot(np.linalg.inv(matrix_sqrt(cov_content)))
+     # Apply the transform to the centered content pixels.
+     result = result.dot((content - mu_content).T).T
+     # result = result.dot((content*1 - mu_content*0.5).T).T*3  # (experimental variant, disabled)
+     # Shift to the reference mean.
+     result = result + mu_reference
+
+     # Restore image dimensions.
+     result = result.reshape(shape).clip(0, 255).round().astype(np.uint8)
+
+     return result
+
+
+ def transfer_pccm(content, reference):
+     """Transfers colors from a reference image to a content image using
+     Principal Component Color Matching (PCCM).
+
+     content: NumPy array (HxWxC)
+     reference: NumPy array (HxWxC)
+     """
+     # Convert HxWxC image to a (H*W)xC matrix.
+     shape = content.shape
+     assert len(shape) == 3
+     content = content.reshape(-1, shape[-1]).astype(np.float32)
+     reference = reference.reshape(-1, shape[-1]).astype(np.float32)
+
+     mu_content = np.mean(content, axis=0)
+     mu_reference = np.mean(reference, axis=0)
+
+     cov_content = np.cov(content, rowvar=False)
+     cov_reference = np.cov(reference, rowvar=False)
+
+     eigval_content, eigvec_content = np.linalg.eig(cov_content)
+     eigval_reference, eigvec_reference = np.linalg.eig(cov_reference)
+
+     scaling = np.diag(np.sqrt(eigval_reference / eigval_content))
+     transform = eigvec_reference.dot(scaling).dot(eigvec_content.T)
+     result = (content - mu_content).dot(transform.T) + mu_reference
+
+     # Restore image dimensions.
+     result = result.reshape(shape).clip(0, 255).round().astype(np.uint8)
+
+     return result
+
+
+ def transfer_reinhard(content, reference):
+     """Transfers colors from a reference image to a content image using the
+     technique from Reinhard et al.
+
+     content: NumPy array (HxWxC)
+     reference: NumPy array (HxWxC)
+     """
+     # Convert HxWxC image to a (H*W)xC matrix.
+     shape = content.shape
+     assert len(shape) == 3
+     content = content.reshape(-1, shape[-1]).astype(np.float32)
+     reference = reference.reshape(-1, shape[-1]).astype(np.float32)
+
+     # RGB -> LMS.
+     m1 = np.array([
+         [0.3811, 0.1967, 0.0241],
+         [0.5783, 0.7244, 0.1288],
+         [0.0402, 0.0782, 0.8444],
+     ])
+
+     # log-LMS -> l-alpha-beta.
+     m2 = np.array([
+         [0.5774, 0.4082, 0.7071],
+         [0.5774, 0.4082, -0.7071],
+         [0.5774, -0.8165, 0.0000],
+     ])
+
+     # l-alpha-beta -> log-LMS.
+     m3 = np.array([
+         [0.5774, 0.5774, 0.5774],
+         [0.4082, 0.4082, -0.8165],
+         [0.7071, -0.7071, 0.0000],
+     ])
+
+     # LMS -> RGB.
+     m4 = np.array([
+         [4.4679, -1.2186, 0.0497],
+         [-3.5873, 2.3809, -0.2439],
+         [0.1193, -0.1624, 1.2045],
+     ])
+
+     # Avoid log of 0. Clipping is used instead of adding epsilon, to avoid
+     # taking the log of a small number whose very low output distorts the results.
+     # WARN: This differs from the Reinhard paper, where no adjustment is made.
+     lab_content = np.log10(np.maximum(1.0, content.dot(m1))).dot(m2)
+     lab_reference = np.log10(np.maximum(1.0, reference.dot(m1))).dot(m2)
+
+     # Per-channel statistics in l-alpha-beta space.
+     mu_content = lab_content.mean(axis=0)  # shape=3
+     mu_reference = lab_reference.mean(axis=0)
+     std_content = lab_content.std(axis=0)
+     std_reference = lab_reference.std(axis=0)
+
+     # Scale the centered content channels by the std ratio, then shift to the
+     # reference mean.
+     result = lab_content - mu_content
+     result *= std_reference
+     result /= std_content
+     result += mu_reference
+     result = (10 ** result.dot(m3)).dot(m4)
+
+     # Restore image dimensions.
+     result = result.reshape(shape).clip(0, 255).round().astype(np.uint8)
+
+     return result
+
+
+ # ===================================================================================
+
+
+ def parse_args(argv):
+     parser = argparse.ArgumentParser(
+         prog='colortrans',
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     # Optional arguments
+     parser.add_argument(
+         '--method', default='lhm', choices=METHODS,
+         help='Algorithm to use for color transfer.')
+
+     # Required arguments
+     parser.add_argument('content', help='Path to content image (qualitative appearance).')
+     parser.add_argument('reference', help='Path to reference image (desired colors).')
+     parser.add_argument('output', help='Path to output image.')
+
+     args = parser.parse_args(argv[1:])
+
+     return args
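+
+
+ # CLI example (illustrative paths):
+ #   python enhence_reinhard.py --method reinhard content.jpg reference.jpg output.jpg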
+ def main(argv=sys.argv):
+     args = parse_args(argv)
+     content_img = Image.open(args.content).convert('RGB')
+     # The slicing is to remove transparency channels if they exist.
+     content = np.array(content_img)[:, :, :3]
+     reference_img = Image.open(args.reference).convert('RGB')
+     reference = np.array(reference_img)[:, :, :3]
+     transfer = globals()[f'transfer_{args.method}']
+     output = transfer(content, reference)
+     Image.fromarray(output).save(args.output)
+
+
+ # ==================================================================================
+
+
+ def test_reinhard():
+     content_path = img_dir + 'content/brad_pitt.jpg'
+     style_path = 'output/brad_pitt_stylized_Neon_City.jpg'
+     content_img = Image.open(content_path).convert('RGB')
+     content = np.array(content_img)[:, :, :3]
+     style_img = Image.open(style_path).convert('RGB')
+     style = np.array(style_img)[:, :, :3]
+     # Note: despite the name, this currently applies LHM, not Reinhard.
+     output = transfer_lhm(content, style)
+     Image.fromarray(output).save('output/processed.jpg')
+
+
+ def test1():
+     img_path = img_dir + '2.jpg'
+     img = skimage.io.imread(img_path)
+     sk_imgf = skimage.util.img_as_float32(img)
+     cv_img = skimage.img_as_ubyte(img)
+
+     print('')
+
+
+ # ==============================================================
+ if __name__ == '__main__':
+     test_reinhard()
+     # test1()
function.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+
+
+ def calc_mean_std(feat, eps=1e-5):
+     # eps is a small value added to the variance to avoid divide-by-zero.
+     size = feat.size()
+     assert (len(size) == 4)
+     N, C = size[:2]
+     feat_var = feat.view(N, C, -1).var(dim=2) + eps
+     feat_std = feat_var.sqrt().view(N, C, 1, 1)
+     feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
+     return feat_mean, feat_std
+
+
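+ # AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y): align the channel-wise
+ # mean and std of the content features with those of the style features.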
+ def adaptive_instance_normalization(content_feat, style_feat):
+     assert (content_feat.size()[:2] == style_feat.size()[:2])
+     size = content_feat.size()
+     style_mean, style_std = calc_mean_std(style_feat)
+     content_mean, content_std = calc_mean_std(content_feat)
+
+     normalized_feat = (content_feat - content_mean.expand(
+         size)) / content_std.expand(size)
+     return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+ def _calc_feat_flatten_mean_std(feat):
+     # Takes a 3D feature map (C, H, W); returns the flattened features and
+     # their per-channel mean and std.
+     assert (feat.size()[0] == 3)
+     assert (isinstance(feat, torch.FloatTensor))
+     feat_flatten = feat.view(3, -1)
+     mean = feat_flatten.mean(dim=-1, keepdim=True)
+     std = feat_flatten.std(dim=-1, keepdim=True)
+     return feat_flatten, mean, std
+
+
+ def _mat_sqrt(x):
+     U, D, V = torch.svd(x)
+     return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
+
+
+ def coral(source, target):
+     # Assumes both source and target are 3D arrays (C, H, W).
+     # Naming: the _f suffix means "flattened".
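+     # CORAL: whiten the source's 3x3 channel covariance, then re-color it with
+     # the target covariance; adding eye(3) regularizes the covariance estimates.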
+     source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
+     source_f_norm = (source_f - source_f_mean.expand_as(
+         source_f)) / source_f_std.expand_as(source_f)
+     source_f_cov_eye = \
+         torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)
+
+     target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
+     target_f_norm = (target_f - target_f_mean.expand_as(
+         target_f)) / target_f_std.expand_as(target_f)
+     target_f_cov_eye = \
+         torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)
+
+     source_f_norm_transfer = torch.mm(
+         _mat_sqrt(target_f_cov_eye),
+         torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
+                  source_f_norm)
+     )
+
+     source_f_transfer = source_f_norm_transfer * \
+         target_f_std.expand_as(source_f_norm) + \
+         target_f_mean.expand_as(source_f_norm)
+
+     return source_f_transfer.view(source.size())
net.py ADDED
@@ -0,0 +1,152 @@
+ import torch.nn as nn
+
+ from function import adaptive_instance_normalization as adain
+ from function import calc_mean_std
+
+ decoder = nn.Sequential(
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 256, (3, 3)),
+     nn.ReLU(),
+     nn.Upsample(scale_factor=2, mode='nearest'),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 128, (3, 3)),
+     nn.ReLU(),
+     nn.Upsample(scale_factor=2, mode='nearest'),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(128, 128, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(128, 64, (3, 3)),
+     nn.ReLU(),
+     nn.Upsample(scale_factor=2, mode='nearest'),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(64, 64, (3, 3)),
+     nn.ReLU(),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(64, 3, (3, 3)),
+ )
+
+ vgg = nn.Sequential(
+     nn.Conv2d(3, 3, (1, 1)),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(3, 64, (3, 3)),
+     nn.ReLU(),  # relu1-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(64, 64, (3, 3)),
+     nn.ReLU(),  # relu1-2
+     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(64, 128, (3, 3)),
+     nn.ReLU(),  # relu2-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(128, 128, (3, 3)),
+     nn.ReLU(),  # relu2-2
+     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(128, 256, (3, 3)),
+     nn.ReLU(),  # relu3-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),  # relu3-2
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),  # relu3-3
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 256, (3, 3)),
+     nn.ReLU(),  # relu3-4
+     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(256, 512, (3, 3)),
+     nn.ReLU(),  # relu4-1, this is the last layer used
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu4-2
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu4-3
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu4-4
+     nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu5-1
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu5-2
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU(),  # relu5-3
+     nn.ReflectionPad2d((1, 1, 1, 1)),
+     nn.Conv2d(512, 512, (3, 3)),
+     nn.ReLU()  # relu5-4
+ )
+
+
+ class Net(nn.Module):
+     def __init__(self, encoder, decoder):
+         super(Net, self).__init__()
+         enc_layers = list(encoder.children())
+         self.enc_1 = nn.Sequential(*enc_layers[:4])  # input -> relu1_1
+         self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
+         self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
+         self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
+         self.decoder = decoder
+         self.mse_loss = nn.MSELoss()
+
+         # fix the encoder
+         for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
+             for param in getattr(self, name).parameters():
+                 param.requires_grad = False
+
+     # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
+     def encode_with_intermediate(self, input):
+         results = [input]
+         for i in range(4):
+             func = getattr(self, 'enc_{:d}'.format(i + 1))
+             results.append(func(results[-1]))
+         return results[1:]
+
+     # extract relu4_1 from input image
+     def encode(self, input):
+         for i in range(4):
+             input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
+         return input
+
+     def calc_content_loss(self, input, target):
+         assert (input.size() == target.size())
+         assert (target.requires_grad is False)
+         return self.mse_loss(input, target)
+
+     def calc_style_loss(self, input, target):
+         assert (input.size() == target.size())
+         assert (target.requires_grad is False)
+         input_mean, input_std = calc_mean_std(input)
+         target_mean, target_std = calc_mean_std(target)
+         return self.mse_loss(input_mean, target_mean) + \
+                self.mse_loss(input_std, target_std)
+
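+     # Stylize with AdaIN and decode, then score the result: content loss compares
+     # the decoded image's relu4_1 features with the AdaIN target t, and style loss
+     # sums mean/std mismatches at relu1_1 .. relu4_1.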
+     def forward(self, content, style, alpha=1.0):
+         assert 0 <= alpha <= 1
+         style_feats = self.encode_with_intermediate(style)
+         content_feat = self.encode(content)
+         t = adain(content_feat, style_feats[-1])
+         t = alpha * t + (1 - alpha) * content_feat
+
+         g_t = self.decoder(t)
+         g_t_feats = self.encode_with_intermediate(g_t)
+
+         loss_c = self.calc_content_loss(g_t_feats[-1], t)
+         loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
+         for i in range(1, 4):
+             loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
+         return loss_c, loss_s
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ Pillow==9.0.1
+ protobuf==3.15.0
+ six==1.12.0
+ tensorboardX==1.8
+ torch==1.2.0
+ torchvision==0.4.0
+ tqdm==4.35.0
+ opencv-python==4.4.0.46
+ imageio==2.9.0
sampler.py ADDED
@@ -0,0 +1,26 @@
+ import numpy as np
+ from torch.utils import data
+
+
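+ # Endlessly yields indices drawn from random permutations of range(n),
+ # reshuffling (with a fresh seed) each time a permutation is exhausted.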
+ def InfiniteSampler(n):
+     i = n - 1
+     order = np.random.permutation(n)
+     while True:
+         yield order[i]
+         i += 1
+         if i >= n:
+             np.random.seed()
+             order = np.random.permutation(n)
+             i = 0
+
+
+ class InfiniteSamplerWrapper(data.sampler.Sampler):
+     def __init__(self, data_source):
+         self.num_samples = len(data_source)
+
+     def __iter__(self):
+         return iter(InfiniteSampler(self.num_samples))
+
+     def __len__(self):
+         return 2 ** 31
torch_to_pytorch.py ADDED
@@ -0,0 +1,322 @@
+ from __future__ import print_function
+
+ import argparse
+ from functools import reduce
+
+ import torch
+ assert torch.__version__.split('.')[0] == '0', 'Only works with PyTorch 0.x.x'
+ import torch.nn as nn
+ from torch.autograd import Variable
+ from torch.utils.serialization import load_lua
+
+
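+ # Lambda, LambdaMap, and LambdaReduce wrap plain functions as nn modules so that
+ # Torch containers (ConcatTable, CAddTable, Concat, ...) can be emulated inside
+ # an nn.Sequential.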
+ class LambdaBase(nn.Sequential):
+     def __init__(self, fn, *args):
+         super(LambdaBase, self).__init__(*args)
+         self.lambda_func = fn
+
+     def forward_prepare(self, input):
+         output = []
+         for module in self._modules.values():
+             output.append(module(input))
+         return output if output else input
+
+
+ class Lambda(LambdaBase):
+     def forward(self, input):
+         return self.lambda_func(self.forward_prepare(input))
+
+
+ class LambdaMap(LambdaBase):
+     def forward(self, input):
+         # result is a list of Variables [Variable1, Variable2, ...]
+         return list(map(self.lambda_func, self.forward_prepare(input)))
+
+
+ class LambdaReduce(LambdaBase):
+     def forward(self, input):
+         # result is a single Variable
+         return reduce(self.lambda_func, self.forward_prepare(input))
+
+
+ def copy_param(m, n):
+     if m.weight is not None: n.weight.data.copy_(m.weight)
+     if m.bias is not None: n.bias.data.copy_(m.bias)
+     if hasattr(n, 'running_mean'): n.running_mean.copy_(m.running_mean)
+     if hasattr(n, 'running_var'): n.running_var.copy_(m.running_var)
+
+
+ def add_submodule(seq, *args):
+     for n in args:
+         seq.add_module(str(len(seq._modules)), n)
+
+
+ def lua_recursive_model(module, seq):
+     for m in module.modules:
+         name = type(m).__name__
+         real = m
+         if name == 'TorchObject':
+             name = m._typename.replace('cudnn.', '')
+             m = m._obj
+
+         if name == 'SpatialConvolution':
+             if not hasattr(m, 'groups'): m.groups = 1
+             n = nn.Conv2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH),
+                           (m.dW, m.dH), (m.padW, m.padH), 1, m.groups,
+                           bias=(m.bias is not None))
+             copy_param(m, n)
+             add_submodule(seq, n)
+         elif name == 'SpatialBatchNormalization':
+             n = nn.BatchNorm2d(m.running_mean.size(0), m.eps, m.momentum,
+                                m.affine)
+             copy_param(m, n)
+             add_submodule(seq, n)
+         elif name == 'ReLU':
+             n = nn.ReLU()
+             add_submodule(seq, n)
+         elif name == 'SpatialMaxPooling':
+             n = nn.MaxPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH),
+                              ceil_mode=m.ceil_mode)
+             add_submodule(seq, n)
+         elif name == 'SpatialAveragePooling':
+             n = nn.AvgPool2d((m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH),
+                              ceil_mode=m.ceil_mode)
+             add_submodule(seq, n)
+         elif name == 'SpatialUpSamplingNearest':
+             n = nn.UpsamplingNearest2d(scale_factor=m.scale_factor)
+             add_submodule(seq, n)
+         elif name == 'View':
+             n = Lambda(lambda x: x.view(x.size(0), -1))
+             add_submodule(seq, n)
+         elif name == 'Linear':
+             # Linear in PyTorch only accepts 2D input
+             n1 = Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x)
+             n2 = nn.Linear(m.weight.size(1), m.weight.size(0),
+                            bias=(m.bias is not None))
+             copy_param(m, n2)
+             n = nn.Sequential(n1, n2)
+             add_submodule(seq, n)
+         elif name == 'Dropout':
+             m.inplace = False
+             n = nn.Dropout(m.p)
+             add_submodule(seq, n)
+         elif name == 'SoftMax':
+             n = nn.Softmax()
+             add_submodule(seq, n)
+         elif name == 'Identity':
+             n = Lambda(lambda x: x)  # do nothing
+             add_submodule(seq, n)
+         elif name == 'SpatialFullConvolution':
+             n = nn.ConvTranspose2d(m.nInputPlane, m.nOutputPlane, (m.kW, m.kH),
+                                    (m.dW, m.dH), (m.padW, m.padH))
+             add_submodule(seq, n)
+         elif name == 'SpatialReplicationPadding':
+             n = nn.ReplicationPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
+             add_submodule(seq, n)
+         elif name == 'SpatialReflectionPadding':
+             n = nn.ReflectionPad2d((m.pad_l, m.pad_r, m.pad_t, m.pad_b))
+             add_submodule(seq, n)
+         elif name == 'Copy':
+             n = Lambda(lambda x: x)  # do nothing
+             add_submodule(seq, n)
+         elif name == 'Narrow':
+             n = Lambda(
+                 lambda x, a=(m.dimension, m.index, m.length): x.narrow(*a))
+             add_submodule(seq, n)
+         elif name == 'SpatialCrossMapLRN':
+             lrn = torch.legacy.nn.SpatialCrossMapLRN(m.size, m.alpha, m.beta,
+                                                      m.k)
+             n = Lambda(lambda x, lrn=lrn: lrn.forward(x))
+             add_submodule(seq, n)
+         elif name == 'Sequential':
+             n = nn.Sequential()
+             lua_recursive_model(m, n)
+             add_submodule(seq, n)
+         elif name == 'ConcatTable':  # output is a list
+             n = LambdaMap(lambda x: x)
+             lua_recursive_model(m, n)
+             add_submodule(seq, n)
+         elif name == 'CAddTable':  # input is a list
+             n = LambdaReduce(lambda x, y: x + y)
+             add_submodule(seq, n)
+         elif name == 'Concat':
+             dim = m.dimension
+             n = LambdaReduce(lambda x, y, dim=dim: torch.cat((x, y), dim))
+             lua_recursive_model(m, n)
+             add_submodule(seq, n)
+         elif name == 'TorchObject':
+             print('Not implemented:', name, real._typename)
+         else:
+             print('Not implemented:', name)
+
+
+ def lua_recursive_source(module):
+     s = []
+     for m in module.modules:
+         name = type(m).__name__
+         real = m
+         if name == 'TorchObject':
+             name = m._typename.replace('cudnn.', '')
+             m = m._obj
+
+         if name == 'SpatialConvolution':
+             if not hasattr(m, 'groups'): m.groups = 1
+             s += ['nn.Conv2d({},{},{},{},{},{},{},bias={}),#Conv2d'.format(
+                 m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH),
+                 (m.padW, m.padH), 1, m.groups, m.bias is not None)]
+         elif name == 'SpatialBatchNormalization':
+             s += ['nn.BatchNorm2d({},{},{},{}),#BatchNorm2d'.format(
+                 m.running_mean.size(0), m.eps, m.momentum, m.affine)]
+         elif name == 'ReLU':
+             s += ['nn.ReLU()']
+         elif name == 'SpatialMaxPooling':
+             s += ['nn.MaxPool2d({},{},{},ceil_mode={}),#MaxPool2d'.format(
+                 (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
+         elif name == 'SpatialAveragePooling':
+             s += ['nn.AvgPool2d({},{},{},ceil_mode={}),#AvgPool2d'.format(
+                 (m.kW, m.kH), (m.dW, m.dH), (m.padW, m.padH), m.ceil_mode)]
+         elif name == 'SpatialUpSamplingNearest':
+             s += ['nn.UpsamplingNearest2d(scale_factor={})'.format(
+                 m.scale_factor)]
+         elif name == 'View':
+             s += ['Lambda(lambda x: x.view(x.size(0),-1)), # View']
+         elif name == 'Linear':
+             s1 = 'Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x )'
+             s2 = 'nn.Linear({},{},bias={})'.format(
+                 m.weight.size(1), m.weight.size(0), (m.bias is not None))
+             s += ['nn.Sequential({},{}),#Linear'.format(s1, s2)]
+         elif name == 'Dropout':
+             s += ['nn.Dropout({})'.format(m.p)]
+         elif name == 'SoftMax':
+             s += ['nn.Softmax()']
+         elif name == 'Identity':
+             s += ['Lambda(lambda x: x), # Identity']
+         elif name == 'SpatialFullConvolution':
+             s += ['nn.ConvTranspose2d({},{},{},{},{})'.format(
+                 m.nInputPlane, m.nOutputPlane, (m.kW, m.kH), (m.dW, m.dH),
+                 (m.padW, m.padH))]
+         elif name == 'SpatialReplicationPadding':
+             s += ['nn.ReplicationPad2d({})'.format(
+                 (m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
+         elif name == 'SpatialReflectionPadding':
+             s += ['nn.ReflectionPad2d({})'.format(
+                 (m.pad_l, m.pad_r, m.pad_t, m.pad_b))]
+         elif name == 'Copy':
+             s += ['Lambda(lambda x: x), # Copy']
+         elif name == 'Narrow':
+             s += ['Lambda(lambda x,a={}: x.narrow(*a))'.format(
+                 (m.dimension, m.index, m.length))]
+         elif name == 'SpatialCrossMapLRN':
+             lrn = 'torch.legacy.nn.SpatialCrossMapLRN(*{})'.format(
+                 (m.size, m.alpha, m.beta, m.k))
+             s += ['Lambda(lambda x,lrn={}: Variable(lrn.forward(x)))'.format(
+                 lrn)]
+         elif name == 'Sequential':
+             s += ['nn.Sequential( # Sequential']
+             s += lua_recursive_source(m)
+             s += [')']
+         elif name == 'ConcatTable':
+             s += ['LambdaMap(lambda x: x, # ConcatTable']
+             s += lua_recursive_source(m)
+             s += [')']
+         elif name == 'CAddTable':
+             s += ['LambdaReduce(lambda x,y: x+y), # CAddTable']
+         elif name == 'Concat':
+             s += ['LambdaReduce(lambda x,y,dim={}: torch.cat((x,y),dim), # Concat'.format(
+                 m.dimension)]
+             s += lua_recursive_source(m)
+             s += [')']
+         else:
+             s += ['# ' + name + ' Not implemented,\n']
+     s = map(lambda x: '\t{}'.format(x), s)
+     return s
+
+
+ def simplify_source(s):
+     s = map(lambda x: x.replace(',(1, 1),(0, 0),1,1,bias=True),#Conv2d', ')'),
+             s)
+     s = map(lambda x: x.replace(',(0, 0),1,1,bias=True),#Conv2d', ')'), s)
+     s = map(lambda x: x.replace(',1,1,bias=True),#Conv2d', ')'), s)
+     s = map(lambda x: x.replace(',bias=True),#Conv2d', ')'), s)
+     s = map(lambda x: x.replace('),#Conv2d', ')'), s)
+     s = map(lambda x: x.replace(',1e-05,0.1,True),#BatchNorm2d', ')'), s)
+     s = map(lambda x: x.replace('),#BatchNorm2d', ')'), s)
+     s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#MaxPool2d', ')'), s)
+     s = map(lambda x: x.replace(',ceil_mode=False),#MaxPool2d', ')'), s)
+     s = map(lambda x: x.replace('),#MaxPool2d', ')'), s)
+     s = map(lambda x: x.replace(',(0, 0),ceil_mode=False),#AvgPool2d', ')'), s)
+     s = map(lambda x: x.replace(',ceil_mode=False),#AvgPool2d', ')'), s)
+     s = map(lambda x: x.replace(',bias=True)),#Linear', ')), # Linear'), s)
+     s = map(lambda x: x.replace(')),#Linear', ')), # Linear'), s)
+
+     s = map(lambda x: '{},\n'.format(x), s)
+     s = map(lambda x: x[1:], s)
+     s = reduce(lambda x, y: x + y, s)
+     return s
+
+
+ def torch_to_pytorch(t7_filename, outputname=None):
+     model = load_lua(t7_filename, unknown_classes=True)
+     if type(model).__name__ == 'hashable_uniq_dict': model = model.model
+     model.gradInput = None
+     slist = lua_recursive_source(torch.legacy.nn.Sequential().add(model))
+     s = simplify_source(slist)
+     header = '''
+ import torch
+ import torch.nn as nn
+ from torch.autograd import Variable
+ from functools import reduce
+
+ class LambdaBase(nn.Sequential):
+     def __init__(self, fn, *args):
+         super(LambdaBase, self).__init__(*args)
+         self.lambda_func = fn
+
+     def forward_prepare(self, input):
+         output = []
+         for module in self._modules.values():
+             output.append(module(input))
+         return output if output else input
+
+ class Lambda(LambdaBase):
+     def forward(self, input):
+         return self.lambda_func(self.forward_prepare(input))
+
+ class LambdaMap(LambdaBase):
+     def forward(self, input):
+         return list(map(self.lambda_func,self.forward_prepare(input)))
+
+ class LambdaReduce(LambdaBase):
+     def forward(self, input):
+         return reduce(self.lambda_func,self.forward_prepare(input))
+ '''
+     varname = t7_filename.replace('.t7', '').replace('.', '_').replace('-', '_')
+     s = '{}\n\n{} = {}'.format(header, varname, s[:-2])
+
+     if outputname is None: outputname = varname
+     with open(outputname + '.py', "w") as pyfile:
+         pyfile.write(s)
+
+     n = nn.Sequential()
+     lua_recursive_model(model, n)
+     torch.save(n.state_dict(), outputname + '.pth')
+
+
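+ # Example (the .t7 path is illustrative):
+ #   python torch_to_pytorch.py --model models/vgg_normalized.t7 --output models/vgg_normalized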
+ parser = argparse.ArgumentParser(
+     description='Convert a Torch t7 model to PyTorch')
+ parser.add_argument('--model', '-m', type=str, required=True,
+                     help='torch model file in t7 format')
+ parser.add_argument('--output', '-o', type=str, default=None,
+                     help='output file name prefix, xxx.py xxx.pth')
+ args = parser.parse_args()
+
+ torch_to_pytorch(args.model, args.output)