Yi Xie committed
Commit 321f459 · 1 Parent(s): d7ffaa3

Add MangaScaleV3 on ESRGAN+ arch

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitignore +26 -0
  2. converter.py +4 -1
  3. esrgan_plus/LICENSE +201 -0
  4. esrgan_plus/README.md +48 -0
  5. esrgan_plus/codes/auto_test.py +32 -0
  6. esrgan_plus/codes/data/LRHR_dataset.py +128 -0
  7. esrgan_plus/codes/data/LRHR_seg_bg_dataset.py +149 -0
  8. esrgan_plus/codes/data/LR_dataset.py +40 -0
  9. esrgan_plus/codes/data/__init__.py +37 -0
  10. esrgan_plus/codes/data/util.py +434 -0
  11. esrgan_plus/codes/models/SFTGAN_ACD_model.py +261 -0
  12. esrgan_plus/codes/models/SRGAN_model.py +240 -0
  13. esrgan_plus/codes/models/SRRaGAN_model.py +251 -0
  14. esrgan_plus/codes/models/SR_model.py +151 -0
  15. esrgan_plus/codes/models/__init__.py +20 -0
  16. esrgan_plus/codes/models/__pycache__/__init__.cpython-310.pyc +0 -0
  17. esrgan_plus/codes/models/base_model.py +85 -0
  18. esrgan_plus/codes/models/modules/__pycache__/architecture.cpython-310.pyc +0 -0
  19. esrgan_plus/codes/models/modules/__pycache__/block.cpython-310.pyc +0 -0
  20. esrgan_plus/codes/models/modules/__pycache__/spectral_norm.cpython-310.pyc +0 -0
  21. esrgan_plus/codes/models/modules/architecture.py +394 -0
  22. esrgan_plus/codes/models/modules/block.py +322 -0
  23. esrgan_plus/codes/models/modules/loss.py +60 -0
  24. esrgan_plus/codes/models/modules/seg_arch.py +70 -0
  25. esrgan_plus/codes/models/modules/sft_arch.py +226 -0
  26. esrgan_plus/codes/models/modules/spectral_norm.py +149 -0
  27. esrgan_plus/codes/models/networks.py +155 -0
  28. esrgan_plus/codes/options/options.py +120 -0
  29. esrgan_plus/codes/options/test/test_ESRGANplus.json +40 -0
  30. esrgan_plus/codes/options/test/test_SRGAN.json +37 -0
  31. esrgan_plus/codes/options/test/test_SRResNet.json +40 -0
  32. esrgan_plus/codes/options/test/test_sr.json +40 -0
  33. esrgan_plus/codes/options/train/train_ESRGANplus.json +83 -0
  34. esrgan_plus/codes/options/train/train_SRGAN.json +87 -0
  35. esrgan_plus/codes/options/train/train_SRResNet.json +66 -0
  36. esrgan_plus/codes/options/train/train_sftgan.json +76 -0
  37. esrgan_plus/codes/options/train/train_sr.json +66 -0
  38. esrgan_plus/codes/scripts/README.md +8 -0
  39. esrgan_plus/codes/scripts/back_projection/backprojection.m +20 -0
  40. esrgan_plus/codes/scripts/back_projection/main_bp.m +22 -0
  41. esrgan_plus/codes/scripts/back_projection/main_reverse_filter.m +25 -0
  42. esrgan_plus/codes/scripts/color2gray.py +63 -0
  43. esrgan_plus/codes/scripts/create_lmdb.py +66 -0
  44. esrgan_plus/codes/scripts/extract_enlarge_patches.py +64 -0
  45. esrgan_plus/codes/scripts/extract_subimgs_single.py +88 -0
  46. esrgan_plus/codes/scripts/generate_mod_LR_bic.m +82 -0
  47. esrgan_plus/codes/scripts/generate_mod_LR_bic.py +74 -0
  48. esrgan_plus/codes/scripts/make_gif_video.py +106 -0
  49. esrgan_plus/codes/scripts/net_interp.py +20 -0
  50. esrgan_plus/codes/scripts/rename.py +25 -0
.gitignore ADDED
@@ -0,0 +1,26 @@
+ # General
+ .DS_Store
+ .AppleDouble
+ .LSOverride
+
+ # Icon must end with two \r
+ Icon
+
+ # Thumbnails
+ ._*
+
+ # Files that might appear in the root of a volume
+ .DocumentRevisions-V100
+ .fseventsd
+ .Spotlight-V100
+ .TemporaryItems
+ .Trashes
+ .VolumeIcon.icns
+ .com.apple.timemachine.donotpresent
+
+ # Directories potentially created on remote AFP share
+ .AppleDB
+ .AppleDesktop
+ Network Trash Folder
+ Temporary Items
+ .apdisk
converter.py CHANGED
@@ -28,7 +28,7 @@ parser = argparse.ArgumentParser(
  )
  parser.add_argument('filename')
  required_args = parser.add_argument_group('required')
- required_args.add_argument('--type', choices=['esrgan_old', 'esrgan_old_lite', 'real_esrgan', 'real_esrgan_compact'], required=True, help='Type of the model')
+ required_args.add_argument('--type', choices=['esrgan_old', 'esrgan_old_lite', 'real_esrgan', 'real_esrgan_compact', 'esrgan_plus'], required=True, help='Type of the model')
  required_args.add_argument('--name', type=str, required=True, help='Name of the model')
  required_args.add_argument('--scale', type=int, required=True, help='Scale factor of the model')
  required_args.add_argument('--out-dir', type=str, required=True, help='Output directory')
@@ -118,6 +118,9 @@ elif args.type == 'real_esrgan':
  elif args.type == 'real_esrgan_compact':
      from basicsr.archs.srvgg_arch import SRVGGNetCompact
      torch_model = SRVGGNetCompact(num_in_ch=channels, num_out_ch=channels, num_feat=num_features, num_conv=num_convs, upscale=args.scale, act_type='prelu')
+ elif args.type == 'esrgan_plus':
+     from esrgan_plus.codes.models.modules.architecture import RRDBNet
+     torch_model = RRDBNet(in_nc=channels, out_nc=channels, nf=num_features, nb=num_blocks, gc=32, upscale=args.scale)
  else:
      logger.fatal('Unknown model type: %s', args.type)
      sys.exit(-1)
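With this change, an ESRGAN+ checkpoint is converted like the other supported types. A hypothetical invocation sketch; the checkpoint file name MangaScaleV3.pth and the 4x scale are assumptions based on the commit title, and num_features/num_blocks are derived from the script's other options:

    python converter.py MangaScaleV3.pth --type esrgan_plus --name MangaScaleV3 --scale 4 --out-dir out/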
esrgan_plus/LICENSE ADDED
@@ -0,0 +1,201 @@
+                                  Apache License
+                            Version 2.0, January 2004
+                         http://www.apache.org/licenses/
+
+    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+    1. Definitions.
+
+       "License" shall mean the terms and conditions for use, reproduction,
+       and distribution as defined by Sections 1 through 9 of this document.
+
+       "Licensor" shall mean the copyright owner or entity authorized by
+       the copyright owner that is granting the License.
+
+       "Legal Entity" shall mean the union of the acting entity and all
+       other entities that control, are controlled by, or are under common
+       control with that entity. For the purposes of this definition,
+       "control" means (i) the power, direct or indirect, to cause the
+       direction or management of such entity, whether by contract or
+       otherwise, or (ii) ownership of fifty percent (50%) or more of the
+       outstanding shares, or (iii) beneficial ownership of such entity.
+
+       "You" (or "Your") shall mean an individual or Legal Entity
+       exercising permissions granted by this License.
+
+       "Source" form shall mean the preferred form for making modifications,
+       including but not limited to software source code, documentation
+       source, and configuration files.
+
+       "Object" form shall mean any form resulting from mechanical
+       transformation or translation of a Source form, including but
+       not limited to compiled object code, generated documentation,
+       and conversions to other media types.
+
+       "Work" shall mean the work of authorship, whether in Source or
+       Object form, made available under the License, as indicated by a
+       copyright notice that is included in or attached to the work
+       (an example is provided in the Appendix below).
+
+       "Derivative Works" shall mean any work, whether in Source or Object
+       form, that is based on (or derived from) the Work and for which the
+       editorial revisions, annotations, elaborations, or other modifications
+       represent, as a whole, an original work of authorship. For the purposes
+       of this License, Derivative Works shall not include works that remain
+       separable from, or merely link (or bind by name) to the interfaces of,
+       the Work and Derivative Works thereof.
+
+       "Contribution" shall mean any work of authorship, including
+       the original version of the Work and any modifications or additions
+       to that Work or Derivative Works thereof, that is intentionally
+       submitted to Licensor for inclusion in the Work by the copyright owner
+       or by an individual or Legal Entity authorized to submit on behalf of
+       the copyright owner. For the purposes of this definition, "submitted"
+       means any form of electronic, verbal, or written communication sent
+       to the Licensor or its representatives, including but not limited to
+       communication on electronic mailing lists, source code control systems,
+       and issue tracking systems that are managed by, or on behalf of, the
+       Licensor for the purpose of discussing and improving the Work, but
+       excluding communication that is conspicuously marked or otherwise
+       designated in writing by the copyright owner as "Not a Contribution."
+
+       "Contributor" shall mean Licensor and any individual or Legal Entity
+       on behalf of whom a Contribution has been received by Licensor and
+       subsequently incorporated within the Work.
+
+    2. Grant of Copyright License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       copyright license to reproduce, prepare Derivative Works of,
+       publicly display, publicly perform, sublicense, and distribute the
+       Work and such Derivative Works in Source or Object form.
+
+    3. Grant of Patent License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       (except as stated in this section) patent license to make, have made,
+       use, offer to sell, sell, import, and otherwise transfer the Work,
+       where such license applies only to those patent claims licensable
+       by such Contributor that are necessarily infringed by their
+       Contribution(s) alone or by combination of their Contribution(s)
+       with the Work to which such Contribution(s) was submitted. If You
+       institute patent litigation against any entity (including a
+       cross-claim or counterclaim in a lawsuit) alleging that the Work
+       or a Contribution incorporated within the Work constitutes direct
+       or contributory patent infringement, then any patent licenses
+       granted to You under this License for that Work shall terminate
+       as of the date such litigation is filed.
+
+    4. Redistribution. You may reproduce and distribute copies of the
+       Work or Derivative Works thereof in any medium, with or without
+       modifications, and in Source or Object form, provided that You
+       meet the following conditions:
+
+       (a) You must give any other recipients of the Work or
+           Derivative Works a copy of this License; and
+
+       (b) You must cause any modified files to carry prominent notices
+           stating that You changed the files; and
+
+       (c) You must retain, in the Source form of any Derivative Works
+           that You distribute, all copyright, patent, trademark, and
+           attribution notices from the Source form of the Work,
+           excluding those notices that do not pertain to any part of
+           the Derivative Works; and
+
+       (d) If the Work includes a "NOTICE" text file as part of its
+           distribution, then any Derivative Works that You distribute must
+           include a readable copy of the attribution notices contained
+           within such NOTICE file, excluding those notices that do not
+           pertain to any part of the Derivative Works, in at least one
+           of the following places: within a NOTICE text file distributed
+           as part of the Derivative Works; within the Source form or
+           documentation, if provided along with the Derivative Works; or,
+           within a display generated by the Derivative Works, if and
+           wherever such third-party notices normally appear. The contents
+           of the NOTICE file are for informational purposes only and
+           do not modify the License. You may add Your own attribution
+           notices within Derivative Works that You distribute, alongside
+           or as an addendum to the NOTICE text from the Work, provided
+           that such additional attribution notices cannot be construed
+           as modifying the License.
+
+       You may add Your own copyright statement to Your modifications and
+       may provide additional or different license terms and conditions
+       for use, reproduction, or distribution of Your modifications, or
+       for any such Derivative Works as a whole, provided Your use,
+       reproduction, and distribution of the Work otherwise complies with
+       the conditions stated in this License.
+
+    5. Submission of Contributions. Unless You explicitly state otherwise,
+       any Contribution intentionally submitted for inclusion in the Work
+       by You to the Licensor shall be under the terms and conditions of
+       this License, without any additional terms or conditions.
+       Notwithstanding the above, nothing herein shall supersede or modify
+       the terms of any separate license agreement you may have executed
+       with Licensor regarding such Contributions.
+
+    6. Trademarks. This License does not grant permission to use the trade
+       names, trademarks, service marks, or product names of the Licensor,
+       except as required for reasonable and customary use in describing the
+       origin of the Work and reproducing the content of the NOTICE file.
+
+    7. Disclaimer of Warranty. Unless required by applicable law or
+       agreed to in writing, Licensor provides the Work (and each
+       Contributor provides its Contributions) on an "AS IS" BASIS,
+       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+       implied, including, without limitation, any warranties or conditions
+       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+       PARTICULAR PURPOSE. You are solely responsible for determining the
+       appropriateness of using or redistributing the Work and assume any
+       risks associated with Your exercise of permissions under this License.
+
+    8. Limitation of Liability. In no event and under no legal theory,
+       whether in tort (including negligence), contract, or otherwise,
+       unless required by applicable law (such as deliberate and grossly
+       negligent acts) or agreed to in writing, shall any Contributor be
+       liable to You for damages, including any direct, indirect, special,
+       incidental, or consequential damages of any character arising as a
+       result of this License or out of the use or inability to use the
+       Work (including but not limited to damages for loss of goodwill,
+       work stoppage, computer failure or malfunction, or any and all
+       other commercial damages or losses), even if such Contributor
+       has been advised of the possibility of such damages.
+
+    9. Accepting Warranty or Additional Liability. While redistributing
+       the Work or Derivative Works thereof, You may choose to offer,
+       and charge a fee for, acceptance of support, warranty, indemnity,
+       or other liability obligations and/or rights consistent with this
+       License. However, in accepting such obligations, You may act only
+       on Your own behalf and on Your sole responsibility, not on behalf
+       of any other Contributor, and only if You agree to indemnify,
+       defend, and hold each Contributor harmless for any liability
+       incurred by, or claims asserted against, such Contributor by reason
+       of your accepting any such warranty or additional liability.
+
+    END OF TERMS AND CONDITIONS
+
+    APPENDIX: How to apply the Apache License to your work.
+
+       To apply the Apache License to your work, attach the following
+       boilerplate notice, with the fields enclosed by brackets "[]"
+       replaced with your own identifying information. (Don't include
+       the brackets!)  The text should be enclosed in the appropriate
+       comment syntax for the file format. We also recommend that a
+       file or class name and description of purpose be included on the
+       same "printed page" as the copyright notice for easier
+       identification within third-party archives.
+
+    Copyright [yyyy] [name of copyright owner]
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
esrgan_plus/README.md ADDED
@@ -0,0 +1,48 @@
+ # ESRGAN+ nESRGAN+ Tarsier
+ ## ICASSP 2020 - ESRGAN+ : Further Improving Enhanced Super-Resolution Generative Adversarial Network
+ ### [Paper arXiv](https://arxiv.org/abs/2001.08073)
+ ### [Paper IEEE Xplore](https://ieeexplore.ieee.org/document/9054071)
+ ## ICPR 2020 - Tarsier: Evolving Noise Injection in Super-Resolution GANs
+ ### [Paper arXiv](https://arxiv.org/abs/2009.12177)
+
+ <p align="center">
+   <img height="250" src="./figures/noise_per_residual_dense_block.PNG">
+ </p>
+
+ <p align="center">
+   <img src="./figures/qualitative_result.PNG">
+ </p>
+
+ ### Dependencies
+
+ - Python 3 (we recommend [Anaconda](https://www.anaconda.com/download/#linux))
+ - [PyTorch >= 1.0.0](https://pytorch.org/)
+ - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
+ - Python packages: `pip install numpy opencv-python lmdb tensorboardX`
+
+ ## How to test
+ 1. Place your low-resolution images in the `test_image/LR` folder.
+ 2. Download the pretrained models from [Google Drive](https://drive.google.com/drive/folders/1lNky9afqEP-qdxrAwDFPJ1g0ui4x7Sin?usp=sharing) and place them in `test_image/pretrained_models`.
+ 3. Run the command: `python test_image/test.py test_image/pretrained_models/nESRGANplus.pth` (or any other model).
+ 4. The results are saved in the `test_image/results` folder.
+
+ ## How to train
+ 1. Prepare the datasets, which can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU).
+ 2. Prepare the PSNR-oriented pretrained model (all pretrained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1lNky9afqEP-qdxrAwDFPJ1g0ui4x7Sin?usp=sharing)).
+ 3. Modify the configuration file `codes/options/train/train_ESRGANplus.json`.
+ 4. Run the command `python train.py -opt codes/options/train/train_ESRGANplus.json`.
+
+ ## Acknowledgement
+ - This code is based on [BasicSR](https://github.com/xinntao/BasicSR).
+
+ ## Citation
+
+     @INPROCEEDINGS{9054071,
+       author={N. C. {Rakotonirina} and A. {Rasoanaivo}},
+       booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
+       title={ESRGAN+ : Further Improving Enhanced Super-Resolution Generative Adversarial Network},
+       year={2020},
+       volume={},
+       number={},
+       pages={3637-3641}}
esrgan_plus/codes/auto_test.py ADDED
@@ -0,0 +1,32 @@
+ '''Auto-test several models.'''
+
+ import json
+ import os
+
+ test_json_path = 'options/test/test_esrgan_auto.json'
+
+
+ def modify_json(json_path, model_name, iteration):
+     with open(json_path, 'r+') as json_file:
+         config = json.load(json_file)
+
+         config['name'] = model_name
+         config['datasets']['test_1']['name'] = 'pirm_test_{:d}k'.format(iteration)
+         # config['datasets']['test_1']['dataroot_LR'] = \
+         #     '/home/carraz/datasets/PIRM/PIRM_Test_set/LR'
+         config['path']['pretrain_model_G'] = \
+             '../experiments/{:s}/models/{:d}_G.pth'.format(model_name, iteration * 1000)
+         json_file.seek(0)  # rewind to overwrite in place
+         json.dump(config, json_file)
+         json_file.truncate()  # in case the new JSON is shorter than the old one
+
+
+ model_iter_dict = {}
+ model_iter_dict['100_ESRGAN_SRResNet_pristine_pixel10_minc'] = [80, 85, 90, 95]
+
+ for model_name, iter_list in model_iter_dict.items():
+     for iteration in iter_list:
+         modify_json(test_json_path, model_name, iteration)
+         # run the test script
+         print('\n\nTesting {:s} {:d}k...'.format(model_name, iteration))
+         os.system('python test.py -opt ' + test_json_path)
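For reference, modify_json rewrites exactly three fields of the option file. A minimal, hypothetical sketch of the shape it assumes (the real options/test/test_esrgan_auto.json carries many more keys):

    # hypothetical minimal shape of the option file edited above
    config = {
        'name': '100_ESRGAN_SRResNet_pristine_pixel10_minc',
        'datasets': {'test_1': {'name': 'pirm_test_80k'}},
        'path': {'pretrain_model_G':
                 '../experiments/100_ESRGAN_SRResNet_pristine_pixel10_minc/models/80000_G.pth'},
    }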
esrgan_plus/codes/data/LRHR_dataset.py ADDED
@@ -0,0 +1,128 @@
+ import os.path
+ import random
+ import numpy as np
+ import cv2
+ import torch
+ import torch.utils.data as data
+ import data.util as util
+
+
+ class LRHRDataset(data.Dataset):
+     '''
+     Read LR and HR image pairs.
+     If only the HR image is provided, generate the LR image on-the-fly.
+     Pairing is ensured by the 'sorted' function, so please check the naming convention.
+     '''
+
+     def __init__(self, opt):
+         super(LRHRDataset, self).__init__()
+         self.opt = opt
+         self.paths_LR = None
+         self.paths_HR = None
+         self.LR_env = None  # environment for lmdb
+         self.HR_env = None
+
+         # read image list from subset list txt
+         if opt['subset_file'] is not None and opt['phase'] == 'train':
+             with open(opt['subset_file']) as f:
+                 self.paths_HR = sorted([os.path.join(opt['dataroot_HR'], line.rstrip('\n')) \
+                         for line in f])
+             if opt['dataroot_LR'] is not None:
+                 raise NotImplementedError('Now subset only supports generating LR on-the-fly.')
+         else:  # read image list from lmdb or image files
+             self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
+             self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
+
+         assert self.paths_HR, 'Error: HR path is empty.'
+         if self.paths_LR and self.paths_HR:
+             assert len(self.paths_LR) == len(self.paths_HR), \
+                 'HR and LR datasets have different number of images - {}, {}.'.format(\
+                 len(self.paths_LR), len(self.paths_HR))
+
+         self.random_scale_list = [1]
+
+     def __getitem__(self, index):
+         HR_path, LR_path = None, None
+         scale = self.opt['scale']
+         HR_size = self.opt['HR_size']
+
+         # get HR image
+         HR_path = self.paths_HR[index]
+         img_HR = util.read_img(self.HR_env, HR_path)
+         # modcrop in the validation / test phase
+         if self.opt['phase'] != 'train':
+             img_HR = util.modcrop(img_HR, scale)
+         # change color space if necessary
+         if self.opt['color']:
+             img_HR = util.channel_convert(img_HR.shape[2], self.opt['color'], [img_HR])[0]
+
+         # get LR image
+         if self.paths_LR:
+             LR_path = self.paths_LR[index]
+             img_LR = util.read_img(self.LR_env, LR_path)
+         else:  # down-sampling on-the-fly
+             # randomly scale during training
+             if self.opt['phase'] == 'train':
+                 random_scale = random.choice(self.random_scale_list)
+                 H_s, W_s, _ = img_HR.shape
+
+                 def _mod(n, random_scale, scale, thres):
+                     rlt = int(n * random_scale)
+                     rlt = (rlt // scale) * scale
+                     return thres if rlt < thres else rlt
+
+                 H_s = _mod(H_s, random_scale, scale, HR_size)
+                 W_s = _mod(W_s, random_scale, scale, HR_size)
+                 img_HR = cv2.resize(np.copy(img_HR), (W_s, H_s), interpolation=cv2.INTER_LINEAR)
+                 # force to 3 channels
+                 if img_HR.ndim == 2:
+                     img_HR = cv2.cvtColor(img_HR, cv2.COLOR_GRAY2BGR)
+
+             H, W, _ = img_HR.shape
+             # using matlab imresize
+             img_LR = util.imresize_np(img_HR, 1 / scale, True)
+             if img_LR.ndim == 2:
+                 img_LR = np.expand_dims(img_LR, axis=2)
+
+         if self.opt['phase'] == 'train':
+             # if the image size is too small
+             H, W, _ = img_HR.shape
+             if H < HR_size or W < HR_size:
+                 img_HR = cv2.resize(
+                     np.copy(img_HR), (HR_size, HR_size), interpolation=cv2.INTER_LINEAR)
+                 # using matlab imresize
+                 img_LR = util.imresize_np(img_HR, 1 / scale, True)
+                 if img_LR.ndim == 2:
+                     img_LR = np.expand_dims(img_LR, axis=2)
+
+             H, W, C = img_LR.shape
+             LR_size = HR_size // scale
+
+             # randomly crop
+             rnd_h = random.randint(0, max(0, H - LR_size))
+             rnd_w = random.randint(0, max(0, W - LR_size))
+             img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
+             rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
+             img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]
+
+             # augmentation - flip, rotate
+             img_LR, img_HR = util.augment([img_LR, img_HR], self.opt['use_flip'], \
+                 self.opt['use_rot'])
+
+             # change color space if necessary
+             if self.opt['color']:
+                 img_LR = util.channel_convert(C, self.opt['color'], [img_LR])[0]  # TODO: no LR color conversion is defined during validation
+
+         # BGR to RGB, HWC to CHW, numpy to tensor
+         if img_HR.shape[2] == 3:
+             img_HR = img_HR[:, :, [2, 1, 0]]
+             img_LR = img_LR[:, :, [2, 1, 0]]
+         img_HR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HR, (2, 0, 1)))).float()
+         img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()
+
+         if LR_path is None:
+             LR_path = HR_path
+         return {'LR': img_LR, 'HR': img_HR, 'LR_path': LR_path, 'HR_path': HR_path}
+
+     def __len__(self):
+         return len(self.paths_HR)
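A minimal sketch of the opt dict LRHRDataset expects, with placeholder paths; only the keys read above are shown, and real option files add more:

    # hypothetical opt for LRHRDataset; dataroot paths are placeholders
    opt = {
        'phase': 'train', 'data_type': 'img', 'subset_file': None,
        'dataroot_HR': '/data/DIV2K/HR', 'dataroot_LR': None,  # None: LR is generated on-the-fly
        'scale': 4, 'HR_size': 128, 'color': None,
        'use_flip': True, 'use_rot': True,
    }
    dataset = LRHRDataset(opt)
    sample = dataset[0]  # dict with 'LR', 'HR' (CHW float tensors), 'LR_path', 'HR_path'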
esrgan_plus/codes/data/LRHR_seg_bg_dataset.py ADDED
@@ -0,0 +1,149 @@
+ import os.path
+ import random
+ import numpy as np
+ import cv2
+ import torch
+ import torch.utils.data as data
+ import data.util as util
+
+
+ class LRHRSeg_BG_Dataset(data.Dataset):
+     '''
+     Read the HR image and its segmentation probability map; generate the LR image and category for SFTGAN.
+     Also samples general scenes for the background.
+     LR images need to be generated on-the-fly.
+     '''
+
+     def __init__(self, opt):
+         super(LRHRSeg_BG_Dataset, self).__init__()
+         self.opt = opt
+         self.paths_LR = None
+         self.paths_HR = None
+         self.paths_HR_bg = None  # HR images for background scenes
+         self.LR_env = None  # environment for lmdb
+         self.HR_env = None
+         self.HR_env_bg = None
+
+         # read image list from lmdb or image files
+         self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
+         self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
+         self.HR_env_bg, self.paths_HR_bg = util.get_image_paths(opt['data_type'], \
+             opt['dataroot_HR_bg'])
+
+         assert self.paths_HR, 'Error: HR path is empty.'
+         if self.paths_LR and self.paths_HR:
+             assert len(self.paths_LR) == len(self.paths_HR), \
+                 'HR and LR datasets have different number of images - {}, {}.'.format(\
+                 len(self.paths_LR), len(self.paths_HR))
+
+         self.random_scale_list = [1, 0.9, 0.8, 0.7, 0.6, 0.5]
+         self.ratio = 10  # roughly 10 OST samples for every 1 DIV2K general (background) sample
+
+     def __getitem__(self, index):
+         HR_path, LR_path = None, None
+         scale = self.opt['scale']
+         HR_size = self.opt['HR_size']
+
+         # get HR image
+         if self.opt['phase'] == 'train' and \
+                 random.choice(list(range(self.ratio))) == 0:  # read background images
+             bg_index = random.randint(0, len(self.paths_HR_bg) - 1)
+             HR_path = self.paths_HR_bg[bg_index]
+             img_HR = util.read_img(self.HR_env_bg, HR_path)
+             seg = torch.FloatTensor(8, img_HR.shape[0], img_HR.shape[1]).fill_(0)
+             seg[0, :, :] = 1  # background
+         else:
+             HR_path = self.paths_HR[index]
+             img_HR = util.read_img(self.HR_env, HR_path)
+             seg = torch.load(HR_path.replace('/img/', '/bicseg/').replace('.png', '.pth'))
+             # reads segmentation files; change this to match your settings
+
+         # modcrop in the validation / test phase
+         if self.opt['phase'] != 'train':
+             img_HR = util.modcrop(img_HR, 8)
+
+         seg = np.transpose(seg.numpy(), (1, 2, 0))
+
+         # get LR image
+         if self.paths_LR:
+             LR_path = self.paths_LR[index]
+             img_LR = util.read_img(self.LR_env, LR_path)
+         else:  # down-sampling on-the-fly
+             # randomly scale during training
+             if self.opt['phase'] == 'train':
+                 random_scale = random.choice(self.random_scale_list)
+                 H_s, W_s, _ = seg.shape
+
+                 def _mod(n, random_scale, scale, thres):
+                     rlt = int(n * random_scale)
+                     rlt = (rlt // scale) * scale
+                     return thres if rlt < thres else rlt
+
+                 H_s = _mod(H_s, random_scale, scale, HR_size)
+                 W_s = _mod(W_s, random_scale, scale, HR_size)
+                 img_HR = cv2.resize(np.copy(img_HR), (W_s, H_s), interpolation=cv2.INTER_LINEAR)
+                 seg = cv2.resize(np.copy(seg), (W_s, H_s), interpolation=cv2.INTER_NEAREST)
+
+             H, W, _ = img_HR.shape
+             # using matlab imresize
+             img_LR = util.imresize_np(img_HR, 1 / scale, True)
+             if img_LR.ndim == 2:
+                 img_LR = np.expand_dims(img_LR, axis=2)
+
+         H, W, C = img_LR.shape
+         if self.opt['phase'] == 'train':
+             LR_size = HR_size // scale
+
+             # randomly crop
+             rnd_h = random.randint(0, max(0, H - LR_size))
+             rnd_w = random.randint(0, max(0, W - LR_size))
+             img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
+             rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
+             img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]
+             seg = seg[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]
+
+             # augmentation - flip, rotate
+             img_LR, img_HR, seg = util.augment([img_LR, img_HR, seg], self.opt['use_flip'],
+                 self.opt['use_rot'])
+
+             # category
+             if 'building' in HR_path:
+                 category = 1
+             elif 'plant' in HR_path:
+                 category = 2
+             elif 'mountain' in HR_path:
+                 category = 3
+             elif 'water' in HR_path:
+                 category = 4
+             elif 'sky' in HR_path:
+                 category = 5
+             elif 'grass' in HR_path:
+                 category = 6
+             elif 'animal' in HR_path:
+                 category = 7
+             else:
+                 category = 0  # background
+         else:
+             category = -1  # unused during validation
+
+         # BGR to RGB, HWC to CHW, numpy to tensor
+         if img_HR.shape[2] == 3:
+             img_HR = img_HR[:, :, [2, 1, 0]]
+             img_LR = img_LR[:, :, [2, 1, 0]]
+         img_HR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HR, (2, 0, 1)))).float()
+         img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()
+         seg = torch.from_numpy(np.ascontiguousarray(np.transpose(seg, (2, 0, 1)))).float()
+
+         if LR_path is None:
+             LR_path = HR_path
+         return {
+             'LR': img_LR,
+             'HR': img_HR,
+             'seg': seg,
+             'category': category,
+             'LR_path': LR_path,
+             'HR_path': HR_path
+         }
+
+     def __len__(self):
+         return len(self.paths_HR)
esrgan_plus/codes/data/LR_dataset.py ADDED
@@ -0,0 +1,40 @@
+ import numpy as np
+ import torch
+ import torch.utils.data as data
+ import data.util as util
+
+
+ class LRDataset(data.Dataset):
+     '''Read LR images only in the test phase.'''
+
+     def __init__(self, opt):
+         super(LRDataset, self).__init__()
+         self.opt = opt
+         self.paths_LR = None
+         self.LR_env = None  # environment for lmdb
+
+         # read image list from lmdb or image files
+         self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
+         assert self.paths_LR, 'Error: LR paths are empty.'
+
+     def __getitem__(self, index):
+         LR_path = None
+
+         # get LR image
+         LR_path = self.paths_LR[index]
+         img_LR = util.read_img(self.LR_env, LR_path)
+         H, W, C = img_LR.shape
+
+         # change color space if necessary
+         if self.opt['color']:
+             img_LR = util.channel_convert(C, self.opt['color'], [img_LR])[0]
+
+         # BGR to RGB, HWC to CHW, numpy to tensor
+         if img_LR.shape[2] == 3:
+             img_LR = img_LR[:, :, [2, 1, 0]]
+         img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()
+
+         return {'LR': img_LR, 'LR_path': LR_path}
+
+     def __len__(self):
+         return len(self.paths_LR)
esrgan_plus/codes/data/__init__.py ADDED
@@ -0,0 +1,37 @@
+ '''create dataset and dataloader'''
+ import logging
+ import torch.utils.data
+
+
+ def create_dataloader(dataset, dataset_opt):
+     '''create dataloader'''
+     phase = dataset_opt['phase']
+     if phase == 'train':
+         return torch.utils.data.DataLoader(
+             dataset,
+             batch_size=dataset_opt['batch_size'],
+             shuffle=dataset_opt['use_shuffle'],
+             num_workers=dataset_opt['n_workers'],
+             drop_last=True,
+             pin_memory=True)
+     else:
+         return torch.utils.data.DataLoader(
+             dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
+
+
+ def create_dataset(dataset_opt):
+     '''create dataset'''
+     mode = dataset_opt['mode']
+     if mode == 'LR':
+         from data.LR_dataset import LRDataset as D
+     elif mode == 'LRHR':
+         from data.LRHR_dataset import LRHRDataset as D
+     elif mode == 'LRHRseg_bg':
+         from data.LRHR_seg_bg_dataset import LRHRSeg_BG_Dataset as D
+     else:
+         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
+     dataset = D(dataset_opt)
+     logger = logging.getLogger('base')
+     logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
+                                                            dataset_opt['name']))
+     return dataset
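Wired together, the two factories are used like this; a minimal sketch with a hypothetical dataset_opt built only from the keys read above (the path is a placeholder):

    # hypothetical test-phase dataset_opt
    dataset_opt = {
        'name': 'demo_LR', 'mode': 'LR', 'phase': 'test',
        'data_type': 'img', 'dataroot_LR': '/data/test/LR', 'color': None,
    }
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt)  # non-train branch: batch_size=1, no shuffle
    for batch in test_loader:
        print(batch['LR_path'], batch['LR'].shape)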
esrgan_plus/codes/data/util.py ADDED
@@ -0,0 +1,434 @@
+ import os
+ import math
+ import pickle
+ import random
+ import numpy as np
+ import lmdb
+ import torch
+ import cv2
+ import logging
+
+ IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
+
+ ####################
+ # Files & IO
+ ####################
+
+
+ def is_image_file(filename):
+     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+ def _get_paths_from_images(path):
+     assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+     images = []
+     for dirpath, _, fnames in sorted(os.walk(path)):
+         for fname in sorted(fnames):
+             if is_image_file(fname):
+                 img_path = os.path.join(dirpath, fname)
+                 images.append(img_path)
+     assert images, '{:s} has no valid image file'.format(path)
+     return images
+
+
+ def _get_paths_from_lmdb(dataroot):
+     env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
+     keys_cache_file = os.path.join(dataroot, '_keys_cache.p')
+     logger = logging.getLogger('base')
+     if os.path.isfile(keys_cache_file):
+         logger.info('Read lmdb keys from cache: {}'.format(keys_cache_file))
+         keys = pickle.load(open(keys_cache_file, "rb"))
+     else:
+         with env.begin(write=False) as txn:
+             logger.info('Creating lmdb keys cache: {}'.format(keys_cache_file))
+             keys = [key.decode('ascii') for key, _ in txn.cursor()]
+         pickle.dump(keys, open(keys_cache_file, 'wb'))
+     paths = sorted([key for key in keys if not key.endswith('.meta')])
+     return env, paths
+
+
+ def get_image_paths(data_type, dataroot):
+     env, paths = None, None
+     if dataroot is not None:
+         if data_type == 'lmdb':
+             env, paths = _get_paths_from_lmdb(dataroot)
+         elif data_type == 'img':
+             paths = sorted(_get_paths_from_images(dataroot))
+         else:
+             raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
+     return env, paths
+
+
+ def _read_lmdb_img(env, path):
+     with env.begin(write=False) as txn:
+         buf = txn.get(path.encode('ascii'))
+         buf_meta = txn.get((path + '.meta').encode('ascii')).decode('ascii')
+     img_flat = np.frombuffer(buf, dtype=np.uint8)
+     H, W, C = [int(s) for s in buf_meta.split(',')]
+     img = img_flat.reshape(H, W, C)
+     return img
+
+
+ def read_img(env, path):
+     # read image by cv2 or from lmdb
+     # return: Numpy float32, HWC, BGR, [0,1]
+     if env is None:  # img
+         img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+     else:
+         img = _read_lmdb_img(env, path)
+     img = img.astype(np.float32) / 255.
+     if img.ndim == 2:
+         img = np.expand_dims(img, axis=2)
+     # some images have 4 channels
+     if img.shape[2] > 3:
+         img = img[:, :, :3]
+     return img
+
+
+ ####################
+ # image processing
+ # process on numpy image
+ ####################
+
+
+ def augment(img_list, hflip=True, rot=True):
+     # horizontal flip OR rotate
+     hflip = hflip and random.random() < 0.5
+     vflip = rot and random.random() < 0.5
+     rot90 = rot and random.random() < 0.5
+
+     def _augment(img):
+         if hflip: img = img[:, ::-1, :]
+         if vflip: img = img[::-1, :, :]
+         if rot90: img = img.transpose(1, 0, 2)
+         return img
+
+     return [_augment(img) for img in img_list]
+
+
+ def channel_convert(in_c, tar_type, img_list):
+     # conversion among BGR, gray and y
+     if in_c == 3 and tar_type == 'gray':  # BGR to gray
+         gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+         return [np.expand_dims(img, axis=2) for img in gray_list]
+     elif in_c == 3 and tar_type == 'y':  # BGR to y
+         y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+         return [np.expand_dims(img, axis=2) for img in y_list]
+     elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
+         return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+     else:
+         return img_list
+
+
+ def rgb2ycbcr(img, only_y=True):
+     '''same as matlab rgb2ycbcr
+     only_y: only return Y channel
+     Input:
+         uint8, [0, 255]
+         float, [0, 1]
+     '''
+     in_img_type = img.dtype
+     img = img.astype(np.float32)  # assign the float copy (the original discarded the astype result)
+     if in_img_type != np.uint8:
+         img *= 255.
+     # convert
+     if only_y:
+         rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+     else:
+         rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+                               [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+     if in_img_type == np.uint8:
+         rlt = rlt.round()
+     else:
+         rlt /= 255.
+     return rlt.astype(in_img_type)
+
+
+ def bgr2ycbcr(img, only_y=True):
+     '''bgr version of rgb2ycbcr
+     only_y: only return Y channel
+     Input:
+         uint8, [0, 255]
+         float, [0, 1]
+     '''
+     in_img_type = img.dtype
+     img = img.astype(np.float32)  # assign the float copy; avoids mutating the caller's array
+     if in_img_type != np.uint8:
+         img *= 255.
+     # convert
+     if only_y:
+         rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+     else:
+         rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+                               [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+     if in_img_type == np.uint8:
+         rlt = rlt.round()
+     else:
+         rlt /= 255.
+     return rlt.astype(in_img_type)
+
+
+ def ycbcr2rgb(img):
+     '''same as matlab ycbcr2rgb
+     Input:
+         uint8, [0, 255]
+         float, [0, 1]
+     '''
+     in_img_type = img.dtype
+     img = img.astype(np.float32)  # assign the float copy
+     if in_img_type != np.uint8:
+         img *= 255.
+     # convert
+     rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+                           [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+     if in_img_type == np.uint8:
+         rlt = rlt.round()
+     else:
+         rlt /= 255.
+     return rlt.astype(in_img_type)
+
+
+ def modcrop(img_in, scale):
+     # img_in: Numpy, HWC or HW
+     img = np.copy(img_in)
+     if img.ndim == 2:
+         H, W = img.shape
+         H_r, W_r = H % scale, W % scale
+         img = img[:H - H_r, :W - W_r]
+     elif img.ndim == 3:
+         H, W, C = img.shape
+         H_r, W_r = H % scale, W % scale
+         img = img[:H - H_r, :W - W_r, :]
+     else:
+         raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+     return img
+
+
+ ####################
+ # Functions
+ ####################
+
+
+ # matlab 'imresize' function, currently only supports 'bicubic'
+ def cubic(x):
+     absx = torch.abs(x)
+     absx2 = absx**2
+     absx3 = absx**3
+     return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+         (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+     if (scale < 1) and (antialiasing):
+         # Use a modified kernel to simultaneously interpolate and antialias: larger kernel width
+         kernel_width = kernel_width / scale
+
+     # Output-space coordinates
+     x = torch.linspace(1, out_length, out_length)
+
+     # Input-space coordinates. Calculate the inverse mapping such that 0.5
+     # in output space maps to 0.5 in input space, and 0.5+scale in output
+     # space maps to 1.5 in input space.
+     u = x / scale + 0.5 * (1 - 1 / scale)
+
+     # What is the left-most pixel that can be involved in the computation?
+     left = torch.floor(u - kernel_width / 2)
+
+     # What is the maximum number of pixels that can be involved in the
+     # computation? Note: it's OK to use an extra pixel here; if the
+     # corresponding weights are all zero, it will be eliminated at the end
+     # of this function.
+     P = math.ceil(kernel_width) + 2
+
+     # The indices of the input pixels involved in computing the k-th output
+     # pixel are in row k of the indices matrix.
+     indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+         1, P).expand(out_length, P)
+
+     # The weights used to compute the k-th output pixel are in row k of the
+     # weights matrix.
+     distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+     # apply cubic kernel
+     if (scale < 1) and (antialiasing):
+         weights = scale * cubic(distance_to_center * scale)
+     else:
+         weights = cubic(distance_to_center)
+     # Normalize the weights matrix so that each row sums to 1.
+     weights_sum = torch.sum(weights, 1).view(out_length, 1)
+     weights = weights / weights_sum.expand(out_length, P)
+
+     # If a column in weights is all zero, get rid of it. Only consider the first and last column.
+     weights_zero_tmp = torch.sum((weights == 0), 0)
+     if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+         indices = indices.narrow(1, 1, P - 2)
+         weights = weights.narrow(1, 1, P - 2)
+     if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+         indices = indices.narrow(1, 0, P - 2)
+         weights = weights.narrow(1, 0, P - 2)
+     weights = weights.contiguous()
+     indices = indices.contiguous()
+     sym_len_s = -indices.min() + 1
+     sym_len_e = indices.max() - in_length
+     indices = indices + sym_len_s - 1
+     return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+ def imresize(img, scale, antialiasing=True):
+     # Now the scale should be the same for H and W
+     # input: img: CHW RGB [0,1]
+     # output: CHW RGB [0,1] w/o round
+
+     in_C, in_H, in_W = img.size()
+     out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+     kernel_width = 4
+     kernel = 'cubic'
+
+     # Return the desired dimension order for performing the resize. The
+     # strategy is to perform the resize first along the dimension with the
+     # smallest scale factor.
+     # Now we do not support this.
+
+     # get weights and indices
+     weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+         in_H, out_H, scale, kernel, kernel_width, antialiasing)
+     weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+         in_W, out_W, scale, kernel, kernel_width, antialiasing)
+     # process H dimension
+     # symmetric copying
+     img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+     img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+     sym_patch = img[:, :sym_len_Hs, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+     sym_patch = img[:, -sym_len_He:, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+     out_1 = torch.FloatTensor(in_C, out_H, in_W)
+     kernel_width = weights_H.size(1)
+     for i in range(out_H):
+         idx = int(indices_H[i][0])
+         out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+         out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+         out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+     # process W dimension
+     # symmetric copying
+     out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+     out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+     sym_patch = out_1[:, :, :sym_len_Ws]
+     inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(2, inv_idx)
+     out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+     sym_patch = out_1[:, :, -sym_len_We:]
+     inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(2, inv_idx)
+     out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+     out_2 = torch.FloatTensor(in_C, out_H, out_W)
+     kernel_width = weights_W.size(1)
+     for i in range(out_W):
+         idx = int(indices_W[i][0])
+         out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
+         out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
+         out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])
+
+     return out_2
+
+
+ def imresize_np(img, scale, antialiasing=True):
+     # Now the scale should be the same for H and W
+     # input: img: Numpy, HWC BGR [0,1]
+     # output: HWC BGR [0,1] w/o round
+     img = torch.from_numpy(img)
+
+     in_H, in_W, in_C = img.size()
+     out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+     kernel_width = 4
+     kernel = 'cubic'
+
+     # Return the desired dimension order for performing the resize. The
+     # strategy is to perform the resize first along the dimension with the
+     # smallest scale factor.
+     # Now we do not support this.
+
+     # get weights and indices
+     weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+         in_H, out_H, scale, kernel, kernel_width, antialiasing)
+     weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+         in_W, out_W, scale, kernel, kernel_width, antialiasing)
+     # process H dimension
+     # symmetric copying
+     img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+     img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+     sym_patch = img[:sym_len_Hs, :, :]
+     inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(0, inv_idx)
+     img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+     sym_patch = img[-sym_len_He:, :, :]
+     inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(0, inv_idx)
+     img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+     out_1 = torch.FloatTensor(out_H, in_W, in_C)
+     kernel_width = weights_H.size(1)
+     for i in range(out_H):
+         idx = int(indices_H[i][0])
+         out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
+         out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
+         out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])
+
+     # process W dimension
+     # symmetric copying
+     out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+     out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+     sym_patch = out_1[:, :sym_len_Ws, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+     sym_patch = out_1[:, -sym_len_We:, :]
+     inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+     sym_patch_inv = sym_patch.index_select(1, inv_idx)
+     out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+     out_2 = torch.FloatTensor(out_H, out_W, in_C)
+     kernel_width = weights_W.size(1)
+     for i in range(out_W):
+         idx = int(indices_W[i][0])
+         out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
+         out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
+         out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])
+
+     return out_2.numpy()
+
+
+ if __name__ == '__main__':
+     # test imresize function
+     # read images
+     img = cv2.imread('test.png')
+     img = img * 1.0 / 255
+     img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
+     # imresize
+     scale = 1 / 4
+     import time
+     total_time = 0
+     for i in range(10):
+         start_time = time.time()
+         rlt = imresize(img, scale, antialiasing=True)
+         use_time = time.time() - start_time
+         total_time += use_time
+     print('average time: {}'.format(total_time / 10))
+
+     import torchvision.utils
+     torchvision.utils.save_image(
+         (rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0, normalize=False)
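Since the dataset classes call imresize_np on numpy images, here is a minimal usage sketch; it follows read_img's HWC, BGR, [0,1] convention, and input.png is a placeholder filename:

    # hypothetical example: MATLAB-style 4x bicubic downscale of an image on disk
    import cv2
    from data import util
    img = util.read_img(None, 'input.png')              # HWC, BGR, float32 in [0, 1]
    lr = util.imresize_np(img, 1 / 4, antialiasing=True)
    cv2.imwrite('LR.png', (lr * 255).round().clip(0, 255).astype('uint8'))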
esrgan_plus/codes/models/SFTGAN_ACD_model.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from collections import OrderedDict
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.optim import lr_scheduler
8
+
9
+ import models.networks as networks
10
+ from .base_model import BaseModel
11
+ from models.modules.loss import GANLoss, GradientPenaltyLoss
12
+
13
+ logger = logging.getLogger('base')
+
+
+ class SFTGAN_ACD_Model(BaseModel):
+     def __init__(self, opt):
+         super(SFTGAN_ACD_Model, self).__init__(opt)
+         train_opt = opt['train']
+
+         # define networks and load pretrained models
+         self.netG = networks.define_G(opt).to(self.device)  # G
+         if self.is_train:
+             self.netD = networks.define_D(opt).to(self.device)  # D
+             self.netG.train()
+             self.netD.train()
+         self.load()  # load G and D if needed
+
+         # define losses, optimizer and scheduler
+         if self.is_train:
+             # G pixel loss
+             if train_opt['pixel_weight'] > 0:
+                 l_pix_type = train_opt['pixel_criterion']
+                 if l_pix_type == 'l1':
+                     self.cri_pix = nn.L1Loss().to(self.device)
+                 elif l_pix_type == 'l2':
+                     self.cri_pix = nn.MSELoss().to(self.device)
+                 else:
+                     raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
+                 self.l_pix_w = train_opt['pixel_weight']
+             else:
+                 logger.info('Remove pixel loss.')
+                 self.cri_pix = None
+
+             # G feature loss
+             if train_opt['feature_weight'] > 0:
+                 l_fea_type = train_opt['feature_criterion']
+                 if l_fea_type == 'l1':
+                     self.cri_fea = nn.L1Loss().to(self.device)
+                 elif l_fea_type == 'l2':
+                     self.cri_fea = nn.MSELoss().to(self.device)
+                 else:
+                     raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
+                 self.l_fea_w = train_opt['feature_weight']
+             else:
+                 logger.info('Remove feature loss.')
+                 self.cri_fea = None
+             if self.cri_fea:  # load VGG perceptual loss
+                 self.netF = networks.define_F(opt, use_bn=False).to(self.device)
+
+             # GD gan loss
+             self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
+             self.l_gan_w = train_opt['gan_weight']
+             # D_update_ratio and D_init_iters are for WGAN
+             self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
+             self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
+
+             if train_opt['gan_type'] == 'wgan-gp':
+                 self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
+                 # gradient penalty loss
+                 self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
+                 self.l_gp_w = train_opt['gp_weigth']  # note: key spelling matches the option files
+
+             # D cls loss
+             self.cri_ce = nn.CrossEntropyLoss(ignore_index=0).to(self.device)
+             # ignore background, since bg images may conflict with other classes
+
+             # optimizers
+             # G
+             wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+             optim_params_SFT = []
+             optim_params_other = []
+             for k, v in self.netG.named_parameters():  # can optimize for a part of the model
+                 if 'SFT' in k or 'Cond' in k:
+                     optim_params_SFT.append(v)
+                 else:
+                     optim_params_other.append(v)
+             self.optimizer_G_SFT = torch.optim.Adam(optim_params_SFT, lr=train_opt['lr_G'] * 5,
+                 weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
+             self.optimizer_G_other = torch.optim.Adam(optim_params_other, lr=train_opt['lr_G'],
+                 weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
+             self.optimizers.append(self.optimizer_G_SFT)
+             self.optimizers.append(self.optimizer_G_other)
+             # D
+             wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
+             self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
+                 weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
+             self.optimizers.append(self.optimizer_D)
+
+             # schedulers
+             if train_opt['lr_scheme'] == 'MultiStepLR':
+                 for optimizer in self.optimizers:
+                     self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
+                         train_opt['lr_steps'], train_opt['lr_gamma']))
+             else:
+                 raise NotImplementedError('Only the MultiStepLR scheme is implemented.')
+
+             self.log_dict = OrderedDict()
+         # print network
+         self.print_network()
+
+     def feed_data(self, data, need_HR=True):
+         # LR
+         self.var_L = data['LR'].to(self.device)
+         # seg
+         self.var_seg = data['seg'].to(self.device)
+         # category
+         self.var_cat = data['category'].long().to(self.device)
+
+         if need_HR:  # train or val
+             self.var_H = data['HR'].to(self.device)
+
+     def optimize_parameters(self, step):
+         # G
+         self.optimizer_G_SFT.zero_grad()
+         self.optimizer_G_other.zero_grad()
+         self.fake_H = self.netG((self.var_L, self.var_seg))
+
+         l_g_total = 0
+         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+             if self.cri_pix:  # pixel loss
+                 l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
+                 l_g_total += l_g_pix
+             if self.cri_fea:  # feature loss
+                 real_fea = self.netF(self.var_H).detach()
+                 fake_fea = self.netF(self.fake_H)
+                 l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
+                 l_g_total += l_g_fea
+             # G gan + cls loss
+             pred_g_fake, cls_g_fake = self.netD(self.fake_H)
+             l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
+             l_g_cls = self.l_gan_w * self.cri_ce(cls_g_fake, self.var_cat)
+             l_g_total += l_g_gan
+             l_g_total += l_g_cls
+
+             l_g_total.backward()
+             self.optimizer_G_SFT.step()
+             if step > 20000:
+                 self.optimizer_G_other.step()
+
+         # D
+         self.optimizer_D.zero_grad()
+         l_d_total = 0
+         # real data
+         pred_d_real, cls_d_real = self.netD(self.var_H)
+         l_d_real = self.cri_gan(pred_d_real, True)
+         l_d_cls_real = self.cri_ce(cls_d_real, self.var_cat)
+         # fake data
+         pred_d_fake, cls_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
+         l_d_fake = self.cri_gan(pred_d_fake, False)
+         l_d_cls_fake = self.cri_ce(cls_d_fake, self.var_cat)
+
+         l_d_total = l_d_real + l_d_cls_real + l_d_fake + l_d_cls_fake
+
+         if self.opt['train']['gan_type'] == 'wgan-gp':
+             batch_size = self.var_H.size(0)
+             if self.random_pt.size(0) != batch_size:
+                 self.random_pt.resize_(batch_size, 1, 1, 1)
+             self.random_pt.uniform_()  # draw random interpolation points
+             interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_H
+             interp.requires_grad = True
+             interp_crit, _ = self.netD(interp)
+             l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)  # maybe wrong in cls?
+             l_d_total += l_d_gp
+
+         l_d_total.backward()
+         self.optimizer_D.step()
+
+         # set log
+         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+             # G
+             if self.cri_pix:
+                 self.log_dict['l_g_pix'] = l_g_pix.item()
+             if self.cri_fea:
+                 self.log_dict['l_g_fea'] = l_g_fea.item()
+             self.log_dict['l_g_gan'] = l_g_gan.item()
+         # D
+         self.log_dict['l_d_real'] = l_d_real.item()
+         self.log_dict['l_d_fake'] = l_d_fake.item()
+         self.log_dict['l_d_cls_real'] = l_d_cls_real.item()
+         self.log_dict['l_d_cls_fake'] = l_d_cls_fake.item()
+         if self.opt['train']['gan_type'] == 'wgan-gp':
+             self.log_dict['l_d_gp'] = l_d_gp.item()
+         # D outputs
+         self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
+         self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
+
+     def test(self):
+         self.netG.eval()
+         with torch.no_grad():
+             self.fake_H = self.netG((self.var_L, self.var_seg))
+         self.netG.train()
+
+     def get_current_log(self):
+         return self.log_dict
+
+     def get_current_visuals(self, need_HR=True):
+         out_dict = OrderedDict()
+         out_dict['LR'] = self.var_L.detach()[0].float().cpu()
+         out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
+         if need_HR:
+             out_dict['HR'] = self.var_H.detach()[0].float().cpu()
+         return out_dict
+
+     def print_network(self):
+         # G
+         s, n = self.get_network_description(self.netG)
+         if isinstance(self.netG, nn.DataParallel):
+             net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+                                              self.netG.module.__class__.__name__)
+         else:
+             net_struc_str = '{}'.format(self.netG.__class__.__name__)
+
+         logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+         logger.info(s)
+         if self.is_train:
+             # D
+             s, n = self.get_network_description(self.netD)
+             if isinstance(self.netD, nn.DataParallel):
+                 net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
+                                                  self.netD.module.__class__.__name__)
+             else:
+                 net_struc_str = '{}'.format(self.netD.__class__.__name__)
+
+             logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+             logger.info(s)
+
+             if self.cri_fea:  # F, Perceptual Network
+                 s, n = self.get_network_description(self.netF)
+                 if isinstance(self.netF, nn.DataParallel):
+                     net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
+                                                      self.netF.module.__class__.__name__)
+                 else:
+                     net_struc_str = '{}'.format(self.netF.__class__.__name__)
+
+                 logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+                 logger.info(s)
+
+     def load(self):
+         load_path_G = self.opt['path']['pretrain_model_G']
+         if load_path_G is not None:
+             logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
+             self.load_network(load_path_G, self.netG)
+         load_path_D = self.opt['path']['pretrain_model_D']
+         if self.opt['is_train'] and load_path_D is not None:
+             logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
+             self.load_network(load_path_D, self.netD)
+
+     def save(self, iter_step):
+         self.save_network(self.netG, 'G', iter_step)
+         self.save_network(self.netD, 'D', iter_step)
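The two-optimizer scheme above is the easiest part of this model to misread: parameters whose names contain 'SFT' or 'Cond' train from the start at five times the base learning rate, while every other generator weight is only stepped once step > 20000. A self-contained sketch of the same split; TinyG and the 2e-4 base rate are illustrative stand-ins, not values from this repo:

import torch
import torch.nn as nn

class TinyG(nn.Module):
    def __init__(self):
        super(TinyG, self).__init__()
        self.SFT_scale = nn.Conv2d(8, 8, 1)   # name contains 'SFT' -> fast group
        self.body = nn.Conv2d(8, 8, 3, 1, 1)  # everything else -> slow group

net = TinyG()
sft = [v for k, v in net.named_parameters() if 'SFT' in k or 'Cond' in k]
other = [v for k, v in net.named_parameters() if 'SFT' not in k and 'Cond' not in k]
opt_sft = torch.optim.Adam(sft, lr=2e-4 * 5)
opt_other = torch.optim.Adam(other, lr=2e-4)

def maybe_step(step):
    opt_sft.step()       # SFT/Cond parameters always step
    if step > 20000:     # the rest of G stays frozen for the first 20k iters
        opt_other.step()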
esrgan_plus/codes/models/SRGAN_model.py ADDED
@@ -0,0 +1,240 @@
+ import os
+ import logging
+ from collections import OrderedDict
+
+ import torch
+ import torch.nn as nn
+ from torch.optim import lr_scheduler
+
+ import models.networks as networks
+ from .base_model import BaseModel
+ from models.modules.loss import GANLoss, GradientPenaltyLoss
+
+ logger = logging.getLogger('base')
+
+
+ class SRGANModel(BaseModel):
+     def __init__(self, opt):
+         super(SRGANModel, self).__init__(opt)
+         train_opt = opt['train']
+
+         # define networks and load pretrained models
+         self.netG = networks.define_G(opt).to(self.device)  # G
+         if self.is_train:
+             self.netD = networks.define_D(opt).to(self.device)  # D
+             self.netG.train()
+             self.netD.train()
+         self.load()  # load G and D if needed
+
+         # define losses, optimizer and scheduler
+         if self.is_train:
+             # G pixel loss
+             if train_opt['pixel_weight'] > 0:
+                 l_pix_type = train_opt['pixel_criterion']
+                 if l_pix_type == 'l1':
+                     self.cri_pix = nn.L1Loss().to(self.device)
+                 elif l_pix_type == 'l2':
+                     self.cri_pix = nn.MSELoss().to(self.device)
+                 else:
+                     raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
+                 self.l_pix_w = train_opt['pixel_weight']
+             else:
+                 logger.info('Remove pixel loss.')
+                 self.cri_pix = None
+
+             # G feature loss
+             if train_opt['feature_weight'] > 0:
+                 l_fea_type = train_opt['feature_criterion']
+                 if l_fea_type == 'l1':
+                     self.cri_fea = nn.L1Loss().to(self.device)
+                 elif l_fea_type == 'l2':
+                     self.cri_fea = nn.MSELoss().to(self.device)
+                 else:
+                     raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
+                 self.l_fea_w = train_opt['feature_weight']
+             else:
+                 logger.info('Remove feature loss.')
+                 self.cri_fea = None
+             if self.cri_fea:  # load VGG perceptual loss
+                 self.netF = networks.define_F(opt, use_bn=False).to(self.device)
+
+             # GD gan loss
+             self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
+             self.l_gan_w = train_opt['gan_weight']
+             # D_update_ratio and D_init_iters are for WGAN
+             self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
+             self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
+
+             if train_opt['gan_type'] == 'wgan-gp':
+                 self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
+                 # gradient penalty loss
+                 self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
+                 self.l_gp_w = train_opt['gp_weigth']  # note: key spelling matches the option files
+
+             # optimizers
+             # G
+             wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+             optim_params = []
+             for k, v in self.netG.named_parameters():  # can optimize for a part of the model
+                 if v.requires_grad:
+                     optim_params.append(v)
+                 else:
+                     logger.warning('Params [{:s}] will not optimize.'.format(k))
+             self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
+                 weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
+             self.optimizers.append(self.optimizer_G)
+             # D
+             wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
+             self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
+                 weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
+             self.optimizers.append(self.optimizer_D)
+
+             # schedulers
+             if train_opt['lr_scheme'] == 'MultiStepLR':
+                 for optimizer in self.optimizers:
+                     self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
+                         train_opt['lr_steps'], train_opt['lr_gamma']))
+             else:
+                 raise NotImplementedError('Only the MultiStepLR scheme is implemented.')
+
+             self.log_dict = OrderedDict()
+         # print network
+         self.print_network()
+
+     def feed_data(self, data, need_HR=True):
+         # LR
+         self.var_L = data['LR'].to(self.device)
+         if need_HR:  # train or val
+             self.var_H = data['HR'].to(self.device)
+
+             input_ref = data['ref'] if 'ref' in data else data['HR']
+             self.var_ref = input_ref.to(self.device)
+
+     def optimize_parameters(self, step):
+         # G
+         self.optimizer_G.zero_grad()
+         self.fake_H = self.netG(self.var_L)
+
+         l_g_total = 0
+         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+             if self.cri_pix:  # pixel loss
+                 l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
+                 l_g_total += l_g_pix
+             if self.cri_fea:  # feature loss
+                 real_fea = self.netF(self.var_H).detach()
+                 fake_fea = self.netF(self.fake_H)
+                 l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
+                 l_g_total += l_g_fea
+             # G gan loss
+             pred_g_fake = self.netD(self.fake_H)
+             l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
+             l_g_total += l_g_gan
+
+             l_g_total.backward()
+             self.optimizer_G.step()
+
+         # D
+         self.optimizer_D.zero_grad()
+         l_d_total = 0
+         # real data
+         pred_d_real = self.netD(self.var_ref)
+         l_d_real = self.cri_gan(pred_d_real, True)
+         # fake data
+         pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
+         l_d_fake = self.cri_gan(pred_d_fake, False)
+
+         l_d_total = l_d_real + l_d_fake
+
+         if self.opt['train']['gan_type'] == 'wgan-gp':
+             batch_size = self.var_ref.size(0)
+             if self.random_pt.size(0) != batch_size:
+                 self.random_pt.resize_(batch_size, 1, 1, 1)
+             self.random_pt.uniform_()  # draw random interpolation points
+             interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_ref
+             interp.requires_grad = True
+             interp_crit = self.netD(interp)
+             l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)
+             l_d_total += l_d_gp
+
+         l_d_total.backward()
+         self.optimizer_D.step()
+
+         # set log
+         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+             # G
+             if self.cri_pix:
+                 self.log_dict['l_g_pix'] = l_g_pix.item()
+             if self.cri_fea:
+                 self.log_dict['l_g_fea'] = l_g_fea.item()
+             self.log_dict['l_g_gan'] = l_g_gan.item()
+         # D
+         self.log_dict['l_d_real'] = l_d_real.item()
+         self.log_dict['l_d_fake'] = l_d_fake.item()
+
+         if self.opt['train']['gan_type'] == 'wgan-gp':
+             self.log_dict['l_d_gp'] = l_d_gp.item()
+         # D outputs
+         self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
+         self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
+
+     def test(self):
+         self.netG.eval()
+         with torch.no_grad():
+             self.fake_H = self.netG(self.var_L)
+         self.netG.train()
+
+     def get_current_log(self):
+         return self.log_dict
+
+     def get_current_visuals(self, need_HR=True):
+         out_dict = OrderedDict()
+         out_dict['LR'] = self.var_L.detach()[0].float().cpu()
+         out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
+         if need_HR:
+             out_dict['HR'] = self.var_H.detach()[0].float().cpu()
+         return out_dict
+
+     def print_network(self):
+         # Generator
+         s, n = self.get_network_description(self.netG)
+         if isinstance(self.netG, nn.DataParallel):
+             net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+                                              self.netG.module.__class__.__name__)
+         else:
+             net_struc_str = '{}'.format(self.netG.__class__.__name__)
+         logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+         logger.info(s)
+         if self.is_train:
+             # Discriminator
+             s, n = self.get_network_description(self.netD)
+             if isinstance(self.netD, nn.DataParallel):
+                 net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
+                                                  self.netD.module.__class__.__name__)
+             else:
+                 net_struc_str = '{}'.format(self.netD.__class__.__name__)
+             logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+             logger.info(s)
+
+             if self.cri_fea:  # F, Perceptual Network
+                 s, n = self.get_network_description(self.netF)
+                 if isinstance(self.netF, nn.DataParallel):
+                     net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
+                                                      self.netF.module.__class__.__name__)
+                 else:
+                     net_struc_str = '{}'.format(self.netF.__class__.__name__)
+                 logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+                 logger.info(s)
+
+     def load(self):
+         load_path_G = self.opt['path']['pretrain_model_G']
+         if load_path_G is not None:
+             logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
+             self.load_network(load_path_G, self.netG)
+         load_path_D = self.opt['path']['pretrain_model_D']
+         if self.opt['is_train'] and load_path_D is not None:
+             logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
+             self.load_network(load_path_D, self.netD)
+
+     def save(self, iter_step):
+         self.save_network(self.netG, 'G', iter_step)
+         self.save_network(self.netD, 'D', iter_step)
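GradientPenaltyLoss is imported from models.modules.loss, whose source is not part of this diff. For orientation, here is a hedged, self-contained sketch of what a WGAN-GP penalty conventionally computes on the (interp, interp_crit) pair built above, following the standard two-sided formulation (Gulrajani et al., 2017) rather than this repo's exact implementation; the tiny critic is a stand-in:

import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    # interpolate between real and fake, then penalize deviation of the
    # critic's input-gradient norm from 1
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * fake + (1 - alpha) * real).requires_grad_(True)
    crit = critic(interp)
    grad = torch.autograd.grad(crit.sum(), interp, create_graph=True)[0]
    grad = grad.view(grad.size(0), -1)
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()

critic = nn.Sequential(nn.Conv2d(3, 8, 3, 1, 1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
real, fake = torch.randn(4, 3, 16, 16), torch.randn(4, 3, 16, 16)
print(gradient_penalty(critic, real, fake))  # scalar tensor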
esrgan_plus/codes/models/SRRaGAN_model.py ADDED
@@ -0,0 +1,251 @@
+ import os
+ import logging
+ from collections import OrderedDict
+
+ import torch
+ import torch.nn as nn
+ from torch.optim import lr_scheduler
+
+ import models.networks as networks
+ from .base_model import BaseModel
+ from models.modules.loss import GANLoss, GradientPenaltyLoss
+ logger = logging.getLogger('base')
+
+
+ class SRRaGANModel(BaseModel):
+     def __init__(self, opt):
+         super(SRRaGANModel, self).__init__(opt)
+         train_opt = opt['train']
+
+         # define networks and load pretrained models
+         self.netG = networks.define_G(opt).to(self.device)  # G
+         if self.is_train:
+             self.netD = networks.define_D(opt).to(self.device)  # D
+             self.netG.train()
+             self.netD.train()
+         self.load()  # load G and D if needed
+
+         # define losses, optimizer and scheduler
+         if self.is_train:
+             # G pixel loss
+             if train_opt['pixel_weight'] > 0:
+                 l_pix_type = train_opt['pixel_criterion']
+                 if l_pix_type == 'l1':
+                     self.cri_pix = nn.L1Loss().to(self.device)
+                 elif l_pix_type == 'l2':
+                     self.cri_pix = nn.MSELoss().to(self.device)
+                 else:
+                     raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
+                 self.l_pix_w = train_opt['pixel_weight']
+             else:
+                 logger.info('Remove pixel loss.')
+                 self.cri_pix = None
+
+             # G feature loss
+             if train_opt['feature_weight'] > 0:
+                 l_fea_type = train_opt['feature_criterion']
+                 if l_fea_type == 'l1':
+                     self.cri_fea = nn.L1Loss().to(self.device)
+                 elif l_fea_type == 'l2':
+                     self.cri_fea = nn.MSELoss().to(self.device)
+                 else:
+                     raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
+                 self.l_fea_w = train_opt['feature_weight']
+             else:
+                 logger.info('Remove feature loss.')
+                 self.cri_fea = None
+             if self.cri_fea:  # load VGG perceptual loss
+                 self.netF = networks.define_F(opt, use_bn=False).to(self.device)
+
+             # GD gan loss
+             self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
+             self.l_gan_w = train_opt['gan_weight']
+             # D_update_ratio and D_init_iters are for WGAN
+             self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
+             self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
+
+             if train_opt['gan_type'] == 'wgan-gp':
+                 self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
+                 # gradient penalty loss
+                 self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
+                 self.l_gp_w = train_opt['gp_weigth']  # note: key spelling matches the option files
+
+             # optimizers
+             # G
+             wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+             optim_params = []
+             for k, v in self.netG.named_parameters():  # can optimize for a part of the model
+                 if v.requires_grad:
+                     optim_params.append(v)
+                 else:
+                     logger.warning('Params [{:s}] will not optimize.'.format(k))
+             self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
+                 weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
+             self.optimizers.append(self.optimizer_G)
+             # D
+             wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
+             self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
+                 weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
+             self.optimizers.append(self.optimizer_D)
+
+             # schedulers
+             if train_opt['lr_scheme'] == 'MultiStepLR':
+                 for optimizer in self.optimizers:
+                     self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
+                         train_opt['lr_steps'], train_opt['lr_gamma']))
+             else:
+                 raise NotImplementedError('Only the MultiStepLR scheme is implemented.')
+
+             self.log_dict = OrderedDict()
+         # print network
+         self.print_network()
+
+     def feed_data(self, data, need_HR=True):
+         # LR
+         self.var_L = data['LR'].to(self.device)
+
+         if need_HR:  # train or val
+             self.var_H = data['HR'].to(self.device)
+
+             input_ref = data['ref'] if 'ref' in data else data['HR']
+             self.var_ref = input_ref.to(self.device)
+
+     def optimize_parameters(self, step):
+         # G
+         for p in self.netD.parameters():
+             p.requires_grad = False
+
+         self.optimizer_G.zero_grad()
+
+         self.fake_H = self.netG(self.var_L)
+
+         l_g_total = 0
+         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+             if self.cri_pix:  # pixel loss
+                 l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
+                 l_g_total += l_g_pix
+             if self.cri_fea:  # feature loss
+                 real_fea = self.netF(self.var_H).detach()
+                 fake_fea = self.netF(self.fake_H)
+                 l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
+                 l_g_total += l_g_fea
+             # G gan loss (relativistic average GAN)
+             pred_g_fake = self.netD(self.fake_H)
+             pred_d_real = self.netD(self.var_ref).detach()
+
+             l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
+                                       self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+             l_g_total += l_g_gan
+
+             l_g_total.backward()
+             self.optimizer_G.step()
+
+         # D
+         for p in self.netD.parameters():
+             p.requires_grad = True
+
+         self.optimizer_D.zero_grad()
+         l_d_total = 0
+         pred_d_real = self.netD(self.var_ref)
+         pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
+         l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
+         l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
+
+         l_d_total = (l_d_real + l_d_fake) / 2
+
+         if self.opt['train']['gan_type'] == 'wgan-gp':
+             batch_size = self.var_ref.size(0)
+             if self.random_pt.size(0) != batch_size:
+                 self.random_pt.resize_(batch_size, 1, 1, 1)
+             self.random_pt.uniform_()  # draw random interpolation points
+             interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_ref
+             interp.requires_grad = True
+             interp_crit = self.netD(interp)
+             l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)
+             l_d_total += l_d_gp
+
+         l_d_total.backward()
+         self.optimizer_D.step()
+
+         # set log
+         if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+             # G
+             if self.cri_pix:
+                 self.log_dict['l_g_pix'] = l_g_pix.item()
+             if self.cri_fea:
+                 self.log_dict['l_g_fea'] = l_g_fea.item()
+             self.log_dict['l_g_gan'] = l_g_gan.item()
+         # D
+         self.log_dict['l_d_real'] = l_d_real.item()
+         self.log_dict['l_d_fake'] = l_d_fake.item()
+
+         if self.opt['train']['gan_type'] == 'wgan-gp':
+             self.log_dict['l_d_gp'] = l_d_gp.item()
+         # D outputs
+         self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
+         self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
+
+     def test(self):
+         self.netG.eval()
+         with torch.no_grad():
+             self.fake_H = self.netG(self.var_L)
+         self.netG.train()
+
+     def get_current_log(self):
+         return self.log_dict
+
+     def get_current_visuals(self, need_HR=True):
+         out_dict = OrderedDict()
+         out_dict['LR'] = self.var_L.detach()[0].float().cpu()
+         out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
+         if need_HR:
+             out_dict['HR'] = self.var_H.detach()[0].float().cpu()
+         return out_dict
+
+     def print_network(self):
+         # Generator
+         s, n = self.get_network_description(self.netG)
+         if isinstance(self.netG, nn.DataParallel):
+             net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+                                              self.netG.module.__class__.__name__)
+         else:
+             net_struc_str = '{}'.format(self.netG.__class__.__name__)
+
+         logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+         logger.info(s)
+         if self.is_train:
+             # Discriminator
+             s, n = self.get_network_description(self.netD)
+             if isinstance(self.netD, nn.DataParallel):
+                 net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
+                                                  self.netD.module.__class__.__name__)
+             else:
+                 net_struc_str = '{}'.format(self.netD.__class__.__name__)
+
+             logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+             logger.info(s)
+
+             if self.cri_fea:  # F, Perceptual Network
+                 s, n = self.get_network_description(self.netF)
+                 if isinstance(self.netF, nn.DataParallel):
+                     net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
+                                                      self.netF.module.__class__.__name__)
+                 else:
+                     net_struc_str = '{}'.format(self.netF.__class__.__name__)
+
+                 logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+                 logger.info(s)
+
+     def load(self):
+         load_path_G = self.opt['path']['pretrain_model_G']
+         if load_path_G is not None:
+             logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
+             self.load_network(load_path_G, self.netG)
+         load_path_D = self.opt['path']['pretrain_model_D']
+         if self.opt['is_train'] and load_path_D is not None:
+             logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
+             self.load_network(load_path_D, self.netD)
+
+     def save(self, iter_step):
+         self.save_network(self.netG, 'G', iter_step)
+         self.save_network(self.netD, 'D', iter_step)
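The relativistic average GAN (RaGAN) terms in optimize_parameters above can be read in isolation. A minimal sketch of the generator-side loss, assuming the underlying criterion reduces to BCE-with-logits (as GANLoss does for the vanilla gan_type); this restates the formula, it is not a copy of the repo's GANLoss:

import torch
import torch.nn.functional as F

def ragan_g_loss(pred_real, pred_fake):
    # the generator pushes fake logits above the average real logit,
    # and real logits below the average fake logit
    return (F.binary_cross_entropy_with_logits(
                pred_real - pred_fake.mean(), torch.zeros_like(pred_real)) +
            F.binary_cross_entropy_with_logits(
                pred_fake - pred_real.mean(), torch.ones_like(pred_fake))) / 2

pred_real, pred_fake = torch.randn(4, 1), torch.randn(4, 1)
print(ragan_g_loss(pred_real, pred_fake))  # scalar tensor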
esrgan_plus/codes/models/SR_model.py ADDED
@@ -0,0 +1,151 @@
+ import os
+ import logging
+ from collections import OrderedDict
+
+ import torch
+ import torch.nn as nn
+ from torch.optim import lr_scheduler
+
+ import models.networks as networks
+ from .base_model import BaseModel
+
+ logger = logging.getLogger('base')
+
+
+ class SRModel(BaseModel):
+     def __init__(self, opt):
+         super(SRModel, self).__init__(opt)
+         train_opt = opt['train']
+
+         # define network and load pretrained models
+         self.netG = networks.define_G(opt).to(self.device)
+         self.load()
+
+         if self.is_train:
+             self.netG.train()
+
+             # loss
+             loss_type = train_opt['pixel_criterion']
+             if loss_type == 'l1':
+                 self.cri_pix = nn.L1Loss().to(self.device)
+             elif loss_type == 'l2':
+                 self.cri_pix = nn.MSELoss().to(self.device)
+             else:
+                 raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
+             self.l_pix_w = train_opt['pixel_weight']
+
+             # optimizers
+             wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+             optim_params = []
+             for k, v in self.netG.named_parameters():  # can optimize for a part of the model
+                 if v.requires_grad:
+                     optim_params.append(v)
+                 else:
+                     logger.warning('Params [{:s}] will not optimize.'.format(k))
+             self.optimizer_G = torch.optim.Adam(
+                 optim_params, lr=train_opt['lr_G'], weight_decay=wd_G)
+             self.optimizers.append(self.optimizer_G)
+
+             # schedulers
+             if train_opt['lr_scheme'] == 'MultiStepLR':
+                 for optimizer in self.optimizers:
+                     self.schedulers.append(lr_scheduler.MultiStepLR(optimizer,
+                         train_opt['lr_steps'], train_opt['lr_gamma']))
+             else:
+                 raise NotImplementedError('Only the MultiStepLR scheme is implemented.')
+
+             self.log_dict = OrderedDict()
+         # print network
+         self.print_network()
+
+     def feed_data(self, data, need_HR=True):
+         self.var_L = data['LR'].to(self.device)  # LR
+         if need_HR:
+             self.real_H = data['HR'].to(self.device)  # HR
+
+     def optimize_parameters(self, step):
+         self.optimizer_G.zero_grad()
+         self.fake_H = self.netG(self.var_L)
+         l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
+         l_pix.backward()
+         self.optimizer_G.step()
+
+         # set log
+         self.log_dict['l_pix'] = l_pix.item()
+
+     def test(self):
+         self.netG.eval()
+         with torch.no_grad():
+             self.fake_H = self.netG(self.var_L)
+         self.netG.train()
+
+     def test_x8(self):
+         # from https://github.com/thstkdgus35/EDSR-PyTorch
+         self.netG.eval()
+         for k, v in self.netG.named_parameters():
+             v.requires_grad = False
+
+         def _transform(v, op):
+             # if self.precision != 'single': v = v.float()
+             v2np = v.data.cpu().numpy()
+             if op == 'v':
+                 tfnp = v2np[:, :, :, ::-1].copy()
+             elif op == 'h':
+                 tfnp = v2np[:, :, ::-1, :].copy()
+             elif op == 't':
+                 tfnp = v2np.transpose((0, 1, 3, 2)).copy()
+
+             ret = torch.Tensor(tfnp).to(self.device)
+             # if self.precision == 'half': ret = ret.half()
+
+             return ret
+
+         lr_list = [self.var_L]
+         for tf in 'v', 'h', 't':
+             lr_list.extend([_transform(t, tf) for t in lr_list])
+         sr_list = [self.netG(aug) for aug in lr_list]
+         for i in range(len(sr_list)):
+             if i > 3:
+                 sr_list[i] = _transform(sr_list[i], 't')
+             if i % 4 > 1:
+                 sr_list[i] = _transform(sr_list[i], 'h')
+             if (i % 4) % 2 == 1:
+                 sr_list[i] = _transform(sr_list[i], 'v')
+
+         output_cat = torch.cat(sr_list, dim=0)
+         self.fake_H = output_cat.mean(dim=0, keepdim=True)
+
+         for k, v in self.netG.named_parameters():
+             v.requires_grad = True
+         self.netG.train()
+
+     def get_current_log(self):
+         return self.log_dict
+
+     def get_current_visuals(self, need_HR=True):
+         out_dict = OrderedDict()
+         out_dict['LR'] = self.var_L.detach()[0].float().cpu()
+         out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
+         if need_HR:
+             out_dict['HR'] = self.real_H.detach()[0].float().cpu()
+         return out_dict
+
+     def print_network(self):
+         s, n = self.get_network_description(self.netG)
+         if isinstance(self.netG, nn.DataParallel):
+             net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+                                              self.netG.module.__class__.__name__)
+         else:
+             net_struc_str = '{}'.format(self.netG.__class__.__name__)
+
+         logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+         logger.info(s)
+
+     def load(self):
+         load_path_G = self.opt['path']['pretrain_model_G']
+         if load_path_G is not None:
+             logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
+             self.load_network(load_path_G, self.netG)
+
+     def save(self, iter_step):
+         self.save_network(self.netG, 'G', iter_step)
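The subtle part of test_x8 above is the inverse bookkeeping: the 8 augmented outputs must be un-transformed in the reverse branch order ('t', then 'h', then 'v') before averaging. A tiny standalone sanity sketch of that ordering, using tensor flips in place of the numpy round trip:

import torch

x = torch.randn(1, 3, 5, 7)
t = x.flip(-1).flip(-2).transpose(-2, -1)     # apply 'v', then 'h', then 't' (index 7)
back = t.transpose(-2, -1).flip(-2).flip(-1)  # undo 't', then 'h', then 'v'
assert torch.equal(back, x)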
esrgan_plus/codes/models/__init__.py ADDED
@@ -0,0 +1,20 @@
+ import logging
+ logger = logging.getLogger('base')
+
+
+ def create_model(opt):
+     model = opt['model']
+
+     if model == 'sr':
+         from .SR_model import SRModel as M
+     elif model == 'srgan':
+         from .SRGAN_model import SRGANModel as M
+     elif model == 'srragan':
+         from .SRRaGAN_model import SRRaGANModel as M
+     elif model == 'sftgan':
+         from .SFTGAN_ACD_model import SFTGAN_ACD_Model as M
+     else:
+         raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
+     m = M(opt)
+     logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
+     return m
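create_model is a lazy-import dispatch: each model class is imported only when its key is requested, so unused models never pull in their dependencies. A self-contained sketch of the same pattern with stand-in classes (the collections types here are illustrative, not from this repo):

import logging
logger = logging.getLogger('base')

def create_container(kind):
    if kind == 'deque':
        from collections import deque as M      # stand-in for a model class
    elif kind == 'odict':
        from collections import OrderedDict as M
    else:
        raise NotImplementedError('Kind [{:s}] not recognized.'.format(kind))
    m = M()
    logger.info('Container [{:s}] is created.'.format(m.__class__.__name__))
    return m

print(create_container('deque').__class__.__name__)  # deque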
esrgan_plus/codes/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (807 Bytes).
 
esrgan_plus/codes/models/base_model.py ADDED
@@ -0,0 +1,85 @@
+ import os
+ import torch
+ import torch.nn as nn
+
+
+ class BaseModel():
+     def __init__(self, opt):
+         self.opt = opt
+         self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
+         self.is_train = opt['is_train']
+         self.schedulers = []
+         self.optimizers = []
+
+     def feed_data(self, data):
+         pass
+
+     def optimize_parameters(self):
+         pass
+
+     def get_current_visuals(self):
+         pass
+
+     def get_current_losses(self):
+         pass
+
+     def print_network(self):
+         pass
+
+     def save(self, label):
+         pass
+
+     def load(self):
+         pass
+
+     def update_learning_rate(self):
+         for scheduler in self.schedulers:
+             scheduler.step()
+
+     def get_current_learning_rate(self):
+         return self.schedulers[0].get_lr()[0]
+
+     def get_network_description(self, network):
+         '''Get the string and total parameters of the network'''
+         if isinstance(network, nn.DataParallel):
+             network = network.module
+         s = str(network)
+         n = sum(map(lambda x: x.numel(), network.parameters()))
+         return s, n
+
+     def save_network(self, network, network_label, iter_step):
+         save_filename = '{}_{}.pth'.format(iter_step, network_label)
+         save_path = os.path.join(self.opt['path']['models'], save_filename)
+         if isinstance(network, nn.DataParallel):
+             network = network.module
+         state_dict = network.state_dict()
+         for key, param in state_dict.items():
+             state_dict[key] = param.cpu()
+         torch.save(state_dict, save_path)
+
+     def load_network(self, load_path, network, strict=True):
+         if isinstance(network, nn.DataParallel):
+             network = network.module
+         network.load_state_dict(torch.load(load_path), strict=strict)
+
+     def save_training_state(self, epoch, iter_step):
+         '''Saves training state during training, which will be used for resuming'''
+         state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
+         for s in self.schedulers:
+             state['schedulers'].append(s.state_dict())
+         for o in self.optimizers:
+             state['optimizers'].append(o.state_dict())
+         save_filename = '{}.state'.format(iter_step)
+         save_path = os.path.join(self.opt['path']['training_state'], save_filename)
+         torch.save(state, save_path)
+
+     def resume_training(self, resume_state):
+         '''Resume the optimizers and schedulers for training'''
+         resume_optimizers = resume_state['optimizers']
+         resume_schedulers = resume_state['schedulers']
+         assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
+         assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
+         for i, o in enumerate(resume_optimizers):
+             self.optimizers[i].load_state_dict(o)
+         for i, s in enumerate(resume_schedulers):
+             self.schedulers[i].load_state_dict(s)
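save_training_state and resume_training form a round trip: the first persists optimizer and scheduler state to a .state file, the second restores it into a freshly constructed model. A hedged usage sketch (resume_if_requested is an illustrative helper, not a function from this repo; network weights are loaded separately via load_network):

import torch

def resume_if_requested(model, resume_path):
    # restore optimizer/scheduler state saved by save_training_state above
    if resume_path is None:
        return 0, 0
    resume_state = torch.load(resume_path)
    model.resume_training(resume_state)
    return resume_state['epoch'], resume_state['iter']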
esrgan_plus/codes/models/modules/__pycache__/architecture.cpython-310.pyc ADDED
Binary file (11.1 kB).
 
esrgan_plus/codes/models/modules/__pycache__/block.cpython-310.pyc ADDED
Binary file (10.6 kB).
 
esrgan_plus/codes/models/modules/__pycache__/spectral_norm.cpython-310.pyc ADDED
Binary file (5.46 kB).
 
esrgan_plus/codes/models/modules/architecture.py ADDED
@@ -0,0 +1,394 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from . import block as B
+ from . import spectral_norm as SN
+
+ ####################
+ # Generator
+ ####################
+
+
+ class SRResNet(nn.Module):
+     def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='batch', act_type='relu',
+                  mode='NAC', res_scale=1, upsample_mode='upconv'):
+         super(SRResNet, self).__init__()
+         n_upscale = int(math.log(upscale, 2))
+         if upscale == 3:
+             n_upscale = 1
+
+         fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
+         resnet_blocks = [B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type,
+                                        mode=mode, res_scale=res_scale) for _ in range(nb)]
+         LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
+
+         if upsample_mode == 'upconv':
+             upsample_block = B.upconv_blcok  # (sic: this is the helper's spelling in block.py)
+         elif upsample_mode == 'pixelshuffle':
+             upsample_block = B.pixelshuffle_block
+         else:
+             raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+         if upscale == 3:
+             upsampler = upsample_block(nf, nf, 3, act_type=act_type)
+         else:
+             upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
+         HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
+         HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
+
+         self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*resnet_blocks, LR_conv)),
+                                   *upsampler, HR_conv0, HR_conv1)
+
+     def forward(self, x):
+         x = self.model(x)
+         return x
+
+
+ class RRDBNet(nn.Module):
+     def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None,
+                  act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
+         super(RRDBNet, self).__init__()
+         n_upscale = int(math.log(upscale, 2))
+         if upscale == 3:
+             n_upscale = 1
+
+         fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
+         rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero',
+                             norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
+         LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
+
+         if upsample_mode == 'upconv':
+             upsample_block = B.upconv_blcok
+         elif upsample_mode == 'pixelshuffle':
+             upsample_block = B.pixelshuffle_block
+         else:
+             raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
+         if upscale == 3:
+             upsampler = upsample_block(nf, nf, 3, act_type=act_type)
+         else:
+             upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
+         HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
+         HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
+
+         self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
+                                   *upsampler, HR_conv0, HR_conv1)
+
+     def forward(self, x):
+         x = self.model(x)
+         return x
+
+
+ ####################
+ # Discriminator
+ ####################
+
+
+ # VGG style Discriminator with input size 128*128
+ class Discriminator_VGG_128(nn.Module):
+     def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
+         super(Discriminator_VGG_128, self).__init__()
+         # features
+         # hxw, c
+         # 128, 64
+         conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
+                              mode=mode)
+         conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 64, 64
+         conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 32, 128
+         conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 16, 256
+         conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 8, 512
+         conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 4, 512
+         self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7,
+                                      conv8, conv9)
+
+         # classifier
+         self.classifier = nn.Sequential(
+             nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
+
+     def forward(self, x):
+         x = self.features(x)
+         x = x.view(x.size(0), -1)
+         x = self.classifier(x)
+         return x
+
+
+ # VGG style Discriminator with input size 128*128, Spectral Normalization
+ class Discriminator_VGG_128_SN(nn.Module):
+     def __init__(self):
+         super(Discriminator_VGG_128_SN, self).__init__()
+         # features
+         # hxw, c
+         # 128, 64
+         self.lrelu = nn.LeakyReLU(0.2, True)
+
+         self.conv0 = SN.spectral_norm(nn.Conv2d(3, 64, 3, 1, 1))
+         self.conv1 = SN.spectral_norm(nn.Conv2d(64, 64, 4, 2, 1))
+         # 64, 64
+         self.conv2 = SN.spectral_norm(nn.Conv2d(64, 128, 3, 1, 1))
+         self.conv3 = SN.spectral_norm(nn.Conv2d(128, 128, 4, 2, 1))
+         # 32, 128
+         self.conv4 = SN.spectral_norm(nn.Conv2d(128, 256, 3, 1, 1))
+         self.conv5 = SN.spectral_norm(nn.Conv2d(256, 256, 4, 2, 1))
+         # 16, 256
+         self.conv6 = SN.spectral_norm(nn.Conv2d(256, 512, 3, 1, 1))
+         self.conv7 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
+         # 8, 512
+         self.conv8 = SN.spectral_norm(nn.Conv2d(512, 512, 3, 1, 1))
+         self.conv9 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
+         # 4, 512
+
+         # classifier
+         self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100))
+         self.linear1 = SN.spectral_norm(nn.Linear(100, 1))
+
+     def forward(self, x):
+         x = self.lrelu(self.conv0(x))
+         x = self.lrelu(self.conv1(x))
+         x = self.lrelu(self.conv2(x))
+         x = self.lrelu(self.conv3(x))
+         x = self.lrelu(self.conv4(x))
+         x = self.lrelu(self.conv5(x))
+         x = self.lrelu(self.conv6(x))
+         x = self.lrelu(self.conv7(x))
+         x = self.lrelu(self.conv8(x))
+         x = self.lrelu(self.conv9(x))
+         x = x.view(x.size(0), -1)
+         x = self.lrelu(self.linear0(x))
+         x = self.linear1(x)
+         return x
+
+
+ class Discriminator_VGG_96(nn.Module):
+     def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
+         super(Discriminator_VGG_96, self).__init__()
+         # features
+         # hxw, c
+         # 96, 64
+         conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
+                              mode=mode)
+         conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 48, 64
+         conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 24, 128
+         conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 12, 256
+         conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 6, 512
+         conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 3, 512
+         self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7,
+                                      conv8, conv9)
+
+         # classifier
+         self.classifier = nn.Sequential(
+             nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
+
+     def forward(self, x):
+         x = self.features(x)
+         x = x.view(x.size(0), -1)
+         x = self.classifier(x)
+         return x
+
+
+ class Discriminator_VGG_192(nn.Module):
+     def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
+         super(Discriminator_VGG_192, self).__init__()
+         # features
+         # hxw, c
+         # 192, 64
+         conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type,
+                              mode=mode)
+         conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 96, 64
+         conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 48, 128
+         conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 24, 256
+         conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 12, 512
+         conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
+                              act_type=act_type, mode=mode)
+         # 6, 512
+         conv10 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type,
+                               act_type=act_type, mode=mode)
+         conv11 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type,
+                               act_type=act_type, mode=mode)
+         # 3, 512
+         self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7,
+                                      conv8, conv9, conv10, conv11)
+
+         # classifier
+         self.classifier = nn.Sequential(
+             nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))
+
+     def forward(self, x):
+         x = self.features(x)
+         x = x.view(x.size(0), -1)
+         x = self.classifier(x)
+         return x
+
+
+ ####################
+ # Perceptual Network
+ ####################
+
+
+ # Assume input range is [0, 1]
+ class VGGFeatureExtractor(nn.Module):
+     def __init__(self,
+                  feature_layer=34,
+                  use_bn=False,
+                  use_input_norm=True,
+                  device=torch.device('cpu')):
+         super(VGGFeatureExtractor, self).__init__()
+         if use_bn:
+             model = torchvision.models.vgg19_bn(pretrained=True)
+         else:
+             model = torchvision.models.vgg19(pretrained=True)
+         self.use_input_norm = use_input_norm
+         if self.use_input_norm:
+             mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
+             # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
+             std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
+             # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
+             self.register_buffer('mean', mean)
+             self.register_buffer('std', std)
+         self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
+         # No need to BP to variable
+         for k, v in self.features.named_parameters():
+             v.requires_grad = False
+
+     def forward(self, x):
+         if self.use_input_norm:
+             x = (x - self.mean) / self.std
+         output = self.features(x)
+         return output
+
+
+ # Assume input range is [0, 1]
+ class ResNet101FeatureExtractor(nn.Module):
+     def __init__(self, use_input_norm=True, device=torch.device('cpu')):
+         super(ResNet101FeatureExtractor, self).__init__()
+         model = torchvision.models.resnet101(pretrained=True)
+         self.use_input_norm = use_input_norm
+         if self.use_input_norm:
+             mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
+             # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
+             std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
+             # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
+             self.register_buffer('mean', mean)
+             self.register_buffer('std', std)
+         self.features = nn.Sequential(*list(model.children())[:8])
+         # No need to BP to variable
+         for k, v in self.features.named_parameters():
+             v.requires_grad = False
+
+     def forward(self, x):
+         if self.use_input_norm:
+             x = (x - self.mean) / self.std
+         output = self.features(x)
+         return output
+
+
+ class MINCNet(nn.Module):
+     def __init__(self):
+         super(MINCNet, self).__init__()
+         self.ReLU = nn.ReLU(True)
+         self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
+         self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
+         self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
+         self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
+         self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
+         self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
+         self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
+         self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
+         self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
+         self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
+         self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
+         self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
+         self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
+         self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
+         self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
+         self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
+         self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)
+
+     def forward(self, x):
+         out = self.ReLU(self.conv11(x))
+         out = self.ReLU(self.conv12(out))
+         out = self.maxpool1(out)
+         out = self.ReLU(self.conv21(out))
+         out = self.ReLU(self.conv22(out))
+         out = self.maxpool2(out)
+         out = self.ReLU(self.conv31(out))
+         out = self.ReLU(self.conv32(out))
+         out = self.ReLU(self.conv33(out))
+         out = self.maxpool3(out)
+         out = self.ReLU(self.conv41(out))
+         out = self.ReLU(self.conv42(out))
+         out = self.ReLU(self.conv43(out))
+         out = self.maxpool4(out)
+         out = self.ReLU(self.conv51(out))
+         out = self.ReLU(self.conv52(out))
+         out = self.conv53(out)
+         return out
+
+
+ # Assume input range is [0, 1]
+ class MINCFeatureExtractor(nn.Module):
+     def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
+                  device=torch.device('cpu')):
+         super(MINCFeatureExtractor, self).__init__()
+
+         self.features = MINCNet()
+         self.features.load_state_dict(
+             torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True)
+         self.features.eval()
+         # No need to BP to variable
+         for k, v in self.features.named_parameters():
+             v.requires_grad = False
+
+     def forward(self, x):
+         output = self.features(x)
+         return output
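A quick shape check for the RRDBNet generator defined above (ESRGAN-style RRDB trunk plus upsampling). The nf=64, nb=23 settings follow the usual ESRGAN defaults and are assumptions here, not values read from a config; the import assumes the script runs with codes/ on the path, as train.py/test.py arrange:

import torch
from models.modules.architecture import RRDBNet

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, upscale=4)
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 3, 128, 128])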
esrgan_plus/codes/models/modules/block.py ADDED
@@ -0,0 +1,322 @@
+ from collections import OrderedDict
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import copy
+
+ ####################
+ # Basic blocks
+ ####################
+
+
+ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
+     # helper selecting activation
+     # neg_slope: for leakyrelu and init of prelu
+     # n_prelu: for p_relu num_parameters
+     act_type = act_type.lower()
+     if act_type == 'relu':
+         layer = nn.ReLU(inplace)
+     elif act_type == 'leakyrelu':
+         layer = nn.LeakyReLU(neg_slope, inplace)
+     elif act_type == 'prelu':
+         layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
+     else:
+         raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
+     return layer
+
+
+ def norm(norm_type, nc):
+     # helper selecting normalization layer
+     norm_type = norm_type.lower()
+     if norm_type == 'batch':
+         layer = nn.BatchNorm2d(nc, affine=True)
+     elif norm_type == 'instance':
+         layer = nn.InstanceNorm2d(nc, affine=False)
+     else:
+         raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
+     return layer
+
+
+ def pad(pad_type, padding):
+     # helper selecting padding layer
+     # if padding is 'zero', it is done inside the conv layers
+     pad_type = pad_type.lower()
+     if padding == 0:
+         return None
+     if pad_type == 'reflect':
+         layer = nn.ReflectionPad2d(padding)
+     elif pad_type == 'replicate':
+         layer = nn.ReplicationPad2d(padding)
+     else:
+         raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
+     return layer
+
+
+ def get_valid_padding(kernel_size, dilation):
+     kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+     padding = (kernel_size - 1) // 2
+     return padding
+
+
+ class ConcatBlock(nn.Module):
+     # Concat the output of a submodule to its input
+     def __init__(self, submodule):
+         super(ConcatBlock, self).__init__()
+         self.sub = submodule
+
+     def forward(self, x):
+         output = torch.cat((x, self.sub(x)), dim=1)
+         return output
+
+     def __repr__(self):
+         tmpstr = 'Identity .. \n|'
+         modstr = self.sub.__repr__().replace('\n', '\n|')
+         tmpstr = tmpstr + modstr
+         return tmpstr
+
+
+ class ShortcutBlock(nn.Module):
+     # Elementwise sum the output of a submodule to its input
+     def __init__(self, submodule):
+         super(ShortcutBlock, self).__init__()
+         self.sub = submodule
+
+     def forward(self, x):
+         output = x + self.sub(x)
+         return output
+
+     def __repr__(self):
+         tmpstr = 'Identity + \n|'
+         modstr = self.sub.__repr__().replace('\n', '\n|')
+         tmpstr = tmpstr + modstr
+         return tmpstr
+
+
+ def sequential(*args):
+     # Flatten Sequential. It unwraps nn.Sequential.
+     if len(args) == 1:
+         if isinstance(args[0], OrderedDict):
+             raise NotImplementedError('sequential does not support OrderedDict input.')
+         return args[0]  # No sequential is needed.
+     modules = []
+     for module in args:
+         if isinstance(module, nn.Sequential):
+             for submodule in module.children():
+                 modules.append(submodule)
+         elif isinstance(module, nn.Module):
+             modules.append(module)
+     return nn.Sequential(*modules)
+
+
+ class GaussianNoise(nn.Module):
+     def __init__(self, sigma=0.1, is_relative_detach=False):
+         super().__init__()
+         self.sigma = sigma
+         self.is_relative_detach = is_relative_detach
+         self.noise = torch.tensor(0, dtype=torch.float)
+
+     def forward(self, x):
+         if self.training and self.sigma != 0:
+             self.noise = self.noise.to(x.device)  # keep the noise buffer on the input's device
+             scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
+             sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
+             x = x + sampled_noise
+         return x
+
+
+ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
+                pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
+     '''
+     Conv layer with padding, normalization, activation
+     mode: CNA --> Conv -> Norm -> Act
+           NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
+     '''
+     assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
+     padding = get_valid_padding(kernel_size, dilation)
+     p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
+     padding = padding if pad_type == 'zero' else 0
+
+     c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
+                   dilation=dilation, bias=bias, groups=groups)
+     a = act(act_type) if act_type else None
+     if 'CNA' in mode:
+         n = norm(norm_type, out_nc) if norm_type else None
+         return sequential(p, c, n, a)
+     elif mode == 'NAC':
+         if norm_type is None and act_type is not None:
+             a = act(act_type, inplace=False)
+             # Important!
+             # input----ReLU(inplace)----Conv--+----output
+             #        |________________________|
+             # an inplace ReLU would modify the input and corrupt the shortcut
+         n = norm(norm_type, in_nc) if norm_type else None
+         return sequential(n, a, p, c)
+
+
+ def conv1x1(in_planes, out_planes, stride=1):
+     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+ # https://github.com/github-pengge/PyTorch-progressive_growing_of_gans/blob/master/models/base_model.py
+ class minibatch_std_concat_layer(nn.Module):
+     def __init__(self, averaging='all'):
+         super(minibatch_std_concat_layer, self).__init__()
+         self.averaging = averaging.lower()
+         if 'group' in self.averaging:
+             self.n = int(self.averaging[5:])
+         else:
+             assert self.averaging in ['all', 'flat', 'spatial', 'none', 'gpool'], \
+                 'Invalid averaging mode %s' % self.averaging
+         self.adjusted_std = lambda x, **kwargs: torch.sqrt(
+             torch.mean((x - torch.mean(x, **kwargs)) ** 2, **kwargs) + 1e-8)
+
+     def forward(self, x):
+         shape = list(x.size())
+         target_shape = copy.deepcopy(shape)
+         vals = self.adjusted_std(x, dim=0, keepdim=True)
+         if self.averaging == 'all':
+             target_shape[1] = 1
+             vals = torch.mean(vals, dim=1, keepdim=True)
+         elif self.averaging == 'spatial':
+             if len(shape) == 4:
+                 vals = torch.mean(vals, dim=[2, 3], keepdim=True)
+         elif self.averaging == 'none':
+             target_shape = [target_shape[0]] + [s for s in target_shape[1:]]
+         elif self.averaging == 'gpool':
+             if len(shape) == 4:
+                 vals = torch.mean(x, dim=[0, 2, 3], keepdim=True)
+         elif self.averaging == 'flat':
+             target_shape[1] = 1
+             vals = torch.FloatTensor([self.adjusted_std(x)])
+         else:  # self.averaging == 'group'
+             target_shape[1] = self.n
+             vals = vals.view(self.n, shape[1] // self.n, shape[2], shape[3])
+             vals = torch.mean(vals, dim=0, keepdim=True).view(1, self.n, 1, 1)
+         vals = vals.expand(*target_shape)
+         return torch.cat([x, vals], 1)
+
+
+ ####################
+ # Useful blocks
+ ####################
+
+
+ class ResNetBlock(nn.Module):
+     '''
+     ResNet Block, 3-3 style
+     with extra residual scaling used in EDSR
+     (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
+     '''
+
+     def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \
207
+ bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1):
208
+ super(ResNetBlock, self).__init__()
209
+ conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
210
+ norm_type, act_type, mode)
211
+ if mode == 'CNA':
212
+ act_type = None
213
+ if mode == 'CNAC': # Residual path: |-CNAC-|
214
+ act_type = None
215
+ norm_type = None
216
+ conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
217
+ norm_type, act_type, mode)
218
+ # if in_nc != out_nc:
219
+ # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
220
+ # None, None)
221
+ # print('Need a projecter in ResNetBlock.')
222
+ # else:
223
+ # self.project = lambda x:x
224
+ self.res = sequential(conv0, conv1)
225
+ self.res_scale = res_scale
226
+
227
+ def forward(self, x):
228
+ res = self.res(x).mul(self.res_scale)
229
+ return x + res
230
+
231
+
232
+ class ResidualDenseBlock_5C(nn.Module):
233
+ '''
234
+ Residual Dense Block
235
+ style: 5 convs
236
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
237
+ '''
238
+
239
+ def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
240
+ norm_type=None, act_type='leakyrelu', mode='CNA', gaussian_noise=True):
241
+ super(ResidualDenseBlock_5C, self).__init__()
242
+ # gc: growth channel, i.e. intermediate channels
243
+ self.noise = GaussianNoise() if gaussian_noise else None
244
+ self.conv1x1 = conv1x1(nc, gc)
245
+ self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
246
+ norm_type=norm_type, act_type=act_type, mode=mode)
247
+ self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
248
+ norm_type=norm_type, act_type=act_type, mode=mode)
249
+ self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
250
+ norm_type=norm_type, act_type=act_type, mode=mode)
251
+ self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
252
+ norm_type=norm_type, act_type=act_type, mode=mode)
253
+ if mode == 'CNA':
254
+ last_act = None
255
+ else:
256
+ last_act = act_type
257
+ self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
258
+ norm_type=norm_type, act_type=last_act, mode=mode)
259
+
260
+ def forward(self, x):
261
+ x1 = self.conv1(x)
262
+ x2 = self.conv2(torch.cat((x, x1), 1))
263
+ x2 = x2 + self.conv1x1(x)
264
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
265
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
266
+ x4 = x4 + x2
267
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
268
+ return self.noise(x5.mul(0.2) + x)
269
+
270
+
271
+ class RRDB(nn.Module):
272
+ '''
273
+ Residual in Residual Dense Block
274
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
275
+ '''
276
+
277
+ def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
278
+ norm_type=None, act_type='leakyrelu', mode='CNA'):
279
+ super(RRDB, self).__init__()
280
+ self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
281
+ norm_type, act_type, mode)
282
+ self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
283
+ norm_type, act_type, mode)
284
+ self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
285
+ norm_type, act_type, mode)
286
+
287
+ def forward(self, x):
288
+ out = self.RDB1(x)
289
+ out = self.RDB2(out)
290
+ out = self.RDB3(out)
291
+ return out.mul(0.2) + x
292
+
293
+
294
+ ####################
295
+ # Upsampler
296
+ ####################
297
+
298
+
299
+ def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
300
+ pad_type='zero', norm_type=None, act_type='relu'):
301
+ '''
302
+ Pixel shuffle layer
303
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
304
+ Neural Network, CVPR17)
305
+ '''
306
+ conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, \
307
+ pad_type=pad_type, norm_type=None, act_type=None)
308
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
309
+
310
+ n = norm(norm_type, out_nc) if norm_type else None
311
+ a = act(act_type) if act_type else None
312
+ return sequential(conv, pixel_shuffle, n, a)
313
+
314
+
315
+ def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
316
+ pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
317
+ # Up conv
318
+ # described in https://distill.pub/2016/deconv-checkerboard/
319
+ upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
320
+ conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, \
321
+ pad_type=pad_type, norm_type=norm_type, act_type=act_type)
322
+ return sequential(upsample, conv)
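For orientation, a minimal sketch of how these helpers compose, assuming it is run from the `codes/` directory; the tensor sizes and variable names are illustrative, not from the repo:

import torch
import models.modules.block as B

conv = B.conv_block(3, 64, kernel_size=3, norm_type=None, act_type='leakyrelu', mode='CNA')
rrdb = B.RRDB(64, kernel_size=3, gc=32)   # residual-in-residual dense block
fea = conv(torch.randn(1, 3, 32, 32))
out = rrdb(fea)                           # same shape as fea: (1, 64, 32, 32)
up = B.upconv_blcok(64, 64)               # nearest-neighbor upsample + conv, x2
print(up(out).shape)                      # torch.Size([1, 64, 64, 64])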
esrgan_plus/codes/models/modules/loss.py ADDED
@@ -0,0 +1,60 @@
+ import torch
+ import torch.nn as nn
+
+
+ # Define GAN loss: [vanilla | lsgan | wgan-gp]
+ class GANLoss(nn.Module):
+     def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
+         super(GANLoss, self).__init__()
+         self.gan_type = gan_type.lower()
+         self.real_label_val = real_label_val
+         self.fake_label_val = fake_label_val
+
+         if self.gan_type == 'vanilla':
+             self.loss = nn.BCEWithLogitsLoss()
+         elif self.gan_type == 'lsgan':
+             self.loss = nn.MSELoss()
+         elif self.gan_type == 'wgan-gp':
+
+             def wgan_loss(input, target):
+                 # target is boolean
+                 return -1 * input.mean() if target else input.mean()
+
+             self.loss = wgan_loss
+         else:
+             raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
+
+     def get_target_label(self, input, target_is_real):
+         if self.gan_type == 'wgan-gp':
+             return target_is_real
+         if target_is_real:
+             return torch.empty_like(input).fill_(self.real_label_val)
+         else:
+             return torch.empty_like(input).fill_(self.fake_label_val)
+
+     def forward(self, input, target_is_real):
+         target_label = self.get_target_label(input, target_is_real)
+         loss = self.loss(input, target_label)
+         return loss
+
+
+ class GradientPenaltyLoss(nn.Module):
+     def __init__(self, device=torch.device('cpu')):
+         super(GradientPenaltyLoss, self).__init__()
+         self.register_buffer('grad_outputs', torch.Tensor())
+         self.grad_outputs = self.grad_outputs.to(device)
+
+     def get_grad_outputs(self, input):
+         if self.grad_outputs.size() != input.size():
+             self.grad_outputs.resize_(input.size()).fill_(1.0)
+         return self.grad_outputs
+
+     def forward(self, interp, interp_crit):
+         grad_outputs = self.get_grad_outputs(interp_crit)
+         grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
+                                           grad_outputs=grad_outputs, create_graph=True,
+                                           retain_graph=True, only_inputs=True)[0]
+         grad_interp = grad_interp.view(grad_interp.size(0), -1)
+         grad_interp_norm = grad_interp.norm(2, dim=1)
+
+         loss = ((grad_interp_norm - 1)**2).mean()
+         return loss
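A hedged usage sketch for GANLoss; the logits tensor below is a stand-in for real discriminator output:

import torch
from models.modules.loss import GANLoss

cri_gan = GANLoss('vanilla', real_label_val=1.0, fake_label_val=0.0)
pred_g_fake = torch.randn(16, 1)                  # raw logits from the discriminator
l_g_gan = cri_gan(pred_g_fake, True)              # generator loss: fakes should look real
l_d_fake = cri_gan(pred_g_fake.detach(), False)   # discriminator side of the same batch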
esrgan_plus/codes/models/modules/seg_arch.py ADDED
@@ -0,0 +1,70 @@
+ '''
+ architecture for segmentation
+ '''
+ import torch.nn as nn
+ from . import block as B
+
+
+ class Res131(nn.Module):
+     def __init__(self, in_nc, mid_nc, out_nc, dilation=1, stride=1):
+         super(Res131, self).__init__()
+         conv0 = B.conv_block(in_nc, mid_nc, 1, 1, 1, 1, False, 'zero', 'batch')
+         conv1 = B.conv_block(mid_nc, mid_nc, 3, stride, dilation, 1, False, 'zero', 'batch')
+         conv2 = B.conv_block(mid_nc, out_nc, 1, 1, 1, 1, False, 'zero', 'batch', None)  # No ReLU
+         self.res = B.sequential(conv0, conv1, conv2)
+         if in_nc == out_nc:
+             self.has_proj = False
+         else:
+             self.has_proj = True
+             self.proj = B.conv_block(in_nc, out_nc, 1, stride, 1, 1, False, 'zero', 'batch', None)
+             # No ReLU
+
+     def forward(self, x):
+         res = self.res(x)
+         if self.has_proj:
+             x = self.proj(x)
+         return nn.functional.relu(x + res, inplace=True)
+
+
+ class OutdoorSceneSeg(nn.Module):
+     def __init__(self):
+         super(OutdoorSceneSeg, self).__init__()
+         # conv1
+         blocks = []
+         conv1_1 = B.conv_block(3, 64, 3, 2, 1, 1, False, 'zero', 'batch')  # /2
+         conv1_2 = B.conv_block(64, 64, 3, 1, 1, 1, False, 'zero', 'batch')
+         conv1_3 = B.conv_block(64, 128, 3, 1, 1, 1, False, 'zero', 'batch')
+         max_pool = nn.MaxPool2d(3, stride=2, padding=0, ceil_mode=True)  # /2
+         blocks = [conv1_1, conv1_2, conv1_3, max_pool]
+         # conv2, 3 blocks
+         blocks.append(Res131(128, 64, 256))
+         for i in range(2):
+             blocks.append(Res131(256, 64, 256))
+         # conv3, 4 blocks
+         blocks.append(Res131(256, 128, 512, 1, 2))  # /2
+         for i in range(3):
+             blocks.append(Res131(512, 128, 512))
+         # conv4, 23 blocks
+         blocks.append(Res131(512, 256, 1024, 2))
+         for i in range(22):
+             blocks.append(Res131(1024, 256, 1024, 2))
+         # conv5
+         blocks.append(Res131(1024, 512, 2048, 4))
+         blocks.append(Res131(2048, 512, 2048, 4))
+         blocks.append(Res131(2048, 512, 2048, 4))
+         blocks.append(B.conv_block(2048, 512, 3, 1, 1, 1, False, 'zero', 'batch'))
+         blocks.append(nn.Dropout(0.1))
+         # conv6
+         blocks.append(nn.Conv2d(512, 8, 1, 1))
+
+         self.feature = B.sequential(*blocks)
+         # deconv
+         self.deconv = nn.ConvTranspose2d(8, 8, 16, 8, 4, 0, 8, False, 1)
+         # softmax
+         self.softmax = nn.Softmax(1)
+
+     def forward(self, x):
+         x = self.feature(x)
+         x = self.deconv(x)
+         x = self.softmax(x)
+         return x
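The conv trunk downsamples by 8 and the grouped transposed conv recovers the input resolution, so a quick shape check (sizes illustrative) looks like:

import torch
from models.modules.seg_arch import OutdoorSceneSeg

seg = OutdoorSceneSeg().eval()
with torch.no_grad():
    prob = seg(torch.randn(1, 3, 96, 96))
print(prob.shape)   # torch.Size([1, 8, 96, 96]): per-pixel probabilities over 8 categories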
esrgan_plus/codes/models/modules/sft_arch.py ADDED
@@ -0,0 +1,226 @@
+ '''
+ architecture for sft
+ '''
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class SFTLayer(nn.Module):
+     def __init__(self):
+         super(SFTLayer, self).__init__()
+         self.SFT_scale_conv0 = nn.Conv2d(32, 32, 1)
+         self.SFT_scale_conv1 = nn.Conv2d(32, 64, 1)
+         self.SFT_shift_conv0 = nn.Conv2d(32, 32, 1)
+         self.SFT_shift_conv1 = nn.Conv2d(32, 64, 1)
+
+     def forward(self, x):
+         # x[0]: fea; x[1]: cond
+         scale = self.SFT_scale_conv1(F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.1, inplace=True))
+         shift = self.SFT_shift_conv1(F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.1, inplace=True))
+         return x[0] * (scale + 1) + shift
+
+
+ class ResBlock_SFT(nn.Module):
+     def __init__(self):
+         super(ResBlock_SFT, self).__init__()
+         self.sft0 = SFTLayer()
+         self.conv0 = nn.Conv2d(64, 64, 3, 1, 1)
+         self.sft1 = SFTLayer()
+         self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)
+
+     def forward(self, x):
+         # x[0]: fea; x[1]: cond
+         fea = self.sft0(x)
+         fea = F.relu(self.conv0(fea), inplace=True)
+         fea = self.sft1((fea, x[1]))
+         fea = self.conv1(fea)
+         return (x[0] + fea, x[1])  # return a tuple containing features and conditions
+
+
+ class SFT_Net(nn.Module):
+     def __init__(self):
+         super(SFT_Net, self).__init__()
+         self.conv0 = nn.Conv2d(3, 64, 3, 1, 1)
+
+         sft_branch = []
+         for i in range(16):
+             sft_branch.append(ResBlock_SFT())
+         sft_branch.append(SFTLayer())
+         sft_branch.append(nn.Conv2d(64, 64, 3, 1, 1))
+         self.sft_branch = nn.Sequential(*sft_branch)
+
+         self.HR_branch = nn.Sequential(
+             nn.Conv2d(64, 256, 3, 1, 1),
+             nn.PixelShuffle(2),
+             nn.ReLU(True),
+             nn.Conv2d(64, 256, 3, 1, 1),
+             nn.PixelShuffle(2),
+             nn.ReLU(True),
+             nn.Conv2d(64, 64, 3, 1, 1),
+             nn.ReLU(True),
+             nn.Conv2d(64, 3, 3, 1, 1)
+         )
+
+         self.CondNet = nn.Sequential(
+             nn.Conv2d(8, 128, 4, 4),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(128, 128, 1),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(128, 128, 1),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(128, 128, 1),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(128, 32, 1)
+         )
+
+     def forward(self, x):
+         # x[0]: img; x[1]: seg
+         cond = self.CondNet(x[1])
+         fea = self.conv0(x[0])
+         res = self.sft_branch((fea, cond))
+         fea = fea + res
+         out = self.HR_branch(fea)
+         return out
+
+
+ # Auxiliary Classifier Discriminator
+ class ACD_VGG_BN_96(nn.Module):
+     def __init__(self):
+         super(ACD_VGG_BN_96, self).__init__()
+
+         self.feature = nn.Sequential(
+             nn.Conv2d(3, 64, 3, 1, 1),
+             nn.LeakyReLU(0.1, True),
+
+             nn.Conv2d(64, 64, 4, 2, 1),
+             nn.BatchNorm2d(64, affine=True),
+             nn.LeakyReLU(0.1, True),
+
+             nn.Conv2d(64, 128, 3, 1, 1),
+             nn.BatchNorm2d(128, affine=True),
+             nn.LeakyReLU(0.1, True),
+
+             nn.Conv2d(128, 128, 4, 2, 1),
+             nn.BatchNorm2d(128, affine=True),
+             nn.LeakyReLU(0.1, True),
+
+             nn.Conv2d(128, 256, 3, 1, 1),
+             nn.BatchNorm2d(256, affine=True),
+             nn.LeakyReLU(0.1, True),
+
+             nn.Conv2d(256, 256, 4, 2, 1),
+             nn.BatchNorm2d(256, affine=True),
+             nn.LeakyReLU(0.1, True),
+
+             nn.Conv2d(256, 512, 3, 1, 1),
+             nn.BatchNorm2d(512, affine=True),
+             nn.LeakyReLU(0.1, True),
+
+             nn.Conv2d(512, 512, 4, 2, 1),
+             nn.BatchNorm2d(512, affine=True),
+             nn.LeakyReLU(0.1, True),
+         )
+
+         # gan
+         self.gan = nn.Sequential(
+             nn.Linear(512*6*6, 100),
+             nn.LeakyReLU(0.1, True),
+             nn.Linear(100, 1)
+         )
+
+         self.cls = nn.Sequential(
+             nn.Linear(512*6*6, 100),
+             nn.LeakyReLU(0.1, True),
+             nn.Linear(100, 8)
+         )
+
+     def forward(self, x):
+         fea = self.feature(x)
+         fea = fea.view(fea.size(0), -1)
+         gan = self.gan(fea)
+         cls = self.cls(fea)
+         return [gan, cls]
+
+
+ #############################################
+ # below is the sft arch for the torch version
+ #############################################
+
+
+ class SFTLayer_torch(nn.Module):
+     def __init__(self):
+         super(SFTLayer_torch, self).__init__()
+         self.SFT_scale_conv0 = nn.Conv2d(32, 32, 1)
+         self.SFT_scale_conv1 = nn.Conv2d(32, 64, 1)
+         self.SFT_shift_conv0 = nn.Conv2d(32, 32, 1)
+         self.SFT_shift_conv1 = nn.Conv2d(32, 64, 1)
+
+     def forward(self, x):
+         # x[0]: fea; x[1]: cond
+         scale = self.SFT_scale_conv1(F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.01, inplace=True))
+         shift = self.SFT_shift_conv1(F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.01, inplace=True))
+         return x[0] * scale + shift
+
+
+ class ResBlock_SFT_torch(nn.Module):
+     def __init__(self):
+         super(ResBlock_SFT_torch, self).__init__()
+         self.sft0 = SFTLayer_torch()
+         self.conv0 = nn.Conv2d(64, 64, 3, 1, 1)
+         self.sft1 = SFTLayer_torch()
+         self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)
+
+     def forward(self, x):
+         # x[0]: fea; x[1]: cond
+         fea = F.relu(self.sft0(x), inplace=True)
+         fea = self.conv0(fea)
+         fea = F.relu(self.sft1((fea, x[1])), inplace=True)
+         fea = self.conv1(fea)
+         return (x[0] + fea, x[1])  # return a tuple containing features and conditions
+
+
+ class SFT_Net_torch(nn.Module):
+     def __init__(self):
+         super(SFT_Net_torch, self).__init__()
+         self.conv0 = nn.Conv2d(3, 64, 3, 1, 1)
+
+         sft_branch = []
+         for i in range(16):
+             sft_branch.append(ResBlock_SFT_torch())
+         sft_branch.append(SFTLayer_torch())
+         sft_branch.append(nn.Conv2d(64, 64, 3, 1, 1))
+         self.sft_branch = nn.Sequential(*sft_branch)
+
+         self.HR_branch = nn.Sequential(
+             nn.Upsample(scale_factor=2, mode='nearest'),
+             nn.Conv2d(64, 64, 3, 1, 1),
+             nn.ReLU(True),
+             nn.Upsample(scale_factor=2, mode='nearest'),
+             nn.Conv2d(64, 64, 3, 1, 1),
+             nn.ReLU(True),
+             nn.Conv2d(64, 64, 3, 1, 1),
+             nn.ReLU(True),
+             nn.Conv2d(64, 3, 3, 1, 1)
+         )
+
+         # Condition network
+         self.CondNet = nn.Sequential(
+             nn.Conv2d(8, 128, 4, 4),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(128, 128, 1),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(128, 128, 1),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(128, 128, 1),
+             nn.LeakyReLU(0.1, True),
+             nn.Conv2d(128, 32, 1)
+         )
+
+     def forward(self, x):
+         # x[0]: img; x[1]: seg
+         cond = self.CondNet(x[1])
+         fea = self.conv0(x[0])
+         res = self.sft_branch((fea, cond))
+         fea = fea + res
+         out = self.HR_branch(fea)
+         return out
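SFT_Net takes an (image, segmentation) pair; CondNet's first stride-4 conv brings the 8-channel probability maps down to the LR feature resolution. A shape sketch (sizes illustrative):

import torch
from models.modules.sft_arch import SFT_Net

net = SFT_Net()
img = torch.randn(1, 3, 24, 24)   # LR image
seg = torch.randn(1, 8, 96, 96)   # condition maps at HR resolution (24 * 4)
out = net((img, seg))             # x4 output: (1, 3, 96, 96)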
esrgan_plus/codes/models/modules/spectral_norm.py ADDED
@@ -0,0 +1,149 @@
+ '''
+ Copy from pytorch github repo
+ Spectral Normalization from https://arxiv.org/abs/1802.05957
+ '''
+ import torch
+ from torch.nn.functional import normalize
+ from torch.nn.parameter import Parameter
+
+
+ class SpectralNorm(object):
+     def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
+         self.name = name
+         self.dim = dim
+         if n_power_iterations <= 0:
+             raise ValueError('Expected n_power_iterations to be positive, but '
+                              'got n_power_iterations={}'.format(n_power_iterations))
+         self.n_power_iterations = n_power_iterations
+         self.eps = eps
+
+     def compute_weight(self, module):
+         weight = getattr(module, self.name + '_orig')
+         u = getattr(module, self.name + '_u')
+         weight_mat = weight
+         if self.dim != 0:
+             # permute dim to front
+             weight_mat = weight_mat.permute(self.dim,
+                                             *[d for d in range(weight_mat.dim()) if d != self.dim])
+         height = weight_mat.size(0)
+         weight_mat = weight_mat.reshape(height, -1)
+         with torch.no_grad():
+             for _ in range(self.n_power_iterations):
+                 # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
+                 # are the first left and right singular vectors.
+                 # This power iteration produces approximations of `u` and `v`.
+                 v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
+                 u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
+
+         sigma = torch.dot(u, torch.matmul(weight_mat, v))
+         weight = weight / sigma
+         return weight, u
+
+     def remove(self, module):
+         weight = getattr(module, self.name)
+         delattr(module, self.name)
+         delattr(module, self.name + '_u')
+         delattr(module, self.name + '_orig')
+         module.register_parameter(self.name, torch.nn.Parameter(weight))
+
+     def __call__(self, module, inputs):
+         if module.training:
+             weight, u = self.compute_weight(module)
+             setattr(module, self.name, weight)
+             setattr(module, self.name + '_u', u)
+         else:
+             r_g = getattr(module, self.name + '_orig').requires_grad
+             getattr(module, self.name).detach_().requires_grad_(r_g)
+
+     @staticmethod
+     def apply(module, name, n_power_iterations, dim, eps):
+         fn = SpectralNorm(name, n_power_iterations, dim, eps)
+         weight = module._parameters[name]
+         height = weight.size(dim)
+
+         u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
+         delattr(module, fn.name)
+         module.register_parameter(fn.name + "_orig", weight)
+         # We still need to assign weight back as fn.name because all sorts of
+         # things may assume that it exists, e.g., when initializing weights.
+         # However, we can't directly assign as it could be an nn.Parameter and
+         # gets added as a parameter. Instead, we register weight.data as a
+         # buffer, which will cause weight to be included in the state dict
+         # and also supports nn.init due to shared storage.
+         module.register_buffer(fn.name, weight.data)
+         module.register_buffer(fn.name + "_u", u)
+
+         module.register_forward_pre_hook(fn)
+         return fn
+
+
+ def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
+     r"""Applies spectral normalization to a parameter in the given module.
+
+     .. math::
+         \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
+         \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
+
+     Spectral normalization stabilizes the training of discriminators (critics)
+     in Generative Adversarial Networks (GANs) by rescaling the weight tensor
+     with the spectral norm :math:`\sigma` of the weight matrix, calculated with
+     the power iteration method. If the dimension of the weight tensor is greater
+     than 2, it is reshaped to 2D in the power iteration to get the spectral
+     norm. This is implemented via a hook that calculates the spectral norm and
+     rescales the weight before every :meth:`~Module.forward` call.
+
+     See `Spectral Normalization for Generative Adversarial Networks`_ .
+
+     .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
+
+     Args:
+         module (nn.Module): containing module
+         name (str, optional): name of weight parameter
+         n_power_iterations (int, optional): number of power iterations to
+             calculate spectral norm
+         eps (float, optional): epsilon for numerical stability in
+             calculating norms
+         dim (int, optional): dimension corresponding to number of outputs,
+             the default is 0, except for modules that are instances of
+             ConvTranspose1/2/3d, when it is 1
+
+     Returns:
+         The original module with the spectral norm hook
+
+     Example::
+
+         >>> m = spectral_norm(nn.Linear(20, 40))
+         Linear (20 -> 40)
+         >>> m.weight_u.size()
+         torch.Size([20])
+
+     """
+     if dim is None:
+         if isinstance(
+                 module,
+                 (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
+             dim = 1
+         else:
+             dim = 0
+     SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
+     return module
+
+
+ def remove_spectral_norm(module, name='weight'):
+     r"""Removes the spectral normalization reparameterization from a module.
+
+     Args:
+         module (nn.Module): containing module
+         name (str, optional): name of weight parameter
+
+     Example:
+         >>> m = spectral_norm(nn.Linear(40, 10))
+         >>> remove_spectral_norm(m)
+     """
+     for k, hook in module._forward_pre_hooks.items():
+         if isinstance(hook, SpectralNorm) and hook.name == name:
+             hook.remove(module)
+             del module._forward_pre_hooks[k]
+             return module
+
+     raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
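A minimal application sketch; spectral_norm installs the forward pre-hook and remove_spectral_norm undoes the reparameterization (attribute names follow the implementation above):

import torch.nn as nn
from models.modules.spectral_norm import spectral_norm, remove_spectral_norm

conv = spectral_norm(nn.Conv2d(64, 64, 3, padding=1))
print(hasattr(conv, 'weight_orig'), conv.weight_u.size())   # True torch.Size([64])
remove_spectral_norm(conv)   # weight becomes a plain Parameter again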
esrgan_plus/codes/models/networks.py ADDED
@@ -0,0 +1,155 @@
+ import functools
+ import logging
+ import torch
+ import torch.nn as nn
+ from torch.nn import init
+
+ import models.modules.architecture as arch
+ import models.modules.sft_arch as sft_arch
+ logger = logging.getLogger('base')
+ ####################
+ # initialize
+ ####################
+
+
+ def weights_init_normal(m, std=0.02):
+     classname = m.__class__.__name__
+     if classname.find('Conv') != -1:
+         init.normal_(m.weight.data, 0.0, std)
+         if m.bias is not None:
+             m.bias.data.zero_()
+     elif classname.find('Linear') != -1:
+         init.normal_(m.weight.data, 0.0, std)
+         if m.bias is not None:
+             m.bias.data.zero_()
+     elif classname.find('BatchNorm2d') != -1:
+         init.normal_(m.weight.data, 1.0, std)  # BN weights also use normal init
+         init.constant_(m.bias.data, 0.0)
+
+
+ def weights_init_kaiming(m, scale=1):
+     classname = m.__class__.__name__
+     if classname.find('Conv') != -1:
+         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+         m.weight.data *= scale
+         if m.bias is not None:
+             m.bias.data.zero_()
+     elif classname.find('Linear') != -1:
+         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+         m.weight.data *= scale
+         if m.bias is not None:
+             m.bias.data.zero_()
+     elif classname.find('BatchNorm2d') != -1:
+         init.constant_(m.weight.data, 1.0)
+         init.constant_(m.bias.data, 0.0)
+
+
+ def weights_init_orthogonal(m):
+     classname = m.__class__.__name__
+     if classname.find('Conv') != -1:
+         init.orthogonal_(m.weight.data, gain=1)
+         if m.bias is not None:
+             m.bias.data.zero_()
+     elif classname.find('Linear') != -1:
+         init.orthogonal_(m.weight.data, gain=1)
+         if m.bias is not None:
+             m.bias.data.zero_()
+     elif classname.find('BatchNorm2d') != -1:
+         init.constant_(m.weight.data, 1.0)
+         init.constant_(m.bias.data, 0.0)
+
+
+ def init_weights(net, init_type='kaiming', scale=1, std=0.02):
+     # scale for 'kaiming', std for 'normal'.
+     logger.info('Initialization method [{:s}]'.format(init_type))
+     if init_type == 'normal':
+         weights_init_normal_ = functools.partial(weights_init_normal, std=std)
+         net.apply(weights_init_normal_)
+     elif init_type == 'kaiming':
+         weights_init_kaiming_ = functools.partial(weights_init_kaiming, scale=scale)
+         net.apply(weights_init_kaiming_)
+     elif init_type == 'orthogonal':
+         net.apply(weights_init_orthogonal)
+     else:
+         raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))
+
+
+ ####################
+ # define network
+ ####################
+
+
+ # Generator
+ def define_G(opt):
+     gpu_ids = opt['gpu_ids']
+     opt_net = opt['network_G']
+     which_model = opt_net['which_model_G']
+
+     if which_model == 'sr_resnet':  # SRResNet
+         netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
+                              nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
+                              act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
+
+     elif which_model == 'sft_arch':  # SFT-GAN
+         netG = sft_arch.SFT_Net()
+
+     elif which_model == 'RRDB_net':  # RRDB
+         netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
+                             nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
+                             norm_type=opt_net['norm_type'], act_type='leakyrelu',
+                             mode=opt_net['mode'], upsample_mode='upconv')
+     else:
+         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
+
+     if opt['is_train']:
+         init_weights(netG, init_type='kaiming', scale=0.1)
+     if gpu_ids:
+         assert torch.cuda.is_available()
+         netG = nn.DataParallel(netG)
+     return netG
+
+
+ # Discriminator
+ def define_D(opt):
+     gpu_ids = opt['gpu_ids']
+     opt_net = opt['network_D']
+     which_model = opt_net['which_model_D']
+
+     if which_model == 'discriminator_vgg_128':
+         netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
+                                           norm_type=opt_net['norm_type'], mode=opt_net['mode'],
+                                           act_type=opt_net['act_type'])
+
+     elif which_model == 'dis_acd':  # sft-gan, Auxiliary Classifier Discriminator
+         netD = sft_arch.ACD_VGG_BN_96()
+
+     elif which_model == 'discriminator_vgg_96':
+         netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
+                                          norm_type=opt_net['norm_type'], mode=opt_net['mode'],
+                                          act_type=opt_net['act_type'])
+     elif which_model == 'discriminator_vgg_192':
+         netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
+                                           norm_type=opt_net['norm_type'], mode=opt_net['mode'],
+                                           act_type=opt_net['act_type'])
+     elif which_model == 'discriminator_vgg_128_SN':
+         netD = arch.Discriminator_VGG_128_SN()
+     else:
+         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
+
+     init_weights(netD, init_type='kaiming', scale=1)
+     if gpu_ids:
+         netD = nn.DataParallel(netD)
+     return netD
+
+
+ def define_F(opt, use_bn=False):
+     gpu_ids = opt['gpu_ids']
+     device = torch.device('cuda' if gpu_ids else 'cpu')
+     # pytorch pretrained VGG19-54, before ReLU.
+     if use_bn:
+         feature_layer = 49
+     else:
+         feature_layer = 34
+     netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
+                                     use_input_norm=True, device=device)
+     # netF = arch.ResNet101FeatureExtractor(use_input_norm=True, device=device)
+     if gpu_ids:
+         netF = nn.DataParallel(netF)
+     netF.eval()  # No need to train
+     return netF
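define_G only reads a handful of keys; a hedged sketch of the minimal option dict it needs for the RRDB generator (values illustrative, run from codes/):

from models.networks import define_G

opt = {
    'is_train': False,
    'gpu_ids': [],
    'network_G': {
        'which_model_G': 'RRDB_net', 'in_nc': 3, 'out_nc': 3,
        'nf': 64, 'nb': 23, 'gc': 32, 'scale': 4,
        'norm_type': None, 'mode': 'CNA',
    },
}
netG = define_G(opt)   # plain nn.Module: no init and no DataParallel here, since is_train is False and gpu_ids is empty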
esrgan_plus/codes/options/options.py ADDED
@@ -0,0 +1,120 @@
+ import os
+ import os.path as osp
+ import logging
+ from collections import OrderedDict
+ import json
+
+
+ def parse(opt_path, is_train=True):
+     # remove comments starting with '//'
+     json_str = ''
+     with open(opt_path, 'r') as f:
+         for line in f:
+             line = line.split('//')[0] + '\n'
+             json_str += line
+     opt = json.loads(json_str, object_pairs_hook=OrderedDict)
+
+     opt['is_train'] = is_train
+     scale = opt['scale']
+
+     # datasets
+     for phase, dataset in opt['datasets'].items():
+         phase = phase.split('_')[0]
+         dataset['phase'] = phase
+         dataset['scale'] = scale
+         is_lmdb = False
+         if 'dataroot_HR' in dataset and dataset['dataroot_HR'] is not None:
+             dataset['dataroot_HR'] = os.path.expanduser(dataset['dataroot_HR'])
+             if dataset['dataroot_HR'].endswith('lmdb'):
+                 is_lmdb = True
+         if 'dataroot_HR_bg' in dataset and dataset['dataroot_HR_bg'] is not None:
+             dataset['dataroot_HR_bg'] = os.path.expanduser(dataset['dataroot_HR_bg'])
+         if 'dataroot_LR' in dataset and dataset['dataroot_LR'] is not None:
+             dataset['dataroot_LR'] = os.path.expanduser(dataset['dataroot_LR'])
+             if dataset['dataroot_LR'].endswith('lmdb'):
+                 is_lmdb = True
+         dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
+
+         if phase == 'train' and 'subset_file' in dataset and dataset['subset_file'] is not None:
+             dataset['subset_file'] = os.path.expanduser(dataset['subset_file'])
+
+     # path
+     for key, path in opt['path'].items():
+         if path and key in opt['path']:
+             opt['path'][key] = os.path.expanduser(path)
+     if is_train:
+         experiments_root = os.path.join(opt['path']['root'], 'experiments', opt['name'])
+         opt['path']['experiments_root'] = experiments_root
+         opt['path']['models'] = os.path.join(experiments_root, 'models')
+         opt['path']['training_state'] = os.path.join(experiments_root, 'training_state')
+         opt['path']['log'] = experiments_root
+         opt['path']['val_images'] = os.path.join(experiments_root, 'val_images')
+
+         # change some options for debug mode
+         if 'debug' in opt['name']:
+             opt['train']['val_freq'] = 8
+             opt['logger']['print_freq'] = 2
+             opt['logger']['save_checkpoint_freq'] = 8
+             opt['train']['lr_decay_iter'] = 10
+     else:  # test
+         results_root = os.path.join(opt['path']['root'], 'results', opt['name'])
+         opt['path']['results_root'] = results_root
+         opt['path']['log'] = results_root
+
+     # network
+     opt['network_G']['scale'] = scale
+
+     # export CUDA_VISIBLE_DEVICES
+     gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
+     os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
+     print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
+
+     return opt
+
+
+ class NoneDict(dict):
+     def __missing__(self, key):
+         return None
+
+
+ # convert to NoneDict, which returns None for a missing key.
+ def dict_to_nonedict(opt):
+     if isinstance(opt, dict):
+         new_opt = dict()
+         for key, sub_opt in opt.items():
+             new_opt[key] = dict_to_nonedict(sub_opt)
+         return NoneDict(**new_opt)
+     elif isinstance(opt, list):
+         return [dict_to_nonedict(sub_opt) for sub_opt in opt]
+     else:
+         return opt
+
+
+ def dict2str(opt, indent_l=1):
+     '''dict to string for logger'''
+     msg = ''
+     for k, v in opt.items():
+         if isinstance(v, dict):
+             msg += ' ' * (indent_l * 2) + k + ':[\n'
+             msg += dict2str(v, indent_l + 1)
+             msg += ' ' * (indent_l * 2) + ']\n'
+         else:
+             msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
+     return msg
+
+
+ def check_resume(opt):
+     '''Check resume states and pretrain_model paths'''
+     logger = logging.getLogger('base')
+     if opt['path']['resume_state']:
+         if opt['path']['pretrain_model_G'] or opt['path']['pretrain_model_D']:
+             logger.warning('pretrain_model path will be ignored when resuming training.')
+
+         state_idx = osp.basename(opt['path']['resume_state']).split('.')[0]
+         opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
+                                                    '{}_G.pth'.format(state_idx))
+         logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
+         if 'gan' in opt['model']:
+             opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
+                                                        '{}_D.pth'.format(state_idx))
+             logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
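End to end, a config file goes through parse and dict_to_nonedict before any model sees it; a hedged sketch, run from codes/:

import options.options as option

opt = option.parse('options/test/test_ESRGANplus.json', is_train=False)
opt = option.dict_to_nonedict(opt)   # missing keys now read as None instead of raising KeyError
print(option.dict2str(opt))          # pretty-printed options for the log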
esrgan_plus/codes/options/test/test_ESRGANplus.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "name": "nESRGAN+_x4"
+   , "suffix": "_ESRGAN"
+   , "model": "srragan"
+   , "scale": 4
+   , "gpu_ids": [0]
+
+   , "datasets": {
+     "test_1": { // the 1st test dataset
+       "name": "set5"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/val_set5/Set5"
+       , "dataroot_LR": "/home/carraz/datasets/val_set5/Set5_bicLRx4"
+     }
+     , "test_2": { // the 2nd test dataset
+       "name": "set14"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/val_set14/Set14"
+       , "dataroot_LR": "/home/carraz/datasets/val_set14/Set14_bicLRx4"
+     }
+   }
+
+   , "path": {
+     "root": "/home/carraz/nESRGANplus"
+     , "pretrain_model_G": "../experiments/pretrained_models/RRDB_ESRGAN_x4.pth"
+   }
+
+   , "network_G": {
+     "which_model_G": "RRDB_net" // RRDB_net | sr_resnet
+     , "norm_type": null
+     , "mode": "CNA"
+     , "nf": 64
+     , "nb": 23
+     , "in_nc": 3
+     , "out_nc": 3
+
+     , "gc": 32
+     , "group": 1
+   }
+ }
esrgan_plus/codes/options/test/test_SRGAN.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "name": "SRGAN"
+   , "suffix": "_SRGAN"
+   , "model": "srgan"
+   , "scale": 4
+   , "gpu_ids": [0]
+
+   , "datasets": {
+     "test_1": { // the 1st test dataset
+       "name": "set5"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/val_set5/Set5"
+       , "dataroot_LR": "/home/carraz/datasets/val_set5/Set5_bicLRx4"
+     }
+     , "test_2": { // the 2nd test dataset
+       "name": "set14"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/val_set14/Set14"
+       , "dataroot_LR": "/home/carraz/datasets/val_set14/Set14_bicLRx4"
+     }
+   }
+
+   , "path": {
+     "root": "/home/carraz/nESRGANplus"
+     , "pretrain_model_G": "../experiments/pretrained_models/SRGAN_bicx4_303_505.pth"
+   }
+
+   , "network_G": {
+     "which_model_G": "sr_resnet"
+     , "norm_type": null
+     , "mode": "CNA"
+     , "nf": 64
+     , "nb": 16
+     , "in_nc": 3
+     , "out_nc": 3
+   }
+ }
esrgan_plus/codes/options/test/test_SRResNet.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "name": "SRResNet_bicx4_in3nf64nb16"
+   , "suffix": null
+   , "model": "sr"
+   , "scale": 4
+   , "gpu_ids": [0]
+
+   , "datasets": {
+     "test_1": { // the 1st test dataset
+       "name": "set5"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/val_set5/Set5"
+       , "dataroot_LR": "/home/carraz/datasets/val_set5/Set5_bicLRx4"
+     }
+     , "test_2": { // the 2nd test dataset
+       "name": "set14"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/val_set14/Set14"
+       , "dataroot_LR": "/home/carraz/datasets/val_set14/Set14_bicLRx4"
+     }
+   }
+
+   , "path": {
+     "root": "/home/carraz/nESRGANplus"
+     , "pretrain_model_G": "../experiments/pretrained_models/SRResNet_bicx4_in3nf64nb16.pth"
+   }
+
+   , "network_G": {
+     "which_model_G": "sr_resnet" // RRDB_net | sr_resnet
+     , "norm_type": null
+     , "mode": "CNA"
+     , "nf": 64
+     , "nb": 16
+     , "in_nc": 3
+     , "out_nc": 3
+
+     , "gc": 32
+     , "group": 1
+   }
+ }
esrgan_plus/codes/options/test/test_sr.json ADDED
@@ -0,0 +1,40 @@
+ {
+   "name": "RRDB_PSNR_x4"
+   , "suffix": null
+   , "model": "sr"
+   , "scale": 4
+   , "gpu_ids": [0]
+
+   , "datasets": {
+     "test_1": { // the 1st test dataset
+       "name": "set5"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/val_set5/Set5"
+       , "dataroot_LR": "/home/carraz/datasets/val_set5/Set5_bicLRx4"
+     }
+     , "test_2": { // the 2nd test dataset
+       "name": "set14"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/val_set14/Set14"
+       , "dataroot_LR": "/home/carraz/datasets/val_set14/Set14_bicLRx4"
+     }
+   }
+
+   , "path": {
+     "root": "/home/carraz/nESRGANplus"
+     , "pretrain_model_G": "../experiments/pretrained_models/RRDB_PSNR_x4.pth"
+   }
+
+   , "network_G": {
+     "which_model_G": "RRDB_net" // RRDB_net | sr_resnet
+     , "norm_type": null
+     , "mode": "CNA"
+     , "nf": 64
+     , "nb": 23
+     , "in_nc": 3
+     , "out_nc": 3
+
+     , "gc": 32
+     , "group": 1
+   }
+ }
esrgan_plus/codes/options/train/train_ESRGANplus.json ADDED
@@ -0,0 +1,83 @@
+ {
+   "name": "nESRGANplus_x4_DIV2K",
+   "use_tb_logger": true,
+   "model": "srragan",
+   "scale": 4,
+   "gpu_ids": [0],
+   "datasets": {
+     "train": {
+       "name": "DIV2K",
+       "mode": "LRHR",
+       "dataroot_HR": "/content/gdrive/My Drive/DIV2K_train_HR_sub",
+       "dataroot_LR": "/content/gdrive/My Drive/DIV2K_train_LR_sub_bicLRx4",
+       "subset_file": null,
+       "use_shuffle": true,
+       "n_workers": 8,
+       "batch_size": 16,
+       "HR_size": 128,
+       "use_flip": true,
+       "use_rot": true
+     },
+     "val": {
+       "name": "val_set14_part",
+       "mode": "LRHR",
+       "dataroot_HR": "/content/gdrive/My Drive/ESRGAN/Set14",
+       "dataroot_LR": "/content/gdrive/My Drive/ESRGAN/Set14_LR_sub_bicLRx4"
+     }
+   },
+   "path": {
+     "root": "/content/gdrive/My Drive/ESRGAN/BasicSR",
+     "resume_state": "/content/gdrive/My Drive/ESRGAN/BasicSR/experiments/002_RRDB_ESRGAN_x4_DIV2K/training_state/495000.state",
+     "pretrain_model_G": "/content/gdrive/My Drive/ESRGAN/RRDB_PSNR_x4.pth"
+   },
+   "network_G": {
+     "which_model_G": "RRDB_net",
+     "norm_type": null,
+     "mode": "CNA",
+     "nf": 64,
+     "nb": 23,
+     "in_nc": 3,
+     "out_nc": 3,
+     "gc": 32,
+     "group": 1
+   },
+   "network_D": {
+     "which_model_D": "discriminator_vgg_128",
+     "norm_type": "batch",
+     "act_type": "leakyrelu",
+     "mode": "CNA",
+     "nf": 64,
+     "in_nc": 3
+   },
+   "train": {
+     "lr_G": 0.0001,
+     "weight_decay_G": 0,
+     "beta1_G": 0.9,
+     "lr_D": 0.0001,
+     "weight_decay_D": 0,
+     "beta1_D": 0.9,
+     "lr_scheme": "MultiStepLR",
+     "lr_steps": [50000, 100000, 200000, 300000],
+     "lr_gamma": 0.5,
+     "pixel_criterion": "l1",
+     "pixel_weight": 0.01,
+     "feature_criterion": "l1",
+     "feature_weight": 1,
+     "gan_type": "vanilla",
+     "gan_weight": 0.005,
+     "manual_seed": 0,
+     "niter": 500000.0,
+     "val_freq": 500.0
+   },
+   "logger": {
+     "print_freq": 50,
+     "save_checkpoint_freq": 500.0
+   }
+ }
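The lr_scheme/lr_steps/lr_gamma trio above maps onto PyTorch's MultiStepLR; a hedged sketch of the schedule this config describes (the parameter list is a stand-in for netG.parameters()):

import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-4, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[50000, 100000, 200000, 300000], gamma=0.5)
# after 300k of the 500k iterations the LR has been halved four times: 1e-4 -> 6.25e-6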
esrgan_plus/codes/options/train/train_SRGAN.json ADDED
@@ -0,0 +1,87 @@
+ // Not totally the same as SRGAN in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
+ {
+   "name": "debug_002_SRGAN_x4_DIV2K" // please remove "debug_" during training
+   , "use_tb_logger": true
+   , "model": "srgan"
+   , "scale": 4
+   , "gpu_ids": [0]
+
+   , "datasets": {
+     "train": {
+       "name": "DIV2K"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/DIV2K800/DIV2K800_sub.lmdb"
+       , "dataroot_LR": "/home/carraz/datasets/DIV2K800/DIV2K800_sub_bicLRx4.lmdb"
+       , "subset_file": null
+       , "use_shuffle": true
+       , "n_workers": 8
+       , "batch_size": 16
+       , "HR_size": 128
+       , "use_flip": true
+       , "use_rot": true
+     }
+     , "val": {
+       "name": "val_set14_part"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/val_set14_part/Set14"
+       , "dataroot_LR": "/home/carraz/datasets/val_set14_part/Set14_bicLRx4"
+     }
+   }
+
+   , "path": {
+     "root": "/home/carraz/nESRGANplus"
+     // , "resume_state": "../experiments/debug_002_SRGAN_x4_DIV2K/training_state/16.state"
+     , "pretrain_model_G": "../experiments/pretrained_models/SRResNet_bicx4_in3nf64nb16.pth"
+   }
+
+   , "network_G": {
+     "which_model_G": "sr_resnet" // RRDB_net | sr_resnet
+     , "norm_type": null
+     , "mode": "CNA"
+     , "nf": 64
+     , "nb": 16
+     , "in_nc": 3
+     , "out_nc": 3
+   }
+   , "network_D": {
+     "which_model_D": "discriminator_vgg_128"
+     , "norm_type": "batch"
+     , "act_type": "leakyrelu"
+     , "mode": "CNA"
+     , "nf": 64
+     , "in_nc": 3
+   }
+
+   , "train": {
+     "lr_G": 1e-4
+     , "weight_decay_G": 0
+     , "beta1_G": 0.9
+     , "lr_D": 1e-4
+     , "weight_decay_D": 0
+     , "beta1_D": 0.9
+     , "lr_scheme": "MultiStepLR"
+     , "lr_steps": [50000, 100000, 200000, 300000]
+     , "lr_gamma": 0.5
+
+     , "pixel_criterion": "l1"
+     , "pixel_weight": 1e-2
+     , "feature_criterion": "l1"
+     , "feature_weight": 1
+     , "gan_type": "vanilla"
+     , "gan_weight": 5e-3
+
+     // for wgan-gp
+     // , "D_update_ratio": 1
+     // , "D_init_iters": 0
+     // , "gp_weigth": 10
+
+     , "manual_seed": 0
+     , "niter": 5e5
+     , "val_freq": 5e3
+   }
+
+   , "logger": {
+     "print_freq": 200
+     , "save_checkpoint_freq": 5e3
+   }
+ }
esrgan_plus/codes/options/train/train_SRResNet.json ADDED
@@ -0,0 +1,66 @@
+ // Not totally the same as SRResNet in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
+ // With 16 residual blocks w/o BN
+ {
+   "name": "debug_001_SRResNet_PSNR_x4_DIV2K" // please remove "debug_" during training
+   , "use_tb_logger": true
+   , "model": "sr"
+   , "scale": 4
+   , "gpu_ids": [0]
+
+   , "datasets": {
+     "train": {
+       "name": "DIV2K"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/DIV2K800/DIV2K800_sub.lmdb"
+       , "dataroot_LR": "/home/carraz/datasets/DIV2K800/DIV2K800_sub_bicLRx4.lmdb"
+       , "subset_file": null
+       , "use_shuffle": true
+       , "n_workers": 8
+       , "batch_size": 16
+       , "HR_size": 128 // 128 | 192
+       , "use_flip": true
+       , "use_rot": true
+     }
+     , "val": {
+       "name": "val_set5"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/val_set5/Set5"
+       , "dataroot_LR": "/home/carraz/datasets/val_set5/Set5_bicLRx4"
+     }
+   }
+
+   , "path": {
+     "root": "/home/carraz/nESRGANplus"
+     // , "resume_state": "../experiments/debug_001_RRDB_PSNR_x4_DIV2K/training_state/200.state"
+     , "pretrain_model_G": null
+   }
+
+   , "network_G": {
+     "which_model_G": "sr_resnet" // RRDB_net | sr_resnet
+     , "norm_type": null
+     , "mode": "CNA"
+     , "nf": 64
+     , "nb": 16
+     , "in_nc": 3
+     , "out_nc": 3
+   }
+
+   , "train": {
+     "lr_G": 2e-4
+     , "lr_scheme": "MultiStepLR"
+     , "lr_steps": [200000, 400000, 600000, 800000]
+     , "lr_gamma": 0.5
+
+     , "pixel_criterion": "l1"
+     , "pixel_weight": 1.0
+     , "val_freq": 5e3
+
+     , "manual_seed": 0
+     , "niter": 1e6
+   }
+
+   , "logger": {
+     "print_freq": 200
+     , "save_checkpoint_freq": 5e3
+   }
+ }
esrgan_plus/codes/options/train/train_sftgan.json ADDED
@@ -0,0 +1,76 @@
+ {
+   "name": "debug_003_SFTGANx4_OST" // please remove "debug_" during training
+   , "use_tb_logger": false
+   , "model": "sftgan"
+   , "scale": 4
+   , "gpu_ids": [0]
+
+   , "datasets": {
+     "train": {
+       "name": "OST"
+       , "mode": "LRHRseg_bg"
+       , "dataroot_HR": "/home/carraz/datasets/OST/train/img"
+       , "dataroot_HR_bg": "/home/carraz/datasets/DIV2K800/DIV2K800_sub"
+       , "dataroot_LR": null
+       , "subset_file": null
+       , "use_shuffle": true
+       , "n_workers": 8
+       , "batch_size": 16
+       , "HR_size": 96
+       , "use_flip": true
+       , "use_rot": false
+     }
+     , "val": {
+       "name": "val_OST300_part"
+       , "mode": "LRHRseg_bg"
+       , "dataroot_HR": "/home/carraz/datasets/OST/val/img"
+       , "dataroot_LR": null
+     }
+   }
+
+   , "path": {
+     "root": "/home/carraz/nESRGANplus"
+     , "resume_state": null
+     , "pretrain_model_G": "../experiments/pretrained_models/sft_net_ini.pth"
+   }
+
+   , "network_G": {
+     "which_model_G": "sft_arch"
+   }
+   , "network_D": {
+     "which_model_D": "dis_acd"
+   }
+
+   , "train": {
+     "lr_G": 1e-4
+     , "weight_decay_G": 0
+     , "beta1_G": 0.9
+     , "lr_D": 1e-4
+     , "weight_decay_D": 0
+     , "beta1_D": 0.9
+     , "lr_scheme": "MultiStepLR"
+     , "lr_steps": [50000, 100000, 150000, 200000]
+     , "lr_gamma": 0.5
+
+     , "pixel_criterion": "l1"
+     , "pixel_weight": 0
+     , "feature_criterion": "l1"
+     , "feature_weight": 1
+     , "gan_type": "vanilla"
+     , "gan_weight": 5e-3
+
+     // for wgan-gp
+     // , "D_update_ratio": 1
+     // , "D_init_iters": 0
+     // , "gp_weigth": 10
+
+     , "manual_seed": 0
+     , "niter": 6e5
+     , "val_freq": 2e3
+   }
+
+   , "logger": {
+     "print_freq": 200
+     , "save_checkpoint_freq": 2e3
+   }
+ }
esrgan_plus/codes/options/train/train_sr.json ADDED
@@ -0,0 +1,66 @@
+ {
+   "name": "debug_001_RRDB_PSNR_x4_DIV2K" // please remove "debug_" during training
+   , "use_tb_logger": true
+   , "model": "sr"
+   , "scale": 4
+   , "gpu_ids": [0]
+
+   , "datasets": {
+     "train": {
+       "name": "DIV2K"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/DIV2K800/DIV2K800_sub.lmdb"
+       , "dataroot_LR": "/home/carraz/datasets/DIV2K800/DIV2K800_sub_bicLRx4.lmdb"
+       , "subset_file": null
+       , "use_shuffle": true
+       , "n_workers": 8
+       , "batch_size": 16
+       , "HR_size": 128 // 128 | 192
+       , "use_flip": true
+       , "use_rot": true
+     }
+     , "val": {
+       "name": "val_set5"
+       , "mode": "LRHR"
+       , "dataroot_HR": "/home/carraz/datasets/val_set5/Set5"
+       , "dataroot_LR": "/home/carraz/datasets/val_set5/Set5_bicLRx4"
+     }
+   }
+
+   , "path": {
+     "root": "/home/carraz/nESRGANplus"
+     // , "resume_state": "../experiments/debug_001_RRDB_PSNR_x4_DIV2K/training_state/200.state"
+     , "pretrain_model_G": null
+   }
+
+   , "network_G": {
+     "which_model_G": "RRDB_net" // RRDB_net | sr_resnet
+     , "norm_type": null
+     , "mode": "CNA"
+     , "nf": 64
+     , "nb": 23
+     , "in_nc": 3
+     , "out_nc": 3
+     , "gc": 32
+     , "group": 1
+   }
+
+   , "train": {
+     "lr_G": 2e-4
+     , "lr_scheme": "MultiStepLR"
+     , "lr_steps": [200000, 400000, 600000, 800000]
+     , "lr_gamma": 0.5
+
+     , "pixel_criterion": "l1"
+     , "pixel_weight": 1.0
+     , "val_freq": 5e3
+
+     , "manual_seed": 0
+     , "niter": 1e6
+   }
+
+   , "logger": {
+     "print_freq": 200
+     , "save_checkpoint_freq": 5e3
+   }
+ }
esrgan_plus/codes/scripts/README.md ADDED
@@ -0,0 +1,8 @@
+ # Scripts
+ We provide some useful scripts here.
+
+ ## List
+
+ | Name | Description |
+ |:---:|:---:|
+ | back projection | `Matlab` codes for back projection |
esrgan_plus/codes/scripts/back_projection/backprojection.m ADDED
@@ -0,0 +1,20 @@
+ function [im_h] = backprojection(im_h, im_l, maxIter)
+
+ [row_l, col_l,~] = size(im_l);
+ [row_h, col_h,~] = size(im_h);
+
+ p = fspecial('gaussian', 5, 1);
+ p = p.^2;
+ p = p./sum(p(:));
+
+ im_l = double(im_l);
+ im_h = double(im_h);
+
+ for ii = 1:maxIter
+     im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
+     im_diff = im_l - im_l_s;
+     im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
+     im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
+     im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
+     im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
+ end
esrgan_plus/codes/scripts/back_projection/main_bp.m ADDED
@@ -0,0 +1,22 @@
+ clear; close all; clc;
+
+ LR_folder = './LR'; % LR
+ preout_folder = './results'; % pre output
+ save_folder = './results_20bp';
+ filepaths = dir(fullfile(preout_folder, '*.png'));
+ max_iter = 20;
+
+ if ~ exist(save_folder, 'dir')
+     mkdir(save_folder);
+ end
+
+ for idx_im = 1:length(filepaths)
+     fprintf([num2str(idx_im) '\n']);
+     im_name = filepaths(idx_im).name;
+     im_LR = im2double(imread(fullfile(LR_folder, im_name)));
+     im_out = im2double(imread(fullfile(preout_folder, im_name)));
+     %tic
+     im_out = backprojection(im_out, im_LR, max_iter);
+     %toc
+     imwrite(im_out, fullfile(save_folder, im_name));
+ end
esrgan_plus/codes/scripts/back_projection/main_reverse_filter.m ADDED
@@ -0,0 +1,25 @@
+ clear; close all; clc;
+
+ LR_folder = './LR'; % LR
+ preout_folder = './results'; % pre output
+ save_folder = './results_20if';
+ filepaths = dir(fullfile(preout_folder, '*.png'));
+ max_iter = 20;
+
+ if ~ exist(save_folder, 'dir')
+     mkdir(save_folder);
+ end
+
+ for idx_im = 1:length(filepaths)
+     fprintf([num2str(idx_im) '\n']);
+     im_name = filepaths(idx_im).name;
+     im_LR = im2double(imread(fullfile(LR_folder, im_name)));
+     im_out = im2double(imread(fullfile(preout_folder, im_name)));
+     J = imresize(im_LR,4,'bicubic');
+     %tic
+     for m = 1:max_iter
+         im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
+     end
+     %toc
+     imwrite(im_out, fullfile(save_folder, im_name));
+ end
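The same reverse-filtering refinement written in Python, as a hedged sketch with OpenCV (function name and array conventions are illustrative, not part of the repo):

import cv2

def reverse_filter(sr, lr, scale=4, iters=20):
    # sr: model output, lr: its low-resolution source, both float arrays in [0, 1]
    h, w = sr.shape[:2]
    j = cv2.resize(lr, (w, h), interpolation=cv2.INTER_CUBIC)
    for _ in range(iters):
        down = cv2.resize(sr, (w // scale, h // scale), interpolation=cv2.INTER_CUBIC)
        sr = sr + (j - cv2.resize(down, (w, h), interpolation=cv2.INTER_CUBIC))
    return sr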
esrgan_plus/codes/scripts/color2gray.py ADDED
@@ -0,0 +1,63 @@
+ import os
+ import os.path
+ import sys
+ from multiprocessing import Pool
+ import cv2
+ 
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ from data.util import bgr2ycbcr
+ from utils.progress_bar import ProgressBar
+ 
+ 
+ def main():
+     """A multiprocessing tool for converting RGB images to gray/Y images."""
+ 
+     input_folder = '/home/carraz/datasets/DIV2K800/DIV2K800'
+     save_folder = '/home/carraz/datasets/DIV2K800/DIV2K800_gray'
+     mode = 'gray'  # 'gray' | 'y': Y channel in YCbCr space
+     compression_level = 3  # 3 is the default value in cv2
+     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
+     # compression time. If raw images are read during training, use 0 for faster IO speed.
+     n_thread = 20  # number of worker processes
+ 
+     if not os.path.exists(save_folder):
+         os.makedirs(save_folder)
+         print('mkdir [{:s}] ...'.format(save_folder))
+     else:
+         print('Folder [{:s}] already exists. Exit...'.format(save_folder))
+         sys.exit(1)
+     # print('Parent process {:d}.'.format(os.getpid()))
+ 
+     img_list = []
+     for root, _, file_list in sorted(os.walk(input_folder)):
+         path = [os.path.join(root, x) for x in file_list]  # assume only images in the input_folder
+         img_list.extend(path)
+ 
+     def update(arg):
+         pbar.update(arg)
+ 
+     pbar = ProgressBar(len(img_list))
+ 
+     pool = Pool(n_thread)
+     for path in img_list:
+         pool.apply_async(worker, args=(path, save_folder, mode, compression_level), callback=update)
+     pool.close()
+     pool.join()
+     print('All subprocesses done.')
+ 
+ 
+ def worker(path, save_folder, mode, compression_level):
+     img_name = os.path.basename(path)
+     img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR
+     if mode == 'gray':
+         img_y = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+     else:
+         img_y = bgr2ycbcr(img, only_y=True)
+     cv2.imwrite(
+         os.path.join(save_folder, img_name), img_y,
+         [cv2.IMWRITE_PNG_COMPRESSION, compression_level])
+     return 'Processing {:s} ...'.format(img_name)
+ 
+ 
+ if __name__ == '__main__':
+     main()
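
In 'y' mode the script relies on `data.util.bgr2ycbcr`; assuming that helper follows the BT.601 convention common in SR codebases, the Y channel of an 8-bit image is

$$Y = \frac{65.481\,R + 128.553\,G + 24.966\,B}{255} + 16,$$

a limited-range luma that differs from OpenCV's full-range `COLOR_BGR2GRAY`, so the two modes are not interchangeable.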
esrgan_plus/codes/scripts/create_lmdb.py ADDED
@@ -0,0 +1,66 @@
+ import sys
+ import os.path
+ import glob
+ import pickle
+ import lmdb
+ import cv2
+ 
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ from utils.progress_bar import ProgressBar
+ 
+ # configurations
+ img_folder = '/home/carraz/datasets/DIV2K800/DIV2K800/*'  # glob matching pattern
+ lmdb_save_path = '/home/carraz/datasets/DIV2K800/DIV2K800.lmdb'  # must end with .lmdb
+ mode = 1  # 1 for small data (more memory), 2 for large data (less memory)
+ 
+ img_list = sorted(glob.glob(img_folder))
+ 
+ print('Read images...')
+ # mode 1: small data, read all images up front
+ if mode == 1:
+     dataset = [cv2.imread(v, cv2.IMREAD_UNCHANGED) for v in img_list]
+     data_size = sum([img.nbytes for img in dataset])
+ # mode 2: large data, read images lazily while writing
+ elif mode == 2:
+     data_size = sum(os.stat(v).st_size for v in img_list)
+ else:
+     raise ValueError('mode should be 1 or 2')
+ 
+ env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
+ print('Finish reading {} images.\nWrite lmdb...'.format(len(img_list)))
+ 
+ pbar = ProgressBar(len(img_list))
+ batch = 3000  # can be modified according to memory usage
+ txn = env.begin(write=True)  # txn is a Transaction object
+ for i, v in enumerate(img_list):
+     pbar.update('Write {}'.format(v))
+     base_name = os.path.splitext(os.path.basename(v))[0]
+     key = base_name.encode('ascii')
+     data = dataset[i] if mode == 1 else cv2.imread(v, cv2.IMREAD_UNCHANGED)
+     if data.ndim == 2:
+         H, W = data.shape
+         C = 1
+     else:
+         H, W, C = data.shape
+     meta_key = (base_name + '.meta').encode('ascii')
+     meta = '{:d}, {:d}, {:d}'.format(H, W, C)
+     # the encode is only essential in Python 3
+     txn.put(key, data)
+     txn.put(meta_key, meta.encode('ascii'))
+     if mode == 2 and i % batch == batch - 1:
+         txn.commit()
+         txn = env.begin(write=True)
+ 
+ txn.commit()
+ env.close()
+ 
+ print('Finish writing lmdb.')
+ 
+ # create keys cache
+ keys_cache_file = os.path.join(lmdb_save_path, '_keys_cache.p')
+ env = lmdb.open(lmdb_save_path, readonly=True, lock=False, readahead=False, meminit=False)
+ with env.begin(write=False) as txn:
+     print('Create lmdb keys cache: {}'.format(keys_cache_file))
+     keys = [key.decode('ascii') for key, _ in txn.cursor()]
+     pickle.dump(keys, open(keys_cache_file, "wb"))
+ print('Finish creating lmdb keys cache.')
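
A minimal read-back sketch for this storage layout (not part of the repo; '0001_s001' is a hypothetical key, and 8-bit images are assumed):

import lmdb
import numpy as np

env = lmdb.open('/home/carraz/datasets/DIV2K800/DIV2K800.lmdb', readonly=True,
                lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    buf = txn.get('0001_s001'.encode('ascii'))        # raw image buffer
    meta = txn.get('0001_s001.meta'.encode('ascii'))  # b'H, W, C'
H, W, C = [int(x) for x in meta.decode('ascii').split(',')]
img = np.frombuffer(buf, dtype=np.uint8).reshape(H, W, C)  # BGR order, as written by cv2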
esrgan_plus/codes/scripts/extract_enlarge_patches.py ADDED
@@ -0,0 +1,64 @@
+ import os.path
+ import glob
+ import cv2
+ 
+ crt_path = os.path.dirname(os.path.realpath(__file__))
+ 
+ # configurations
+ h_start, h_len = 170, 64
+ w_start, w_len = 232, 100
+ enlarge_ratio = 3
+ line_width = 2
+ color = 'yellow'
+ 
+ folder = os.path.join(crt_path, './ori/*')
+ save_patch_folder = os.path.join(crt_path, './patch')
+ save_rect_folder = os.path.join(crt_path, './rect')
+ 
+ color_tb = {}
+ color_tb['yellow'] = (0, 255, 255)
+ color_tb['green'] = (0, 255, 0)
+ color_tb['red'] = (0, 0, 255)
+ color_tb['magenta'] = (255, 0, 255)
+ color_tb['matlab_blue'] = (189, 114, 0)
+ color_tb['matlab_orange'] = (25, 83, 217)
+ color_tb['matlab_yellow'] = (32, 177, 237)
+ color_tb['matlab_purple'] = (142, 47, 126)
+ color_tb['matlab_green'] = (48, 172, 119)
+ color_tb['matlab_liblue'] = (238, 190, 77)
+ color_tb['matlab_brown'] = (47, 20, 162)
+ color = color_tb[color]
+ img_list = glob.glob(folder)
+ images = []
+ 
+ # make output folders
+ if not os.path.exists(save_patch_folder):
+     os.makedirs(save_patch_folder)
+     print('mkdir [{}] ...'.format(save_patch_folder))
+ if not os.path.exists(save_rect_folder):
+     os.makedirs(save_rect_folder)
+     print('mkdir [{}] ...'.format(save_rect_folder))
+ 
+ for i, path in enumerate(img_list):
+     img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+     base_name = os.path.splitext(os.path.basename(path))[0]
+     print(i, base_name)
+     # crop patch
+     if img.ndim == 2:
+         patch = img[h_start:h_start + h_len, w_start:w_start + w_len]
+     elif img.ndim == 3:
+         patch = img[h_start:h_start + h_len, w_start:w_start + w_len, :]
+     else:
+         raise ValueError('Wrong image dim [{:d}]'.format(img.ndim))
+ 
+     # enlarge patch if necessary
+     if enlarge_ratio > 1:
+         H, W = patch.shape[:2]  # shape[:2] also works for grayscale patches
+         patch = cv2.resize(patch, (W * enlarge_ratio, H * enlarge_ratio),
+                            interpolation=cv2.INTER_CUBIC)
+     cv2.imwrite(os.path.join(save_patch_folder, base_name + '_patch.png'), patch)
+ 
+     # draw rectangle
+     img_rect = cv2.rectangle(img, (w_start, h_start), (w_start + w_len, h_start + h_len),
+                              color, line_width)
+     cv2.imwrite(os.path.join(save_rect_folder, base_name + '_rect.png'), img_rect)
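
Note the color tuples are in OpenCV's BGR channel order, which is why 'yellow' is (0, 255, 255) and 'red' is (0, 0, 255); the 'matlab_*' entries appear to be MATLAB's default plot palette converted to BGR.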
esrgan_plus/codes/scripts/extract_subimgs_single.py ADDED
@@ -0,0 +1,88 @@
+ import os
+ import os.path
+ import sys
+ from multiprocessing import Pool
+ import numpy as np
+ import cv2
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ from utils.progress_bar import ProgressBar
+ 
+ 
+ def main():
+     """A multiprocessing tool to crop large images into sub-images."""
+     input_folder = '/home/carraz/datasets/DIV2K800/DIV2K800'
+     save_folder = '/home/carraz/datasets/DIV2K800/DIV2K800_sub'
+     n_thread = 20
+     crop_sz = 480
+     step = 240
+     thres_sz = 48
+     compression_level = 3  # 3 is the default value in cv2
+     # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
+     # compression time. If raw images are read during training, use 0 for faster IO speed.
+ 
+     if not os.path.exists(save_folder):
+         os.makedirs(save_folder)
+         print('mkdir [{:s}] ...'.format(save_folder))
+     else:
+         print('Folder [{:s}] already exists. Exit...'.format(save_folder))
+         sys.exit(1)
+ 
+     img_list = []
+     for root, _, file_list in sorted(os.walk(input_folder)):
+         path = [os.path.join(root, x) for x in file_list]  # assume only images in the input_folder
+         img_list.extend(path)
+ 
+     def update(arg):
+         pbar.update(arg)
+ 
+     pbar = ProgressBar(len(img_list))
+ 
+     pool = Pool(n_thread)
+     for path in img_list:
+         pool.apply_async(worker,
+                          args=(path, save_folder, crop_sz, step, thres_sz, compression_level),
+                          callback=update)
+     pool.close()
+     pool.join()
+     print('All subprocesses done.')
+ 
+ 
+ def worker(path, save_folder, crop_sz, step, thres_sz, compression_level):
+     img_name = os.path.basename(path)
+     img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ 
+     n_channels = len(img.shape)
+     if n_channels == 2:
+         h, w = img.shape
+     elif n_channels == 3:
+         h, w, c = img.shape
+     else:
+         raise ValueError('Wrong image shape - {}'.format(n_channels))
+ 
+     h_space = np.arange(0, h - crop_sz + 1, step)
+     if h - (h_space[-1] + crop_sz) > thres_sz:
+         h_space = np.append(h_space, h - crop_sz)
+     w_space = np.arange(0, w - crop_sz + 1, step)
+     if w - (w_space[-1] + crop_sz) > thres_sz:
+         w_space = np.append(w_space, w - crop_sz)
+ 
+     index = 0
+     for x in h_space:
+         for y in w_space:
+             index += 1
+             if n_channels == 2:
+                 crop_img = img[x:x + crop_sz, y:y + crop_sz]
+             else:
+                 crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
+             crop_img = np.ascontiguousarray(crop_img)
+             # var = np.var(crop_img / 255)
+             # if var > 0.008:
+             #     print(img_name, index, var)
+             cv2.imwrite(
+                 os.path.join(save_folder, img_name.replace('.png', '_s{:03d}.png'.format(index))),
+                 crop_img, [cv2.IMWRITE_PNG_COMPRESSION, compression_level])
+     return 'Processing {:s} ...'.format(img_name)
+ 
+ 
+ if __name__ == '__main__':
+     main()
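
To make the tiling arithmetic concrete (a hypothetical example, not output from the repo): with crop_sz=480, step=240, thres_sz=48, a 1356x2040 image gives h_space = [0, 240, 480, 720] plus an appended 876 (the 156-pixel remainder exceeds thres_sz), and w_space = [0, 240, ..., 1440] plus 1560, i.e. 5 x 8 = 40 sub-images per image.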
esrgan_plus/codes/scripts/generate_mod_LR_bic.m ADDED
@@ -0,0 +1,82 @@
+ function generate_mod_LR_bic()
+ %% matlab code to generate mod images, bicubic-downsampled LR and bicubic-upsampled images.
+ 
+ %% set parameters
+ % comment out the unnecessary lines
+ input_folder = '/home/carraz/datasets/DIV2K800/DIV2K800_sub';
+ % save_mod_folder = '';
+ save_LR_folder = '/home/carraz/datasets/DIV2K800/DIV2K800_sub_bicLRx4';
+ % save_bic_folder = '';
+ 
+ up_scale = 4;
+ mod_scale = 4;
+ 
+ if exist('save_mod_folder', 'var')
+     if exist(save_mod_folder, 'dir')
+         disp(['It will cover ', save_mod_folder]);
+     else
+         mkdir(save_mod_folder);
+     end
+ end
+ if exist('save_LR_folder', 'var')
+     if exist(save_LR_folder, 'dir')
+         disp(['It will cover ', save_LR_folder]);
+     else
+         mkdir(save_LR_folder);
+     end
+ end
+ if exist('save_bic_folder', 'var')
+     if exist(save_bic_folder, 'dir')
+         disp(['It will cover ', save_bic_folder]);
+     else
+         mkdir(save_bic_folder);
+     end
+ end
+ 
+ idx = 0;
+ filepaths = dir(fullfile(input_folder, '*.*'));
+ for i = 1 : length(filepaths)
+     [paths, imname, ext] = fileparts(filepaths(i).name);
+     if isempty(imname)
+         disp('Ignore . folder.');
+     elseif strcmp(imname, '.')
+         disp('Ignore .. folder.');
+     else
+         idx = idx + 1;
+         str_rlt = sprintf('%d\t%s.\n', idx, imname);
+         fprintf(str_rlt);
+         % read image
+         img = imread(fullfile(input_folder, [imname, ext]));
+         img = im2double(img);
+         % modcrop
+         img = modcrop(img, mod_scale);
+         if exist('save_mod_folder', 'var')
+             imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
+         end
+         % LR
+         im_LR = imresize(img, 1/up_scale, 'bicubic');
+         if exist('save_LR_folder', 'var')
+             imwrite(im_LR, fullfile(save_LR_folder, [imname, '_bicLRx4.png']));
+         end
+         % Bicubic
+         if exist('save_bic_folder', 'var')
+             im_B = imresize(im_LR, up_scale, 'bicubic');
+             imwrite(im_B, fullfile(save_bic_folder, [imname, '_bicx4.png']));
+         end
+     end
+ end
+ end
+ 
+ %% modcrop
+ function img = modcrop(img, modulo)
+ if size(img, 3) == 1
+     sz = size(img);
+     sz = sz - mod(sz, modulo);
+     img = img(1:sz(1), 1:sz(2));
+ else
+     tmpsz = size(img);
+     sz = tmpsz(1:2);
+     sz = sz - mod(sz, modulo);
+     img = img(1:sz(1), 1:sz(2), :);
+ end
+ end
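
modcrop trims height and width down to multiples of mod_scale so that downsampling by 1/up_scale and upsampling by up_scale reproduce the original size exactly; for example (my arithmetic), a 501x403 image with mod_scale = 4 is cropped to 500x400, giving a 125x100 LR image.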
esrgan_plus/codes/scripts/generate_mod_LR_bic.py ADDED
@@ -0,0 +1,74 @@
+ import os
+ import sys
+ import cv2
+ import numpy as np
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ from data.util import imresize_np
+ 
+ 
+ def generate_mod_LR_bic():
+     # set parameters
+     up_scale = 4
+     mod_scale = 4
+     # set data dir
+     sourcedir = '/data/datasets/img'
+     savedir = '/data/datasets/mod'
+ 
+     saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale))
+     saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
+     saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))
+ 
+     if not os.path.isdir(sourcedir):
+         print('Error: no source data found')
+         sys.exit(1)
+     if not os.path.isdir(savedir):
+         os.mkdir(savedir)
+ 
+     if not os.path.isdir(os.path.join(savedir, 'HR')):
+         os.mkdir(os.path.join(savedir, 'HR'))
+     if not os.path.isdir(os.path.join(savedir, 'LR')):
+         os.mkdir(os.path.join(savedir, 'LR'))
+     if not os.path.isdir(os.path.join(savedir, 'Bic')):
+         os.mkdir(os.path.join(savedir, 'Bic'))
+ 
+     if not os.path.isdir(saveHRpath):
+         os.mkdir(saveHRpath)
+     else:
+         print('It will cover ' + str(saveHRpath))
+ 
+     if not os.path.isdir(saveLRpath):
+         os.mkdir(saveLRpath)
+     else:
+         print('It will cover ' + str(saveLRpath))
+ 
+     if not os.path.isdir(saveBicpath):
+         os.mkdir(saveBicpath)
+     else:
+         print('It will cover ' + str(saveBicpath))
+ 
+     filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')]
+     num_files = len(filepaths)
+ 
+     # prepare data with augmentation
+     for i in range(num_files):
+         filename = filepaths[i]
+         print('No.{} -- Processing {}'.format(i, filename))
+         # read image
+         image = cv2.imread(os.path.join(sourcedir, filename))
+ 
+         width = int(np.floor(image.shape[1] / mod_scale))
+         height = int(np.floor(image.shape[0] / mod_scale))
+         # modcrop
+         if len(image.shape) == 3:
+             image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
+         else:
+             image_HR = image[0:mod_scale * height, 0:mod_scale * width]
+         # LR
+         image_LR = imresize_np(image_HR, 1 / up_scale, True)
+         # bic
+         image_Bic = imresize_np(image_LR, up_scale, True)
+ 
+         cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
+         cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
+         cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
+ 
+ 
+ if __name__ == "__main__":
+     generate_mod_LR_bic()
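
Unlike the MATLAB version above, this script uses `data.util.imresize_np`, which by its name appears to be a NumPy reimplementation of MATLAB-style bicubic resizing with antialiasing, so the generated LR/Bic images should closely match the .m script; I have not verified the equivalence here.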
esrgan_plus/codes/scripts/make_gif_video.py ADDED
@@ -0,0 +1,106 @@
+ """
+ Add text to images, then make a gif/video sequence from them.
+ 
+ The gif created here has low quality and color issues, so use this script mainly to
+ generate the text-annotated images, then build the gif with `gifski`.
+ 
+ `ffmpeg` is called to make the video.
+ """
+ 
+ import os.path
+ import numpy as np
+ import cv2
+ 
+ 
+ crt_path = os.path.dirname(os.path.realpath(__file__))
+ 
+ # configurations
+ img_name_list = ['x1', 'x2', 'x3', 'x4', 'x5']
+ ext = '.png'
+ text_list = ['1', '2', '3', '4', '5']
+ h_start, h_len = 0, 576
+ w_start, w_len = 10, 352
+ enlarge_ratio = 1
+ txt_pos = (10, 50)  # w, h
+ font_size = 1.5
+ font_thickness = 4
+ color = 'red'
+ duration = 0.8  # seconds per frame
+ use_imageio = False  # use imageio to make the gif
+ make_video = False  # make a video using ffmpeg
+ 
+ is_crop = True
+ if h_start == 0 or w_start == 0:
+     is_crop = False  # do not crop
+ 
+ img_name_list = [x + ext for x in img_name_list]
+ input_folder = os.path.join(crt_path, './ori')
+ save_folder = os.path.join(crt_path, './ori')
+ color_tb = {}
+ color_tb['yellow'] = (0, 255, 255)
+ color_tb['green'] = (0, 255, 0)
+ color_tb['red'] = (0, 0, 255)
+ color_tb['magenta'] = (255, 0, 255)
+ color_tb['matlab_blue'] = (189, 114, 0)
+ color_tb['matlab_orange'] = (25, 83, 217)
+ color_tb['matlab_yellow'] = (32, 177, 237)
+ color_tb['matlab_purple'] = (142, 47, 126)
+ color_tb['matlab_green'] = (48, 172, 119)
+ color_tb['matlab_liblue'] = (238, 190, 77)
+ color_tb['matlab_brown'] = (47, 20, 162)
+ color = color_tb[color]
+ 
+ img_list = []
+ 
+ # make output dir
+ if not os.path.exists(save_folder):
+     os.makedirs(save_folder)
+     print('mkdir [{}] ...'.format(save_folder))
+ if make_video:
+     # tmp folder to save images for the video
+     tmp_video_folder = os.path.join(crt_path, '_tmp_video')
+     if not os.path.exists(tmp_video_folder):
+         os.makedirs(tmp_video_folder)
+ 
+ idx = 0
+ for img_name, write_txt in zip(img_name_list, text_list):
+     img = cv2.imread(os.path.join(input_folder, img_name), cv2.IMREAD_UNCHANGED)
+     base_name = os.path.splitext(img_name)[0]
+     print(base_name)
+     # crop image
+     if is_crop:
+         print('Crop image ...')
+         if img.ndim == 2:
+             img = img[h_start:h_start + h_len, w_start:w_start + w_len]
+         elif img.ndim == 3:
+             img = img[h_start:h_start + h_len, w_start:w_start + w_len, :]
+         else:
+             raise ValueError('Wrong image dim [{:d}]'.format(img.ndim))
+ 
+     # enlarge img if necessary
+     if enlarge_ratio > 1:
+         H, W = img.shape[:2]  # shape[:2] also works for grayscale images
+         img = cv2.resize(img, (W * enlarge_ratio, H * enlarge_ratio),
+                          interpolation=cv2.INTER_CUBIC)
+ 
+     # add text
+     font = cv2.FONT_HERSHEY_COMPLEX
+     cv2.putText(img, write_txt, txt_pos, font, font_size, color, font_thickness, cv2.LINE_AA)
+     cv2.imwrite(os.path.join(save_folder, base_name + '_text.png'), img)
+     if make_video:
+         idx += 1
+         cv2.imwrite(os.path.join(tmp_video_folder, '{:05d}.png'.format(idx)), img)
+ 
+     img = np.ascontiguousarray(img[:, :, [2, 1, 0]])  # BGR -> RGB for imageio
+     img_list.append(img)
+ 
+ if use_imageio:
+     import imageio
+     imageio.mimsave(os.path.join(save_folder, 'out.gif'), img_list, format='GIF', duration=duration)
+ 
+ if make_video:
+     os.system('ffmpeg -r {:f} -i {:s}/%05d.png -vcodec mpeg4 -y {:s}/movie.mp4'.format(
+         1 / duration, tmp_video_folder, save_folder))
+     # clean up only when the tmp folder was actually created
+     if os.path.exists(tmp_video_folder):
+         os.system('rm -rf {}'.format(tmp_video_folder))
esrgan_plus/codes/scripts/net_interp.py ADDED
@@ -0,0 +1,20 @@
+ import torch
+ from collections import OrderedDict
+ 
+ alpha = 0.8
+ 
+ net_PSNR_path = './models/RRDB_PSNR_x4.pth'
+ net_ESRGAN_path = './models/RRDB_ESRGAN_x4.pth'
+ net_interp_path = './models/interp_{:02d}.pth'.format(int(alpha * 10))
+ 
+ net_PSNR = torch.load(net_PSNR_path)
+ net_ESRGAN = torch.load(net_ESRGAN_path)
+ net_interp = OrderedDict()
+ 
+ print('Interpolating with alpha = ', alpha)
+ 
+ for k, v_PSNR in net_PSNR.items():
+     v_ESRGAN = net_ESRGAN[k]
+     net_interp[k] = (1 - alpha) * v_PSNR + alpha * v_ESRGAN
+ 
+ torch.save(net_interp, net_interp_path)
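
This is the network interpolation trick from the ESRGAN paper: every parameter of the PSNR-oriented and GAN-trained models is blended as

$$\theta_{\text{interp}} = (1 - \alpha)\,\theta_{\text{PSNR}} + \alpha\,\theta_{\text{GAN}},$$

so alpha trades fidelity (alpha = 0) against perceptual quality (alpha = 1) without any retraining.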
esrgan_plus/codes/scripts/rename.py ADDED
@@ -0,0 +1,25 @@
+ import os
+ import os.path
+ import sys
+ import glob
+ 
+ 
+ input_folder = '/home/xtwang/Projects/PIRM18/results/pirm_selfval_img06/*'  # glob matching pattern
+ save_folder = '/home/xtwang/Projects/PIRM18/results/pirm_selfval_img'
+ 
+ mode = 'cp'  # 'cp' | 'mv'
+ file_list = sorted(glob.glob(input_folder))
+ 
+ if not os.path.exists(save_folder):
+     os.makedirs(save_folder)
+     print('mkdir ... ' + save_folder)
+ else:
+     print('Folder [{}] already exists. Exit.'.format(save_folder))
+     sys.exit(1)
+ 
+ for i, path in enumerate(file_list):
+     base_name = os.path.splitext(os.path.basename(path))[0]
+ 
+     new_name = base_name.split('_')[0]
+     new_path = os.path.join(save_folder, new_name + '.png')
+ 
+     os.system(mode + ' ' + path + ' ' + new_path)
+     print(i, base_name)
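
Shelling out to `cp`/`mv` breaks on paths containing spaces and is not portable; a sketch of a standard-library equivalent (my substitution, not part of the repo):

import shutil

def transfer(mode, src, dst):
    # mirrors the os.system call above: 'cp' copies, 'mv' moves
    if mode == 'cp':
        shutil.copy(src, dst)
    elif mode == 'mv':
        shutil.move(src, dst)
    else:
        raise ValueError("mode should be 'cp' or 'mv'")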